"""Lazy filtered views and geometry-specific collections.
``ZVView`` is a filtered projection of a ``ZVLevel`` that narrows
which chunks, bins, objects, or groups will be read. Filters chain:
each ``.filter()`` returns a new view with the intersection of all
constraints. Data is materialised only on ``.compute()``.
``ZVPolylineCollection`` provides per-object lazy access to
polylines/streamlines.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable
import numpy as np
import numpy.typing as npt
from zarr_vectors.core.arrays import (
list_chunk_keys,
read_all_object_manifests,
read_chunk_vertices,
read_object_vertices,
read_vertex_group,
)
from zarr_vectors.core.store import FsGroup
from zarr_vectors.core.metadata import RootMetadata, LevelMetadata
from zarr_vectors.typing import BinCoords, ChunkCoords
# Optional dask support. When dask cannot be imported, ``dask_delayed`` is
# replaced by a minimal stand-in so that ``@dask_delayed``-decorated functions
# still return objects exposing ``.compute()``.
try:
    import dask
    from dask import delayed as dask_delayed
    HAS_DASK = True
except ImportError:
    HAS_DASK = False
    dask = None  # type: ignore

    def dask_delayed(func):  # type: ignore
        """Fallback for ``dask.delayed`` used when dask is not installed.

        Args:
            func: The callable to defer.

        Returns:
            A callable wrapper. Calling it with arguments does not run
            ``func``; it returns an object whose ``.compute()`` calls
            ``func`` with those arguments and returns the result.
        """

        class _FakeDelayed:
            """Captured call of ``func``, executed on ``.compute()``."""

            def __init__(self, *a, **kw):
                self._f, self._a, self._kw = func, a, kw

            def compute(self):
                return self._f(*self._a, **self._kw)

            def __repr__(self):
                # Not every callable carries ``__name__`` (e.g.
                # ``functools.partial``); repr must never raise.
                name = getattr(self._f, "__name__", type(self._f).__name__)
                return f"Delayed({name})"

        class _W:
            """Callable returned in place of the decorated function."""

            def __call__(self, *a, **kw):
                return _FakeDelayed(*a, **kw)

        return _W()
# ===================================================================
# Filter specification
# ===================================================================
@dataclass
class FilterSpec:
    """Accumulated filter constraints for a lazy view.

    Every field defaults to ``None``, which means "no constraint of this
    kind".
    """

    # Chunk coordinates to read, or None for no chunk constraint.
    target_chunks: set[ChunkCoords] | None = None
    # Bin coordinates to read, or None for no bin constraint.
    target_bins: set[BinCoords] | None = None
    # Object IDs to keep, or None for no object constraint.
    target_object_ids: set[int] | None = None
    # ``(min_corner, max_corner)`` bounding box, or None for no box.
    bbox: tuple[npt.NDArray, npt.NDArray] | None = None

    def intersect(self, other: FilterSpec) -> FilterSpec:
        """Return a new FilterSpec that is the intersection of self and other.

        Set-valued fields are intersected; a ``None`` on one side means
        "unconstrained", so the other side's value is used as-is. Bounding
        boxes are intersected component-wise, so a chain of bbox filters
        keeps only the region common to all of them.

        Args:
            other: The constraints to combine with this spec.

        Returns:
            A new ``FilterSpec``; neither input is modified.
        """
        return FilterSpec(
            target_chunks=self._intersect_sets(
                self.target_chunks, other.target_chunks,
            ),
            target_bins=self._intersect_sets(
                self.target_bins, other.target_bins,
            ),
            target_object_ids=self._intersect_sets(
                self.target_object_ids, other.target_object_ids,
            ),
            bbox=self._intersect_bboxes(self.bbox, other.bbox),
        )

    @staticmethod
    def _intersect_sets(a, b):
        """Intersect two optional sets, treating ``None`` as unconstrained."""
        if a is None:
            return b
        if b is None:
            return a
        return a & b

    @staticmethod
    def _intersect_bboxes(a, b):
        """Intersect two optional ``(min, max)`` boxes component-wise.

        If the boxes do not overlap, the result has ``min > max`` on at
        least one axis, so a ``min <= p <= max`` mask selects nothing.
        """
        if a is None:
            return b
        if b is None:
            return a
        return (np.maximum(a[0], b[0]), np.minimum(a[1], b[1]))
# ===================================================================
# ZVView — filtered lazy view
# ===================================================================
class ZVView:
    """A filtered lazy view of a resolution level.

    Created by calling ``.filter()`` on a ``ZVLevel`` or another
    ``ZVView``. Each filter narrows the read plan; data is only
    loaded on ``.compute()``.

    Args:
        group: Resolution level FsGroup.
        root_meta: Root metadata.
        level_meta: Level metadata (or None).
        all_chunk_keys: Full list of chunk keys at this level.
        spec: The accumulated filter constraints.
    """

    def __init__(
        self,
        group: FsGroup,
        root_meta: RootMetadata,
        level_meta: LevelMetadata | None,
        all_chunk_keys: list[ChunkCoords],
        spec: FilterSpec,
    ) -> None:
        self._group = group
        self._root_meta = root_meta
        self._level_meta = level_meta
        self._all_chunk_keys = all_chunk_keys
        self._spec = spec

    def filter(
        self,
        *,
        bbox: tuple[npt.NDArray, npt.NDArray] | None = None,
        object_ids: list[int] | None = None,
        group_ids: list[int] | None = None,
    ) -> ZVView:
        """Apply additional filter constraints, returning a new view.

        Args:
            bbox: Bounding box ``(min_corner, max_corner)``.
            object_ids: Keep only these object IDs.
            group_ids: Keep only objects in these groups (resolved
                to object IDs via groupings). When given together with
                ``object_ids``, only objects satisfying both are kept.

        Returns:
            A new ``ZVView`` with the intersection of all constraints.
        """
        new_spec = FilterSpec()

        if bbox is not None:
            lo, hi = np.asarray(bbox[0]), np.asarray(bbox[1])
            new_spec.bbox = (lo, hi)

            # Translate the box into the chunks it touches.
            from zarr_vectors.spatial.chunking import chunks_intersecting_bbox
            new_spec.target_chunks = set(chunks_intersecting_bbox(
                lo, hi, self._root_meta.chunk_shape,
            ))

            # With more than one bin per chunk along some axis, also
            # record the bins the box touches.
            if any(b > 1 for b in self._root_meta.bins_per_chunk):
                from zarr_vectors.spatial.chunking import bins_intersecting_bbox
                new_spec.target_bins = set(bins_intersecting_bbox(
                    lo, hi, self._root_meta.effective_bin_shape,
                ))

        wanted: set[int] | None = None
        if object_ids is not None:
            wanted = set(object_ids)

        if group_ids is not None:
            from zarr_vectors.core.arrays import read_group_object_ids
            members: set[int] = set()
            for gid in group_ids:
                try:
                    members.update(read_group_object_ids(self._group, gid))
                except Exception:
                    # Best-effort: a group that cannot be read simply
                    # contributes no members.
                    continue
            # Both arguments are "keep only" constraints, so combine them
            # by intersection rather than letting group_ids replace
            # object_ids.
            wanted = members if wanted is None else wanted & members

        new_spec.target_object_ids = wanted

        merged = self._spec.intersect(new_spec)
        return ZVView(
            self._group, self._root_meta, self._level_meta,
            self._all_chunk_keys, merged,
        )

    @property
    def vertices(self) -> _FilteredVertices:
        """Lazy filtered vertex accessor."""
        return _FilteredVertices(self)

    def compute(self) -> dict[str, Any]:
        """Materialise the filtered data.

        Returns:
            Dict with ``positions``, ``vertex_count``, and optionally
            ``object_ids`` (present only when the view has an object-ID
            constraint).
        """
        ndim = self._root_meta.sid_ndim
        dtype = np.float32

        # Path 1: object-ID based read
        if self._spec.target_object_ids is not None:
            return self._compute_by_objects(ndim, dtype)

        # Path 2: bin/chunk spatial read
        return self._compute_spatial(ndim, dtype)

    @staticmethod
    def _bbox_mask(positions, bbox):
        """Return a boolean row mask of ``positions`` inside the closed box."""
        return np.all((positions >= bbox[0]) & (positions <= bbox[1]), axis=1)

    def _compute_by_objects(self, ndim: int, dtype: np.dtype) -> dict[str, Any]:
        """Read by object ID using manifests.

        Objects are visited in ascending ID order; an object whose read
        raises is skipped. The bounding box, if any, is applied as a
        post-filter on the concatenated vertices.
        """
        all_positions: list[npt.NDArray] = []
        all_oids: list[npt.NDArray] = []

        for oid in sorted(self._spec.target_object_ids):
            try:
                verts_list = read_object_vertices(
                    self._group, oid, dtype=dtype, ndim=ndim,
                )
            except Exception:
                # Best-effort: skip objects that cannot be read.
                continue
            for vg in verts_list:
                if len(vg) > 0:
                    all_positions.append(vg)
                    # One object-ID entry per vertex, aligned with positions.
                    all_oids.append(np.full(len(vg), oid, dtype=np.int64))

        if not all_positions:
            return {"positions": np.zeros((0, ndim), dtype=dtype),
                    "vertex_count": 0, "object_ids": np.array([], dtype=np.int64)}

        positions = np.concatenate(all_positions, axis=0)
        object_ids = np.concatenate(all_oids)

        # Apply bbox post-filter
        if self._spec.bbox is not None:
            mask = self._bbox_mask(positions, self._spec.bbox)
            positions = positions[mask]
            object_ids = object_ids[mask]

        return {
            "positions": positions,
            "vertex_count": len(positions),
            "object_ids": object_ids,
        }

    def _compute_spatial(self, ndim: int, dtype: np.dtype) -> dict[str, Any]:
        """Read by spatial targeting (bins or chunks).

        Uses a bin-level read when the store has more than one bin per
        chunk and the spec carries target bins; otherwise reads whole
        chunks. Reads that raise are skipped. The bounding box, if any, is
        applied as a final mask.
        """
        bins_per_chunk = self._root_meta.bins_per_chunk
        has_bins = any(b > 1 for b in bins_per_chunk)

        all_positions: list[npt.NDArray] = []

        if has_bins and self._spec.target_bins is not None:
            # Bin-level read: group the wanted vertex-group indices by the
            # chunk that holds them.
            from zarr_vectors.spatial.chunking import bin_to_chunk, bin_to_vg_index

            chunk_keys_set = set(self._all_chunk_keys)
            chunk_vg_targets: dict[ChunkCoords, list[int]] = {}
            for bc in self._spec.target_bins:
                cc = bin_to_chunk(bc, bins_per_chunk)
                vgi = bin_to_vg_index(bc, cc, bins_per_chunk)
                chunk_vg_targets.setdefault(cc, []).append(vgi)

            for cc, vg_indices in chunk_vg_targets.items():
                # Only chunks that exist at this level can be read.
                if cc not in chunk_keys_set:
                    continue
                for vgi in vg_indices:
                    try:
                        vg = read_vertex_group(
                            self._group, cc, vgi, dtype=dtype, ndim=ndim,
                        )
                        if len(vg) > 0:
                            all_positions.append(vg)
                    except Exception:
                        continue
        else:
            # Chunk-level read over the existing chunks that pass the
            # chunk constraint (all of them when there is none).
            if self._spec.target_chunks is not None:
                active_chunks = [ck for ck in self._all_chunk_keys
                                 if ck in self._spec.target_chunks]
            else:
                active_chunks = self._all_chunk_keys

            for ck in active_chunks:
                try:
                    groups = read_chunk_vertices(
                        self._group, ck, dtype=dtype, ndim=ndim,
                    )
                    for vg in groups:
                        if len(vg) > 0:
                            all_positions.append(vg)
                except Exception:
                    continue

        if not all_positions:
            return {"positions": np.zeros((0, ndim), dtype=dtype), "vertex_count": 0}

        positions = np.concatenate(all_positions, axis=0)

        # Final bbox mask
        if self._spec.bbox is not None:
            positions = positions[self._bbox_mask(positions, self._spec.bbox)]

        return {"positions": positions, "vertex_count": len(positions)}

    def __repr__(self) -> str:
        parts = []
        if self._spec.target_object_ids is not None:
            parts.append(f"objects={len(self._spec.target_object_ids)}")
        if self._spec.target_chunks is not None:
            parts.append(f"chunks={len(self._spec.target_chunks)}")
        if self._spec.target_bins is not None:
            parts.append(f"bins={len(self._spec.target_bins)}")
        if self._spec.bbox is not None:
            parts.append("bbox=set")
        desc = ", ".join(parts) if parts else "unfiltered"
        return f"ZVView({desc})"
class _FilteredVertices:
    """Accessor exposing only the vertex positions of a filtered view."""

    def __init__(self, view: ZVView) -> None:
        # The view is kept as-is; nothing is read until ``compute()``.
        self._view = view

    def compute(self) -> npt.NDArray[np.floating]:
        """Materialise the wrapped view and return its ``"positions"`` entry."""
        return self._view.compute()["positions"]

    def __repr__(self) -> str:
        return f"FilteredVertices({self._view!r})"
# ===================================================================
# ZVPolylineCollection — per-object lazy polyline access
# ===================================================================
class ZVPolylineCollection:
    """Lazy collection of polylines accessible by object ID.

    Each polyline is reconstructed by following its object_index
    manifest and concatenating vertex groups from the relevant chunks.

    Args:
        group: Resolution level FsGroup.
        ndim: Number of spatial dimensions.
    """

    def __init__(self, group: FsGroup, ndim: int = 3) -> None:
        self._group = group
        self._ndim = ndim
        # Manifests are read on first use and cached afterwards.
        self._manifests: list | None = None

    def _ensure_manifests(self) -> list:
        """Return the cached manifests, reading them on first call.

        Best-effort: if the read raises, an empty list is cached and the
        collection behaves as empty.
        """
        if self._manifests is None:
            try:
                self._manifests = read_all_object_manifests(self._group)
            except Exception:
                self._manifests = []
        return self._manifests

    @property
    def count(self) -> int:
        """Number of polylines."""
        return len(self._ensure_manifests())

    @property
    def object_ids(self) -> list[int]:
        """List of polyline object IDs."""
        return list(range(self.count))

    def __len__(self) -> int:
        return self.count

    def __getitem__(self, object_id: int) -> Any:
        """Get a delayed polyline by object ID.

        Returns a delayed object that, when computed, returns the full
        reconstructed polyline as an ``(N, D)`` numpy array.

        Raises:
            IndexError: If ``object_id`` is negative or not less than the
                number of polylines.
        """
        manifests = self._ensure_manifests()
        if object_id < 0 or object_id >= len(manifests):
            raise IndexError(
                f"Polyline {object_id} out of range [0, {len(manifests)})"
            )
        return _delayed_read_polyline(
            self._group, object_id, self._ndim,
        )

    def items(self):
        """Iterate over ``(object_id, delayed_polyline)`` pairs."""
        for oid in range(self.count):
            yield oid, self[oid]

    def compute(self) -> list[npt.NDArray[np.floating]]:
        """Materialise all polylines.

        Returns:
            List of ``(N_k, D)`` arrays, one per polyline.
        """
        delayed_list = [self[oid] for oid in range(self.count)]
        # With dask available, hand all reads to one ``dask.compute`` call;
        # otherwise compute them one after another.
        if HAS_DASK and len(delayed_list) > 1:
            return list(dask.compute(*delayed_list))
        return [d.compute() for d in delayed_list]

    def filter(
        self,
        *,
        object_ids: list[int] | None = None,
        length_range: tuple[float, float] | None = None,
    ) -> FilteredPolylineCollection:
        """Filter polylines by ID or length range.

        Returns a new filtered collection (lazy).

        Args:
            object_ids: Keep only these IDs.
            length_range: ``(min_length, max_length)`` — keep polylines
                whose Euclidean path length falls in this range.
                Requires computing lengths on first access.
        """
        return FilteredPolylineCollection(
            self, object_ids=object_ids, length_range=length_range,
        )

    def __repr__(self) -> str:
        return f"ZVPolylineCollection(count={self.count})"
class FilteredPolylineCollection:
    """A filtered subset of a polyline collection.

    Args:
        parent: The collection being filtered.
        object_ids: Keep only these IDs; ``None`` keeps every ID.
        length_range: ``(min_length, max_length)``, inclusive at both ends;
            ``None`` applies no length constraint.
    """

    def __init__(
        self,
        parent: ZVPolylineCollection,
        *,
        object_ids: list[int] | None = None,
        length_range: tuple[float, float] | None = None,
    ) -> None:
        self._parent = parent
        self._explicit_ids = set(object_ids) if object_ids is not None else None
        self._length_range = length_range
        # Filled in by the first call to ``_resolve``.
        self._resolved_ids: list[int] | None = None

    def _resolve(self) -> list[int]:
        """Resolve which object IDs pass all filters.

        The result is cached. Explicit IDs that fall outside the parent's
        ``[0, count)`` range are dropped, so they are neither counted nor
        cause an ``IndexError`` later. Applying a length range computes
        every candidate polyline.
        """
        if self._resolved_ids is not None:
            return self._resolved_ids

        total = self._parent.count
        if self._explicit_ids is not None:
            candidates = sorted(
                oid for oid in self._explicit_ids if 0 <= oid < total
            )
        else:
            candidates = list(range(total))

        if self._length_range is not None:
            lo, hi = self._length_range
            kept: list[int] = []
            for oid in candidates:
                poly = self._parent[oid].compute()
                if len(poly) < 2:
                    # Fewer than two vertices: no segments, zero length.
                    length = 0.0
                else:
                    # Sum of Euclidean distances between consecutive vertices.
                    diffs = np.diff(poly, axis=0)
                    length = float(np.sum(np.sqrt(np.sum(diffs ** 2, axis=1))))
                if lo <= length <= hi:
                    kept.append(oid)
            candidates = kept

        self._resolved_ids = candidates
        return candidates

    @property
    def count(self) -> int:
        """Number of polylines that pass all filters."""
        return len(self._resolve())

    def __len__(self) -> int:
        return self.count

    def compute(self) -> list[npt.NDArray[np.floating]]:
        """Materialise the filtered polylines."""
        ids = self._resolve()
        delayed_list = [self._parent[oid] for oid in ids]
        if HAS_DASK and len(delayed_list) > 1:
            return list(dask.compute(*delayed_list))
        return [d.compute() for d in delayed_list]

    def items(self):
        """Iterate over ``(object_id, delayed_polyline)`` pairs that pass."""
        for oid in self._resolve():
            yield oid, self._parent[oid]

    def __repr__(self) -> str:
        parts = []
        if self._explicit_ids is not None:
            parts.append(f"ids={len(self._explicit_ids)}")
        if self._length_range is not None:
            parts.append(f"length={self._length_range}")
        # Join the count with the other parts so an unfiltered collection
        # does not render with a leading ", ".
        parts.append(f"count={self.count}")
        return f"FilteredPolylineCollection({', '.join(parts)})"
# ===================================================================
# Delayed helpers
# ===================================================================
@dask_delayed
def _read_polyline(
    group: FsGroup,
    object_id: int,
    ndim: int,
) -> npt.NDArray[np.floating]:
    """Load one object's vertex groups and join them into a single polyline.

    Empty vertex groups are ignored. If nothing remains, or if anything in
    the read or concatenation raises, an empty ``(0, ndim)`` float32 array
    is returned instead.
    """

    def _empty() -> npt.NDArray[np.floating]:
        return np.zeros((0, ndim), dtype=np.float32)

    try:
        segments = read_object_vertices(
            group, object_id, dtype=np.float32, ndim=ndim,
        )
        kept = [seg for seg in segments if len(seg) > 0]
        polyline = np.concatenate(kept, axis=0) if kept else _empty()
    except Exception:
        polyline = _empty()
    return polyline
def _delayed_read_polyline(group: FsGroup, object_id: int, ndim: int) -> Any:
    """Return the deferred read of a single polyline.

    Forwards the arguments unchanged to the ``dask_delayed``-decorated
    ``_read_polyline`` and hands back whatever that call produces.
    """
    pending = _read_polyline(group, object_id, ndim)
    return pending