trainer module
- class trainer.Trainer(model, train_loader, val_loader, loss_fn='default', optimizer='default', learning_rate=0.0001, device='default')[source]
Bases:
object
Trainer class for training and validating a PyTorch model on a dataset.
This class provides methods to train the model on a training dataset, validate it on a validation dataset, and save/load model checkpoints for resuming training or inference. During training, the model is saved at regular intervals and the best model based on the validation loss is saved separately. During validation, the model’s predictions can be saved as images for visualization.
- model
Model to be trained and validated.
- Type:
torch.nn.Module
- optimizer
Optimizer to be used for training.
- Type:
torch.optim.Optimizer
- train_loader
DataLoader for the training dataset.
- Type:
torch.utils.data.DataLoader
- val_loader
DataLoader for the validation dataset.
- Type:
torch.utils.data.DataLoader
- model_is_binary
True if the model is a binary segmentation model, False if it is a multi-class segmentation model.
- Type:
bool
- train_loss
List of training losses for each epoch.
- Type:
list
- val_loss
List of validation losses for each epoch.
- Type:
list
- val_accuracy
List of accuracies on the validation set for each epoch.
- Type:
list
- val_dice
List of Dice scores on the validation set for each epoch.
- Type:
list
- loss_fn
Loss function to be used for training.
- Type:
torch.nn.Module
- device
Device to run the model on (‘cuda’ or ‘cpu’).
- Type:
str
- scaler
Gradient scaler to prevent underflow and overflow during training.
- Type:
torch.cuda.amp.GradScaler
- train_step(save_img_dir=None)[source]
Performs a single training step (forward and backward pass for one epoch) on the training dataset and optionally saves images of predictions.
- val_step(save_img_dir=None)[source]
Performs validation on the validation dataset and optionally saves images of predictions.
- save_checkpoint(file_path='checkpoints/last.pth')[source]
Save the model and optimizer state to a checkpoint file.
- load_checkpoint(path)[source]
Load the model and optimizer state from a checkpoint file or directory.
- train(num_epochs, save_interval=5, early_stop_patience=None, save_val_img=True, save_train_img=False)[source]
Train the model for a specified number of epochs, saving the last model and the best model based on the validation loss.
- load_checkpoint(path)[source]
Load the model and optimizer state from a checkpoint file or directory. If a directory is provided, the most recent checkpoint file in the directory is loaded.
- Parameters:
path (str) – Path to the checkpoint file or directory containing the checkpoint files.
- Raises:
FileNotFoundError – If the file does not exist or the directory is empty.
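The file-or-directory behaviour described above can be sketched as follows. This is a minimal illustration, not the module's actual implementation: the helper name `resolve_checkpoint` is hypothetical, and it assumes that checkpoint files use the `.pth` extension and that "most recent" means the latest modification time.

```python
from pathlib import Path


def resolve_checkpoint(path):
    """Return the checkpoint file that `path` refers to.

    Hypothetical helper, not part of the documented API.
    """
    p = Path(path)
    if p.is_file():
        return p
    if p.is_dir():
        # Assumption: checkpoints end in .pth and "most recent" is judged
        # by file modification time.
        candidates = [f for f in p.glob("*.pth") if f.is_file()]
        if not candidates:
            raise FileNotFoundError(f"No checkpoint files found in directory: {p}")
        return max(candidates, key=lambda f: f.stat().st_mtime)
    raise FileNotFoundError(f"Checkpoint path does not exist: {p}")
```

Under these assumptions, passing a directory such as `checkpoints/` would select whichever `.pth` file was written last, which is typically the one to resume from.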
- save_checkpoint(file_path='checkpoints/last.pth')[source]
Save the model and optimizer state to a checkpoint file.
- Parameters:
file_path (str, optional) – Path to the checkpoint file. Default is ‘checkpoints/last.pth’.
- train(num_epochs, save_interval=5, early_stop_patience=None, save_val_img=True, save_train_img=False)[source]
Train the model for a specified number of epochs, saving the last model and the best model based on the validation loss.
- Parameters:
num_epochs (int) – Number of epochs to train the model for.
save_interval (int, optional) – Interval at which to save the model checkpoint. Default is 5. If set to 0, only the best and last models are saved. If set to 1, the model is saved after every epoch.
early_stop_patience (int, optional) – If provided, training stops when the validation loss has not improved for this number of epochs. Default is None (no early stopping).
save_val_img (bool, optional) – If True, save the input images, predictions, and masks during validation. Default is True.
save_train_img (bool, optional) – If True, save the input images, predictions, and masks during training. Default is False.
- Returns:
train_loss (list) – List of training losses for each epoch.
val_loss (list) – List of validation losses for each epoch.
val_accuracy (list) – List of accuracies on the validation set for each epoch.
val_dice (list) – List of Dice scores on the validation set for each epoch.
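To make the interaction between `save_interval` and `early_stop_patience` concrete, here is a framework-free sketch of the bookkeeping such a training loop might do. The names `should_save_periodic` and `EarlyStopper` are hypothetical, and the sketch assumes that epochs are counted from 1 and that "improve" means a strictly lower validation loss than the best seen so far; the actual implementation may differ.

```python
def should_save_periodic(epoch, save_interval):
    """True if a periodic checkpoint is due after `epoch` (counted from 1).

    With save_interval=0 no periodic checkpoints are written, so only the
    best and last models remain; with save_interval=1 every epoch is saved.
    """
    return save_interval > 0 and epoch % save_interval == 0


class EarlyStopper:
    """Tracks the best validation loss and how long it has stagnated."""

    def __init__(self, patience=None):
        self.patience = patience  # None disables early stopping
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def update(self, val_loss):
        """Record one epoch's validation loss; return (is_best, should_stop)."""
        is_best = val_loss < self.best_loss
        if is_best:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        should_stop = (
            self.patience is not None
            and self.epochs_without_improvement >= self.patience
        )
        return is_best, should_stop


# Made-up validation losses, purely for illustration.
stopper = EarlyStopper(patience=2)
stopped_at = None
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.75, 0.6], start=1):
    is_best, should_stop = stopper.update(loss)
    if should_stop:
        stopped_at = epoch  # epoch 4: two epochs without beating 0.7
        break
```

In this run the loss of 0.6 in the fifth epoch is never reached, which shows the usual trade-off of a small patience value.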
- train_step(save_img_dir=None)[source]
Performs a single training step (forward and backward pass for one epoch) on the training dataset. Optionally saves images of the input, predictions, and masks during training.
- Parameters:
save_img_dir (str, optional) – If provided, directory where the input images, predictions, and masks will be saved.
- Returns:
loss – Average loss over the training set.
- Return type:
float
- val_step(save_img_dir=None)[source]
Performs validation on the validation dataset. Optionally saves images of the input, predictions, and masks during validation.
- Parameters:
save_img_dir (str, optional) – If provided, directory where the input images, predictions, and masks will be saved.
- Returns:
loss (float) – Average loss over the validation set.
accuracy (float) – Accuracy of the model on the validation set (ratio of correct predictions to total pixels).
dice_score (float) – Average Dice score of the model on the validation set.
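For reference, the two metrics returned above can be illustrated on flattened binary masks. This is a plain-Python sketch rather than the module's code: the function names, the smoothing term `eps`, and the use of flat lists of 0/1 labels instead of tensors are assumptions made to keep the example self-contained.

```python
def pixel_accuracy(pred, target):
    """Ratio of correctly predicted pixels to total pixels."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and equally long")
    correct = sum(1 for p, t in zip(pred, target) if p == t)
    return correct / len(pred)


def dice_score(pred, target, eps=1e-8):
    """Dice coefficient for binary masks: 2|P & T| / (|P| + |T|).

    eps is an assumed smoothing term that avoids division by zero when
    both masks are empty.
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must be equally long")
    intersection = sum(p * t for p, t in zip(pred, target))
    return (2 * intersection + eps) / (sum(pred) + sum(target) + eps)


# Flattened 0/1 masks, made up for illustration.
pred = [1, 1, 0, 0, 1, 0]
target = [1, 0, 0, 0, 1, 1]
accuracy = pixel_accuracy(pred, target)  # 4 of 6 pixels match
dice = dice_score(pred, target)          # 2 * 2 / (3 + 3), up to eps
```

Accuracy counts background pixels as correct predictions, so it can look high on sparse masks. Dice only rewards overlap on the foreground class, which is why both metrics are worth tracking.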