Source code for cellmap_segmentation_challenge.models.vitnet

# Adapted from:
# https://github.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch
# By Emma Avetissian

# coding=utf-8
from __future__ import absolute_import, division, print_function

import copy
import logging
import math

import ml_collections
import torch
import torch.nn as nn
import torch.nn.functional as nnf
from torch.distributions.normal import Normal
from torch.nn import Conv3d, Dropout, LayerNorm, Linear, Softmax
from torch.nn.modules.utils import _pair, _triple


def get_3DReg_config():
    """Build the default ViT-V-Net configuration for 3-D volumes.

    Returns
    -------
    ml_collections.ConfigDict
        Configuration consumed by ``ViTVNet`` and its sub-modules. Notable
        fields:

        * ``patches.size`` - (8, 8, 8) patch size of the patch-embedding conv.
        * ``hidden_size`` - 252-wide transformer tokens (12 heads of 21).
        * ``transformer`` - MLP width, head count, depth and dropout rates.
        * ``encoder_channels`` / ``decoder_channels`` / ``skip_channels`` -
          channel widths of the CNN encoder, decoder blocks and skips.
        * ``down_factor`` / ``down_num`` - NOTE(review): these appear to be
          coupled to the fixed two ``Down`` stages of the CNN encoder and to
          the patch size; change them together.
        * ``input_channels`` - number of channels of the input volume.
    """
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({"size": (8, 8, 8)})
    config.patches.grid = (8, 8, 8)
    config.hidden_size = 252

    # Transformer encoder hyper-parameters.
    config.transformer = ml_collections.ConfigDict(
        {
            "mlp_dim": 3072,
            "num_heads": 12,
            "num_layers": 12,
            "attention_dropout_rate": 0.0,
            "dropout_rate": 0.1,
        }
    )

    config.patch_size = 8
    config.conv_first_channel = 512

    # CNN encoder / decoder layout.
    config.encoder_channels = (16, 32, 32)
    config.down_factor = 2
    config.down_num = 2
    config.decoder_channels = (96, 48, 32, 32, 16)
    config.skip_channels = (32, 32, 32, 32, 16)
    config.n_skip = 5

    config.input_channels = 1
    return config
logger = logging.getLogger(__name__)

# Parameter-path fragments in the naming scheme of ViT checkpoints
# ("MultiHeadDotProductAttention_1/...", "MlpBlock_3/...", "LayerNorm_*").
# NOTE(review): not referenced anywhere in the visible source; presumably kept
# for weight-loading code elsewhere - confirm before removing.
ATTENTION_NORM = "LayerNorm_0"
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"

MLP_NORM = "LayerNorm_2"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
def np2th(weights, conv=False):
    """Convert a NumPy array to a torch tensor, optionally reordering conv axes.

    Parameters
    ----------
    weights : numpy.ndarray
        Array to convert. When ``conv`` is True it must be 4-D and is treated
        as HWIO (height, width, in, out).
    conv : bool
        If True, permute axes HWIO -> OIHW before conversion.

    Returns
    -------
    torch.Tensor
        Tensor sharing memory with ``weights`` (or its transposed view).
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)
def swish(x):
    """Swish activation: ``x * sigmoid(x)``, applied element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
# Name -> activation-function lookup table. ``Mlp`` selects the "gelu" entry.
ACT2FN = dict(
    gelu=torch.nn.functional.gelu,
    relu=torch.nn.functional.relu,
    swish=swish,
)
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        Reads ``hidden_size``, ``transformer["num_heads"]`` and
        ``transformer["attention_dropout_rate"]``.
    vis : bool
        If True, ``forward`` also returns the attention probabilities;
        otherwise the second return value is ``None``.

    Raises
    ------
    ValueError
        If ``hidden_size`` is smaller than ``num_heads``, which would give
        zero-width heads.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        # Integer per-head width; any remainder of hidden_size / num_heads is
        # simply not projected into the heads.
        self.attention_head_size = config.hidden_size // self.num_attention_heads
        if self.attention_head_size == 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be at least "
                f"num_heads ({self.num_attention_heads})"
            )
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        # The concatenated heads carry all_head_size features, which is less
        # than hidden_size when hidden_size is not divisible by num_heads.
        # For divisible configs this has exactly the previous shape.
        self.out = Linear(self.all_head_size, config.hidden_size)

        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        # NOTE(review): the output projection also uses attention_dropout_rate
        # (as in the upstream port), not transformer["dropout_rate"].
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (B, N, all_head) -> (B, heads, N, head_size)."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Token tensor of shape (B, N, hidden_size).

        Returns
        -------
        tuple
            ``(attention_output, weights)``. ``attention_output`` has shape
            (B, N, hidden_size). ``weights`` is the (B, heads, N, N) softmax
            before dropout when ``vis`` is True, else ``None``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        # Capture the weights before dropout so visualisation is deterministic.
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge heads back: (B, heads, N, head_size) -> (B, N, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
class Mlp(nn.Module):
    """Position-wise feed-forward block of a transformer layer.

    Maps ``hidden_size -> transformer["mlp_dim"] -> hidden_size`` with GELU in
    between and dropout (``transformer["dropout_rate"]``) after each
    projection.
    """

    def __init__(self, config):
        super(Mlp, self).__init__()
        hidden, inner = config.hidden_size, config.transformer["mlp_dim"]
        self.fc1 = Linear(hidden, inner)
        self.fc2 = Linear(inner, hidden)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero normal biases (RNG order kept)."""
        layers = (self.fc1, self.fc2)
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        """Apply fc1 -> GELU -> dropout -> fc2 -> dropout."""
        inner = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(inner))
class Embeddings(nn.Module):
    """CNN-hybrid patch embedding plus learned position embedding.

    The input is first run through ``CNNEncoder``; its last feature map is cut
    into patches by a strided ``Conv3d`` and flattened into tokens.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        Model configuration.
    img_size : sequence of int
        Spatial size (3 values) of the input volume; used to size the
        position-embedding table.
    """

    def __init__(self, config, img_size):
        super(Embeddings, self).__init__()
        self.config = config
        down_factor = config.down_factor
        patch_size = _triple(config.patches["size"])

        # Number of tokens = product over the 3 axes of
        # (axis size / 2**down_factor) // patch size.
        n_patches = 1
        for axis in range(3):
            n_patches *= img_size[axis] / 2**down_factor // patch_size[axis]
        n_patches = int(n_patches)

        self.hybrid_model = CNNEncoder(config, n_channels=config.input_channels)
        self.patch_embeddings = Conv3d(
            in_channels=config["encoder_channels"][-1],
            out_channels=config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, n_patches, config.hidden_size)
        )
        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """Return ``(embeddings, features)``.

        ``embeddings`` has shape (B, n_patches, hidden_size); ``features`` is
        the list of skip feature maps produced by the CNN encoder.
        """
        x, features = self.hybrid_model(x)
        # Conv3d output (B, hidden, l, h, w) -> (B, hidden, l*h*w) -> (B, n_patches, hidden)
        tokens = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
        embeddings = self.dropout(tokens + self.position_embeddings)
        return embeddings, features
class Block(nn.Module):
    """Pre-norm transformer layer.

    LayerNorm -> self-attention -> residual add, then
    LayerNorm -> MLP -> residual add.
    """

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        """Return ``(x, weights)`` where ``weights`` comes from ``Attention``."""
        attended, weights = self.attn(self.attention_norm(x))
        x = attended + x
        x = self.ffn(self.ffn_norm(x)) + x
        return x, weights
class Encoder(nn.Module):
    """Stack of ``transformer["num_layers"]`` ``Block`` layers plus a final LayerNorm."""

    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        # Each Block is freshly constructed, so it needs no extra copy.
        self.layer = nn.ModuleList(
            [Block(config, vis) for _ in range(config.transformer["num_layers"])]
        )
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)

    def forward(self, hidden_states):
        """Return ``(encoded, attn_weights)``.

        ``attn_weights`` collects each layer's weights when ``vis`` is True and
        is an empty list otherwise.
        """
        attn_weights = []
        for block in self.layer:
            hidden_states, weights = block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        return self.encoder_norm(hidden_states), attn_weights
class Transformer(nn.Module):
    """Hybrid CNN embeddings followed by the transformer encoder stack."""

    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        """Return ``(encoded, attn_weights, features)``.

        ``encoded`` is (B, n_patch, hidden); ``features`` are the CNN skip maps.
        """
        tokens, skip_features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(tokens)
        return encoded, attn_weights, skip_features
class Conv3dReLU(nn.Sequential):
    """3-D convolution, optional batch normalisation, then ReLU.

    Parameters
    ----------
    in_channels, out_channels : int
        Channel counts of the convolution.
    kernel_size : int or tuple
        Convolution kernel size.
    padding : int, default 0
        Convolution padding.
    stride : int, default 1
        Convolution stride.
    use_batchnorm : bool, default True
        If True, a ``BatchNorm3d`` follows the convolution and the convolution
        has no bias. If False, the convolution keeps its bias and no
        normalisation is applied (previously a BatchNorm3d was added anyway).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )
        # Keep a module in slot 1 either way so the Sequential indices (and
        # hence state_dict keys "0.*" / "1.*") stay the same for
        # use_batchnorm=True.
        norm = nn.BatchNorm3d(out_channels) if use_batchnorm else nn.Identity()
        relu = nn.ReLU(inplace=True)
        super(Conv3dReLU, self).__init__(conv, norm, relu)
class DecoderBlock(nn.Module):
    """Upsample by 2, optionally concatenate a skip map, then two Conv3dReLU.

    Parameters
    ----------
    in_channels : int
        Channels of the incoming (pre-upsampling) tensor.
    out_channels : int
        Channels produced by both convolutions.
    skip_channels : int, default 0
        Channels of the skip tensor concatenated after upsampling.
    use_batchnorm : bool, default True
        Forwarded to both ``Conv3dReLU`` layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        skip_channels=0,
        use_batchnorm=True,
    ):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.conv1 = Conv3dReLU(in_channels + skip_channels, out_channels, **conv_kwargs)
        self.conv2 = Conv3dReLU(out_channels, out_channels, **conv_kwargs)
        self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False)

    def forward(self, x, skip=None):
        """Upsample ``x``, concatenate ``skip`` on the channel axis if given, convolve."""
        upsampled = self.up(x)
        if skip is not None:
            upsampled = torch.cat((upsampled, skip), dim=1)
        return self.conv2(self.conv1(upsampled))
class DecoderCup(nn.Module):
    """Decoder that turns transformer tokens back into a full-resolution volume.

    Tokens are reshaped into a 3-D grid, projected to ``conv_first_channel``
    channels, then passed through one ``DecoderBlock`` per entry of
    ``config.decoder_channels``. Each block may consume a skip feature map.
    """

    def __init__(self, config, img_size):
        super().__init__()
        self.config = config
        self.down_factor = config.down_factor
        self.img_size = img_size
        head_channels = config.conv_first_channel
        self.conv_more = Conv3dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        out_channels = config.decoder_channels
        # Each block's input width is the previous block's output width.
        in_channels = [head_channels, *out_channels[:-1]]
        self.patch_size = _triple(config.patches["size"])
        self.blocks = nn.ModuleList(
            DecoderBlock(in_ch, out_ch, skip_ch)
            for in_ch, out_ch, skip_ch in zip(
                in_channels, out_channels, config.skip_channels
            )
        )

    def forward(self, hidden_states, features=None):
        """Decode (B, n_patch, hidden) tokens.

        ``features[i]`` is used as the skip for block ``i`` when ``features``
        is given and ``i < config.n_skip``.
        """
        batch, _, hidden = hidden_states.size()
        scale = 2**self.down_factor
        # Token grid (l, h, w); n_patch must equal l*h*w for the view below.
        grid = [self.img_size[axis] // scale // self.patch_size[axis] for axis in range(3)]
        x = hidden_states.permute(0, 2, 1).contiguous().view(batch, hidden, *grid)
        x = self.conv_more(x)
        for index, block in enumerate(self.blocks):
            use_skip = features is not None and index < self.config.n_skip
            x = block(x, skip=features[index] if use_skip else None)
        return x
class DoubleConv(nn.Module):
    """Two 3x3x3 convolutions (padding 1), each followed by ReLU.

    No normalisation layer is used. ``mid_channels`` defaults to
    ``out_channels`` when falsy.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        layers = [
            nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU; spatial size is preserved."""
        out = self.double_conv(x)
        return out
class Down(nn.Module):
    """Downscale with ``MaxPool3d(2)``, then apply ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool then convolve ``x``."""
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved
class CNNEncoder(nn.Module):
    """Convolutional front end of the hybrid ViT.

    Produces the feature map that is patch-embedded into tokens, plus a list
    of skip feature maps for the decoder.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        Reads ``encoder_channels`` (3 widths) and ``down_num``.
    n_channels : int, default 2
        Number of input channels.
    """

    def __init__(self, config, n_channels=2):
        super(CNNEncoder, self).__init__()
        self.n_channels = n_channels
        encoder_channels = config.encoder_channels
        self.down_num = config.down_num
        self.inc = DoubleConv(n_channels, encoder_channels[0])
        self.down1 = Down(encoder_channels[0], encoder_channels[1])
        self.down2 = Down(encoder_channels[1], encoder_channels[2])
        self.width = encoder_channels[-1]

    def forward(self, x):
        """Return ``(feats, skips)``.

        ``feats`` is the output of ``down2``. ``skips`` lists, coarsest first:
        ``down_num`` further 2x max-pooled copies of ``feats``, then ``feats``,
        the ``down1`` output and the ``inc`` output.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        feats = self.down2(x2)
        features = [x1, x2, feats]

        feats_down = feats
        for _ in range(self.down_num):
            # Functional pooling: identical to nn.MaxPool3d(2) (stride defaults
            # to kernel_size) without building a new module on every call.
            feats_down = nnf.max_pool3d(feats_down, kernel_size=2)
            features.append(feats_down)
        return feats, features[::-1]
class RegistrationHead(nn.Sequential):
    """Final 3-D convolution mapping decoder features to ``out_channels``.

    The convolution uses "same" padding (``kernel_size // 2``). Its weights are
    sampled from Normal(0, 1e-5) and its bias is zeroed, so initial outputs
    are close to zero.

    Parameters
    ----------
    in_channels, out_channels : int
        Channel counts of the convolution.
    kernel_size : int, default 3
        Convolution kernel size.
    upsampling : int, default 1
        Accepted for interface compatibility but currently unused.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        head = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        initial_weight = Normal(0, 1e-5).sample(head.weight.shape)
        head.weight = nn.Parameter(initial_weight)
        head.bias = nn.Parameter(torch.zeros(head.bias.shape))
        super().__init__(head)
class ViTVNet(nn.Module):
    """
    ViT-V-Net model: hybrid CNN + transformer encoder, U-Net style decoder.

    Parameters
    ----------
    out_channels : int
        The number of output channels.
    config : str or ml_collections.ConfigDict
        Either a key of ``CONFIGS`` or a configuration object.
    img_size : tuple of int
        Spatial size (3 values) of the input volume. NOTE(review): with the
        default config each dimension presumably needs to be divisible by 32
        so the token grid and decoder skip sizes line up - confirm.
    vis : bool
        Whether attention layers keep their attention weights. Note that
        ``forward`` discards them and only returns the head output.
    """

    def __init__(
        self, out_channels, config="ViT-V-Net", img_size=(128, 128, 128), vis=False
    ):
        super(ViTVNet, self).__init__()
        if not isinstance(config, str):
            assert isinstance(
                config, ml_collections.ConfigDict
            ), "Is not a config object or the name of one"
        else:
            config = CONFIGS[config]

        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config, img_size)
        self.reg_head = RegistrationHead(
            in_channels=config.decoder_channels[-1],
            out_channels=out_channels,
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        """Encode, decode with skips, and apply the output head."""
        tokens, _attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        decoded = self.decoder(tokens, features)
        return self.reg_head(decoded)
# Named model configurations; ``ViTVNet`` resolves string ``config`` arguments here.
CONFIGS = {}
CONFIGS["ViT-V-Net"] = get_3DReg_config()