Source code for cellmap_segmentation_challenge.utils.matched_crop

import logging
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple

import numpy as np
import zarr
from skimage.transform import rescale
from upath import UPath

from cellmap_data import CellMapImage


logger = logging.getLogger(__name__)


def _get_attr_any(attrs, keys):
    for k in keys:
        if k in attrs:
            return attrs[k]
    return None


def _parse_voxel_size(attrs) -> Optional[Tuple[float, ...]]:
    vs = _get_attr_any(attrs, ["voxel_size", "resolution", "scale"])
    if vs is None:
        return None
    return tuple(float(x) for x in vs)


def _parse_translation(attrs) -> Optional[Tuple[float, ...]]:
    tr = _get_attr_any(attrs, ["translation", "offset"])
    if tr is None:
        return None
    return tuple(float(x) for x in tr)
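
# Usage sketch for the attr parsers (hypothetical attrs dicts, not tied to a
# real dataset; any mapping supporting `in` and `[]` works):
#   _parse_voxel_size({"resolution": [8, 8, 8]})  -> (8.0, 8.0, 8.0)
#   _parse_voxel_size({"scale": (4, 4, 4)})       -> (4.0, 4.0, 4.0)
#   _parse_translation({"offset": [0, 64, 64]})   -> (0.0, 64.0, 64.0)
#   _parse_translation({})                        -> None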


def _resize_pad_crop(
    image: np.ndarray, target_shape: Tuple[int, ...], pad_value=0
) -> np.ndarray:
    # Center-pad, then center-crop, so the output matches target_shape exactly
    # (same convention as resize_array).
    arr_shape = image.shape
    resized = image

    pad_width = []
    for i in range(len(target_shape)):
        if arr_shape[i] < target_shape[i]:
            pad_before = (target_shape[i] - arr_shape[i]) // 2
            pad_after = target_shape[i] - arr_shape[i] - pad_before
            pad_width.append((pad_before, pad_after))
        else:
            pad_width.append((0, 0))

    if any(p > 0 for pads in pad_width for p in pads):
        resized = np.pad(resized, pad_width, mode="constant", constant_values=pad_value)

    slices = []
    for i in range(len(target_shape)):
        if resized.shape[i] > target_shape[i]:
            start = (resized.shape[i] - target_shape[i]) // 2
            end = start + target_shape[i]
            slices.append(slice(start, end))
        else:
            slices.append(slice(None))
    return resized[tuple(slices)]
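
# Behavior sketch (hypothetical shapes): _resize_pad_crop on a (3, 8) array
# with target_shape (5, 4) pads the rows to 5 (one pad row before, one after)
# and center-crops the columns to 4 (columns 2..5 survive). For odd
# differences, the extra pad row goes after the data and the extra cropped
# element comes off the end, since start = diff // 2 rounds down.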


@dataclass
class MatchedCrop:
    path: str | UPath
    class_label: str
    target_voxel_size: Sequence[float]
    target_shape: Sequence[int]
    target_translation: Sequence[float]
    instance_classes: Sequence[str]
    semantic_threshold: float = 0.5
    pad_value: float | int = 0

    def _is_instance(self) -> bool:
        return self.class_label in set(self.instance_classes)

    def _select_non_ome_level(
        self, grp: zarr.Group
    ) -> Tuple[str, Optional[Tuple[float, ...]], Optional[Tuple[float, ...]]]:
        """
        Heuristic for non-OME groups: choose among array children (e.g. s0/s1/...)
        by their voxel_size attrs. Prefer the level whose voxel size is not finer
        than target_voxel_size and closest to it; otherwise the closest overall.

        Returns (array_key, voxel_size, translation).
        """
        keys = list(grp.array_keys())
        if not keys:
            raise ValueError(f"No arrays found under {self.path}")

        tgt = np.asarray(self.target_voxel_size, dtype=float)

        best_key = None
        best_vs = None
        best_tr = None
        best_score = None
        for k in keys:
            arr = grp[k]
            vs = _parse_voxel_size(arr.attrs)
            tr = _parse_translation(arr.attrs)
            if vs is None:
                # Unknown voxel size: deprioritize, but keep as a last resort.
                score = 1e18
            else:
                v = np.asarray(vs, dtype=float)
                if v.size != tgt.size:
                    # Dimensionality mismatch (e.g. an extra channel axis):
                    # compare only the trailing dims.
                    v = v[-tgt.size :]
                # A larger voxel size means a coarser level. Prefer levels that
                # are not finer than the target (v >= tgt); finer levels incur a
                # large penalty rather than being excluded outright.
                not_finer = np.all(v >= tgt)
                dist = float(np.linalg.norm(v - tgt))
                score = dist + (0.0 if not_finer else 1e6)
            if best_score is None or score < best_score:
                best_score = score
                best_key = k
                best_vs = vs
                best_tr = tr

        assert best_key is not None
        return best_key, best_vs, best_tr

    def _load_source_array(self):
        """
        Returns (image, input_voxel_size, input_translation), where image is a
        numpy array.
        """
        ds = zarr.open(str(self.path), mode="r")

        # OME-NGFF multiscale group
        if isinstance(ds, zarr.Group) and "multiscales" in ds.attrs:
            # Use CellMapImage only to resolve which multiscale level best
            # matches the target scale.
            img = CellMapImage(
                path=str(self.path),
                target_class=self.class_label,
                target_scale=self.target_voxel_size,
                target_voxel_shape=self.target_shape,
                pad=True,
                pad_value=self.pad_value,
                interpolation="nearest" if self._is_instance() else "linear",
            )
            level = img.scale_level
            level_path = UPath(self.path) / level

            # Extract input voxel size and translation from multiscales metadata
            input_voxel_size = None
            input_translation = None
            for d in ds.attrs["multiscales"][0]["datasets"]:
                if d["path"] == level:
                    for t in d.get("coordinateTransformations", []):
                        if t.get("type") == "scale":
                            input_voxel_size = tuple(t["scale"])
                        elif t.get("type") == "translation":
                            input_translation = tuple(t["translation"])
                    break

            arr = zarr.open(level_path.path, mode="r")
            image = arr[:]
            return image, input_voxel_size, input_translation

        # Non-OME multiscale group. zarr.open on an array path returns an
        # Array, so reaching this branch means the path points at a group:
        # pick the best-matching level.
        if isinstance(ds, zarr.Group):
            key, vs, tr = self._select_non_ome_level(ds)
            arr = ds[key]
            image = arr[:]
            return image, vs, tr

        # Single-scale zarr array with attrs
        if isinstance(ds, zarr.Array):
            image = ds[:]
            return image, _parse_voxel_size(ds.attrs), _parse_translation(ds.attrs)

        raise ValueError(f"Unsupported zarr node type at {self.path}: {type(ds)}")
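
    # Scoring sketch for _select_non_ome_level above (hypothetical levels,
    # with target_voxel_size (8, 8, 8)): s0 at (4, 4, 4) is finer than the
    # target, so it scores ~6.93 + 1e6; s1 at (8, 8, 8) scores 0.0 and wins;
    # s2 at (16, 16, 16) scores ~13.86. Levels with no readable voxel_size
    # score 1e18 and are chosen only if nothing else scores better.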

    def load_aligned(self) -> np.ndarray:
        """
        Return the full aligned volume in target space (target_shape).
        """
        image, input_voxel_size, input_translation = self._load_source_array()

        tgt_vs = np.asarray(self.target_voxel_size, dtype=float)
        tgt_shape = tuple(int(x) for x in self.target_shape)
        tgt_tr = np.asarray(self.target_translation, dtype=float)

        # Resample if the source voxel size differs from the target
        if input_voxel_size is not None:
            in_vs = np.asarray(input_voxel_size, dtype=float)
            if in_vs.size != tgt_vs.size:
                in_vs = in_vs[-tgt_vs.size :]
            if not np.allclose(in_vs, tgt_vs):
                scale_factors = in_vs / tgt_vs
                if self._is_instance():
                    # Instance labels: nearest-neighbor to preserve IDs.
                    image = rescale(
                        image,
                        scale_factors,
                        order=0,
                        mode="constant",
                        preserve_range=True,
                    ).astype(image.dtype)
                else:
                    # Semantic labels: linear interpolation, then re-binarize.
                    imgf = rescale(
                        image,
                        scale_factors,
                        order=1,
                        mode="constant",
                        preserve_range=True,
                    )
                    image = imgf > self.semantic_threshold
        elif image.shape != tgt_shape:
            # No voxel size info: fall back to a center crop/pad.
            image = _resize_pad_crop(image, tgt_shape, pad_value=self.pad_value)

        # Compute the offset of the source relative to the target, in whole
        # voxels, after snapping the source translation to the target grid.
        if input_translation is not None:
            in_tr = np.asarray(input_translation, dtype=float)
            if in_tr.size != tgt_tr.size:
                in_tr = in_tr[-tgt_tr.size :]
            adjusted_in_tr = (in_tr // tgt_vs) * tgt_vs
            rel = (np.abs(adjusted_in_tr - tgt_tr) // tgt_vs) * np.sign(
                adjusted_in_tr - tgt_tr
            )
            rel = rel.astype(int)
        else:
            rel = np.zeros(len(tgt_shape), dtype=int)

        # Translate, then crop/pad into the destination volume
        if (rel != 0).any() or image.shape != tgt_shape:
            result = np.zeros(tgt_shape, dtype=image.dtype)
            input_slices = []
            output_slices = []
            for d in range(len(tgt_shape)):
                if rel[d] < 0:
                    # Source starts before the target: skip the leading voxels.
                    input_start = abs(rel[d])
                    output_start = 0
                    input_end = min(input_start + tgt_shape[d], image.shape[d])
                    length = input_end - input_start
                    output_end = output_start + length
                else:
                    # Source starts after the target: leave leading voxels empty.
                    input_start = 0
                    output_start = rel[d]
                    output_end = min(tgt_shape[d], image.shape[d] + output_start)
                    length = output_end - output_start
                    input_end = input_start + length
                if length <= 0:
                    # No overlap along this axis: the result stays empty.
                    return result
                input_slices.append(slice(int(input_start), int(input_end)))
                output_slices.append(slice(int(output_start), int(output_end)))
            result[tuple(output_slices)] = image[tuple(input_slices)]
            return result

        return image
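
# Usage sketch (hypothetical paths and values; the challenge tooling is
# expected to construct MatchedCrop from crop metadata):
#
#   crop = MatchedCrop(
#       path="predictions.zarr/mito",             # hypothetical store path
#       class_label="mito",
#       target_voxel_size=(8.0, 8.0, 8.0),
#       target_shape=(128, 128, 128),
#       target_translation=(0.0, 0.0, 0.0),
#       instance_classes=("mito",),
#   )
#   aligned = crop.load_aligned()                 # ndarray, shape (128, 128, 128)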