import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
class DoubleConvolution(nn.Module):
"""Applies double convolution to the input
Double convolution is a sequence of two 3x3 convolutional layers with batch normalization and ReLU activation function.
Attributes
----------
double_conv : nn.Sequential
        Sequence of two 3x3 convolutional layers, each followed by batch normalization and a ReLU activation
Methods
-------
forward(x)
Perform the forward propagation
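    Examples
    --------
    A minimal usage sketch (illustrative only); the spatial size is preserved because both convolutions use 3x3 kernels with stride 1 and padding 1:
    >>> block = DoubleConvolution(in_channels=3, out_channels=64)
    >>> block(torch.randn(1, 3, 64, 64)).shape
    torch.Size([1, 64, 64, 64])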
"""
def __init__(self, in_channels, out_channels):
"""Initializes the DoubleConvolution module
Parameters
----------
in_channels : int
Number of input channels
out_channels : int
Number of output channels
Raises
------
ValueError
If the number of input channels is less than 1
If the number of output channels is less than 1
"""
if in_channels < 1:
raise ValueError("The number of input channels should be at least 1")
if out_channels < 1:
raise ValueError("The number of output channels should be at least 1")
super(DoubleConvolution, self).__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),  # the original 2015 U-Net paper has no batch normalization; it is added here as a common improvement
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
"""Perform the forward propagation
Parameters
----------
x : torch.Tensor
Input tensor to be processed by the double convolution
Returns
-------
result : torch.Tensor
Output tensor after applying the double convolution
"""
return self.double_conv(x)
class UNet(nn.Module):
"""UNet model
    UNet is a convolutional neural network used for image segmentation. It consists of an encoder and a decoder: the encoder
    downsamples the input image and the decoder upsamples the encoder's output back to the original size. The encoder and
    decoder are connected by so-called skip connections, which concatenate the output of the encoder to the input of the
    decoder at the same resolution. This helps the decoder recover the spatial information lost during downsampling. Each skip
    connection is concatenated channel-wise with the upsampled feature map, so the number of channels doubles after each
    concatenation.

    The contracting path (encoder) follows the typical architecture of a convolutional neural network, with a series of
    convolutional layers each followed by a max-pooling layer. The expansive path (decoder) consists of a series of
    up-convolutions, which increase the spatial resolution of the input, each followed by further convolutional layers. The
    final layer of the UNet model is a 1x1 convolutional layer that maps each pixel to the desired number of classes.
Attributes
----------
num_features : int
Number of features in the encoder
ups : nn.ModuleList
List of up-convolutions in the decoder
downs : nn.ModuleList
List of double convolutions in the encoder
pool : nn.MaxPool2d
Max pooling layer
bottleneck : DoubleConvolution
Bottleneck layer
final_conv : nn.Conv2d
Final convolutional layer
Methods
-------
is_binary()
Returns whether the model output is binary or multi-class
forward(x)
Perform the forward propagation
References
----------
Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. arXiv preprint arXiv:1505.04597.
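    Examples
    --------
    A minimal usage sketch assuming the default feature list [64, 128, 256, 512] (shapes are illustrative only):
    >>> model = UNet(in_channels=3, out_channels=1)
    >>> model(torch.randn(1, 3, 160, 160)).shape
    torch.Size([1, 1, 160, 160])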
"""
def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
"""Initialize the UNet model
Parameters
----------
in_channels: int
Number of input channels: 1 for grayscale, 3 for RGB
out_channels: int
Number of output channels: 1 for binary segmentation, >1 for multi-class segmentation
features: list
List of features in the encoder and decoder. The length of the list determines the depth of the UNet model.
Raises
------
ValueError
            If the number of features is less than 2 or any feature is less than 1
If the number of input channels is less than 1
If the number of output channels is less than 1
"""
if len(features) < 2:
raise ValueError("The number of features should be at least 2")
        if any(f < 1 for f in features):
            raise ValueError("Each feature should be at least 1")
if in_channels < 1:
raise ValueError("The number of input channels should be at least 1")
if out_channels < 1:
raise ValueError("The number of output channels should be at least 1")
super(UNet, self).__init__()
self.num_features = len(features)
self.ups = nn.ModuleList()
self.downs = nn.ModuleList()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# Down part of UNet
for feature in features:
self.downs.append(DoubleConvolution(in_channels, feature))
in_channels = feature
# Up part of UNet
for feature in reversed(features):
self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
self.ups.append(DoubleConvolution(feature*2, feature))
self.bottleneck = DoubleConvolution(features[-1], features[-1]*2)
self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
def is_binary(self):
"""Returns whether the model output is binary or multi-class
Returns
-------
binary : bool
True if the model output is binary, False if the model output is multi-class
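        Examples
        --------
        A minimal sketch (illustrative only):
        >>> UNet(out_channels=1).is_binary()
        True
        >>> UNet(out_channels=3).is_binary()
        False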
"""
return self.final_conv.out_channels == 1
def forward(self, x):
"""Perform the forward propagation
Parameters
----------
x : torch.Tensor
            Input tensor of shape (batch_size, in_channels, height, width) to be processed by the UNet model
Raises
------
ValueError
            If either spatial dimension of the input tensor is smaller than 2^(len(features)+1), since the spatial size is halved at each downsampling step
Returns
-------
result : torch.Tensor
Output tensor after applying the UNet model
"""
if x.shape[-1] < 2**(self.num_features+1) or x.shape[-2] < 2**(self.num_features+1):
raise ValueError(f"The input tensor sizes must be at least 2^(len(features)+1) = {2**(self.num_features+1)}")
skip_connections = []
for down in self.downs:
x = down(x)
skip_connections.append(x)
x = self.pool(x)
x = self.bottleneck(x)
skip_connections = skip_connections[::-1]
for idx in range(0, len(self.ups), 2):
x = self.ups[idx](x)
skip_connection = skip_connections[idx//2]
            # To keep the implementation general, if the input spatial size is not divisible by 2**len(features),
            # the upsampled tensor is resized to match the skip connection. For example, a 161x161 input is pooled
            # down to 80x80 at the first level and upsampled back to 160x160, so it must be resized to 161x161.
if x.shape != skip_connection.shape:
                # Cropping or padding would also work, but resizing is simpler; the mismatch is at most
                # one pixel per level, so the effect on performance is negligible.
x = TF.resize(x, size=skip_connection.shape[2:])
concat_skip = torch.cat((skip_connection, x), dim=1)
x = self.ups[idx+1](concat_skip)
return self.final_conv(x)
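

# Minimal sanity-check sketch, assuming the default feature list [64, 128, 256, 512];
# the 161x161 input deliberately exercises the resize branch in UNet.forward, since
# 161 is not divisible by 2**len(features) = 16.
if __name__ == "__main__":
    model = UNet(in_channels=1, out_channels=1)
    x = torch.randn(3, 1, 161, 161)
    preds = model(x)
    # The output spatial size matches the input because the last upsampled tensor is
    # resized to the 161x161 skip connection before the final 1x1 convolution.
    assert preds.shape == (3, 1, 161, 161), preds.shape
    print("UNet output shape:", preds.shape)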