"""Create, write, and read ZV per-level arrays.

Each "array" is a subdirectory within a resolution-level group.  Chunk
data is stored as raw binary files named by their spatial chunk
coordinates (e.g. ``vertices/0.0.0``).  Array metadata is in
``<array>/.zattrs``.

This module is the single point of contact for all array I/O — type
modules (``types/*.py``) call these functions rather than touching
the store or encoding modules directly.
"""

from __future__ import annotations

import warnings
from contextlib import contextmanager
from typing import Any

import numpy as np
import numpy.typing as npt
import zarr
from zarr.codecs import VLenBytesCodec
from zarr.errors import UnstableSpecificationWarning

from zarr_vectors.constants import (
    CROSS_CHUNK_LINK_ATTRIBUTES,
    CROSS_CHUNK_LINKS,
    GROUP_ATTRIBUTES,
    GROUPS,
    LINK_ATTRIBUTES,
    LINK_FRAGMENTS,
    LINKS,
    OBJECT_ATTRIBUTES,
    OBJECT_INDEX,
    VERTEX_ATTRIBUTES,
    VERTEX_FRAGMENTS,
    VERTICES,
)
from zarr_vectors.core.paths import (
    cross_chunk_link_attributes_path,
    cross_chunk_links_path,
    format_delta,
    link_attributes_path,
    links_path,
    parse_delta,
)
from zarr_vectors.core.store import FsGroup
from zarr_vectors.encoding.fragments import (
    FragmentIndex,
    decode_fragments,
    decode_object_manifest_blocks,
    encode_fragments,
    encode_object_manifest_blocks,
)
from zarr_vectors.encoding.ragged import (
    decode_ragged_blob,
    decode_ragged_ints,
    decode_vertex_groups,
    encode_ragged_blob,
    encode_ragged_ints,
    encode_vertex_groups,
)
from zarr_vectors.exceptions import ArrayError
from zarr_vectors.typing import (
    ChunkCoords,
    CrossChunkLink,
    ObjectManifest,
)


# Discriminator written into ``object_index`` group attrs to identify
# the ragged-array layout (one vlen-bytes zarr array named ``manifests``,
# one entry per object_id).  Absent on legacy stores, which used a pair
# of single-chunk ``data`` + ``offsets`` byte blobs.
OBJECT_INDEX_LAYOUT_V1 = "vlen_manifests_v1"

# Objects per zarr chunk of ``object_index/manifests``.  A single-object
# read fetches only the chunk containing the requested oid, so this sets
# the read amplification ceiling (~16K manifest blobs per fetch).
OBJECT_INDEX_MANIFEST_BUCKET = 16_384


# ===================================================================
# Helpers
# ===================================================================

def _chunk_key(coords: ChunkCoords) -> str:
    """Convert chunk coordinates to a dot-separated key string.

    ``(0, 1, 2)`` → ``"0.1.2"``
    """
    return ".".join(str(c) for c in coords)


def _parse_chunk_key(key: str) -> ChunkCoords:
    """Parse a dot-separated chunk key back to coordinates.

    ``"0.1.2"`` → ``(0, 1, 2)``
    """
    return tuple(int(x) for x in key.split("."))


@contextmanager
def _maybe_batched_reads(
    level_group: FsGroup,
    plan: list[tuple[str, list[str]]],
):
    """Open a single-chunk prefetch unless one is already active.

    Wraps :meth:`FsGroup.batched_reads` for the common "I'm about to
    read N>1 sibling arrays for one chunk" case: it fans the reads out
    in one ``asyncio.gather`` instead of paying N sequential
    round-trips.  When the caller has already entered an outer
    :meth:`batched_reads` context (i.e. a multi-chunk loop), this is a
    no-op so the outer plan stays in charge.
    """
    if level_group._prefetch_cache is not None:
        yield
        return
    with level_group.batched_reads(plan):
        yield


def _ensure_array_dir(level_group: FsGroup, array_name: str) -> None:
    """Ensure an array subdirectory exists within a level group.

    Inside :meth:`Group.batched_writes` we skip the sync ``require_group``
    round-trip — the metadata flush will PUT the parent ``zarr.json``
    directly, including the right attributes, in the same gather as the
    chunk PUTs.  Outside batched mode we still call ``require_group`` so
    the parent group exists before any subsequent ``write_array_meta``
    call (which only ``attrs.update``s — it does not create the group).
    """
    if level_group._pending_array_metas is not None:
        return
    level_group.require_group(array_name)


def read_zv_array_tag(meta: dict) -> str | None:
    """Return the discriminator string from an array's ``.zattrs`` dict.

    Looks under the new key ``zv_array`` first and falls back to the
    legacy ``zvf_array`` key for stores written before the rename.  Use
    this helper instead of indexing the dict directly so that conformance
    / validation code keeps working against existing-on-disk stores.
    """
    return meta.get("zv_array", meta.get("zvf_array"))


# ===================================================================
# Array creation (set up directory + metadata)
# ===================================================================

def create_vertices_array(
    level_group: FsGroup,
    dtype: str = "float32",
    encoding: str = "raw",
) -> None:
    """Create the ``vertices/`` array within a resolution level.

    Args:
        level_group: The resolution level FsGroup.
        dtype: Numpy dtype string for vertex positions.
        encoding: ``"raw"`` or ``"draco"``.
    """
    _ensure_array_dir(level_group, VERTICES)
    _ensure_array_dir(level_group, VERTEX_FRAGMENTS)
    level_group.write_array_meta(VERTICES, {
        "zv_array": "vertices",
        "dtype": dtype,
        "encoding": encoding,
    })
    level_group.write_array_meta(VERTEX_FRAGMENTS, {
        "zv_array": VERTEX_FRAGMENTS,
        "encoding": "fragment_index_v1",
    })


def create_links_array(
    level_group: FsGroup,
    link_width: int,
    dtype: str = "int64",
    *,
    delta: int = 0,
) -> None:
    """Create a ``links/<delta>/`` array.

    Under the 0.4 multiscale links layout each ``<delta>`` segment is a
    distinct array; ``delta=0`` is the intra-level array (the only one
    written pre-0.4) and non-zero deltas hold edges that point ``delta``
    pyramid levels away (positive = coarser, negative = finer).

    Args:
        level_group: The resolution level FsGroup.
        link_width: Number of vertex indices per link entry (L).
            1 for skeleton parents, 2 for edges, 3 for triangle faces.
        dtype: Integer dtype.
        delta: Level delta; see :mod:`zarr_vectors.core.paths`.
    """
    full_name = links_path(delta)
    _ensure_array_dir(level_group, full_name)
    level_group.write_array_meta(full_name, {
        "zv_array": "links",
        "dtype": dtype,
        "link_width": link_width,
        "level_delta": int(delta),
    })


def create_attribute_array(
    level_group: FsGroup,
    name: str,
    dtype: str = "float32",
    channel_names: list[str] | None = None,
    extra_meta: dict[str, Any] | None = None,
) -> None:
    """Create a vertex attribute array ``attributes/<name>/``.

    Args:
        level_group: The resolution level FsGroup.
        name: Attribute name (e.g. ``"radius"``, ``"gene_expression"``).
        dtype: Numpy dtype string.
        channel_names: Optional list of channel names.
        extra_meta: Additional JSON-serialisable fields merged into the
            array metadata.  Used for the dictionary-encoding
            convention (``encoding="dictionary"``, ``categories``,
            ``ordered``, ``_FillValue``) and other userspace
            extensions.  Keys collide-check against the core fields
            (``zv_array``, ``name``, ``dtype``, ``channel_names``).
    """
    full_name = f"{VERTEX_ATTRIBUTES}/{name}"
    _ensure_array_dir(level_group, full_name)
    meta: dict[str, Any] = {
        "zv_array": "attribute",
        "name": name,
        "dtype": dtype,
    }
    if channel_names is not None:
        meta["channel_names"] = channel_names
    if extra_meta:
        reserved = {"zv_array", "name", "dtype", "channel_names"}
        clobber = reserved & set(extra_meta)
        if clobber:
            raise ArrayError(
                f"extra_meta cannot override core attribute fields: {sorted(clobber)}"
            )
        meta.update(extra_meta)
    level_group.write_array_meta(full_name, meta)


def create_object_index_array(level_group: FsGroup) -> None:
    """Create the ``object_index/`` array."""
    _ensure_array_dir(level_group, OBJECT_INDEX)
    level_group.write_array_meta(OBJECT_INDEX, {
        "zv_array": "object_index",
    })


def create_object_attributes_array(
    level_group: FsGroup,
    name: str,
    dtype: str = "float32",
    num_channels: int = 1,
) -> None:
    """Create an object attribute array ``object_attributes/<name>/``.

    Args:
        level_group: The resolution level FsGroup.
        name: Attribute name.
        dtype: Numpy dtype string.
        num_channels: Number of channels (C dimension).
    """
    full_name = f"{OBJECT_ATTRIBUTES}/{name}"
    _ensure_array_dir(level_group, full_name)
    level_group.write_array_meta(full_name, {
        "zv_array": "object_attribute",
        "name": name,
        "dtype": dtype,
        "num_channels": num_channels,
    })


def create_groupings_array(level_group: FsGroup) -> None:
    """Create the ``groupings/`` array."""
    _ensure_array_dir(level_group, GROUPS)
    level_group.write_array_meta(GROUPS, {
        "zv_array": "groups",
    })


def create_groupings_attributes_array(
    level_group: FsGroup,
    name: str,
    dtype: str = "float32",
    num_channels: int = 1,
) -> None:
    """Create a groupings attribute array ``groupings_attributes/<name>/``."""
    full_name = f"{GROUP_ATTRIBUTES}/{name}"
    _ensure_array_dir(level_group, full_name)
    level_group.write_array_meta(full_name, {
        "zv_array": "groupings_attribute",
        "name": name,
        "dtype": dtype,
        "num_channels": num_channels,
    })


def create_cross_chunk_links_array(
    level_group: FsGroup,
    *,
    delta: int = 0,
    link_width: int = 2,
) -> None:
    """Create a ``cross_chunk_links/<delta>/`` array.

    Source-side endpoints live at the owning resolution level;
    target-side endpoints live at ``this_level + delta``.

    Args:
        level_group: Resolution level group.
        delta: Level delta (0 for intra-level, ±N for cross-level).
        link_width: Number of vertex refs per record.  2 for edges
            (the default — chunk pairs straddling a boundary), 3 for
            triangle faces, 1 for parent→child metanode references.
    """
    full_name = cross_chunk_links_path(delta)
    _ensure_array_dir(level_group, full_name)
    level_group.write_array_meta(full_name, {
        "zv_array": "cross_chunk_links",
        "level_delta": int(delta),
        "link_width": int(link_width),
    })


def create_link_attributes_array(
    level_group: FsGroup,
    name: str,
    dtype: str = "float32",
    *,
    delta: int = 0,
) -> None:
    """Create a ``link_attributes/<name>/<delta>/`` array (parallel to
    the matching ``links/<delta>/`` array)."""
    full_name = link_attributes_path(name, delta)
    _ensure_array_dir(level_group, full_name)
    level_group.write_array_meta(full_name, {
        "zv_array": "link_attribute",
        "name": name,
        "dtype": dtype,
        "level_delta": int(delta),
    })


def create_cross_chunk_link_attributes_array(
    level_group: FsGroup,
    name: str,
    dtype: str = "float32",
    *,
    delta: int = 0,
) -> None:
    """Create a ``cross_chunk_link_attributes/<name>/<delta>/`` array.

    Parallel attribute storage for the matching
    ``cross_chunk_links/<delta>/`` array; one value (or one ``C``-vector
    row) per cross-chunk link in path order.
    """
    full_name = cross_chunk_link_attributes_path(name, delta)
    _ensure_array_dir(level_group, full_name)
    level_group.write_array_meta(full_name, {
        "zv_array": "cross_chunk_link_attribute",
        "name": name,
        "dtype": dtype,
        "level_delta": int(delta),
    })


# ===================================================================
# Writing data
# ===================================================================

def write_chunk_vertices(
    level_group: FsGroup,
    chunk_coords: ChunkCoords,
    groups: list[npt.NDArray[np.floating]],
    dtype: np.dtype | str = np.float32,
) -> npt.NDArray[np.int64]:
    """Write vertex groups to a spatial chunk.

    Encodes the groups as a contiguous byte buffer in ``vertices/`` and
    writes a v0.6 fragment-index to ``vertex_fragments/`` describing each
    group as a contiguous range of vertex rows in source order.

    Args:
        level_group: Resolution level group.
        chunk_coords: Spatial chunk coordinates.
        groups: List of arrays, each ``(N_k, D)``.
        dtype: Numpy dtype for serialisation.

    Returns:
        ``(K,)`` int64 array of vertex byte offsets (kept for backwards-
        compatible signature; callers that need the v0.6 fragment-index
        should use :func:`read_vertex_fragment_index`).
    """
    dtype = np.dtype(dtype)
    key = _chunk_key(chunk_coords)

    raw_bytes, vertex_byte_offsets = encode_vertex_groups(groups, dtype)
    level_group.write_bytes(VERTICES, key, raw_bytes)

    # Express each group as a contiguous (start_row, count) fragment.
    if len(groups) == 0:
        fragments: list[tuple[int, int]] = []
    else:
        per_group_counts = [int(np.asarray(g).shape[0]) for g in groups]
        cumulative = 0
        fragments = []
        for n in per_group_counts:
            fragments.append((cumulative, n))
            cumulative += n
    level_group.write_bytes(
        VERTEX_FRAGMENTS, key, encode_fragments(fragments),
    )
    return vertex_byte_offsets


def write_chunk_links(
    level_group: FsGroup,
    chunk_coords: ChunkCoords,
    link_groups: list[npt.NDArray[np.integer]],
    dtype: np.dtype | str = np.int64,
    *,
    delta: int = 0,
) -> npt.NDArray[np.int64]:
    """Write link groups to a spatial chunk under ``links/<delta>/``.

    For ``delta=0`` link groups are 1:1 aligned with the chunk's
    vertex groups; readers derive per-group link byte offsets from the
    cumulative sizes of each group's link bytes (see
    :func:`read_chunk_links`).

    For ``delta != 0`` (cross-pyramid-level links) the source vertex
    groups and link groups live at different levels and there is
    typically one link group spanning the chunk.

    Args:
        level_group: Resolution level group.
        chunk_coords: Spatial chunk coordinates.
        link_groups: List of arrays, each ``(M_k, L)``.
        dtype: Integer dtype.
        delta: Level delta; see :mod:`zarr_vectors.core.paths`.

    Returns:
        ``(K,)`` int64 array of link byte offsets.
    """
    dtype = np.dtype(dtype)
    key = _chunk_key(chunk_coords)
    full_name = links_path(delta)

    if delta == 0 and level_group.chunk_exists(VERTEX_FRAGMENTS, key):
        existing_fi = decode_fragments(
            level_group.read_bytes(VERTEX_FRAGMENTS, key),
        )
        if existing_fi.num_fragments != len(link_groups):
            raise ArrayError(
                f"Link group count ({len(link_groups)}) != "
                f"vertex fragment count ({existing_fi.num_fragments}) in chunk {key}"
            )

    if delta == 0:
        # v0.6 intra-level: flat concatenated link data + sibling
        # link_fragments/ describing per-group row ranges.
        data_bytes, link_byte_offsets = encode_ragged_ints(link_groups, dtype)
        level_group.write_bytes(full_name, key, data_bytes)

        link_row_size = dtype.itemsize * (
            int(np.asarray(link_groups[0]).shape[1]) if (
                link_groups and np.asarray(link_groups[0]).ndim == 2
            ) else 1
        )
        # Fragment per group as a contiguous range of link rows.
        if len(link_groups) == 0:
            link_fragments: list[tuple[int, int]] = []
        else:
            cumulative = 0
            link_fragments = []
            for g in link_groups:
                n = int(np.asarray(g).shape[0]) if np.asarray(g).ndim >= 1 else 0
                link_fragments.append((cumulative, n))
                cumulative += n
        # Ensure the sibling array group exists.
        if not level_group.chunk_exists(LINK_FRAGMENTS, key):
            level_group.require_group(LINK_FRAGMENTS)
            try:
                level_group.read_array_meta(LINK_FRAGMENTS)
            except Exception:
                level_group.write_array_meta(LINK_FRAGMENTS, {
                    "zv_array": LINK_FRAGMENTS,
                    "encoding": "fragment_index_v1",
                })
        level_group.write_bytes(
            LINK_FRAGMENTS, key, encode_fragments(link_fragments),
        )
        del link_row_size  # silence unused-variable warning
        return link_byte_offsets

    # delta != 0: cross-level links keep the v0.5 inline self-describing
    # layout (out of scope for the v0.6 fragment-index refactor).
    blob = encode_ragged_blob(link_groups, dtype)
    level_group.write_bytes(full_name, key, blob)
    _, link_byte_offsets = encode_ragged_ints(link_groups, dtype)
    return link_byte_offsets


def write_chunk_attributes(
    level_group: FsGroup,
    attr_name: str,
    chunk_coords: ChunkCoords,
    attr_groups: list[npt.NDArray],
    dtype: np.dtype | str = np.float32,
) -> None:
    """Write vertex attribute data for groups in a spatial chunk.

    Attribute groups align 1:1 with vertex groups, so per-group byte
    offsets are derived at read time from ``vertex_group_offsets`` and
    the attribute dtype/ncols.  No sibling ``_offsets`` blob is written.

    Args:
        level_group: Resolution level group.
        attr_name: Attribute name (e.g. ``"radius"``).
        chunk_coords: Spatial chunk coordinates.
        attr_groups: List of arrays aligned with vertex groups.
            Each array is ``(N_k,)`` for scalar or ``(N_k, C)`` for
            multi-channel attributes.
        dtype: Numpy dtype.
    """
    dtype = np.dtype(dtype)
    key = _chunk_key(chunk_coords)
    full_name = f"{VERTEX_ATTRIBUTES}/{attr_name}"
    raw_bytes, _ = encode_vertex_groups(attr_groups, dtype)
    level_group.write_bytes(full_name, key, raw_bytes)


def write_chunk_link_attributes(
    level_group: FsGroup,
    attr_name: str,
    chunk_coords: ChunkCoords,
    attr_groups: list[npt.NDArray],
    dtype: np.dtype | str = np.float32,
    *,
    delta: int = 0,
) -> None:
    """Write per-edge attribute data parallel to ``links/<delta>/``.

    Args:
        level_group: Resolution level group.
        attr_name: Attribute name (e.g. ``"weight"``).
        chunk_coords: Spatial chunk coordinates.
        attr_groups: List of arrays, each ``(M_k,)`` or ``(M_k, C)``,
            aligned with link groups in the ``links/<delta>/`` array.
        dtype: Numpy dtype.
        delta: Level delta; see :mod:`zarr_vectors.core.paths`.
    """
    dtype = np.dtype(dtype)
    key = _chunk_key(chunk_coords)
    full_name = link_attributes_path(attr_name, delta)
    raw_bytes, _ = encode_vertex_groups(attr_groups, dtype)
    level_group.write_bytes(full_name, key, raw_bytes)


def write_object_index(
    level_group: FsGroup,
    manifests: dict[int, ObjectManifest],
    sid_ndim: int,
    *,
    total_objects: int | None = None,
) -> None:
    """Write object index: object_id → ordered vertex group references.

    Args:
        level_group: Resolution level group.
        manifests: ``{object_id: [(chunk_coords, vg_index), ...], ...}``.
            Sparse — OIDs absent from the dict get empty manifests.
        sid_ndim: Number of spatial index dimensions.
        total_objects: Number of OID slots to write.  When provided,
            the dense manifest list spans ``range(total_objects)`` even
            if the largest OID present is smaller — used by the
            ID-preserving pyramid regime, where surviving OIDs are a
            sparse subset of the parent's OID space.  When ``None``
            (default), the size is ``max(manifests.keys()) + 1``
            (legacy behaviour).
    """
    if not manifests and total_objects is None:
        return

    if total_objects is not None:
        size = int(total_objects)
    else:
        size = max(manifests.keys()) + 1
    # Build a dense list, filling gaps with empty manifests
    manifest_list: list[list[tuple[tuple[int, ...], int]]] = []
    for oid in range(size):
        manifest_list.append(manifests.get(oid, []))

    # v0.6 manifest-block encoding.  Each old (chunk, vg_index) tuple
    # becomes one mode-0 (single fragment) block.  Range / explicit
    # short-circuits are reserved for writers that know they produce
    # ranges or fragment-list shapes — to be plumbed through the
    # higher-level write APIs in a future change.
    manifest_blobs: list[bytes] = []
    for manifest in manifest_list:
        blocks = [
            (tuple(int(c) for c in chunk_coords), int(vg_index))
            for chunk_coords, vg_index in manifest
        ]
        manifest_blobs.append(
            encode_object_manifest_blocks(blocks, sid_ndim=sid_ndim)
        )

    _write_object_index_manifests(level_group, manifest_blobs)
    level_group.write_array_meta(OBJECT_INDEX, {
        "zv_array": "object_index",
        "num_objects": size,
        "sid_ndim": sid_ndim,
        "layout": OBJECT_INDEX_LAYOUT_V1,
    })


def _write_object_index_manifests(
    level_group: FsGroup,
    manifest_blobs: list[bytes],
) -> None:
    """Write ``object_index/manifests`` as a single ragged vlen-bytes array.

    One zarr chunk holds ``OBJECT_INDEX_MANIFEST_BUCKET`` consecutive
    objects, so a single-object read fetches at most one chunk regardless
    of total ``num_objects`` — fixing the legacy O(num_objects) read
    amplification.  Drops legacy ``object_index/{data,offsets}`` arrays
    if they exist from a prior write.
    """
    n = len(manifest_blobs)
    oi_group = level_group.zarr_group.require_group(OBJECT_INDEX)
    for legacy in ("manifests", "data", "offsets"):
        if legacy in oi_group:
            del oi_group[legacy]

    if n == 0:
        return

    chunk_size = min(OBJECT_INDEX_MANIFEST_BUCKET, n)
    # zarr 3.x's variable-length bytes dtype lacks a finalised V3 spec
    # (zarr-extensions tracks it); the warning is informational and ZVF
    # is alpha — accept it and silence at the call site so writes stay
    # quiet.  Revisit if the spec lands incompatibly.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UnstableSpecificationWarning)
        arr = oi_group.create_array(
            "manifests",
            shape=(n,),
            chunks=(chunk_size,),
            dtype="bytes",
            serializer=VLenBytesCodec(),
        )
        obj = np.empty(n, dtype=object)
        for i, blob in enumerate(manifest_blobs):
            obj[i] = blob
        arr[:] = obj


def write_object_attributes(
    level_group: FsGroup,
    attr_name: str,
    data: npt.NDArray,
    *,
    present_mask: npt.NDArray | None = None,
) -> None:
    """Write dense O×C object attribute data.

    Args:
        level_group: Resolution level group.
        attr_name: Attribute name.
        data: ``(O,)`` or ``(O, C)`` array.
        present_mask: Optional ``(O,)`` byte array (``0``/``1`` per
            object) marking which rows are real.  Required for levels
            that use ID-preserving sparsification — rows for dropped
            objects have ``mask[i] == 0`` and the corresponding
            ``data[i]`` row is dtype-zero padding.  When omitted, every
            row is assumed real (backwards compatible).
    """
    full_name = f"{OBJECT_ATTRIBUTES}/{attr_name}"
    _ensure_array_dir(level_group, full_name)
    level_group.write_bytes(full_name, "data", data.tobytes())
    if present_mask is not None:
        mask = np.asarray(present_mask, dtype=np.uint8)
        if mask.shape[0] != data.shape[0]:
            raise ArrayError(
                f"present_mask length {mask.shape[0]} != data row count "
                f"{data.shape[0]}"
            )
        level_group.write_bytes(full_name, "present_mask", mask.tobytes())
    level_group.write_array_meta(full_name, {
        "zv_array": "object_attribute",
        "name": attr_name,
        "dtype": str(data.dtype),
        "shape": list(data.shape),
        "has_present_mask": bool(present_mask is not None),
    })


def read_object_attribute_present_mask(
    level_group: FsGroup,
    attr_name: str,
) -> npt.NDArray[np.uint8] | None:
    """Read the optional ``present_mask`` byte sidecar for an attribute.

    Returns ``None`` when the level was written without a mask (every
    row real) or the array is missing.
    """
    full_name = f"{OBJECT_ATTRIBUTES}/{attr_name}"
    try:
        meta = level_group.read_array_meta(full_name)
    except Exception:
        return None
    if not meta.get("has_present_mask"):
        return None
    if not level_group.chunk_exists(full_name, "present_mask"):
        return None
    raw = level_group.read_bytes(full_name, "present_mask")
    return np.frombuffer(raw, dtype=np.uint8)


def write_groupings(
    level_group: FsGroup,
    groups: dict[int, list[int]],
) -> None:
    """Write group memberships: group_id → list of object_ids.

    Args:
        level_group: Resolution level group.
        groups: ``{group_id: [object_id, ...], ...}``.
            Group IDs must be contiguous starting from 0.
    """
    if not groups:
        return

    max_gid = max(groups.keys())
    group_list: list[npt.NDArray] = []
    for gid in range(max_gid + 1):
        members = groups.get(gid, [])
        group_list.append(np.array(members, dtype=np.int64))

    raw_bytes, offsets = encode_ragged_ints(group_list, dtype=np.dtype(np.int64))
    level_group.write_bytes(GROUPS, "data", raw_bytes)
    level_group.write_bytes(GROUPS, "offsets", offsets.tobytes())
    level_group.write_array_meta(GROUPS, {
        "zv_array": "groups",
        "num_groups": max_gid + 1,
    })


def write_groupings_attributes(
    level_group: FsGroup,
    attr_name: str,
    data: npt.NDArray,
) -> None:
    """Write dense G×C groupings attribute data.

    Args:
        level_group: Resolution level group.
        attr_name: Attribute name.
        data: ``(G,)`` or ``(G, C)`` array.
    """
    full_name = f"{GROUP_ATTRIBUTES}/{attr_name}"
    _ensure_array_dir(level_group, full_name)
    level_group.write_bytes(full_name, "data", data.tobytes())
    level_group.write_array_meta(full_name, {
        "zv_array": "groupings_attribute",
        "name": attr_name,
        "dtype": str(data.dtype),
        "shape": list(data.shape),
    })


def write_cross_chunk_links(
    level_group: FsGroup,
    links: list[list[tuple[ChunkCoords, int]]] | list[CrossChunkLink],
    sid_ndim: int,
    *,
    delta: int = 0,
    link_width: int | None = None,
) -> None:
    """Write cross-chunk link records under ``cross_chunk_links/<delta>/``.

    Each record is ``link_width`` ``(chunk_coords, vertex_idx)``
    endpoints.  ``link_width=2`` (the default) encodes the classic
    cross-chunk edge ``((chunk_A, vi_A), (chunk_B, vi_B))``;
    ``link_width=3`` encodes a triangle face spanning chunks;
    ``link_width=1`` encodes a single parent→child reference used by
    pyramid metanode drill-down.

    Records may be passed either as legacy 2-tuples (compatibility
    with the pre-0.6.0 edge-only API) or as a list of endpoint lists
    when ``link_width`` is supplied explicitly.

    Endpoint 0 is at the owning resolution level; endpoint k (k>0)
    is at ``this_level + delta``.  For ``link_width=1`` (metanode
    drill-down) the single endpoint is at ``this_level + delta`` and
    is paired with an implicit source defined by the writer (the
    record stores only the child reference).

    Args:
        level_group: Resolution level group.
        links: List of records; each record is a list of
            ``(chunk_coords, vertex_idx)`` tuples of length
            ``link_width``.  Legacy 2-tuple form is accepted when
            ``link_width`` is 2 (or omitted).
        sid_ndim: Number of spatial index dimensions.
        delta: Level delta; see :mod:`zarr_vectors.core.paths`.
        link_width: Endpoints per record.  Defaults to 2 (or to the
            arity of the first record if it's a list).
    """
    if not links:
        return

    # Normalise input to a list-of-lists shape; resolve link_width.
    normalised: list[list[tuple[ChunkCoords, int]]] = []
    for rec in links:
        if isinstance(rec, tuple) and len(rec) == 2 and isinstance(rec[0], tuple) and not isinstance(rec[0][0], tuple):
            # Legacy CrossChunkLink: ((chunk_a, vi_a), (chunk_b, vi_b))
            normalised.append([rec[0], rec[1]])
        else:
            normalised.append(list(rec))

    if link_width is None:
        link_width = len(normalised[0])
    for rec in normalised:
        if len(rec) != link_width:
            raise ArrayError(
                f"cross_chunk_links/{format_delta(delta)}: record arity "
                f"{len(rec)} != link_width {link_width}"
            )

    full_name = cross_chunk_links_path(delta)
    flat: list[int] = []
    for rec in normalised:
        for chunk, vi in rec:
            if len(chunk) != sid_ndim:
                raise ArrayError(
                    f"chunk coords arity mismatch in cross_chunk_links/"
                    f"{format_delta(delta)}: sid_ndim={sid_ndim}, "
                    f"got len(chunk)={len(chunk)}"
                )
            flat.extend(int(c) for c in chunk)
            flat.append(int(vi))

    arr = np.array(flat, dtype=np.int64)
    level_group.write_bytes(full_name, "data", arr.tobytes())
    level_group.write_array_meta(full_name, {
        "zv_array": "cross_chunk_links",
        "num_links": len(normalised),
        "sid_ndim": sid_ndim,
        "level_delta": int(delta),
        "link_width": int(link_width),
    })


def write_cross_chunk_link_attributes(
    level_group: FsGroup,
    attr_name: str,
    attr_data: npt.NDArray,
    *,
    num_links: int,
    delta: int = 0,
) -> None:
    """Write per-edge attribute data parallel to ``cross_chunk_links/<delta>/data``.

    The cross-chunk-link attribute array is a single flat blob whose
    rows are in the same order as the cross-chunk links written by
    :func:`write_cross_chunk_links` for the same ``delta``.  Length is
    runtime-checked against the parallel CCL array's ``num_links`` so a
    desynchronized write fails loudly instead of producing silent
    corruption.

    Args:
        level_group: Resolution level group.
        attr_name: Attribute name.
        attr_data: ``(num_links,)`` or ``(num_links, C)`` array.
        num_links: Expected length (the ``num_links`` value on the
            parallel ``cross_chunk_links/<delta>/`` array).
        delta: Level delta; see :mod:`zarr_vectors.core.paths`.
    """
    if attr_data.shape[0] != num_links:
        raise ArrayError(
            f"cross_chunk_link_attributes[{attr_name}] row count "
            f"{attr_data.shape[0]} != num_links {num_links} "
            f"(delta={format_delta(delta)})"
        )
    full_name = cross_chunk_link_attributes_path(attr_name, delta)
    _ensure_array_dir(level_group, full_name)
    level_group.write_bytes(full_name, "data", np.ascontiguousarray(attr_data).tobytes())
    level_group.write_array_meta(full_name, {
        "zv_array": "cross_chunk_link_attribute",
        "name": attr_name,
        "dtype": str(attr_data.dtype),
        "level_delta": int(delta),
        "num_links": int(num_links),
        "shape": list(attr_data.shape),
    })


# ===================================================================
# Reading data
# ===================================================================

def read_chunk_vertices(
    level_group: FsGroup,
    chunk_coords: ChunkCoords,
    dtype: np.dtype | str = np.float32,
    ndim: int = 3,
) -> list[npt.NDArray[np.floating]]:
    """Read all vertex groups from a spatial chunk.

    Args:
        level_group: Resolution level group.
        chunk_coords: Spatial chunk coordinates.
        dtype: Numpy dtype.
        ndim: Number of coordinate dimensions (D).

    Returns:
        List of arrays, each ``(N_k, D)``.

    Raises:
        ArrayError: If the chunk does not exist or data is malformed.
    """
    key = _chunk_key(chunk_coords)
    dtype = np.dtype(dtype)

    with _maybe_batched_reads(level_group, [
        (VERTICES, [key]),
        (VERTEX_FRAGMENTS, [key]),
    ]):
        try:
            raw = level_group.read_bytes(VERTICES, key)
        except Exception as e:
            raise ArrayError(f"Cannot read vertices chunk {key}: {e}") from e

        bytes_per_vertex = int(dtype.itemsize) * int(ndim)
        offsets = _read_vertex_offsets(
            level_group, chunk_coords, bytes_per_vertex=bytes_per_vertex,
        )
    return decode_vertex_groups(raw, offsets, dtype, ndim)


def read_vertex_group(
    level_group: FsGroup,
    chunk_coords: ChunkCoords,
    vg_index: int,
    dtype: np.dtype | str = np.float32,
    ndim: int = 3,
) -> npt.NDArray[np.floating]:
    """Read a single vertex group using byte offsets for efficient access.

    Args:
        level_group: Resolution level group.
        chunk_coords: Spatial chunk coordinates.
        vg_index: Index of the vertex group within the chunk.
        dtype: Numpy dtype.
        ndim: Number of coordinate dimensions.

    Returns:
        Array of shape ``(N, D)``.
    """
    key = _chunk_key(chunk_coords)
    dtype = np.dtype(dtype)

    with _maybe_batched_reads(level_group, [
        (VERTICES, [key]),
        (VERTEX_FRAGMENTS, [key]),
    ]):
        raw = level_group.read_bytes(VERTICES, key)
        bytes_per_vertex = int(dtype.itemsize) * int(ndim)
        offsets = _read_vertex_offsets(
            level_group, chunk_coords, bytes_per_vertex=bytes_per_vertex,
        )

    if vg_index < 0 or vg_index >= len(offsets):
        raise ArrayError(
            f"Vertex group index {vg_index} out of range "
            f"(chunk {key} has {len(offsets)} groups)"
        )

    start = int(offsets[vg_index])
    end = int(offsets[vg_index + 1]) if vg_index + 1 < len(offsets) else len(raw)
    segment = raw[start:end]

    arr = np.frombuffer(segment, dtype=dtype)
    if ndim > 1:
        arr = arr.reshape(-1, ndim)
    return arr


def read_chunk_links(
    level_group: FsGroup,
    chunk_coords: ChunkCoords,
    dtype: np.dtype | str = np.int64,
    link_width: int | None = None,
    *,
    delta: int = 0,
) -> list[npt.NDArray[np.integer]]:
    """Read all link groups from a spatial chunk's ``links/<delta>/`` array.

    Args:
        level_group: Resolution level group.
        chunk_coords: Spatial chunk coordinates.
        dtype: Integer dtype.
        link_width: Number of columns per link (L). If None, read from
            array metadata.
        delta: Level delta; ``0`` is the intra-level array.

    Returns:
        List of arrays, each ``(M_k, L)``.
    """
    key = _chunk_key(chunk_coords)
    dtype = np.dtype(dtype)
    full_name = links_path(delta)

    if link_width is None:
        meta = level_group.read_array_meta(full_name)
        link_width = meta.get("link_width", 2)

    # delta == 0 needs both the link bytes and the fragment-index sibling;
    # delta != 0 keeps the v0.5 inline self-describing layout (one blob).
    plan: list[tuple[str, list[str]]] = [(full_name, [key])]
    if delta == 0:
        plan.append((LINK_FRAGMENTS, [key]))

    with _maybe_batched_reads(level_group, plan):
        try:
            raw = level_group.read_bytes(full_name, key)
        except Exception as e:
            raise ArrayError(
                f"Cannot read links chunk {key} (delta={format_delta(delta)}): {e}"
            ) from e

        if delta == 0:
            # v0.6 intra-level layout: raw is the flat concatenated link
            # data; per-group row counts live in link_fragments/<chunk>.
            fi = read_link_fragment_index(level_group, chunk_coords)
            groups: list[npt.NDArray[np.integer]] = []
            row_size = int(dtype.itemsize) * int(link_width)
            for f in range(fi.num_fragments):
                if not fi.is_range(f):
                    raise ArrayError(
                        f"link_fragments/{key} fragment {f} is non-contiguous; "
                        "read_chunk_links requires every fragment to be a "
                        "contiguous range over rows of links/0/<chunk>.",
                    )
                start, count = fi.range(f)
                byte_lo = int(start) * row_size
                byte_hi = byte_lo + int(count) * row_size
                segment = raw[byte_lo:byte_hi]
                arr = np.frombuffer(segment, dtype=dtype)
                if link_width > 1:
                    arr = arr.reshape(-1, link_width)
                groups.append(arr)
            return groups

        # Cross-level delta != 0: v0.5 inline self-describing layout.
        return decode_ragged_blob(raw, dtype, ncols=link_width)


def read_chunk_attributes(
    level_group: FsGroup,
    attr_name: str,
    chunk_coords: ChunkCoords,
    dtype: np.dtype | str = np.float32,
    ncols: int = 1,
    *,
    vert_dtype: np.dtype | str | None = None,
    vert_ndim: int | None = None,
) -> list[npt.NDArray]:
    """Read vertex attribute data for a chunk.

    Per-group byte offsets are derived from ``vertex_group_offsets``:
    group ``k`` has ``n_k = (vert_offsets[k+1] - vert_offsets[k]) /
    (vert_dtype.itemsize * vert_ndim)`` vertices, so its attribute
    byte offset is ``cumsum(n_k) * dtype.itemsize * ncols``.

    Args:
        level_group: Resolution level group.
        attr_name: Attribute name.
        chunk_coords: Spatial chunk coordinates.
        dtype: Numpy dtype of the attribute.
        ncols: Number of columns (channels). Use 1 for scalars.
        vert_dtype: Vertex dtype (needed to derive per-group sizes).
            When ``None`` (default) it is read from the ``vertices/``
            array metadata.
        vert_ndim: Vertex coordinate dimensionality.  When ``None``
            (default) it is read from root metadata via NGFF axes; on
            failure falls back to 3.

    Returns:
        List of arrays aligned with vertex groups.
    """
    key = _chunk_key(chunk_coords)
    dtype = np.dtype(dtype)
    full_name = f"{VERTEX_ATTRIBUTES}/{attr_name}"

    if vert_dtype is None:
        try:
            vmeta = level_group.read_array_meta(VERTICES)
            vert_dtype = np.dtype(vmeta.get("dtype", "float32"))
        except Exception:
            vert_dtype = np.dtype(np.float32)
    else:
        vert_dtype = np.dtype(vert_dtype)
    if vert_ndim is None:
        vert_ndim = _infer_vert_ndim(level_group)

    with _maybe_batched_reads(level_group, [
        (full_name, [key]),
        (VERTEX_FRAGMENTS, [key]),
    ]):
        try:
            raw = level_group.read_bytes(full_name, key)
        except Exception as e:
            raise ArrayError(
                f"Cannot read attribute '{attr_name}' chunk {key}: {e}"
            ) from e

        attr_offsets = _derive_attribute_offsets(
            level_group, chunk_coords,
            vert_dtype=vert_dtype, vert_ndim=vert_ndim,
            attr_dtype=dtype, attr_ncols=ncols,
            total_attr_bytes=len(raw),
        )
    return decode_vertex_groups(raw, attr_offsets, dtype, ncols)


def _infer_vert_ndim(level_group: FsGroup) -> int:
    """Best-effort lookup of the spatial-index dimensionality.

    Reads NGFF ``multiscales[0].axes`` length from root attrs.  Falls
    back to 3 when unavailable.
    """
    try:
        # Level groups don't carry root attrs; walk up to root via the
        # backend.  Most levels have an ``_backend`` handle that owns
        # the root path.
        from zarr_vectors.core.group import Group
        root_handle = Group._from_backend(level_group._backend, "")
        ms = root_handle.attrs.to_dict().get("multiscales") or []
        if ms and isinstance(ms, list):
            axes = ms[0].get("axes") or []
            if axes:
                return len(axes)
    except Exception:
        pass
    return 3


def _derive_attribute_offsets(
    level_group: FsGroup,
    chunk_coords: ChunkCoords,
    *,
    vert_dtype: np.dtype,
    vert_ndim: int,
    attr_dtype: np.dtype,
    attr_ncols: int,
    total_attr_bytes: int,
) -> npt.NDArray[np.int64]:
    """Compute per-group attribute byte offsets from vertex offsets.

    Attribute groups align 1:1 with vertex groups.  The k-th vertex
    group spans ``vert_offsets[k+1] - vert_offsets[k]`` bytes of
    vertex data, which corresponds to ``n_k`` vertices (and therefore
    ``n_k`` attribute rows).
    """
    vert_row_size = vert_dtype.itemsize * vert_ndim
    if vert_row_size <= 0:
        return np.empty(0, dtype=np.int64)
    fi = read_vertex_fragment_index(level_group, chunk_coords)
    if fi.num_fragments == 0:
        return np.empty(0, dtype=np.int64)
    # Per-fragment vertex row count.  Today's writers always emit
    # range fragments; non-contiguous shapes will need a richer
    # attribute-alignment story (out of scope for this change).
    n_per_group = np.empty(fi.num_fragments, dtype=np.int64)
    for f in range(fi.num_fragments):
        if not fi.is_range(f):
            raise ArrayError(
                f"vertex_fragments fragment {f} is non-contiguous; "
                "attribute alignment requires every fragment to be a "
                "contiguous range of vertex rows.",
            )
        _start, count = fi.range(f)
        n_per_group[f] = int(count)
    attr_row_size = attr_dtype.itemsize * attr_ncols
    attr_byte_lengths = n_per_group * int(attr_row_size)
    attr_offsets = np.empty_like(attr_byte_lengths)
    attr_offsets[0] = 0
    np.cumsum(attr_byte_lengths[:-1], out=attr_offsets[1:])
    del total_attr_bytes  # signature retained for caller compat
    return attr_offsets


def read_object_manifest(
    level_group: FsGroup,
    object_id: int,
) -> ObjectManifest:
    """Read the ordered vertex group reference list for one object.

    Args:
        level_group: Resolution level group.
        object_id: Object ID.

    Returns:
        List of ``(chunk_coords, vg_index)`` tuples.
    """
    meta = level_group.read_array_meta(OBJECT_INDEX)
    sid_ndim = meta["sid_ndim"]
    num_objects = meta["num_objects"]

    if object_id < 0 or object_id >= num_objects:
        raise ArrayError(
            f"Object ID {object_id} out of range [0, {num_objects})"
        )

    if meta.get("layout") == OBJECT_INDEX_LAYOUT_V1:
        manifests_arr = level_group.zarr_group[OBJECT_INDEX]["manifests"]
        # Slice (then index) instead of scalar indexing: zarr 3.x vlen-bytes
        # returns a 0-d object ndarray under ``arr[i]``, whose ``bytes()``
        # is the array header — not the payload.  ``arr[i:i+1][0]`` is the
        # actual bytes object and still fetches only the chunk holding i.
        blob = manifests_arr[object_id:object_id + 1][0]
    else:
        blob = _legacy_read_object_blob(level_group, object_id, num_objects)

    blocks = decode_object_manifest_blocks(blob, sid_ndim=sid_ndim)
    return _expand_blocks(blocks)


def read_all_object_manifests(
    level_group: FsGroup,
) -> list[ObjectManifest]:
    """Read all object manifests at once.

    Returns:
        List indexed by object_id, each a list of ``(chunk_coords, vg_index)``.
    """
    meta = level_group.read_array_meta(OBJECT_INDEX)
    sid_ndim = meta["sid_ndim"]
    num_objects = int(meta.get("num_objects", 0))

    if meta.get("layout") == OBJECT_INDEX_LAYOUT_V1:
        if num_objects == 0:
            return []
        manifests_arr = level_group.zarr_group[OBJECT_INDEX]["manifests"]
        # Slicing yields a 1-D object ndarray whose elements are bytes
        # directly (unlike scalar indexing — see read_object_manifest).
        blobs = manifests_arr[:]
        return [
            _expand_blocks(decode_object_manifest_blocks(b, sid_ndim=sid_ndim))
            for b in blobs
        ]

    # Legacy layout: single-chunk data + offsets byte blobs.
    with _maybe_batched_reads(level_group, [
        (OBJECT_INDEX, ["data", "offsets"]),
    ]):
        raw = level_group.read_bytes(OBJECT_INDEX, "data")
        offsets = np.frombuffer(
            level_group.read_bytes(OBJECT_INDEX, "offsets"),
            dtype=np.int64,
        )
    return [
        _expand_blocks(
            decode_object_manifest_blocks(
                _slice_legacy_blob(raw, offsets, i, num_objects),
                sid_ndim=sid_ndim,
            ),
        )
        for i in range(num_objects)
    ]


def _legacy_read_object_blob(
    level_group: FsGroup,
    object_id: int,
    num_objects: int,
) -> bytes:
    """Load one object's encoded manifest blob from the legacy
    ``object_index/{data,offsets}`` byte-blob layout.

    Reads the full ``data`` and ``offsets`` arrays (each a single-chunk
    blob) and slices to the one object's byte range.  This is the cost
    the vlen-bytes ``manifests`` layout was introduced to eliminate;
    kept for backwards-compatible reads of pre-vlen stores.
    """
    with _maybe_batched_reads(level_group, [
        (OBJECT_INDEX, ["data", "offsets"]),
    ]):
        raw = level_group.read_bytes(OBJECT_INDEX, "data")
        offsets = np.frombuffer(
            level_group.read_bytes(OBJECT_INDEX, "offsets"),
            dtype=np.int64,
        )
    return _slice_legacy_blob(raw, offsets, object_id, num_objects)


def _slice_legacy_blob(
    data: bytes,
    offsets: npt.NDArray[np.int64],
    object_id: int,
    num_objects: int,
) -> bytes:
    start = int(offsets[object_id])
    end = (
        int(offsets[object_id + 1])
        if object_id + 1 < num_objects
        else len(data)
    )
    return data[start:end]


def _expand_blocks(
    blocks: list[tuple[ChunkCoords, Any]],
) -> ObjectManifest:
    """Expand v0.6 manifest blocks to the legacy
    ``[(chunk_coords, vg_index), ...]`` tuple list.

    Mode-1 (range) and mode-2 (explicit list) blocks expand to one
    tuple per fragment so existing call sites that iterate
    ``(chunk_coords, vg_index)`` keep working unchanged.  Callers that
    want the raw block representation can use
    :func:`zarr_vectors.encoding.fragments.decode_object_manifest_blocks`
    directly.
    """
    out: ObjectManifest = []
    for chunk_coords, frag_ref in blocks:
        if isinstance(frag_ref, int):
            out.append((chunk_coords, int(frag_ref)))
        elif isinstance(frag_ref, tuple):
            r_start, r_count = frag_ref
            for k in range(int(r_count)):
                out.append((chunk_coords, int(r_start) + k))
        else:
            # np.ndarray of explicit indices
            for idx in frag_ref:
                out.append((chunk_coords, int(idx)))
    return out


def read_object_vertices(
    level_group: FsGroup,
    object_id: int,
    dtype: np.dtype | str = np.float32,
    ndim: int = 3,
) -> list[npt.NDArray[np.floating]]:
    """Read all vertex data for an object by following its manifest.

    Args:
        level_group: Resolution level group.
        object_id: Object ID.
        dtype: Numpy dtype.
        ndim: Number of coordinate dimensions.

    Returns:
        List of vertex group arrays in reconstruction order.
    """
    manifest = read_object_manifest(level_group, object_id)
    if not manifest:
        return []

    # Prefetch every chunk this object touches in one async gather, so
    # the per-fragment read_vertex_group calls below hit the cache
    # instead of paying one round-trip per fragment.  Distinct chunks
    # appear once in the plan; multiple fragments inside the same chunk
    # share the same cache entry.
    chunk_keys = sorted({_chunk_key(cc) for cc, _ in manifest})
    with _maybe_batched_reads(level_group, [
        (VERTICES, chunk_keys),
        (VERTEX_FRAGMENTS, chunk_keys),
    ]):
        groups: list[npt.NDArray] = []
        for chunk_coords, vg_index in manifest:
            vg = read_vertex_group(
                level_group, chunk_coords, vg_index,
                dtype=dtype, ndim=ndim,
            )
            groups.append(vg)
    return groups


def read_object_attributes(
    level_group: FsGroup,
    attr_name: str,
    dtype: np.dtype | str | None = None,
) -> npt.NDArray:
    """Read dense O×C object attribute data.

    Args:
        level_group: Resolution level group.
        attr_name: Attribute name.
        dtype: Override dtype. If None, read from metadata.

    Returns:
        Array of shape ``(O,)`` or ``(O, C)``.
    """
    full_name = f"{OBJECT_ATTRIBUTES}/{attr_name}"
    meta = level_group.read_array_meta(full_name)
    if dtype is None:
        dtype = np.dtype(meta["dtype"])
    else:
        dtype = np.dtype(dtype)
    shape = tuple(meta["shape"])

    raw = level_group.read_bytes(full_name, "data")
    return np.frombuffer(raw, dtype=dtype).reshape(shape).copy()


def read_group_object_ids(
    level_group: FsGroup,
    group_id: int,
) -> list[int]:
    """Read the list of object IDs belonging to a group.

    Args:
        level_group: Resolution level group.
        group_id: Group ID.

    Returns:
        List of object ID integers.
    """
    meta = level_group.read_array_meta(GROUPS)
    num_groups = meta["num_groups"]

    if group_id < 0 or group_id >= num_groups:
        raise ArrayError(
            f"Group ID {group_id} out of range [0, {num_groups})"
        )

    raw = level_group.read_bytes(GROUPS, "data")
    offsets = np.frombuffer(
        level_group.read_bytes(GROUPS, "offsets"),
        dtype=np.int64,
    )

    all_groups = decode_ragged_ints(raw, offsets, dtype=np.dtype(np.int64), ncols=1)
    return all_groups[group_id].tolist()


def read_all_groupings(
    level_group: FsGroup,
) -> list[list[int]]:
    """Read all group memberships.

    Returns:
        List indexed by group_id, each a list of object_id ints.
    """
    meta = level_group.read_array_meta(GROUPS)

    raw = level_group.read_bytes(GROUPS, "data")
    offsets = np.frombuffer(
        level_group.read_bytes(GROUPS, "offsets"),
        dtype=np.int64,
    )

    all_groups = decode_ragged_ints(raw, offsets, dtype=np.dtype(np.int64), ncols=1)
    return [g.tolist() for g in all_groups]


def read_groupings_attributes(
    level_group: FsGroup,
    attr_name: str,
    dtype: np.dtype | str | None = None,
) -> npt.NDArray:
    """Read dense G×C groupings attribute data."""
    full_name = f"{GROUP_ATTRIBUTES}/{attr_name}"
    meta = level_group.read_array_meta(full_name)
    if dtype is None:
        dtype = np.dtype(meta["dtype"])
    else:
        dtype = np.dtype(dtype)
    shape = tuple(meta["shape"])

    raw = level_group.read_bytes(full_name, "data")
    return np.frombuffer(raw, dtype=dtype).reshape(shape).copy()


def read_cross_chunk_links(
    level_group: FsGroup,
    *,
    delta: int = 0,
) -> list[tuple[tuple[ChunkCoords, int], ...]]:
    """Read all cross-chunk link records from ``cross_chunk_links/<delta>/data``.

    Each record is a list of ``(chunk_coords, vertex_idx)`` endpoints.
    Endpoint 0 lives at the owning resolution level; endpoints k (k>0)
    live at ``this_level + delta``.

    Returns ``[]`` when the ``<delta>`` array does not exist or has no
    records.

    Returns:
        List of records; each record has length ``link_width``.  For
        the common ``link_width=2`` edge case callers can unpack each
        record as ``((chunk_A, vi_A), (chunk_B, vi_B))``.
    """
    full_name = cross_chunk_links_path(delta)
    if not level_group.array_exists(full_name):
        return []
    try:
        meta = level_group.read_array_meta(full_name)
    except Exception:
        return []
    if "num_links" not in meta or "sid_ndim" not in meta:
        return []
    num_links = meta["num_links"]
    sid_ndim = meta["sid_ndim"]
    link_width = int(meta.get("link_width", 2))
    if num_links == 0:
        return []
    if not level_group.chunk_exists(full_name, "data"):
        return []

    raw = level_group.read_bytes(full_name, "data")
    arr = np.frombuffer(raw, dtype=np.int64)

    endpoint_len = sid_ndim + 1
    record_len = link_width * endpoint_len
    records: list[tuple[tuple[ChunkCoords, int], ...]] = []

    for i in range(0, len(arr), record_len):
        endpoints: list[tuple[ChunkCoords, int]] = []
        for j in range(link_width):
            base = i + j * endpoint_len
            chunk = tuple(int(x) for x in arr[base : base + sid_ndim])
            vi = int(arr[base + sid_ndim])
            endpoints.append((chunk, vi))
        records.append(tuple(endpoints))

    return records


def read_cross_chunk_link_attributes(
    level_group: FsGroup,
    attr_name: str,
    dtype: np.dtype | str | None = None,
    *,
    delta: int = 0,
) -> npt.NDArray:
    """Read per-link attribute data parallel to ``cross_chunk_links/<delta>/data``.

    Returns:
        Array of shape ``(num_links,)`` or ``(num_links, C)``.
    """
    full_name = cross_chunk_link_attributes_path(attr_name, delta)
    meta = level_group.read_array_meta(full_name)
    if dtype is None:
        dtype = np.dtype(meta["dtype"])
    else:
        dtype = np.dtype(dtype)
    shape = tuple(meta.get("shape", [meta["num_links"]]))
    raw = level_group.read_bytes(full_name, "data")
    return np.frombuffer(raw, dtype=dtype).reshape(shape).copy()


# ===================================================================
# Listing / introspection
# ===================================================================

def list_chunk_keys(
    level_group: FsGroup,
    array_name: str = VERTICES,
) -> list[ChunkCoords]:
    """List all chunk coordinates that have data for an array.

    Args:
        level_group: Resolution level group.
        array_name: Array name (default: ``"vertices"``).

    Returns:
        Sorted list of chunk coordinate tuples.
    """
    keys = level_group.list_chunks(array_name)
    coords: list[ChunkCoords] = []
    for k in keys:
        try:
            coords.append(_parse_chunk_key(k))
        except ValueError:
            continue  # skip non-chunk files (e.g. .zattrs)
    return sorted(coords)


def _list_deltas_under(level_group: FsGroup, group_path: str) -> list[int]:
    """List signed level-delta segments present under a group path.

    Returns the sorted list of integers parsed from immediate child
    names that look like delta segments (``"0"``, ``"+N"``, ``"-N"``).
    Returns an empty list when the parent group is absent.  Used by the
    public ``list_link_deltas`` / ``list_cross_link_deltas`` helpers
    (and indirectly by readers and validators that walk the multiscale
    link layout).
    """
    if not level_group.array_exists(group_path):
        return []
    try:
        sub = level_group[group_path]
    except Exception:
        return []
    deltas: list[int] = []
    for name in sub:
        try:
            deltas.append(parse_delta(name))
        except ValueError:
            continue
    return sorted(deltas)


def list_link_deltas(level_group: FsGroup) -> list[int]:
    """Sorted list of ``<delta>`` values present under ``links/`` in a level."""
    return _list_deltas_under(level_group, LINKS)


def list_cross_link_deltas(level_group: FsGroup) -> list[int]:
    """Sorted list of ``<delta>`` values present under ``cross_chunk_links/``."""
    return _list_deltas_under(level_group, CROSS_CHUNK_LINKS)


def list_link_attribute_deltas(level_group: FsGroup, name: str) -> list[int]:
    """Sorted list of ``<delta>`` values present under ``link_attributes/<name>/``."""
    return _list_deltas_under(level_group, f"{LINK_ATTRIBUTES}/{name}")


def list_cross_chunk_link_attribute_deltas(
    level_group: FsGroup, name: str,
) -> list[int]:
    """Sorted list of ``<delta>`` values under ``cross_chunk_link_attributes/<name>/``."""
    return _list_deltas_under(level_group, f"{CROSS_CHUNK_LINK_ATTRIBUTES}/{name}")


def resolve_chunk_keys(
    level_group: FsGroup,
    chunk_shape: tuple[float, ...],
    *,
    bbox: tuple[npt.NDArray, npt.NDArray] | None = None,
    chunks: list[ChunkCoords] | None = None,
    array_name: str = VERTICES,
) -> list[ChunkCoords]:
    """Resolve the chunk_keys present in a level, intersected with the
    bbox-implied set and the explicit ``chunks`` whitelist.

    Combination of filters is AND: a chunk must be physically present
    *and* satisfy every supplied constraint.

    Args:
        level_group: Resolution level group.
        chunk_shape: Physical chunk size per spatial dimension.
        bbox: Optional ``(min_corner, max_corner)``. Intersected with the
            stored keys via :func:`chunks_intersecting_bbox`.
        chunks: Optional explicit whitelist of chunk coordinate tuples.
            Pass ``[]`` for "no chunks" (yields an empty result). Pass
            ``None`` (the default) for "no filter".
        array_name: Array whose chunk keys to enumerate.

    Returns:
        Sorted list of chunk coordinate tuples.

    Raises:
        ValueError: If a tuple in ``chunks`` has the wrong arity for
            this store.
    """
    from zarr_vectors.spatial.chunking import chunks_intersecting_bbox

    present = list_chunk_keys(level_group, array_name=array_name)
    keys: set[ChunkCoords] = set(present)

    if bbox is not None:
        target = set(chunks_intersecting_bbox(
            np.asarray(bbox[0]), np.asarray(bbox[1]), tuple(chunk_shape),
        ))
        keys &= target

    if chunks is not None:
        expected_arity = len(chunk_shape)
        normalised: set[ChunkCoords] = set()
        for c in chunks:
            t = tuple(int(x) for x in c)
            if len(t) != expected_arity:
                # Some stores (e.g. attribute-binned points / graphs) prefix
                # spatial chunk coords with an extra binning axis, giving
                # keys of length ``expected_arity + 1``. Accept those too.
                if len(t) != expected_arity + 1:
                    raise ValueError(
                        f"chunks tuple {c!r} has arity {len(t)}; "
                        f"expected {expected_arity} (or {expected_arity + 1} "
                        f"for attribute-binned stores)"
                    )
            normalised.add(t)
        keys &= normalised

    return sorted(keys)


def count_vertex_groups(
    level_group: FsGroup,
    chunk_coords: ChunkCoords,
) -> int:
    """Count vertex groups in a chunk by reading the fragment-index header."""
    return len(read_vertex_fragment_index(level_group, chunk_coords))


# ===================================================================
# Internal helpers
# ===================================================================

def read_vertex_fragment_index(
    level_group: FsGroup,
    chunk_coords: ChunkCoords,
) -> FragmentIndex:
    """Read and decode the ``vertex_fragments/<chunk>`` blob.

    Returns the v0.6 :class:`FragmentIndex` view describing how rows of
    ``vertices/<chunk>`` partition into fragments.
    """
    key = _chunk_key(chunk_coords)
    raw = level_group.read_bytes(VERTEX_FRAGMENTS, key)
    return decode_fragments(raw)


def read_link_fragment_index(
    level_group: FsGroup,
    chunk_coords: ChunkCoords,
) -> FragmentIndex:
    """Read and decode the ``link_fragments/<chunk>`` blob (delta=0 only)."""
    key = _chunk_key(chunk_coords)
    raw = level_group.read_bytes(LINK_FRAGMENTS, key)
    return decode_fragments(raw)


def _read_vertex_offsets(
    level_group: FsGroup,
    chunk_coords: ChunkCoords,
    *,
    bytes_per_vertex: int | None = None,
) -> npt.NDArray[np.int64]:
    """Read the ``(K,)`` int64 vertex byte offsets for a chunk.

    Computed from the v0.6 ``vertex_fragments/<chunk>`` index.  Every
    fragment must be a contiguous range over rows of ``vertices/<chunk>``
    — the only shape the existing writer produces.  Stores written by
    future writers that materialise non-contiguous / shared-row
    fragments must use the higher-level fragment-index API directly;
    this helper raises rather than silently lying about a byte offset
    that doesn't exist.

    Args:
        level_group: Resolution level group.
        chunk_coords: Spatial chunk coordinates.
        bytes_per_vertex: Bytes per vertex row.  When omitted it is
            inferred from the ``vertices/`` array's ``dtype`` metadata
            and the root NGFF axes count.
    """
    if bytes_per_vertex is None:
        vmeta = level_group.read_array_meta(VERTICES)
        vdtype = np.dtype(vmeta.get("dtype", "float32"))
        ndim = _infer_vert_ndim(level_group)
        bytes_per_vertex = int(vdtype.itemsize) * int(ndim)
    fi = read_vertex_fragment_index(level_group, chunk_coords)
    if fi.num_fragments == 0:
        return np.empty(0, dtype=np.int64)
    offsets = np.empty(fi.num_fragments, dtype=np.int64)
    for i in range(fi.num_fragments):
        if not fi.is_range(i):
            raise ArrayError(
                f"vertex_fragments/{_chunk_key(chunk_coords)} fragment {i} "
                "is non-contiguous; byte-offset access requires every "
                "fragment to be a contiguous range over rows of "
                "vertices/<chunk>.  Use read_vertex_fragment_index() "
                "directly for non-contiguous fragments.",
            )
        start, _count = fi.range(i)
        offsets[i] = int(start) * int(bytes_per_vertex)
    return offsets


