Source code for zarr_vectors.rechunk.engine

"""Rechunk engine — reads a source store and writes a rechunked copy.

The rechunked store has an extra prefix dimension on its chunk keys:
``(prefix_bin, z, y, x)`` instead of ``(z, y, x)``.  All objects in
the same prefix bin are physically contiguous, enabling O(1) group
or attribute-based filtering.
"""

from __future__ import annotations

import shutil
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt

from zarr_vectors.constants import VERTICES
from zarr_vectors.core.arrays import (
    create_object_index_array,
    create_vertices_array,
    list_chunk_keys,
    read_all_object_manifests,
    read_chunk_vertices,
    read_object_vertices,
    write_chunk_vertices,
    write_object_index,
)
from zarr_vectors.core.metadata import LevelMetadata
from zarr_vectors.core.store import (
    FsGroup,
    create_resolution_level,
    create_store,
    get_resolution_level,
    open_store,
    read_root_metadata,
)
from zarr_vectors.rechunk.spec import DimensionMapper, RechunkSpec
from zarr_vectors.spatial.chunking import assign_chunks
from zarr_vectors.typing import ChunkCoords, ObjectManifest


[docs] def rechunk( store_path: str | Path, spec: RechunkSpec, output: str | Path | None = None, ) -> dict[str, Any]: """Rechunk a store along a non-spatial dimension. Reads object data from the source store, assigns each object to a rechunk bin via ``DimensionMapper``, and writes the result to an output store where chunk keys have a prefix dimension ``(bin, z, y, x)``. Args: store_path: Source store path. spec: Rechunk specification. output: Output store path. If None, rechunks in-place by writing to a temporary store then replacing the source. Returns: Summary dict with ``objects_rechunked``, ``bins_created``, ``output_path``. """ store_path = Path(store_path) # Determine output path in_place = output is None if in_place: output_path = store_path.parent / (store_path.name + ".rechunked") else: output_path = Path(output) if output_path.exists(): shutil.rmtree(output_path) # Read source src_root = open_store(str(store_path)) src_meta = read_root_metadata(src_root) ndim = src_meta.sid_ndim chunk_shape = spec.spatial_chunk_shape or src_meta.chunk_shape # Read level 0 data src_level = get_resolution_level(src_root, 0) # Read object manifests try: manifests = read_all_object_manifests(src_level) n_objects = len(manifests) except Exception: manifests = [] n_objects = 0 # Read groupings groupings: list[list[int]] | None = None try: from zarr_vectors.core.arrays import read_all_groupings groupings = read_all_groupings(src_level) except Exception: groupings = None # Read object attributes (for attribute-based rechunking) object_attributes: dict[str, npt.NDArray] | None = None if spec.by.startswith("attribute:"): # Try to read the attribute as per-object data attr_name = spec.by.split(":", 1)[1] object_attributes = {} try: from zarr_vectors.core.arrays import read_object_attributes obj_attr_data = read_object_attributes(src_level, attr_name) object_attributes[attr_name] = obj_attr_data except Exception: # Attribute might need to be computed (e.g. length for polylines) if attr_name == "length" and n_objects > 0: lengths = _compute_object_lengths(src_level, n_objects, ndim) object_attributes[attr_name] = lengths else: raise ValueError( f"Cannot read or compute attribute '{attr_name}'" ) # Map objects to rechunk bins mapper = DimensionMapper(spec) if n_objects > 0: obj_to_bin = mapper.map_objects( n_objects=n_objects, groupings=groupings, object_attributes=object_attributes, ) else: # No objects — rechunk spatially only obj_to_bin = {} # Determine unique bins if obj_to_bin: unique_bins = sorted(set(obj_to_bin.values())) else: unique_bins = [0] # Create output store spatial_dim_names = [ a.get("name", f"dim{i}") for i, a in enumerate(src_meta.spatial_index_dims) ] rechunk_dims = [spec.dimension_name, *spatial_dim_names] out_root = create_store( str(output_path), axes=src_meta.spatial_index_dims, chunk_shape=chunk_shape, bounds=src_meta.bounds, geometry_types=src_meta.geometry_types, links_convention=src_meta.links_convention, object_index_convention=src_meta.object_index_convention, cross_chunk_strategy=src_meta.cross_chunk_strategy, base_bin_shape=src_meta.base_bin_shape, ) # Compute the bin → original-value list for attribute-based rechunking. # Only meaningful when ``by="attribute:..."`` and we have the source # values; non-attribute rechunks leave chunk_attribute_* unset. chunk_attribute_name: str | None = None chunk_attribute_values: list[Any] | None = None if spec.by.startswith("attribute:") and object_attributes is not None: attr_name = spec.by.split(":", 1)[1] src_vals = object_attributes.get(attr_name) if src_vals is not None and obj_to_bin: # Build {bin_idx: value} from (object_id → bin_idx) and the # source value array; an object's value picks the bin. bin_to_value: dict[int, Any] = {} for oid, b in obj_to_bin.items(): if b not in bin_to_value: v = src_vals[oid] if hasattr(v, "item"): v = v.item() if isinstance(v, bytes): v = v.decode("utf-8") bin_to_value[b] = v chunk_attribute_name = attr_name chunk_attribute_values = [ bin_to_value[b] for b in sorted(bin_to_value) ] # Create level 0 level_meta = LevelMetadata( level=0, vertex_count=0, # updated below arrays_present=[VERTICES, "object_index"], chunk_dims=rechunk_dims, chunk_attribute_name=chunk_attribute_name, chunk_attribute_values=chunk_attribute_values, ) out_level = create_resolution_level(out_root, 0, level_meta) create_vertices_array(out_level, dtype="float32") create_object_index_array(out_level) # Rechunk: for each bin, gather all objects, assign to spatial chunks # with prefixed keys total_vertices = 0 total_objects = 0 object_manifests_out: dict[int, ObjectManifest] = {} global_obj_counter = 0 for bin_idx in unique_bins: # Collect objects in this bin if obj_to_bin: bin_objects = sorted( oid for oid, b in obj_to_bin.items() if b == bin_idx ) else: bin_objects = [] # Read vertex data for these objects bin_positions: list[npt.NDArray] = [] bin_obj_boundaries: list[int] = [] # cumulative vertex counts per object for oid in bin_objects: try: verts_list = read_object_vertices( src_level, oid, dtype=np.float32, ndim=ndim, ) obj_verts = np.concatenate( [v for v in verts_list if len(v) > 0], axis=0, ) if any(len(v) > 0 for v in verts_list) else np.zeros((0, ndim), dtype=np.float32) except Exception: obj_verts = np.zeros((0, ndim), dtype=np.float32) bin_positions.append(obj_verts) bin_obj_boundaries.append(len(obj_verts)) if not bin_positions or all(len(p) == 0 for p in bin_positions): continue # Assign to spatial chunks within this bin prefix all_pos = np.concatenate(bin_positions, axis=0) spatial_assignments = assign_chunks(all_pos, chunk_shape) # Build object-to-vertex mapping for this bin obj_starts = np.cumsum([0] + bin_obj_boundaries[:-1]) for spatial_cc, global_indices in sorted(spatial_assignments.items()): # Prefixed chunk key: (bin_idx, z, y, x) prefixed_cc: ChunkCoords = (bin_idx,) + spatial_cc chunk_verts = all_pos[global_indices] write_chunk_vertices( out_level, prefixed_cc, [chunk_verts], dtype=np.float32, ) total_vertices += len(chunk_verts) # Build object manifests for this bin for local_idx, oid in enumerate(bin_objects): start = int(obj_starts[local_idx]) n_verts = bin_obj_boundaries[local_idx] if n_verts == 0: continue obj_positions = bin_positions[local_idx] obj_spatial = assign_chunks(obj_positions, chunk_shape) manifest: ObjectManifest = [] for scc, _ in sorted(obj_spatial.items()): prefixed = (bin_idx,) + scc vg_idx = 0 # single VG per chunk in rechunked stores manifest.append((prefixed, vg_idx)) object_manifests_out[global_obj_counter] = manifest global_obj_counter += 1 total_objects += 1 # Write object index (with extended ndim for prefix dimension) if object_manifests_out: # The manifests use (prefix, z, y, x) coords — ndim+1 dimensions extended_ndim = ndim + 1 if obj_to_bin else ndim write_object_index(out_level, object_manifests_out, sid_ndim=extended_ndim) # Write groupings if rechunked by group (preserve group structure) if spec.by == "group" and groupings is not None: try: from zarr_vectors.core.arrays import ( create_groupings_array, write_groupings, ) # Remap group memberships to new object IDs old_to_new: dict[int, int] = {} new_idx = 0 for bin_idx in unique_bins: bin_objects = sorted( oid for oid, b in obj_to_bin.items() if b == bin_idx ) for old_oid in bin_objects: old_to_new[old_oid] = new_idx new_idx += 1 new_groupings: dict[int, list[int]] = {} for gid, members in enumerate(groupings): new_members = [ old_to_new[m] for m in members if m in old_to_new ] if new_members: new_groupings[gid] = new_members if new_groupings: create_groupings_array(out_level) write_groupings(out_level, new_groupings) except Exception: pass # In-place: replace source with output if in_place: backup = store_path.parent / (store_path.name + ".backup") store_path.rename(backup) output_path.rename(store_path) shutil.rmtree(backup) output_path = store_path return { "objects_rechunked": total_objects, "bins_created": len(unique_bins), "total_vertices": total_vertices, "rechunk_dims": rechunk_dims, "output_path": str(output_path), }
def _compute_object_lengths( level_group: FsGroup, n_objects: int, ndim: int, ) -> npt.NDArray[np.float64]: """Compute path length for each object (for polyline-like data).""" lengths = np.zeros(n_objects, dtype=np.float64) for oid in range(n_objects): try: verts_list = read_object_vertices( level_group, oid, dtype=np.float32, ndim=ndim, ) all_verts = np.concatenate( [v for v in verts_list if len(v) > 0], axis=0, ) if len(all_verts) >= 2: diffs = np.diff(all_verts, axis=0) lengths[oid] = float( np.sum(np.sqrt(np.sum(diffs ** 2, axis=1))) ) except Exception: pass return lengths