dacapo.experiments.trainers
Submodules
- dacapo.experiments.trainers.dummy_trainer
- dacapo.experiments.trainers.dummy_trainer_config
- dacapo.experiments.trainers.gp_augments
- dacapo.experiments.trainers.gunpowder_trainer
- dacapo.experiments.trainers.gunpowder_trainer_config
- dacapo.experiments.trainers.optimizers
- dacapo.experiments.trainers.trainer
- dacapo.experiments.trainers.trainer_config
Classes
- Trainer: Trainer Abstract Base Class
- TrainerConfig: A class to represent the Trainer Configurations.
- DummyTrainerConfig: This is just a dummy trainer config used for testing.
- DummyTrainer: This class is used to train a model using dummy data and is used for testing purposes.
- GunpowderTrainerConfig: This class is used to configure a Gunpowder Trainer.
- GunpowderTrainer: GunpowderTrainer class for training a model using gunpowder.
- AugmentConfig: Base class for gunpowder augment configurations.
Package Contents
- class dacapo.experiments.trainers.Trainer
Trainer Abstract Base Class
This serves as the blueprint for any trainer classes in the dacapo library. It defines essential methods that every subclass must implement for effective training of a neural network model.
- iteration
The number of training iterations.
- Type:
int
- batch_size
The size of the training batch.
- Type:
int
- learning_rate
The learning rate for the optimizer.
- Type:
float
- create_optimizer(model: Model) -> torch.optim.Optimizer: Creates an optimizer for the model.
- iterate(num_iterations: int, model: Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[TrainingIterationStats]: Performs a number of training iterations.
- can_train(datasets: List[Dataset]) -> bool: Checks if the trainer can train with a specific set of datasets.
- build_batch_provider(datasets: List[Dataset], model: Model, task: Task, snapshot_container: LocalContainerIdentifier) -> None: Initializes the training pipeline using various components.
Note
The Trainer class is an abstract class that cannot be instantiated directly. It is meant to be subclassed.
- iteration: int
- batch_size: int
- learning_rate: float
- abstract create_optimizer(model: dacapo.experiments.model.Model) -> torch.optim.Optimizer
Creates an optimizer for the model.
- Parameters:
model (Model) – The model for which the optimizer will be created.
- Returns:
The optimizer created for the model.
- Return type:
torch.optim.Optimizer
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> optimizer = trainer.create_optimizer(model)
Note
This method must be implemented by the subclass.
- abstract iterate(num_iterations: int, model: dacapo.experiments.model.Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[dacapo.experiments.training_iteration_stats.TrainingIterationStats]
Performs a number of training iterations.
- Parameters:
num_iterations (int) – Number of training iterations.
model (Model) – The model to be trained.
optimizer (torch.optim.Optimizer) – The optimizer for the model.
device (torch.device) – The device (GPU/CPU) where the model will be trained.
- Returns:
An iterator of the training statistics.
- Return type:
Iterator[TrainingIterationStats]
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> for iteration_stats in trainer.iterate(num_iterations, model, optimizer, device):
...     print(iteration_stats)
Note
This method must be implemented by the subclass.
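Because iterate returns an iterator, the caller drives training: each step runs only when the next stats object is requested, and the caller can stop early. The sketch below illustrates this with a hypothetical toy_iterate generator and a stand-in TrainingIterationStats dataclass; its fields (iteration, loss) are assumptions for illustration, not dacapo's definition.

```python
from dataclasses import dataclass
from itertools import islice
from typing import Iterator


@dataclass
class TrainingIterationStats:
    # Stand-in for dacapo's stats class; the fields are assumed for illustration.
    iteration: int
    loss: float


def toy_iterate(num_iterations: int) -> Iterator[TrainingIterationStats]:
    """Hypothetical generator mimicking the documented iterate() contract."""
    for i in range(num_iterations):
        # A real trainer would fetch a batch and run forward/backward here.
        print(f"running step {i}")
        yield TrainingIterationStats(iteration=i, loss=1.0 / (i + 1))


stats_iter = toy_iterate(10)  # nothing has run yet: generators are lazy
history = list(islice(stats_iter, 3))  # runs exactly three steps, then stops
for stats in history:
    print(stats.iteration, round(stats.loss, 3))
```

Only three "running step" lines are printed even though ten iterations were requested, which is why consumers can cap or interrupt a training run without extra flags.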
- abstract can_train(datasets: List[dacapo.experiments.datasplits.datasets.Dataset]) -> bool
Checks if the trainer can train with a specific set of datasets.
Some trainers may have specific requirements for their training datasets.
- Parameters:
datasets (List[Dataset]) – The training datasets.
- Returns:
True if the trainer can train on the given datasets, False otherwise.
- Return type:
bool
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> can_train = trainer.can_train(datasets)
Note
This method must be implemented by the subclass.
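As an illustration of a dataset requirement a subclass might enforce, the sketch below accepts a set of datasets only if it is non-empty and every dataset carries ground truth. ToyDataset is a hypothetical stand-in for dacapo's Dataset, and both the "gt must be present" rule and the "reject an empty list" choice are assumptions, not the library's documented behavior.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyDataset:
    # Hypothetical stand-in for a dacapo Dataset; only the fields used here.
    name: str
    raw: object
    gt: Optional[object] = None


def can_train(datasets: List[ToyDataset]) -> bool:
    """Example requirement: at least one dataset, and each provides ground truth."""
    # all([]) is True, so the explicit length check rejects an empty list.
    return len(datasets) > 0 and all(ds.gt is not None for ds in datasets)


labelled = ToyDataset("a", raw=[0, 1], gt=[1, 0])
unlabelled = ToyDataset("b", raw=[0, 1])
print(can_train([labelled]))              # True
print(can_train([labelled, unlabelled]))  # False
```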
- abstract build_batch_provider(datasets: List[dacapo.experiments.datasplits.datasets.Dataset], model: dacapo.experiments.model.Model, task: dacapo.experiments.tasks.task.Task, snapshot_container: dacapo.store.array_store.LocalContainerIdentifier) -> None
Initializes the training pipeline using various components.
This method uses the datasets, model, task, and snapshot_container to set up the training pipeline.
- Parameters:
datasets (List[Dataset]) – The datasets to pull data from.
model (Model) – The model to inform the pipeline of required input/output sizes.
task (Task) – The task to transform ground truth into target.
snapshot_container (LocalContainerIdentifier) – Defines where snapshots will be saved.
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> trainer.build_batch_provider(datasets, model, task, snapshot_container)
Note
This method must be implemented by the subclass.
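The abstract-class contract described above can be sketched with Python's abc module. TrainerBase below is a hypothetical stand-in that mirrors the four documented abstract methods; it is not dacapo's actual class, and NoOpTrainer and HalfTrainer are illustrative names. A subclass can be instantiated only once it overrides every abstract method.

```python
from abc import ABC, abstractmethod
from typing import Iterator, List


class TrainerBase(ABC):
    """Stand-in mirroring the documented abstract interface of Trainer."""

    iteration: int = 0
    batch_size: int
    learning_rate: float

    @abstractmethod
    def create_optimizer(self, model): ...

    @abstractmethod
    def iterate(self, num_iterations: int, model, optimizer, device) -> Iterator: ...

    @abstractmethod
    def can_train(self, datasets: List) -> bool: ...

    @abstractmethod
    def build_batch_provider(self, datasets, model, task, snapshot_container) -> None: ...


class HalfTrainer(TrainerBase):
    """Implements only one method, so it stays abstract."""

    def create_optimizer(self, model):
        return None


class NoOpTrainer(TrainerBase):
    """Hypothetical subclass that implements every abstract method trivially."""

    def __init__(self, batch_size: int, learning_rate: float):
        self.batch_size = batch_size
        self.learning_rate = learning_rate

    def create_optimizer(self, model):
        return None  # a real trainer would return a torch.optim.Optimizer

    def iterate(self, num_iterations, model, optimizer, device):
        return iter(())  # yields no stats

    def can_train(self, datasets):
        return True

    def build_batch_provider(self, datasets, model, task, snapshot_container):
        pass


for cls in (TrainerBase, HalfTrainer):
    try:
        cls()
    except TypeError as err:
        print(f"{cls.__name__}: {err}")

trainer = NoOpTrainer(batch_size=2, learning_rate=1e-4)
```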
- class dacapo.experiments.trainers.TrainerConfig
A class to represent the Trainer Configurations.
It is the base class for trainer configurations. Each subclass of a Trainer should have a specific config class derived from TrainerConfig.
- name
A unique name for this trainer.
- Type:
str
- batch_size
The batch size to be used during training.
- Type:
int
- learning_rate
The learning rate of the optimizer.
- Type:
float
- verify() -> Tuple[bool, str]
Verify whether this TrainerConfig is valid or not.
Note
The TrainerConfig class is an abstract class that cannot be instantiated directly. It is meant to be subclassed.
- name: str
- batch_size: int
- learning_rate: float
- verify() -> Tuple[bool, str]
Verify whether this TrainerConfig is valid or not. A TrainerConfig is considered valid if it has a valid batch size and learning rate.
- Returns:
A tuple containing a boolean indicating whether the TrainerConfig is valid and a message explaining why.
- Return type:
tuple
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> valid, message = trainer_config.verify()
>>> valid
True
>>> message
'No validation for this Trainer'
Note
This method must be implemented by the subclass.
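Following the description of a valid config as one with a valid batch size and learning rate, a subclass's verify might look like the sketch below. ToyTrainerConfig is hypothetical, and treating "valid" as "strictly positive" is an assumption. The method returns the documented (bool, str) pair instead of raising, so callers can report the reason.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class ToyTrainerConfig:
    # Hypothetical config mirroring the documented name/batch_size/learning_rate.
    name: str
    batch_size: int
    learning_rate: float

    def verify(self) -> Tuple[bool, str]:
        """Return (is_valid, reason), following the documented verify() contract."""
        if self.batch_size <= 0:
            return False, f"batch_size must be positive, got {self.batch_size}"
        # "not x > 0" also rejects NaN, which "x <= 0" would let through.
        if not self.learning_rate > 0:
            return False, f"learning_rate must be positive, got {self.learning_rate}"
        return True, "ok"


valid, message = ToyTrainerConfig("toy", batch_size=0, learning_rate=1e-4).verify()
print(valid, message)  # False batch_size must be positive, got 0
```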
- class dacapo.experiments.trainers.DummyTrainerConfig
This is just a dummy trainer config used for testing. None of the attributes have any particular meaning. This is just to test the trainer and the trainer config.
- mirror_augment
A boolean value indicating whether to use mirror augmentation or not.
- Type:
bool
- verify(self) -> Tuple[bool, str]
This method verifies the DummyTrainerConfig object.
- trainer_type
- mirror_augment: bool
- verify() -> Tuple[bool, str]
Verify the DummyTrainerConfig object.
- Returns:
A tuple containing a boolean value indicating whether the DummyTrainerConfig object is valid and a string containing the reason why the object is invalid.
- Return type:
Tuple[bool, str]
Examples
>>> valid, reason = trainer_config.verify()
- class dacapo.experiments.trainers.DummyTrainer(trainer_config)
This class trains a model on dummy data for testing purposes. It contains attributes for the learning rate, batch size, and mirror augmentation, and methods to create an optimizer, iterate over the training data, build a batch provider, and check whether the trainer can train on a given data split. It also implements __enter__ and __exit__, so it can be used as a context manager. The iterate method yields training iteration statistics.
- learning_rate
The learning rate to use.
- Type:
float
- batch_size
The batch size to use.
- Type:
int
- mirror_augment
A boolean value indicating whether to use mirror augmentation or not.
- Type:
bool
- __init__(self, trainer_config)
This method initializes the DummyTrainer object.
- create_optimizer(self, model)
This method creates an optimizer for the given model.
- iterate(self, num_iterations: int, model, optimizer, device): This method iterates over the training data for the specified number of iterations.
- build_batch_provider(self, datasplit, architecture, task, snapshot_container)
This method builds a batch provider for the given data split, architecture, task, and snapshot container.
- can_train(self, datasplit)
This method checks if the trainer can train on the given data split.
- __enter__(self)
This method enters the context manager.
- __exit__(self, exc_type, exc_val, exc_tb)
This method exits the context manager.
Note
The iterate method yields TrainingIterationStats.
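The __enter__/__exit__ pair listed above follows Python's standard context-manager protocol. ToyTrainer is a hypothetical stand-in, and what a real trainer sets up on entry (for example, background data fetching) is an assumption. The sketch shows the usual shape: __enter__ returns the trainer, and __exit__ cleans up even when the body raises.

```python
class ToyTrainer:
    """Hypothetical trainer showing the documented __enter__/__exit__ pattern."""

    def __init__(self):
        self.active = False

    def __enter__(self):
        # A real trainer might start its data pipeline or workers here.
        self.active = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Release resources; returning False lets any exception propagate.
        self.active = False
        return False


with ToyTrainer() as trainer:
    print("inside:", trainer.active)  # inside: True
print("after:", trainer.active)       # after: False
```

Because __exit__ returns False, an error raised inside the with block still reaches the caller after cleanup has run.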
- iteration = 0
- learning_rate
- batch_size
- mirror_augment
- create_optimizer(model)
Create an optimizer for the given model.
- Parameters:
model (Model) – The model to optimize.
- Returns:
The optimizer object.
- Return type:
torch.optim.Optimizer
Examples
>>> optimizer = trainer.create_optimizer(model)
- iterate(num_iterations: int, model: dacapo.experiments.model.Model, optimizer, device)
Iterate over the training data for the specified number of iterations.
- Parameters:
num_iterations (int) – The number of iterations to perform.
model (Model) – The model to train.
optimizer (torch.optim.Optimizer) – The optimizer to use.
device (torch.device) – The device to perform the computations on.
- Yields:
TrainingIterationStats – The training iteration statistics.
- Raises:
ValueError – If the number of iterations is less than or equal to zero.
Examples
>>> for stats in trainer.iterate(num_iterations, model, optimizer, device):
...     print(stats)
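The documented ValueError for a non-positive iteration count, together with the iteration = 0 class attribute, suggests a counter that persists across calls. The sketch below is a hypothetical ToyDummyTrainer, not dacapo's implementation; its Stats fields are assumptions. It shows a Python subtlety that applies when iterate is written as a generator: the validation runs on the first next(), not when iterate(...) is called.

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class Stats:
    # Stand-in for TrainingIterationStats; fields assumed for illustration.
    iteration: int
    loss: float


class ToyDummyTrainer:
    """Hypothetical sketch; not dacapo's DummyTrainer implementation."""

    def __init__(self):
        self.iteration = 0  # mirrors the documented `iteration = 0` attribute

    def iterate(self, num_iterations: int) -> Iterator[Stats]:
        if num_iterations <= 0:
            raise ValueError("num_iterations must be greater than zero")
        for _ in range(num_iterations):
            stats = Stats(iteration=self.iteration, loss=0.0)  # dummy loss
            self.iteration += 1  # advance before yielding so early stops stay consistent
            yield stats


trainer = ToyDummyTrainer()
first = [s.iteration for s in trainer.iterate(3)]   # [0, 1, 2]
second = [s.iteration for s in trainer.iterate(2)]  # [3, 4], numbering resumes
gen = trainer.iterate(0)  # no error yet: the generator body has not started
try:
    next(gen)
except ValueError as err:
    print("rejected:", err)
```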
- build_batch_provider(datasplit, architecture, task, snapshot_container)
Build a batch provider for the given data split, architecture, task, and snapshot container.
- Parameters:
datasplit (DataSplit) – The data split to use.
architecture (Architecture) – The architecture to use.
task (Task) – The task to perform.
snapshot_container (SnapshotContainer) – The snapshot container to use.
- Returns:
The batch provider object.
- Return type:
BatchProvider
- Raises:
ValueError – If the task loss is not set.
Examples
>>> batch_provider = trainer.build_batch_provider(datasplit, architecture, task, snapshot_container)
- can_train(datasplit)
Check if the trainer can train on the given data split.
- Parameters:
datasplit (DataSplit) – The data split to check.
- Returns:
True if the trainer can train on the data split, False otherwise.
- Return type:
bool
- Raises:
NotImplementedError – If the method is not implemented.
Examples
>>> trainer.can_train(datasplit)
- class dacapo.experiments.trainers.GunpowderTrainerConfig
This class is used to configure a Gunpowder Trainer. It contains attributes related to trainer type, number of data fetchers, augmentations to apply, snapshot interval, minimum masked value, and a boolean value indicating whether to clip raw or not.
- trainer_type
This is the type of the trainer which is set to GunpowderTrainer by default.
- Type:
class
- num_data_fetchers
This is the number of CPU workers that will be dedicated to fetching and processing the data.
- Type:
int
- augments
This is the list of augments to apply during the training.
- Type:
List[AugmentConfig]
- snapshot_interval
This is the number of iterations after which a new snapshot should be saved.
- Type:
Optional[int]
- min_masked
This is the minimum masked value.
- Type:
Optional[float]
- clip_raw
This is a boolean value indicating if the raw data should be clipped to the size of the GT data or not.
- Type:
bool
- trainer_type
- num_data_fetchers: int
- augments: List[dacapo.experiments.trainers.gp_augments.AugmentConfig]
- clip_raw: bool
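One plausible reading of snapshot_interval ("the number of iterations after which a new snapshot should be saved") is "save whenever the iteration is a multiple of the interval, and never when it is None". The sketch below encodes that reading. ToyGunpowderConfig mirrors the documented attribute names, but its default values and the should_snapshot helper are hypothetical. Whether dacapo also snapshots at iteration 0 is an assumption.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyGunpowderConfig:
    # Hypothetical stand-in; names follow the documented attributes,
    # default values are illustrative only.
    num_data_fetchers: int = 1
    augments: List[object] = field(default_factory=list)
    snapshot_interval: Optional[int] = None
    min_masked: Optional[float] = None
    clip_raw: bool = True


def should_snapshot(config: ToyGunpowderConfig, iteration: int) -> bool:
    """Save every `snapshot_interval` iterations; never if unset or non-positive."""
    interval = config.snapshot_interval
    return interval is not None and interval > 0 and iteration % interval == 0


cfg = ToyGunpowderConfig(snapshot_interval=1000)
snapshots = [i for i in range(0, 3001, 500) if should_snapshot(cfg, i)]
print(snapshots)  # [0, 1000, 2000, 3000]
```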
- class dacapo.experiments.trainers.GunpowderTrainer(trainer_config)
GunpowderTrainer trains a model using gunpowder, a data loading and augmentation library. It is a subclass of the Trainer class and implements the abstract methods defined there, training a model on a dataset for a specific task.
- learning_rate
The learning rate for the optimizer.
- Type:
float
- batch_size
The size of the training batch.
- Type:
int
- num_data_fetchers
The number of data fetchers.
- Type:
int
- print_profiling
The number of iterations after which to print profiling stats.
- Type:
int
- snapshot_iteration
The number of iterations after which to save a snapshot.
- Type:
int
- min_masked
The minimum value of the mask.
- Type:
float
- augments
The list of augmentations to apply to the data.
- Type:
List[Augment]
- mask_integral_downsample_factor
The downsample factor for the mask integral.
- Type:
int
- clip_raw
Whether to clip the raw data.
- Type:
bool
- scheduler
The learning rate scheduler.
- Type:
torch.optim.lr_scheduler.LinearLR
- create_optimizer(model: Model) -> torch.optim.Optimizer: Creates an optimizer for the model.
- build_batch_provider(datasets: List[Dataset], model: Model, task: Task, snapshot_container: LocalContainerIdentifier) -> None: Initializes the training pipeline using various components.
- iterate(num_iterations: int, model: Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[TrainingIterationStats]: Performs a number of training iterations.
- next() -> Tuple[NumpyArray, NumpyArray, NumpyArray, NumpyArray, NumpyArray]
Fetches the next batch of data.
- __enter__() -> GunpowderTrainer
Enters the context manager.
- can_train(datasets: List[Dataset]) -> bool: Checks if the trainer can train with a specific set of datasets.
Note
The GunpowderTrainer class is a subclass of the Trainer class. It is used to train a model using gunpowder.
- iteration = 0
- learning_rate
- batch_size
- num_data_fetchers
- print_profiling = 100
- snapshot_iteration
- min_masked
- augments
- mask_integral_downsample_factor = 4
- clip_raw
- gt_min_reject
- scheduler = None
- create_optimizer(model)
Creates an optimizer for the model.
- Parameters:
model (Model) – The model for which the optimizer will be created.
- Returns:
The optimizer created for the model.
- Return type:
torch.optim.Optimizer
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> optimizer = trainer.create_optimizer(model)
- build_batch_provider(datasets, model, task, snapshot_container=None)
Initializes the training pipeline using various components.
- Parameters:
datasets (List[Dataset]) – The list of datasets.
model (Model) – The model to be trained.
task (Task) – The task to be performed.
snapshot_container (LocalContainerIdentifier) – The snapshot container.
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> trainer.build_batch_provider(datasets, model, task, snapshot_container)
- iterate(num_iterations, model, optimizer, device)
Performs a number of training iterations.
- Parameters:
num_iterations (int) – The number of training iterations.
model (Model) – The model to be trained.
optimizer (torch.optim.Optimizer) – The optimizer for the model.
device (torch.device) – The device (GPU/CPU) where the model will be trained.
- Returns:
An iterator of the training statistics.
- Return type:
Iterator[TrainingIterationStats]
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> for iteration_stats in trainer.iterate(num_iterations, model, optimizer, device):
...     print(iteration_stats)
- next()
Fetches the next batch of data.
- Returns:
A tuple containing the raw data, ground truth data, target data, weight data, and mask data.
- Return type:
Tuple[Array, Array, Array, Array, Array]
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> raw, gt, target, weight, mask = trainer.next()
- can_train(datasets) -> bool
Checks if the trainer can train with a specific set of datasets.
- Parameters:
datasets (List[Dataset]) – The list of datasets.
- Returns:
True if the trainer can train with the datasets, False otherwise.
- Return type:
bool
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> can_train = trainer.can_train(datasets)
- visualize_pipeline(bind_address='0.0.0.0', bind_port=0)
Visualizes the pipeline for the run, including all produced arrays.
- Parameters:
bind_address (str) – Bind address for the Neuroglancer web server.
bind_port (int) – Bind port for the Neuroglancer web server.
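The default bind_port=0 matches the usual socket convention, where port 0 asks the operating system to pick any free port. It is an assumption that visualize_pipeline passes this value through to Neuroglancer's server unchanged. The sketch uses only the standard socket module to show the OS-level behavior. It binds to 127.0.0.1 rather than the documented default 0.0.0.0 so it does not listen on every interface.

```python
import socket

# Binding to port 0 asks the operating system for any free ephemeral port.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    sock.bind(("127.0.0.1", 0))
    host, port = sock.getsockname()
    print(f"OS assigned port {port} on {host}")
```

With bind_address='0.0.0.0', the server would be reachable from other machines, which is convenient on a remote GPU node but worth keeping in mind on shared networks.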
- class dacapo.experiments.trainers.AugmentConfig
Base class for gunpowder augment configurations. Each subclass of an Augment should have a corresponding config class derived from AugmentConfig.
- _raw_key
Key for raw data. Not used in this implementation. Defaults to None.
- _gt_key
Key for ground truth data. Not used in this implementation. Defaults to None.
- _mask_key
Key for mask data. Not used in this implementation. Defaults to None.
- node(_raw_key=None, _gt_key=None, _mask_key=None)
Get a gp.Augment node.
- abstract node(raw_key: gunpowder.ArrayKey, gt_key: gunpowder.ArrayKey, mask_key: gunpowder.ArrayKey) -> gunpowder.BatchFilter
Get a gunpowder augment node.
- Parameters:
raw_key (gp.ArrayKey) – Key for raw data.
gt_key (gp.ArrayKey) – Key for ground truth data.
mask_key (gp.ArrayKey) – Key for mask data.
- Returns:
Augmentation node which can be incorporated in the pipeline.
- Return type:
gunpowder.BatchFilter
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> node = augment_config.node(raw_key, gt_key, mask_key)
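The config-to-node pattern can be sketched without gunpowder installed. In dacapo, node would return a gp.BatchFilter that is added to the pipeline. In the sketch below, plain strings stand in for gp.ArrayKey and a function stands in for the filter. ToyAugmentConfig and ToyMirrorAugmentConfig are hypothetical names. The point is that one config yields one augment node, and that node applies the same spatial transform to raw, gt and mask so they stay aligned.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, List

Batch = Dict[str, List[float]]


class ToyAugmentConfig(ABC):
    """Stand-in for AugmentConfig: each subclass builds one augment 'node'."""

    @abstractmethod
    def node(self, raw_key: str, gt_key: str, mask_key: str) -> Callable[[Batch], Batch]:
        ...


class ToyMirrorAugmentConfig(ToyAugmentConfig):
    """Hypothetical config whose node reverses raw, gt and mask together."""

    def node(self, raw_key, gt_key, mask_key):
        keys = (raw_key, gt_key, mask_key)

        def apply(batch: Batch) -> Batch:
            # Same transform for every keyed array; other entries pass through.
            return {k: (v[::-1] if k in keys else v) for k, v in batch.items()}

        return apply


augment = ToyMirrorAugmentConfig().node("raw", "gt", "mask")
batch = {
    "raw": [1.0, 2.0, 3.0],
    "gt": [0.0, 0.0, 1.0],
    "mask": [1.0, 1.0, 0.0],
    "weights": [0.1, 0.2, 0.3],
}
mirrored = augment(batch)
print(mirrored)
```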