Source code for trainer

import os
import torch
import torchvision
from tqdm import tqdm


[docs] class Trainer: """Trainer class for training and validating a PyTorch model on a dataset. This class provides methods to train the model on a training dataset, validate it on a validation dataset, and save/load model checkpoints for resuming training or inference. During training, the model is saved at regular intervals and the best model based on the validation loss is saved separately. During validation, the model's predictions can be saved as images for visualization. Attributes ---------- model : torch.nn.Module Model to be trained and validated. optimizer : torch.optim.Optimizer Optimizer to be used for training. train_loader : torch.utils.data.DataLoader DataLoader for the training dataset. val_loader : torch.utils.data.DataLoader DataLoader for the validation dataset. model_is_binary : bool True if the model is a binary segmentation model, False if it is a multi-class segmentation model. train_loss : list List of training losses for each epoch. val_loss : list List of validation losses for each epoch. val_accuracy : list List of accuracies on the validation set for each epoch. val_dice : list List of Dice scores on the validation set for each epoch. loss_fn : torch.nn.Module Loss function to be used for training. device : str Device to run the model on ('cuda' or 'cpu'). scaler : torch.cuda.amp.GradScaler Gradient scaler to prevent underflow and overflow during training. Methods ------- train_step() Performs a single training step (forward and backward pass for one epoch) on the training dataset. val_step(save_img_dir=None) Performs validation on the validation dataset and optionally saves images of predictions. save_checkpoint(file_path='checkpoints/last.pth') Save the model and optimizer state to a checkpoint file. load_checkpoint(path) Load the model and optimizer state from a checkpoint file or directory. train(num_epochs, save_interval=5, early_stop_patience=None, save_img=True) Train the model for a specified number of epochs, saving the best model based on the validation loss. """ def __init__(self, model, train_loader, val_loader, loss_fn='default', optimizer='default', learning_rate=1e-4, device='default' ): """Initialize the Trainer class with the model, data loaders, loss function, optimizer, and device. Parameters ---------- model : torch.nn.Module Model to be trained and validated. train_loader : torch.utils.data.DataLoader DataLoader for the training dataset. val_loader : torch.utils.data.DataLoader DataLoader for the validation dataset. loss_fn : torch.nn.Module or str, optional Loss function to be used for training. Default is 'default', which uses BCEWithLogitsLoss for binary segmentation and CrossEntropyLoss for multi-class segmentation. Note that if it cannot be determined whether the model is binary or multi-class, it is assumed to be multi-class. optimizer : torch.optim.Optimizer or str, optional Optimizer to be used for training. Default is 'default', which uses Adam with a learning rate of 1e-4. learning_rate : float, optional Learning rate to be used by the optimizer. Default is 1e-4. device : str, optional Device to run the model on ('cuda' or 'cpu'). Default is 'default', which uses 'cuda' if available, otherwise 'cpu'. """ self.model = model self.optimizer = optimizer self.train_loader = train_loader self.val_loader = val_loader self.model_is_binary = ( (hasattr(self.model, 'is_binary') and self.model.is_binary) or (hasattr(self.model, 'final_conv') and self.model.final_conv.out_channels == 1)) self.train_loss = [] self.val_loss = [] self.val_accuracy = [] self.val_dice = [] self.loss_fn = loss_fn if loss_fn == 'default': if self.model_is_binary: self.loss_fn = torch.nn.BCEWithLogitsLoss() else: self.loss_fn = torch.nn.CrossEntropyLoss() self.optimizer = optimizer if optimizer == 'default': self.optimizer = torch.optim.Adam(self.model.parameters(), learning_rate) self.device = device if device == 'default': self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.scaler = torch.amp.GradScaler(self.device) self.model.to(self.device) def __save_images(self, x, y, prediction, save_img_dir, index): """Save images of the input, predictions, and masks during training or validation. Parameters ---------- x : torch.Tensor Input image tensor. y : torch.Tensor Target mask tensor. prediction : torch.Tensor Predicted mask tensor. save_img_dir : str Directory where the images will be saved. index : int Index of the image in the dataset, used for naming the files. """ os.makedirs(save_img_dir, exist_ok=True) # Normalize the input image for visualization x = (x - x.min()) / (x.max() - x.min()) torchvision.utils.save_image(x, f"{save_img_dir}img_{index}.png") torchvision.utils.save_image(prediction, f"{save_img_dir}pred_{index}.png") torchvision.utils.save_image(y, f"{save_img_dir}mask_{index}.png")
[docs] def train_step(self, save_img_dir = None): """Performs a single training step (forward and backward pass for one epoch) on the training dataset. Optionally saves images of the input, predictions, and masks during training. Parameters ---------- save_img_dir : str, optional If provided, directory where the input images, predictions, and masks will be saved. Returns ------- loss : float Average loss over the training set. """ self.model.train() loop = tqdm(self.train_loader, desc='Training') total_loss = 0 for index, (x, y) in enumerate(loop): x = x.to(self.device).float() y = y.to(self.device).float() if self.model_is_binary: y = y.unsqueeze(1) # Add channel dimension for binary segmentation # Forward pass with torch.amp.autocast(self.device): prediction = self.model(x) loss = self.loss_fn(prediction, y) total_loss += loss.item() # Backward pass self.optimizer.zero_grad() self.scaler.scale(loss).backward() self.scaler.step(self.optimizer) self.scaler.update() # Update tqdm loop loop.set_postfix(loss=loss.item()) # Save images if save_img_dir: self.__save_images(x, y, prediction, save_img_dir, index) return total_loss / len(self.train_loader)
[docs] def val_step(self, save_img_dir = None): """Performs validation on the validation dataset. Optionally saves images of the input, predictions, and masks during validation. Parameters ---------- save_img_dir : str, optional If provided, directory where the input images, predictions, and masks will be saved. Returns ------- loss : float Average loss over the validation set. accuracy : float Accuracy of the model on the validation set (ratio of correct predictions to total pixels). dice_score : float Average Dice score of the model on the validation set. """ self.model.eval() total_loss = 0 num_correct = 0 num_pixels = 0 total_dice = 0 with torch.no_grad(): loop = tqdm(self.val_loader, desc='Validation') for index, (x, y) in enumerate(loop): x = x.to(self.device).float() y = y.to(self.device) if self.model_is_binary: y = y.float().unsqueeze(1) # Add channel dimension for binary segmentation else: # If the model is multi-class, # remove the channel dimension from the target tensor if it exists if y.dim() == 4 and y.size(1) == 1: y = y.squeeze(1) prediction = self.model(x) # prediction = torch.sigmoid(self.model(x)) loss = self.loss_fn(prediction, y) total_loss += loss.item() if self.model_is_binary: prediction = torch.sigmoid(prediction) prediction = (prediction > 0.5).float() else: prediction = torch.softmax(prediction, dim=1) prediction = torch.argmax(prediction, dim=1) num_correct += (prediction == y).sum().item() num_pixels += y.numel() total_dice += ((2 * (prediction*y).sum()) / ((prediction + y).sum() + 1e-8)).item() # add epsilon to avoid division by zero # Update tqdm loop loop.set_postfix(loss=loss.item()) # Save images if save_img_dir: self.__save_images(x, y, prediction, save_img_dir, index) # Compute metrics loss = total_loss / len(self.val_loader) accuracy = num_correct / num_pixels dice_score = total_dice / len(self.val_loader) return loss, accuracy, dice_score
[docs] def save_checkpoint(self, file_path = 'checkpoints/last.pth'): """Save the model and optimizer state to a checkpoint file. Parameters ---------- file_path : str, optional Path to the checkpoint file. Default is 'checkpoints/last.pth'. """ os.makedirs(file_path, exist_ok=True) checkpoint = { 'model': self.model.state_dict(), 'optimizer': self.optimizer.state_dict() } torch.save(checkpoint, file_path)
[docs] def load_checkpoint(self, path): """Load the model and optimizer state from a checkpoint file or directory. If a directory is provided, the most recent checkpoint file in the directory is loaded. Parameters ---------- path : str Path to the checkpoint file or directory containing the checkpoint files. Raises ------ FileNotFoundError If the file does not exist or the directory is empty. """ if os.path.isfile(path): if not os.path.exists(path): raise FileNotFoundError(f"File {path} does not exist") checkpoint_path = path elif os.path.isdir(path): files = os.listdir(path) if not files: raise FileNotFoundError(f"Directory {path} is empty") files.sort(key=os.path.getmtime) checkpoint_path = os.path.join(path, files[-1]) # Most recent file checkpoint = torch.load(checkpoint_path, weights_only=True) self.model.load_state_dict(checkpoint['model']) self.optimizer.load_state_dict(checkpoint['optimizer'])
[docs] def train(self, num_epochs, save_interval=5, early_stop_patience=None, save_val_img=True, save_train_img=False): """Train the model for a specified number of epochs, saving the last model and the best model based on the validation loss. Parameters ---------- num_epochs : int Number of epochs to train the model for. save_interval : int, optional Interval at which to save the model checkpoint. Default is 5. If set to 0, only the best and last models are saved. If set to 1, the model is saved after every epoch. early_stop_patience : int, optional If provided, training will stop if the validation loss does not improve after this number of epochs. save__val_img : bool, optional If True, save the input images, predictions, and masks during validation. Default is True. save_train_img : bool, optional If True, save the input images, predictions, and masks during training. Default is False. Returns ------- train_loss : list List of training losses for each epoch. val_loss : list List of validation losses for each epoch. val_accuracy : list List of accuracies on the validation set for each epoch. val_dice : list List of Dice scores on the validation set for each epoch. """ best_val_loss = 0 patience_counter = 0 save_img_val_dir = None save_img_train_dir = None for epoch in range(num_epochs): print(f"\nEpoch [{epoch}/{num_epochs-1}]") if save_train_img and epoch % save_interval == 0: save_img_train_dir = f"saved_images/epoch_{epoch}/train" else: save_img_train_dir = None if save_val_img and epoch % save_interval == 0: save_img_val_dir = f"saved_images/epoch_{epoch}/val" else: save_img_val_dir = None # Perform a training step train_loss = self.train_step(save_img_train_dir) print(f"Training Loss: {train_loss:.4f}") self.train_loss.append(train_loss) # Perform a validation step val_loss, val_accuracy, val_dice = self.val_step(save_img_val_dir) print(f"Validation Loss: {val_loss:.4f} - Accuracy: {val_accuracy:.4f} - Dice Score: {val_dice:.4f}") self.val_loss.append(val_loss) self.val_accuracy.append(val_accuracy) self.val_dice.append(val_dice) # Save the model checkpoint if save_interval > 0 and epoch % save_interval == 0: self.save_checkpoint(f'checkpoints/{epoch}.pth') if val_loss > best_val_loss: best_val_loss = val_loss self.save_checkpoint('checkpoints/best.pth') # Early Stopping check if early_stop_patience: if val_loss < best_val_loss: patience_counter = 0 else: patience_counter += 1 if patience_counter >= early_stop_patience: print(f"Early stopping at epoch {epoch}") break # Save the last model self.save_checkpoint('checkpoints/last.pth') return self.train_loss, self.val_loss, self.val_accuracy, self.val_dice