import ast
from collections import OrderedDict

import torch
import torch.nn as nn


class IdentityLayer(nn.Module):
    """Pass-through layer used as the residual shortcut."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

    @staticmethod
    def is_zero_layer():
        return False


class ZeroLayer(nn.Module):
    """Outputs an all-zero tensor with the spatial size a strided op would
    produce, so 'Zero' can stand in for any candidate op in the search
    space."""

    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        n, c, h, w = x.shape
        h //= self.stride[0]
        w //= self.stride[1]
        # Match the input's device and dtype so the zeros compose with the
        # rest of the network (e.g. under mixed precision).
        return torch.zeros(n, c, h, w, device=x.device, dtype=x.dtype)

    @staticmethod
    def is_zero_layer():
        return True

    def get_flops(self, x):
        # A zero op costs no FLOPs.
        return 0, self.forward(x)


def get_same_padding(kernel_size):
    """Return the padding that preserves spatial size for an odd kernel."""
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size
        return (get_same_padding(kernel_size[0]),
                get_same_padding(kernel_size[1]))
    assert isinstance(kernel_size,
                      int), 'kernel size should be either `int` or `tuple`'
    assert kernel_size % 2 > 0, 'kernel size should be odd number'
    return kernel_size // 2
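
# For reference (a sketch, not part of the module API): padding values for
# the odd kernels used below.
#   get_same_padding(3)      -> 1
#   get_same_padding((3, 5)) -> (1, 2)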


class MBInvertedConvLayer(nn.Module):
    """MobileNetV2-style inverted residual conv: 1x1 expansion ->
    depthwise conv -> 1x1 linear projection."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=(1, 1),
                 expand_ratio=6,
                 mid_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.mid_channels = mid_channels

        feature_dim = round(
            self.in_channels *
            self.expand_ratio) if mid_channels is None else mid_channels

        if self.expand_ratio == 1:
            # No expansion: skip the 1x1 bottleneck entirely.
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(
                OrderedDict([
                    ('conv',
                     nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0,
                               bias=False)),
                    ('bn', nn.BatchNorm2d(feature_dim)),
                    ('act', nn.ReLU6(inplace=True)),
                ]))

        pad = get_same_padding(self.kernel_size)
        self.depth_conv = nn.Sequential(
            OrderedDict([
                ('conv',
                 nn.Conv2d(feature_dim,
                           feature_dim,
                           kernel_size,
                           stride,
                           pad,
                           groups=feature_dim,
                           bias=False)),
                ('bn', nn.BatchNorm2d(feature_dim)),
                ('act', nn.ReLU6(inplace=True)),
            ]))
        self.point_conv = nn.Sequential(
            OrderedDict([
                ('conv',
                 nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels)),
            ]))

    def forward(self, x):
        if self.inverted_bottleneck is not None:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x

    @staticmethod
    def is_zero_layer():
        return False
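
# Shape sketch (hypothetical numbers, for orientation): with same padding, a
# stride-(2, 1) block halves the height and keeps the width, e.g.
#   layer = MBInvertedConvLayer(32, 64, kernel_size=3, stride=(2, 1))
#   layer(torch.randn(1, 32, 16, 64)).shape  -> torch.Size([1, 64, 8, 64])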


def conv_func_by_name(name):
    """Map an op name to a factory `f(in_C, out_C, stride) -> nn.Module`."""
    name2ops = {
        'Identity': lambda in_C, out_C, S: IdentityLayer(),
        'Zero': lambda in_C, out_C, S: ZeroLayer(stride=S),
    }
    # MBConv variants: kernel sizes 3/5/7 x expand ratios 1..6. The default
    # arguments bind k and e at definition time (lambdas created in a loop
    # would otherwise all capture the final loop values).
    for k in (3, 5, 7):
        for e in range(1, 7):
            name2ops['%dx%d_MBConv%d' % (k, k, e)] = (
                lambda in_C, out_C, S, k=k, e=e:
                MBInvertedConvLayer(in_C, out_C, k, S, e))
    return name2ops[name]
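
# Usage sketch: look up a factory by name, then call it with
# (in_channels, out_channels, stride), e.g.
#   op = conv_func_by_name('3x3_MBConv6')(32, 64, (2, 1))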


def build_candidate_ops(candidate_ops, in_channels, out_channels, stride,
                        ops_order=None):
    """Instantiate one op per name in `candidate_ops`.

    `ops_order` is accepted for API compatibility but unused: the simplified
    `IdentityLayer` above takes no arguments.
    """
    if candidate_ops is None:
        raise ValueError('please specify a candidate set')
    return [
        conv_func_by_name(name)(in_channels, out_channels, stride)
        for name in candidate_ops
    ]


class MobileInvertedResidualBlock(nn.Module):

    def __init__(self, mobile_inverted_conv, shortcut):
        super().__init__()
        self.mobile_inverted_conv = mobile_inverted_conv
        self.shortcut = shortcut

    def forward(self, x):
        if self.mobile_inverted_conv.is_zero_layer():
            # A zero conv plus the shortcut reduces to the identity.
            res = x
        elif self.shortcut is None or self.shortcut.is_zero_layer():
            res = self.mobile_inverted_conv(x)
        else:
            res = self.mobile_inverted_conv(x) + self.shortcut(x)
        return res


class AutoSTREncoder(nn.Module):
    """AutoSTR feature extractor: a searched stack of MBConv blocks followed
    by an optional 2-layer BiLSTM. Expects a 32-pixel-high input; the five
    stride stages reduce the height to 1 before the sequence model."""

    def __init__(self,
                 in_channels,
                 out_dim=256,
                 with_lstm=True,
                 stride_stages='[(2, 2), (2, 2), (2, 1), (2, 1), (2, 1)]',
                 n_cell_stages=[3, 3, 3, 3, 3],
                 conv_op_ids=[5, 5, 5, 5, 5, 5, 5, 6, 6, 5, 4, 3, 4, 6, 6],
                 **kwargs):
        super().__init__()
        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      32,
                      kernel_size=(3, 3),
                      stride=1,
                      padding=1,
                      bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True))

        # `stride_stages` may be given as a list or as its string literal
        # (e.g. from a config file); parse the latter safely rather than
        # with `eval`.
        if isinstance(stride_stages, str):
            stride_stages = ast.literal_eval(stride_stages)
        width_stages = [32, 64, 128, 256, 512]
        conv_candidates = [
            '5x5_MBConv1', '5x5_MBConv3', '5x5_MBConv6', '3x3_MBConv1',
            '3x3_MBConv3', '3x3_MBConv6', 'Zero'
        ]
        assert len(conv_op_ids) == sum(n_cell_stages)

        blocks = []
        input_channel = 32
        for width, n_cell, s in zip(width_stages, n_cell_stages,
                                    stride_stages):
            for i in range(n_cell):
                # Only the first cell of each stage downsamples.
                stride = s if i == 0 else (1, 1)
                block_i = len(blocks)
                conv_op = conv_func_by_name(
                    conv_candidates[conv_op_ids[block_i]])(input_channel,
                                                           width, stride)
                # Residual shortcut only where input and output shapes match.
                if stride == (1, 1) and input_channel == width:
                    shortcut = IdentityLayer()
                else:
                    shortcut = None
                blocks.append(MobileInvertedResidualBlock(conv_op, shortcut))
                input_channel = width
        self.out_channels = input_channel
        self.blocks = nn.ModuleList(blocks)

        self.with_lstm = with_lstm
        if with_lstm:
            self.rnn = nn.LSTM(input_channel,
                               out_dim // 2,
                               bidirectional=True,
                               num_layers=2,
                               batch_first=True)
            self.out_channels = out_dim

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        # The stride stages reduce a 32-pixel-high input to height 1, so
        # dropping dim 2 yields a (N, C, W) map; transpose to (N, W, C)
        # column features for the sequence model.
        cnn_feat = x.squeeze(dim=2).transpose(2, 1)
        if self.with_lstm:
            rnn_feat, _ = self.rnn(cnn_feat)
            return rnn_feat
        return cnn_feat
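

# Minimal smoke test (a sketch, not part of the original module). It assumes
# a 1-channel, 32-pixel-high input; with the default strides the sequence
# length is W / 4 and the feature size is `out_dim`.
if __name__ == '__main__':
    encoder = AutoSTREncoder(in_channels=1)
    images = torch.randn(2, 1, 32, 100)  # (N, C, H, W) dummy batch
    feats = encoder(images)
    print(feats.shape)  # expected: torch.Size([2, 25, 256])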