# Source code for cellmap_segmentation_challenge.utils.matched_crop

import logging
import os
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple

import numpy as np
import zarr
from cellmap_data import CellMapImage
from skimage.transform import rescale
from upath import UPath

logger = logging.getLogger(__name__)

# Memory estimation constants
BYTES_PER_FLOAT32: int = 4  # 4 bytes per float32 voxel; used to estimate in-memory array size

# Maximum allowed size ratio between source and target arrays
# Can be set via environment variable for flexibility
# Default is 16x in each dimension (4096x total volume size ratio)
# With chunked loading, we can handle larger arrays without loading all into memory at once
MAX_VOLUME_SIZE_RATIO: float = float(os.environ.get("MAX_VOLUME_SIZE_RATIO", 16**3))


def _get_attr_any(attrs, keys):
    for k in keys:
        if k in attrs:
            return attrs[k]
    return None


def _parse_voxel_size(attrs) -> Optional[Tuple[float, ...]]:
    vs = _get_attr_any(attrs, ["voxel_size", "resolution", "scale"])
    if vs is None:
        return None
    return tuple(float(x) for x in vs)


def _parse_translation(attrs) -> Optional[Tuple[float, ...]]:
    tr = _get_attr_any(attrs, ["translation", "offset"])
    if tr is None:
        return None
    return tuple(float(x) for x in tr)


def _resize_pad_crop(
    image: np.ndarray, target_shape: Tuple[int, ...], pad_value=0
) -> np.ndarray:
    # center pad/crop like your resize_array
    arr_shape = image.shape
    resized = image

    pad_width = []
    for i in range(len(target_shape)):
        if arr_shape[i] < target_shape[i]:
            pad_before = (target_shape[i] - arr_shape[i]) // 2
            pad_after = target_shape[i] - arr_shape[i] - pad_before
            pad_width.append((pad_before, pad_after))
        else:
            pad_width.append((0, 0))

    if any(p > 0 for pads in pad_width for p in pads):
        resized = np.pad(resized, pad_width, mode="constant", constant_values=pad_value)

    slices = []
    for i in range(len(target_shape)):
        if resized.shape[i] > target_shape[i]:
            start = (resized.shape[i] - target_shape[i]) // 2
            end = start + target_shape[i]
            slices.append(slice(start, end))
        else:
            slices.append(slice(None))
    return resized[tuple(slices)]


@dataclass
class MatchedCrop:
    """Loads a predicted label volume from zarr and aligns it to a target grid.

    Handles three zarr layouts (OME-NGFF multiscale group, non-OME group with
    resolution levels, and single-scale array), resamples to the target voxel
    size, and translates/pads/crops into ``target_shape``.
    """

    path: str | UPath  # path to a zarr array or group holding the prediction
    class_label: str  # label name; also used to pick nearest/linear interpolation
    target_voxel_size: Sequence[float]  # voxel size of the evaluation grid
    target_shape: Sequence[int]  # shape of the evaluation grid
    target_translation: Sequence[float]  # world-space offset of the evaluation grid
    instance_classes: Sequence[str]  # labels treated as instance (integer-ID) volumes
    semantic_threshold: float = 0.5  # binarization threshold after linear resampling
    pad_value: float | int = 0  # fill value for padding
    check_nans: bool = False  # if True, count NaNs in float inputs and warn

    def _is_instance(self) -> bool:
        # Instance labels are resampled with nearest-neighbor and never thresholded.
        return self.class_label in set(self.instance_classes)

    def _warn_nan(self, nan_count: int) -> None:
        # Warn once with the total NaN count found before resampling.
        logger.warning(
            f"{self.path}: {nan_count} NaN value(s) found in raw '{self.class_label}' array "
            "before resampling. NaNs will be treated as background (0) and may spread to "
            "neighboring voxels during interpolation, negatively affecting scores."
        )

    def _select_non_ome_level(
        self, grp: zarr.Group
    ) -> Tuple[str, Optional[Tuple[float, ...]], Optional[Tuple[float, ...]]]:
        """
        Heuristic: choose among arrays like s0/s1/... (or any array keys) by voxel_size attrs.
        Preference: pick the level whose voxel_size is <= target_voxel_size
        (finer or equal to target resolution) and closest to target; else closest overall.
        Returns (array_key, voxel_size, translation)
        """
        keys = list(grp.array_keys())
        if not keys:
            raise ValueError(f"No arrays found under {self.path}")
        tgt = np.asarray(self.target_voxel_size, dtype=float)

        best_key = None
        best_vs = None
        best_tr = None
        best_score = None

        for k in keys:
            arr = grp[k]
            vs = _parse_voxel_size(arr.attrs)
            tr = _parse_translation(arr.attrs)

            if vs is None:
                # treat as unknown; deprioritize
                score = 1e18
            else:
                v = np.asarray(vs, dtype=float)
                if v.size != tgt.size:
                    # keep trailing dims (e.g. drop channel dim) — assumes spatial dims are last
                    v = v[-tgt.size:]
                # prefer v <= tgt
                # Here: voxel_size larger => coarser. We want not coarser than target => v <= tgt
                # If your convention is reversed, adjust this rule.
                not_coarser = np.all(v <= tgt)
                dist = float(np.linalg.norm(v - tgt))
                # 1e6 penalty pushes coarser-than-target levels behind all finer ones
                score = dist + (0.0 if not_coarser else 1e6)

            if best_score is None or score < best_score:
                best_score = score
                best_key = k
                best_vs = vs
                best_tr = tr

        assert best_key is not None
        return best_key, best_vs, best_tr

    def _check_size_ratio(self, source_shape: Tuple[int, ...]):
        """Return (source/target volume ratio, estimated source memory in MB)."""
        tgt_size = np.prod(self.target_shape)
        if tgt_size == 0:
            raise ValueError(
                f"Invalid target shape {tuple(self.target_shape)}: product of dimensions is zero."
            )
        src_size = np.prod(source_shape)
        ratio = src_size / tgt_size
        # Estimate memory usage (assuming float32)
        estimated_memory_mb = (src_size * BYTES_PER_FLOAT32) / (1024 * 1024)
        # Return whether we should use chunked loading
        return ratio, estimated_memory_mb

    def _should_use_chunked_loading(
        self, ratio: float, estimated_memory_mb: float
    ) -> bool:
        """
        Determine if we should use chunked loading based on size ratio.
        Chunked loading is used when the array is large but within acceptable limits.

        Raises ValueError when the source/target ratio exceeds MAX_VOLUME_SIZE_RATIO.
        """
        # Use chunked loading if ratio is high but not exceeding the limit
        # This allows processing of larger arrays without exceeding memory
        if ratio > MAX_VOLUME_SIZE_RATIO:
            raise ValueError(
                f"Source array at {self.path} is too large compared to target shape: "
                f"ratio {ratio:.1f}x > {MAX_VOLUME_SIZE_RATIO}x limit. "
                f"Estimated memory: {estimated_memory_mb:.1f} MB. "
                f"Please downsample your predictions to a resolution closer to the target."
            )
        # Use chunked loading for arrays that would use significant memory
        if estimated_memory_mb > 500:
            logger.debug(
                f"Array at {self.path} is large (estimated {estimated_memory_mb:.1f} MB). Using chunked loading to reduce memory usage."
            )
        else:
            logger.debug(
                f"Array at {self.path} is estimated to use {estimated_memory_mb:.1f} MB, which is within memory limits. Loading normally."
            )
        return estimated_memory_mb > 500

    def _load_array_chunked(
        self, arr: zarr.Array, scale_factors: Tuple[float, ...]
    ) -> np.ndarray:
        """
        Load and downsample a zarr array in chunks to reduce memory usage.

        Args:
            arr: The zarr array to load
            scale_factors: Scale factors for rescaling (in_vs / tgt_vs).
                When input has finer resolution than target (in_vs < tgt_vs),
                scale_factors < 1.0, causing rescale() to downsample.
                Example: in_vs=2nm, tgt_vs=8nm -> scale_factors=0.25 -> output is 0.25x input size

        Returns:
            The downsampled array as a numpy array
        """
        # NOTE(review): the chunk loops assume a 3D (z, y, x) array — confirm callers
        # never pass arrays of other rank.
        logger.info(
            f"Loading and downsampling array in chunks with scale factors: {scale_factors}"
        )

        # Calculate output shape after downsampling
        # Use round() to match skimage.transform.rescale's internal output-shape calculation,
        # which avoids off-by-one errors when input_size * scale_factor is a half-integer.
        output_shape = tuple(
            max(1, int(np.round(s * sf))) for s, sf in zip(arr.shape, scale_factors)
        )
        output = np.zeros(
            output_shape, dtype=arr.dtype if self._is_instance() else np.float32
        )

        # Determine chunk size based on memory constraints
        # Target ~100MB per chunk in memory
        target_chunk_memory_mb = 100
        chunk_voxels = int((target_chunk_memory_mb * 1024 * 1024) / BYTES_PER_FLOAT32)
        chunk_size_per_dim = max(
            32, int(chunk_voxels ** (1 / 3))
        )  # At least 32 voxels per dimension

        logger.info(
            f"Processing with chunk size: {chunk_size_per_dim} voxels per dimension"
        )

        # Process array in chunks
        nan_count = 0
        for z_start in range(0, arr.shape[0], chunk_size_per_dim):
            z_end = min(z_start + chunk_size_per_dim, arr.shape[0])
            for y_start in range(0, arr.shape[1], chunk_size_per_dim):
                y_end = min(y_start + chunk_size_per_dim, arr.shape[1])
                for x_start in range(0, arr.shape[2], chunk_size_per_dim):
                    x_end = min(x_start + chunk_size_per_dim, arr.shape[2])

                    # Load chunk
                    chunk = arr[z_start:z_end, y_start:y_end, x_start:x_end]

                    if self.check_nans and np.issubdtype(chunk.dtype, np.floating):
                        nan_count += int(np.isnan(chunk).sum())

                    # Downsample chunk
                    if self._is_instance():
                        # Nearest-neighbor keeps integer instance IDs intact
                        chunk_downsampled = rescale(
                            chunk,
                            scale_factors,
                            order=0,
                            mode="constant",
                            preserve_range=True,
                        ).astype(chunk.dtype)
                    else:
                        if chunk.dtype == bool:
                            # rescale with order=1 needs a float input
                            chunk = chunk.astype(np.float32)
                        chunk_downsampled = rescale(
                            chunk,
                            scale_factors,
                            order=1,
                            mode="constant",
                            preserve_range=True,
                        )
                        # Don't threshold here, will be done at the end

                    # Calculate output position
                    # scale_factors represents the ratio of dimensions (output_size / input_size)
                    # When downsampling (in_vs < tgt_vs), scale_factors < 1.0
                    out_z_start = int(z_start * scale_factors[0])
                    out_y_start = int(y_start * scale_factors[1])
                    out_x_start = int(x_start * scale_factors[2])

                    # Place downsampled chunk in output, clamping to actual chunk shape
                    # to handle rounding differences between ceil (position math) and
                    # round (skimage rescale internal) for half-integer scaled sizes.
                    ds_z, ds_y, ds_x = chunk_downsampled.shape
                    out_z_end_actual = min(out_z_start + ds_z, output_shape[0])
                    out_y_end_actual = min(out_y_start + ds_y, output_shape[1])
                    out_x_end_actual = min(out_x_start + ds_x, output_shape[2])
                    sl_z = out_z_end_actual - out_z_start
                    sl_y = out_y_end_actual - out_y_start
                    sl_x = out_x_end_actual - out_x_start
                    output[
                        out_z_start:out_z_end_actual,
                        out_y_start:out_y_end_actual,
                        out_x_start:out_x_end_actual,
                    ] = chunk_downsampled[:sl_z, :sl_y, :sl_x]

        if nan_count > 0:
            self._warn_nan(nan_count)

        # Convert back to bool if semantic (threshold once at the end)
        if not self._is_instance():
            output = output > self.semantic_threshold

        return output

    def _load_source_array(self):
        """
        Returns (image, input_voxel_size, input_translation, already_downsampled)
        where image is a numpy array and already_downsampled indicates if
        chunked downsampling was applied.
        """
        try:
            ds = zarr.open(str(self.path), mode="r")
        except Exception as e:
            raise ValueError(
                f"Failed to open zarr at {self.path}. "
                f"Ensure the path points to a valid zarr array or group. "
                f"Error: {e}"
            )

        logger.info(f"Loading from {self.path}, type: {type(ds).__name__}")

        # OME-NGFF multiscale
        if isinstance(ds, zarr.Group) and "multiscales" in ds.attrs:
            logger.info(f"Detected OME-NGFF multiscale format at {self.path}")
            try:
                # CellMapImage selects the resolution level best matching the target scale
                img = CellMapImage(
                    path=str(self.path),
                    target_class=self.class_label,
                    target_scale=self.target_voxel_size,
                    target_voxel_shape=self.target_shape,
                    pad=True,
                    pad_value=self.pad_value,
                    interpolation="nearest" if self._is_instance() else "linear",
                )
                level = img._level_info[img.scale_level][0]
                level_path = UPath(self.path) / level

                # Extract input voxel size and translation from multiscales metadata
                input_voxel_size = None
                input_translation = None
                for d in ds.attrs["multiscales"][0]["datasets"]:
                    if d["path"] == level:
                        for t in d.get("coordinateTransformations", []):
                            if t.get("type") == "scale":
                                input_voxel_size = tuple(t["scale"])
                            elif t.get("type") == "translation":
                                input_translation = tuple(t["translation"])
                        break

                arr = zarr.open(str(level_path), mode="r")
                ratio, estimated_memory_mb = self._check_size_ratio(arr.shape)
                use_chunked = self._should_use_chunked_loading(
                    ratio, estimated_memory_mb
                )
                if use_chunked:
                    logger.warning(
                        f"Large OME-NGFF array detected ({estimated_memory_mb:.1f} MB). "
                        f"Loading entire array - CellMapImage should have already selected appropriate resolution level. "
                        f"If memory issues occur, ensure predictions are saved at appropriate resolution."
                    )
                image = arr[:]
                return (
                    image,
                    input_voxel_size,
                    input_translation,
                    False,
                )  # Not downsampled in chunks
            except Exception as e:
                raise ValueError(
                    f"Failed to load OME-NGFF multiscale data from {self.path}. "
                    f"Error: {e}"
                )

        # Non-OME group multiscale OR single-scale array with attrs
        if isinstance(ds, zarr.Group):
            logger.info(f"Detected zarr Group (non-OME) at {self.path}")
            # If this group directly contains the label array (common): path points at an array node
            # zarr.open on an array path usually returns an Array, not Group. If we got Group, pick a level.
            try:
                key, vs, tr = self._select_non_ome_level(ds)
                arr = ds[key]
                ratio, estimated_memory_mb = self._check_size_ratio(arr.shape)
                use_chunked = self._should_use_chunked_loading(
                    ratio, estimated_memory_mb
                )
                if use_chunked and vs is not None:
                    logger.info(
                        f"Using chunked loading for large non-OME array ({estimated_memory_mb:.1f} MB)"
                    )
                    # Calculate scale factors for downsampling
                    in_vs = np.asarray(vs, dtype=float)
                    tgt_vs = np.asarray(self.target_voxel_size, dtype=float)
                    if in_vs.size != tgt_vs.size:
                        in_vs = in_vs[-tgt_vs.size:]
                    if not np.allclose(in_vs, tgt_vs):
                        # Downsampling needed - use chunked approach
                        # scale_factors = in_vs / tgt_vs
                        # When in_vs < tgt_vs (fine->coarse), scale_factors < 1.0, causing rescale to downsample
                        scale_factors = in_vs / tgt_vs
                        image = self._load_array_chunked(arr, scale_factors)
                        return (
                            image,
                            self.target_voxel_size,
                            tr,
                            True,
                        )  # Already downsampled
                    else:
                        # No downsampling needed, load normally
                        image = arr[:]
                        return image, vs, tr, False
                else:
                    image = arr[:]
                    return image, vs, tr, False
            except Exception as e:
                raise ValueError(
                    f"Failed to load from non-OME zarr Group at {self.path}. "
                    f"Expected to find resolution levels (e.g., 's0', 's1') with voxel_size metadata, "
                    f"but encountered an error: {e}"
                )

        # Single-scale zarr array
        if isinstance(ds, zarr.Array):
            logger.info(f"Detected single-scale zarr Array at {self.path}")
            ratio, estimated_memory_mb = self._check_size_ratio(ds.shape)
            vs = _parse_voxel_size(ds.attrs)
            tr = _parse_translation(ds.attrs)
            if vs is None:
                logger.warning(
                    f"No voxel_size metadata found at {self.path}. "
                    f"Will attempt to match by shape only, which may produce incorrect alignment."
                )
            if tr is None:
                logger.warning(
                    f"No translation metadata found at {self.path}. "
                    f"Assuming zero offset, which may produce incorrect alignment."
                )
            use_chunked = self._should_use_chunked_loading(ratio, estimated_memory_mb)
            if use_chunked and vs is not None:
                logger.info(
                    f"Using chunked loading for large single-scale array ({estimated_memory_mb:.1f} MB)"
                )
                # Calculate scale factors for downsampling
                in_vs = np.asarray(vs, dtype=float)
                tgt_vs = np.asarray(self.target_voxel_size, dtype=float)
                if in_vs.size != tgt_vs.size:
                    in_vs = in_vs[-tgt_vs.size:]
                if not np.allclose(in_vs, tgt_vs):
                    # Downsampling needed - use chunked approach
                    # scale_factors = in_vs / tgt_vs
                    # When in_vs < tgt_vs (fine->coarse), scale_factors < 1.0, causing rescale to downsample
                    scale_factors = in_vs / tgt_vs
                    image = self._load_array_chunked(ds, scale_factors)
                    return (
                        image,
                        self.target_voxel_size,
                        tr,
                        True,
                    )  # Already downsampled
                else:
                    # No downsampling needed, load normally
                    image = ds[:]
                    return image, vs, tr, False
            else:
                image = ds[:]
                return image, vs, tr, False

        raise ValueError(f"Unsupported zarr node type at {self.path}: {type(ds)}")

    def load_aligned(self) -> np.ndarray:
        """
        Return full aligned volume in target space (target_shape).
        """
        tgt_vs = np.asarray(self.target_voxel_size, dtype=float)
        tgt_shape = tuple(int(x) for x in self.target_shape)
        tgt_tr = np.asarray(self.target_translation, dtype=float)

        image, input_voxel_size, input_translation, already_downsampled = (
            self._load_source_array()
        )

        if self.check_nans and np.issubdtype(image.dtype, np.floating):
            nan_count = int(np.isnan(image).sum())
            if nan_count > 0:
                self._warn_nan(nan_count)

        # Resample if needed (skip if already downsampled in chunks)
        if not already_downsampled and input_voxel_size is not None:
            in_vs = np.asarray(input_voxel_size, dtype=float)
            if in_vs.size != tgt_vs.size:
                in_vs = in_vs[-tgt_vs.size:]
            if not np.allclose(in_vs, tgt_vs):
                scale_factors = in_vs / tgt_vs
                if self._is_instance():
                    # Nearest-neighbor keeps instance IDs intact
                    image = rescale(
                        image,
                        scale_factors,
                        order=0,
                        mode="constant",
                        preserve_range=True,
                    ).astype(image.dtype)
                else:
                    if image.dtype == bool:
                        image = image.astype(np.float32)
                    imgf = rescale(
                        image,
                        scale_factors,
                        order=1,
                        mode="constant",
                        preserve_range=True,
                    )
                    image = imgf > self.semantic_threshold
        elif not already_downsampled and image.shape != tgt_shape:
            # If no voxel size info, fall back to center crop/pad
            image = _resize_pad_crop(image, tgt_shape, pad_value=self.pad_value)

        # Compute relative offset in voxel units
        if input_translation is not None:
            in_tr = np.asarray(input_translation, dtype=float)
            if in_tr.size != tgt_tr.size:
                in_tr = in_tr[-tgt_tr.size:]
            # snap to voxel grid
            adjusted_in_tr = (in_tr // tgt_vs) * tgt_vs
            rel = (np.abs(adjusted_in_tr - tgt_tr) // tgt_vs) * np.sign(
                adjusted_in_tr - tgt_tr
            )
            rel = rel.astype(int)
        else:
            rel = np.zeros(len(tgt_shape), dtype=int)

        # Translate + crop/pad into destination
        if any(rel != 0) or image.shape != tgt_shape:
            result = np.zeros(tgt_shape, dtype=image.dtype)
            input_slices = []
            output_slices = []
            for d in range(len(tgt_shape)):
                if rel[d] < 0:
                    # Source starts before the target window: skip the overhang
                    input_start = abs(rel[d])
                    output_start = 0
                    input_end = min(input_start + tgt_shape[d], image.shape[d])
                    length = input_end - input_start
                    output_end = output_start + length
                else:
                    # Source starts inside the target window: shift right
                    input_start = 0
                    output_start = rel[d]
                    output_end = min(tgt_shape[d], image.shape[d] + output_start)
                    length = output_end - output_start
                    input_end = input_start + length
                # No overlap along this axis -> the aligned volume is all zeros
                if length <= 0:
                    return result
                input_slices.append(slice(int(input_start), int(input_end)))
                output_slices.append(slice(int(output_start), int(output_end)))
            result[tuple(output_slices)] = image[tuple(input_slices)]
            return result

        return image