Source code for cellmap_segmentation_challenge.models.model_load
# Imports
from glob import glob
import os
import numpy as np
import torch
from tensorboard.backend.event_processing import event_accumulator
from upath import UPath
from cellmap_segmentation_challenge.utils import get_formatted_fields, format_string


def load_latest(search_path, model):
    """
    Load the latest checkpoint from a directory into a model (in place).

    Parameters
    ----------
    search_path : str
        The path to search for checkpoints.
    model : torch.nn.Module
        The model to load the checkpoint into.

    Returns
    -------
    int or None
        The epoch of the loaded checkpoint, or None if no checkpoint was loaded.
    """
    # Check if there are any files matching the checkpoint save path
    checkpoint_files = glob(format_string(search_path, {"epoch": "*"}))
    if checkpoint_files:
        # If there are checkpoints, sort by modification time and get the latest
        checkpoint_files.sort(key=os.path.getmtime, reverse=True)
        # Get the latest checkpoint
        newest_checkpoint = checkpoint_files[0]
        # Extract the epoch from the filename
        epoch = int(
            get_formatted_fields(newest_checkpoint, search_path, ["{epoch}"])["epoch"]
        )
        # Load the most recent checkpoint into the model and print the file path
        try:
            model.load_state_dict(
                torch.load(newest_checkpoint, weights_only=True), strict=False
            )
            print(f"Loaded latest checkpoint: {newest_checkpoint}")
            return epoch
        except Exception as e:
            print(f"Error loading checkpoint: {newest_checkpoint}")
            print(e)

    # If there are no checkpoints, or an error occurs, return None
    return None
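

# Usage sketch (illustrative only, not part of the original module): resume
# training from the most recent checkpoint. The checkpoint path template and
# the model below are hypothetical placeholders.
#
#   model = MyModel()  # any torch.nn.Module with matching state_dict keys
#   last_epoch = load_latest("checkpoints/model_{epoch}.pth", model)
#   start_epoch = (last_epoch + 1) if last_epoch is not None else 0
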
def load_best_val(logs_save_path, model_save_path, model, low_is_best=True):
    """
    Load the model weights with the best validation score from a directory into an existing model object in place.

    Parameters
    ----------
    logs_save_path : str
        The path to the directory with the tensorboard logs.
    model_save_path : str
        The path to the model checkpoints.
    model : torch.nn.Module
        The model to load the checkpoint into.
    low_is_best : bool
        Whether a lower validation score is better.

    Returns
    -------
    int or None
        The epoch with the best validation score, or None if no checkpoint was loaded.
    """
    # Load the event file
    try:
        print("Loading events files, this may take a minute")
        event_acc = event_accumulator.EventAccumulator(logs_save_path)
        event_acc.Reload()
    except FileNotFoundError:
        print("No events file found, skipping")
        return None

    # Get validation scores
    tags = event_acc.Tags()["scalars"]
    if "validation" in tags:
        events = event_acc.Scalars("validation")
        scores = [event.value for event in events]
        # Find the best score; the index of the best entry is used as the epoch
        # number, which assumes one validation score is logged per epoch,
        # starting at epoch 0
        if low_is_best:
            best_epoch = np.argmin(scores)
        else:
            best_epoch = np.argmax(scores)
        # Load the model with the best validation score
        checkpoint_path = UPath(model_save_path.format(epoch=best_epoch)).path
        checkpoint = torch.load(checkpoint_path, weights_only=True)
        try:
            model.load_state_dict(checkpoint, strict=False)
            print(f"Loaded best validation checkpoint from epoch: {best_epoch}")
            return best_epoch
        except Exception as e:
            print(f"Error loading checkpoint: {checkpoint_path}")
            print(e)
            return None
    else:
        print("No validation scores found, skipping")
        return None
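

# Usage sketch (illustrative only, not part of the original module): after
# training, load the weights that scored best on validation. The log directory
# and checkpoint template are hypothetical placeholders.
#
#   model = MyModel()  # any torch.nn.Module with matching state_dict keys
#   best_epoch = load_best_val(
#       "tensorboard/my_run", "checkpoints/model_{epoch}.pth", model, low_is_best=True
#   )
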
def get_best_val_epoch(logs_save_path, low_is_best=True):
    """
    Get the epoch with the best validation score from tensorboard logs.

    Parameters
    ----------
    logs_save_path : str
        The path to the directory with the tensorboard logs.
    low_is_best : bool
        Whether a lower validation score is better.

    Returns
    -------
    int or None
        The epoch number with the best validation score, or None if not found.
    """
    # Load the event file
    try:
        event_acc = event_accumulator.EventAccumulator(logs_save_path)
        event_acc.Reload()
    except Exception:
        print("No events file found, skipping")
        return None

    # Get validation scores
    tags = event_acc.Tags()["scalars"]
    if "validation" in tags:
        events = event_acc.Scalars("validation")
        scores = [event.value for event in events]
        # The index of the best score is used as the epoch number
        if low_is_best:
            best_epoch = np.argmin(scores)
        else:
            best_epoch = np.argmax(scores)
        return best_epoch
    else:
        print("No validation scores found, skipping")
        return None
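

# Usage sketch (illustrative only, not part of the original module): look up the
# best-scoring epoch and load that checkpoint manually. The paths and model are
# hypothetical placeholders.
#
#   best_epoch = get_best_val_epoch("tensorboard/my_run", low_is_best=True)
#   if best_epoch is not None:
#       state_dict = torch.load(f"checkpoints/model_{best_epoch}.pth", weights_only=True)
#       model.load_state_dict(state_dict, strict=False)
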
def get_latest_checkpoint_epoch(model_save_path):
    """
    Get the latest checkpoint epoch from a directory.

    Parameters
    ----------
    model_save_path : str
        The path to the directory with the model checkpoints.

    Returns
    -------
    int or None
        The epoch number of the latest checkpoint, or None if not found.
    """
    # Check if there are any files matching the checkpoint save path
    checkpoint_files = glob(format_string(model_save_path, {"epoch": "*"}))
    if checkpoint_files:
        # If there are checkpoints, sort by modification time and get the latest
        checkpoint_files.sort(key=os.path.getmtime, reverse=True)
        # Get the latest checkpoint
        newest_checkpoint = checkpoint_files[0]
        # Extract the epoch from the filename
        epoch = int(
            get_formatted_fields(newest_checkpoint, model_save_path, ["{epoch}"])[
                "epoch"
            ]
        )
        return epoch

    # If there are no checkpoints, return None
    return None
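

# Usage sketch (illustrative only, not part of the original module): check how
# far a previous run got without loading any weights. The checkpoint template is
# a hypothetical placeholder.
#
#   latest = get_latest_checkpoint_epoch("checkpoints/model_{epoch}.pth")
#   if latest is None:
#       print("No checkpoints found")
#   else:
#       print(f"Most recent checkpoint is from epoch {latest}")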