cellmap_segmentation_challenge.train

cellmap_segmentation_challenge.train(config_path: str)

Train a model using the configuration file at the specified path. The model checkpoints and training logs, as well as the datasets used for training, will be saved to the paths specified in the configuration file.

Parameters:

config_path (str) – Path to the configuration file to use for training the model. This file should be a Python file that defines the hyperparameters and other configurations for training the model. This may include:

- model_save_path: Path to save the model checkpoints. Default is 'checkpoints/{model_name}_{epoch}.pth'.
- logs_save_path: Path to save the logs for tensorboard. Default is 'tensorboard/{model_name}'. Training progress may be monitored by running tensorboard --logdir <logs_save_path> in the terminal.
- datasplit_path: Path to the datasplit file that defines the train/val split the dataloader should use. Default is 'datasplit.csv'.
- validation_prob: Proportion of the datasets to use for validation. This is used if the datasplit CSV specified by datasplit_path does not already exist. Default is 0.3.
- learning_rate: Learning rate for the optimizer. Default is 0.0001.
- batch_size: Batch size for the dataloader. Default is 8.
- input_array_info: Dictionary containing the shape and scale of the input data. Default is {'shape': (1, 128, 128), 'scale': (8, 8, 8)}.
- target_array_info: Dictionary containing the shape and scale of the target data. Default is to use input_array_info.
- epochs: Number of epochs to train the model for. Default is 1000.
- iterations_per_epoch: Number of iterations per epoch. Each iteration includes an independently generated random batch from the training set. Default is 1000.
- random_seed: Random seed for reproducibility. Default is 42.
- classes: List of classes to train the model to predict. This will be reflected in the data included in the datasplit, if generated de novo after calling this script. Default is ['nuc', 'er'].
- model_name: Name of the model to use. If the config file constructs the PyTorch model, this name can be anything. If the config file does not construct the PyTorch model, the model_name will need to specify which included architecture to use. This includes '2d_unet', '2d_resnet', '3d_unet', '3d_resnet', and 'vitnet'. Default is '2d_unet'. See the models module README.md for more information.
- model_to_load: Name of the pre-trained model to load. Default is the same as model_name.
- model_kwargs: Dictionary of keyword arguments to pass to the model constructor. Default is {}. If the PyTorch model is passed, this will be ignored. See the models module README.md for more information.
- model: PyTorch model to use for training. If this is provided, the model_name and model_to_load can be any string. Default is None.
- load_model: Which model checkpoint to load if it exists. Options are 'latest' or 'best'. If no checkpoints exist, will silently use the already initialized model. Default is 'latest'.
- spatial_transforms: Dictionary of spatial transformations to apply to the training data. Default is {'mirror': {'axes': {'x': 0.5, 'y': 0.5}}, 'transpose': {'axes': ['x', 'y']}, 'rotate': {'axes': {'x': [-180, 180], 'y': [-180, 180]}}}. See the dataloader module documentation for more information.

Return type:

None
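
Example:

The following is a minimal sketch of a configuration file built only from the parameter names and default values documented above. The file name train_config.py and the shortened epochs and iterations_per_epoch values are assumptions made for illustration. The `{model_name}` and `{epoch}` placeholders are copied verbatim from the documented defaults; how they are filled in during training is not specified here.

```python
# train_config.py (hypothetical file name, chosen for this example)
# Every name below is one of the documented config parameters.
# Parameters left out of the file are expected to use their documented defaults.

# Optimization settings (documented defaults)
learning_rate = 0.0001
batch_size = 8
random_seed = 42

# Shortened from the documented defaults of 1000 and 1000 for a quick
# trial run; these two values are illustrative assumptions.
epochs = 10
iterations_per_epoch = 100

# Input and target array geometry (documented defaults)
input_array_info = {"shape": (1, 128, 128), "scale": (8, 8, 8)}
target_array_info = input_array_info

# Classes to predict and the included architecture to use
classes = ["nuc", "er"]
model_name = "2d_unet"
model_to_load = model_name
load_model = "latest"  # documented options: "latest" or "best"

# Output locations; placeholders kept exactly as in the documented defaults
model_save_path = "checkpoints/{model_name}_{epoch}.pth"
logs_save_path = "tensorboard/{model_name}"

# Train/validation split
datasplit_path = "datasplit.csv"
validation_prob = 0.3  # only used if datasplit.csv does not already exist

# Spatial augmentations (documented default)
spatial_transforms = {
    "mirror": {"axes": {"x": 0.5, "y": 0.5}},
    "transpose": {"axes": ["x", "y"]},
    "rotate": {"axes": {"x": [-180, 180], "y": [-180, 180]}},
}
```

With such a file saved, training would be started by passing its path as config_path, for instance train("train_config.py"), assuming the function has been imported under the name shown in the signature above. Progress could then be monitored with tensorboard --logdir pointed at the directory given by logs_save_path.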