Custom Model

This guide explains how to run your own custom model using the CellMap Flow CLI via the script interface.

Overview

To run a custom model:

  1. Create a Python script that:

     • Loads and prepares your model.

     • Optionally defines how to process each chunk (advanced).

  2. Use the CLI to launch the pipeline.

The CLI command looks like:

cellmap_flow script -s /path/to/your_script.py -d /path/to/input_data.zarr -q gpu_h100 -P cellmap

Minimum Requirements for Custom Scripts

When using cellmap_flow script, your Python script must define its configuration (typically as variables in the global scope) that satisfies the following:

Required Attributes

Your script must define the following global-level variables or attributes (a minimal skeleton is shown after this list):

  • model (optional): PyTorch model instance. Required if predict is not defined.

  • predict (optional): Custom callable to run predictions. Required if model is not defined.

  • read_shape (required): The shape of the input block (in world units or voxels).

  • write_shape (required): The shape of the output block (i.e., prediction size).

  • input_voxel_size (required): Size of a voxel in the input data.

  • output_voxel_size (required): Size of a voxel in the model output.

  • output_channels (required): Number of channels in the output prediction.

  • block_shape (required): Shape of each output block in voxels, i.e. write_shape in voxels with output_channels appended as the last dimension.
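
For reference, here is a minimal sketch of what such a script's global scope can look like. The shapes, voxel sizes, and the torch.nn.Identity placeholder are illustrative assumptions only; substitute your own model and values.

# your_script.py -- illustrative skeleton, not a working model
import numpy as np
import torch
from funlib.geometry import Coordinate

input_voxel_size = Coordinate((8, 8, 8))
output_voxel_size = Coordinate((8, 8, 8))
read_shape = Coordinate((178, 178, 178)) * input_voxel_size   # input block in world units
write_shape = Coordinate((56, 56, 56)) * output_voxel_size    # output block in world units

output_channels = 8
block_shape = np.array((56, 56, 56, output_channels))         # write_shape in voxels + channels

# Either provide a model (any torch.nn.Module) ...
model = torch.nn.Identity()  # placeholder only; load your real network here
# ... or, instead of model, define predict(...) / process_chunk(...) as shown later in this guide.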

Testing Your Script

You can test your script locally using:

cellmap_flow script-server-check -s /path/to/your_script.py -d /path/to/input.zarr

This will simulate a small 2x2x2 chunk to ensure your setup works correctly.
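
If you prefer a plain-Python sanity check before involving the CLI, you can import the script yourself and confirm the required attributes exist. This is only a supplementary sketch (the path and module name below are placeholders), not part of CellMap Flow:

import importlib.util

# Load your script as a module (the path is a placeholder).
spec = importlib.util.spec_from_file_location("my_flow_script", "/path/to/your_script.py")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# Check that the attributes listed above are present.
required = ["read_shape", "write_shape", "input_voxel_size",
            "output_voxel_size", "output_channels", "block_shape"]
missing = [name for name in required if not hasattr(module, name)]
assert not missing, f"missing required attributes: {missing}"
assert any(hasattr(module, name) for name in ("model", "predict", "process_chunk")), \
    "define model, predict, or process_chunk"
print("Script defines the expected attributes.")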

Basic Script Template (PyTorch)

If you are using PyTorch and your model can run inference directly via a standard forward pass (in .eval() mode), you do not need to define a process_chunk function.

# pip install fly-organelles
from fly_organelles.model import StandardUnet
from funlib.geometry import Coordinate
import torch
import numpy as np

# Voxel size and chunk shape
input_voxel_size = (8, 8, 8)
output_voxel_size = Coordinate((8, 8, 8))
read_shape = Coordinate((178, 178, 178)) * Coordinate(input_voxel_size)
write_shape = Coordinate((56, 56, 56)) * Coordinate(input_voxel_size)

def load_eval_model(num_labels, checkpoint_path):
    model_backbone = StandardUnet(num_labels)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
    model_backbone.load_state_dict(checkpoint["model_state_dict"])
    model = torch.nn.Sequential(model_backbone, torch.nn.Sigmoid())
    model.to(device)
    model.eval()
    return model

CHECKPOINT_PATH = "/path/to/your/model_checkpoint"
output_channels = 8  # Set according to your model
model = load_eval_model(output_channels, CHECKPOINT_PATH)
block_shape = np.array((56, 56, 56, output_channels))

Note

You must define the model and block_shape variables in the global scope. CellMap Flow will use these automatically.
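
Before launching a full run, you can optionally sanity-check the shapes by pushing a dummy block through the model. This is a rough sketch that assumes the network takes a single-channel 3D input in (batch, channel, z, y, x) order; adjust it to your model's actual input layout:

import torch

# Dummy input covering read_shape in voxels (178^3) with one raw channel (assumption).
device = next(model.parameters()).device
dummy = torch.zeros(1, 1, 178, 178, 178, device=device)
with torch.no_grad():
    out = model(dummy)

# Expect (1, output_channels, 56, 56, 56), i.e. write_shape in voxels plus the channel axis.
print(out.shape)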

Advanced Usage: Custom Process Function (TensorFlow)

For advanced use cases or non-standard frameworks like TensorFlow v1, define a process_chunk function that handles:

  • Rescaling input

  • Feeding to model

  • Retrieving and postprocessing output

import tensorflow.compat.v1 as tf
import os, json
import numpy as np
from funlib.geometry import Coordinate, Roi
from cellmap_flow.image_data_interface import ImageDataInterface

# Define voxel sizes and context
input_voxel_size = Coordinate((8, 8, 8))
output_voxel_size = Coordinate((8, 8, 8))
read_shape = Coordinate((268, 268, 268)) * input_voxel_size
write_shape = Coordinate((164, 164, 164)) * output_voxel_size
context = (read_shape - write_shape) / 2

output_channels = 10
block_shape = np.array((164, 164, 164, output_channels))

# Load TensorFlow model
def load_eval_model(setup_dir, checkpoint):
    graph = tf.Graph()
    session = tf.Session(graph=graph)
    with graph.as_default():
        meta_graph_file = os.path.join(setup_dir, "config.meta")
        saver = tf.train.import_meta_graph(meta_graph_file)
        saver.restore(session, os.path.join(setup_dir, checkpoint))
    return session

setup_dir = "/path/to/tf_model_dir"
checkpoint = "train_net_checkpoint_400000"
session = load_eval_model(setup_dir, checkpoint)

def get_tensor_names(setup_dir, inputs, outputs):
    with open(os.path.join(setup_dir, "config.json"), "r") as f:
        net_config = json.load(f)
    return [net_config[it] for it in inputs], [net_config[ot] for ot in outputs]

def rescale_data(input_array, min_val, max_val):
    return (2.0 * (input_array - min_val) / (max_val - min_val)) - 1.0

def process_lsd(chunk, session, input_tensorname, output_tensorname):
    # Rescale raw intensities (assumed here to lie in [158, 233]) to [-1, 1].
    input_data = rescale_data(chunk, 158, 233)
    result = session.run(
        {ot: ot for ot in output_tensorname},
        feed_dict={input_tensorname[0]: input_data}
    )
    # Clip predictions to [0, 1] and convert to uint8 for writing.
    return (result[output_tensorname[0]].clip(0, 1) * 255).astype(np.uint8)

def process_chunk(idi: ImageDataInterface, input_roi: Roi):
    # Grow the requested (write) ROI by the context so the network sees the full read_shape.
    input_roi = input_roi.grow(context, context)
    chunk = idi.to_ndarray_ts(input_roi)
    # Look up the input/output tensor names from the network's config.json.
    input_tensor, output_tensor = get_tensor_names(setup_dir, ["raw"], ["embedding"])
    return process_lsd(chunk, session, input_tensor, output_tensor)