# Adapted from:
# https://github.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch
# By Emma Avetissian
# coding=utf-8
from __future__ import absolute_import, division, print_function
import copy
import logging
import math
import ml_collections
import torch
import torch.nn as nn
import torch.nn.functional as nnf
from torch.distributions.normal import Normal
from torch.nn import Conv3d, Dropout, LayerNorm, Linear, Softmax
from torch.nn.modules.utils import _pair, _triple
def get_3DReg_config():
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({"size": (8, 8, 8)})
config.patches.grid = (8, 8, 8)
config.hidden_size = 252
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 3072
config.transformer.num_heads = 12
config.transformer.num_layers = 12
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.patch_size = 8
config.conv_first_channel = 512
config.encoder_channels = (16, 32, 32)
config.down_factor = 2
config.down_num = 2
config.decoder_channels = (96, 48, 32, 32, 16)
config.skip_channels = (32, 32, 32, 32, 16)
config.n_skip = 5
config.input_channels = 1
return config
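# Usage sketch (not part of the original module): the ConfigDict returned above can be
# tweaked before the model is built, e.g.
#
#     cfg = get_3DReg_config()
#     cfg.transformer.num_layers = 6   # hypothetical override for a lighter model
#     cfg.input_channels = 2           # e.g. moving and fixed volumes stacked channel-wise
#     model = ViTVNet(out_channels=3, config=cfg, img_size=(128, 128, 128))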
logger = logging.getLogger(__name__)
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"
def np2th(weights, conv=False):
"""Possibly convert HWIO to OIHW."""
if conv:
weights = weights.transpose([3, 2, 0, 1])
return torch.from_numpy(weights)
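# Note: "HWIO" (height, width, in-channels, out-channels) is the kernel layout used by the
# original JAX/Flax ViT checkpoints, while PyTorch convolutions expect "OIHW"; hence the
# transpose when conv=True. This helper is only relevant when loading such pretrained
# 2-D weights.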
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {
"gelu": torch.nn.functional.gelu,
"relu": torch.nn.functional.relu,
"swish": swish,
}
class Attention(nn.Module):
def __init__(self, config, vis):
super(Attention, self).__init__()
self.vis = vis
self.num_attention_heads = config.transformer["num_heads"]
self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = Linear(config.hidden_size, self.all_head_size)
self.key = Linear(config.hidden_size, self.all_head_size)
self.value = Linear(config.hidden_size, self.all_head_size)
self.out = Linear(config.hidden_size, config.hidden_size)
self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.softmax = Softmax(dim=-1)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
weights = attention_probs if self.vis else None
attention_probs = self.attn_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
attention_output = self.out(context_layer)
attention_output = self.proj_dropout(attention_output)
return attention_output, weights
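# Shape walkthrough (sketch, assuming the default config: hidden_size=252, 12 heads,
# head size 252 // 12 = 21, and N input tokens):
#   hidden_states           (B, N, 252)
#   query/key/value layers  (B, 12, N, 21)  after transpose_for_scores
#   attention_scores/probs  (B, 12, N, N)   scaled by 1 / sqrt(21), then softmax
#   context_layer           (B, N, 252)     heads merged back before the output Linear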
class Mlp(nn.Module):
def __init__(self, config):
super(Mlp, self).__init__()
self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
self.act_fn = ACT2FN["gelu"]
self.dropout = Dropout(config.transformer["dropout_rate"])
self._init_weights()
def _init_weights(self):
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.normal_(self.fc1.bias, std=1e-6)
nn.init.normal_(self.fc2.bias, std=1e-6)
def forward(self, x):
x = self.fc1(x)
x = self.act_fn(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class Embeddings(nn.Module):
"""Construct the embeddings from patch, position embeddings."""
def __init__(self, config, img_size):
super(Embeddings, self).__init__()
self.config = config
down_factor = config.down_factor
patch_size = _triple(config.patches["size"])
n_patches = int(
(img_size[0] / 2**down_factor // patch_size[0])
* (img_size[1] / 2**down_factor // patch_size[1])
* (img_size[2] / 2**down_factor // patch_size[2])
)
self.hybrid_model = CNNEncoder(config, n_channels=config.input_channels)
in_channels = config["encoder_channels"][-1]
self.patch_embeddings = Conv3d(
in_channels=in_channels,
out_channels=config.hidden_size,
kernel_size=patch_size,
stride=patch_size,
)
self.position_embeddings = nn.Parameter(
torch.zeros(1, n_patches, config.hidden_size)
)
self.dropout = Dropout(config.transformer["dropout_rate"])
def forward(self, x):
x, features = self.hybrid_model(x)
x = self.patch_embeddings(x)  # (B, hidden, L', H', W'), where n_patches = L' * H' * W'
x = x.flatten(2)
x = x.transpose(-1, -2) # (B, n_patches, hidden)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, features
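# Worked example (assuming the default config and img_size=(128, 128, 128)): the CNN
# encoder downsamples by 2**down_factor = 4, giving a 32^3 feature map; the 8^3 patch
# convolution then yields (32 // 8)^3 = 64 tokens, so this module returns a
# (B, 64, hidden_size=252) sequence plus the list of encoder skip features.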
class Block(nn.Module):
def __init__(self, config, vis):
super(Block, self).__init__()
self.hidden_size = config.hidden_size
self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn = Mlp(config)
self.attn = Attention(config, vis)
def forward(self, x):
h = x
x = self.attention_norm(x)
x, weights = self.attn(x)
x = x + h
h = x
x = self.ffn_norm(x)
x = self.ffn(x)
x = x + h
return x, weights
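# This is a standard pre-norm transformer block: LayerNorm -> attention -> residual add,
# then LayerNorm -> MLP -> residual add, with the per-head attention maps passed through
# for optional visualization.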
class Encoder(nn.Module):
def __init__(self, config, vis):
super(Encoder, self).__init__()
self.vis = vis
self.layer = nn.ModuleList()
self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
for _ in range(config.transformer["num_layers"]):
layer = Block(config, vis)
self.layer.append(copy.deepcopy(layer))
def forward(self, hidden_states):
attn_weights = []
for layer_block in self.layer:
hidden_states, weights = layer_block(hidden_states)
if self.vis:
attn_weights.append(weights)
encoded = self.encoder_norm(hidden_states)
return encoded, attn_weights
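# Note: the Transformer wrapper that ViTVNet instantiates below does not appear in this
# excerpt. A minimal reconstruction, consistent with how ViTVNet.forward unpacks its
# output (encoded tokens, attention weights, CNN skip features), is sketched here:
class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, x):
        # (B, n_patches, hidden) token sequence plus the CNN skip features
        embedding_output, features = self.embeddings(x)
        # encoded tokens and (optionally) per-layer attention weights
        encoded, attn_weights = self.encoder(embedding_output)
        return encoded, attn_weights, features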
class Conv3dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
conv = nn.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
)
relu = nn.ReLU(inplace=True)
bn = nn.BatchNorm3d(out_channels)
super(Conv3dReLU, self).__init__(conv, bn, relu)
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
skip_channels=0,
use_batchnorm=True,
):
super().__init__()
self.conv1 = Conv3dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = Conv3dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False)
def forward(self, x, skip=None):
x = self.up(x)
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x
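# Each decoder block doubles the spatial resolution by trilinear upsampling, optionally
# concatenates a same-resolution skip feature along the channel axis, and refines the
# result with two 3x3x3 Conv3d + BN + ReLU layers.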
class DecoderCup(nn.Module):
def __init__(self, config, img_size):
super().__init__()
self.config = config
self.down_factor = config.down_factor
head_channels = config.conv_first_channel
self.img_size = img_size
self.conv_more = Conv3dReLU(
config.hidden_size,
head_channels,
kernel_size=3,
padding=1,
use_batchnorm=True,
)
decoder_channels = config.decoder_channels
in_channels = [head_channels] + list(decoder_channels[:-1])
out_channels = decoder_channels
self.patch_size = _triple(config.patches["size"])
skip_channels = self.config.skip_channels
blocks = [
DecoderBlock(in_ch, out_ch, sk_ch)
for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
]
self.blocks = nn.ModuleList(blocks)
def forward(self, hidden_states, features=None):
B, n_patch, hidden = hidden_states.size()
# reshape from (B, n_patch, hidden) back to a volume (B, hidden, l, h, w)
l, h, w = (
(self.img_size[0] // 2**self.down_factor // self.patch_size[0]),
(self.img_size[1] // 2**self.down_factor // self.patch_size[1]),
(self.img_size[2] // 2**self.down_factor // self.patch_size[2]),
)
x = hidden_states.permute(0, 2, 1)
x = x.contiguous().view(B, hidden, l, h, w)
x = self.conv_more(x)
for i, decoder_block in enumerate(self.blocks):
if features is not None:
skip = features[i] if (i < self.config.n_skip) else None
else:
skip = None
x = decoder_block(x, skip=skip)
return x
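# Worked example (sketch, default config and img_size=(128, 128, 128), so l = h = w = 4):
#   hidden_states (B, 64, 252) -> permute/view  -> (B, 252, 4, 4, 4)
#   conv_more                                   -> (B, 512, 4, 4, 4)
#   five DecoderBlocks, each upsampling by 2    -> (B, 16, 128, 128, 128),
# with decoder_channels (96, 48, 32, 32, 16) and the CNN encoder skips concatenated
# at every stage (n_skip = 5).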
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool3d(2), DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class CNNEncoder(nn.Module):
def __init__(self, config, n_channels=2):
super(CNNEncoder, self).__init__()
self.n_channels = n_channels
decoder_channels = config.decoder_channels
encoder_channels = config.encoder_channels
self.down_num = config.down_num
self.inc = DoubleConv(n_channels, encoder_channels[0])
self.down1 = Down(encoder_channels[0], encoder_channels[1])
self.down2 = Down(encoder_channels[1], encoder_channels[2])
self.width = encoder_channels[-1]
def forward(self, x):
features = []
x1 = self.inc(x)
features.append(x1)
x2 = self.down1(x1)
features.append(x2)
feats = self.down2(x2)
features.append(feats)
feats_down = feats
for i in range(self.down_num):
feats_down = nn.MaxPool3d(2)(feats_down)
features.append(feats_down)
return feats, features[::-1]
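# The encoder returns the 1/4-resolution feature map (fed to the patch embedding) and the
# full skip list in coarse-to-fine order (deepest first), which is exactly the order in
# which DecoderCup consumes it.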
class RegistrationHead(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
conv3d = nn.Conv3d(
in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
)
conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape))
conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape))
super().__init__(conv3d)
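# The registration head is a single 3x3x3 convolution with weights drawn from
# Normal(0, 1e-5) and zero bias, so the predicted displacement field starts out close to
# zero. The `upsampling` argument is accepted but unused here.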
class ViTVNet(nn.Module):
"""
ViT-V-Net model
Parameters
----------
out_channels : int
The number of output channels.
config : str or ml_collections.ConfigDict
The configuration of the model.
img_size : tuple of int
The spatial size of the input volume, one entry per spatial dimension.
vis : bool
Whether to collect the attention weights for visualization.
"""
def __init__(
self, out_channels, config="ViT-V-Net", img_size=(128, 128, 128), vis=False
):
super(ViTVNet, self).__init__()
if isinstance(config, str):
config = CONFIGS[config]
else:
assert isinstance(
config, ml_collections.ConfigDict
), "config must be an ml_collections.ConfigDict or the name of a predefined config"
self.transformer = Transformer(config, img_size, vis)
self.decoder = DecoderCup(config, img_size)
self.reg_head = RegistrationHead(
in_channels=config.decoder_channels[-1],
out_channels=out_channels,
kernel_size=3,
)
self.config = config
def forward(self, x):
x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
x = self.decoder(x, features)
out = self.reg_head(x)
return out
CONFIGS = {
"ViT-V-Net": get_3DReg_config(),
}
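# Minimal smoke-test sketch (not part of the original module). It builds the default
# model and pushes a random single-channel volume through it; the three output channels
# are a typical choice for a 3-D displacement field, not something this file mandates.
if __name__ == "__main__":
    net = ViTVNet(out_channels=3, img_size=(128, 128, 128))
    dummy = torch.randn(1, 1, 128, 128, 128)  # (B, C, D, H, W) with C = config.input_channels
    with torch.no_grad():
        flow = net(dummy)
    print(flow.shape)  # expected: torch.Size([1, 3, 128, 128, 128])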