- """This code is refer from:
- https://github.com/MelosY/CAM
- """
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_

# ConvNeXtV2 and SVTRv2LNConvTwo33 are resolved by name via eval() in
# CAMEncoder, so both imports are required even though only Block and
# LayerNorm are referenced directly below.
from .convnextv2 import ConvNeXtV2, Block, LayerNorm
from .svtrv2_lnconv_two33 import SVTRv2LNConvTwo33


class Swish(nn.Module):
    """Swish activation: x * sigmoid(x) (equivalent to nn.SiLU)."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


class UNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """A UNet-style decoder block: a transposed conv upsamples each axis
        by its stride, then two 3x3 convs refine the result.
        """
        super().__init__()
        stride_h, stride_w = stride
        # Kernel/padding pairs are chosen so the transposed conv scales each
        # axis exactly by its stride: (k=1, p=0) for 1x, (k=4, p=1) for 2x,
        # (k=4, p=0) for 4x.
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        conv = nn.Conv2d
        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            conv(cin, cin, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            conv(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)
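

# A minimal usage sketch, not part of the model: shape check for UNetBlock.
# The decoders below pass nn.SyncBatchNorm as `bn2d`; plain nn.BatchNorm2d is
# assumed here so the sketch runs without a distributed process group.
def _demo_unet_block():
    blk = UNetBlock(cin=64, cout=32, bn2d=nn.BatchNorm2d, stride=(2, 1))
    y = blk(torch.randn(2, 64, 8, 32))
    assert y.shape == (2, 32, 16, 32)  # height doubled, width unchanged
    return y.shape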


class DepthWiseUNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """Same layout as UNetBlock, but each dense 3x3 conv is factorized
        into a depthwise 3x3 conv plus a pointwise 1x1 conv.
        """
        super().__init__()
        stride_h, stride_w = stride
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin,
                      cout,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)
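

# Hedged sketch: the depthwise variant factorizes each dense 3x3 conv into a
# depthwise 3x3 plus a pointwise 1x1, which shrinks the refinement convs'
# parameter count at identical output shapes. nn.BatchNorm2d is assumed in
# place of SyncBatchNorm, as above.
def _demo_depthwise_savings():
    dense = UNetBlock(64, 32, nn.BatchNorm2d, stride=(2, 2))
    separable = DepthWiseUNetBlock(64, 32, nn.BatchNorm2d, stride=(2, 2))

    def count(m):
        return sum(p.numel() for p in m.parameters())

    assert count(separable) < count(dense)
    return count(dense), count(separable)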


class SFTLayer(nn.Module):
    """Spatial feature transform: predicts a per-element scale and shift from
    a conditioning tensor and applies `fea * (scale + 1) + shift`.
    """

    def __init__(self, dim_in, dim_out):
        super(SFTLayer, self).__init__()
        self.SFT_scale_conv0 = nn.Linear(dim_in, dim_in)
        self.SFT_scale_conv1 = nn.Linear(dim_in, dim_out)
        self.SFT_shift_conv0 = nn.Linear(dim_in, dim_in)
        self.SFT_shift_conv1 = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        # x[0]: feature to modulate; x[1]: conditioning feature
        scale = self.SFT_scale_conv1(
            F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True))
        shift = self.SFT_shift_conv1(
            F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True))
        return x[0] * (scale + 1) + shift
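

# Hedged sketch: SFTLayer expects a (feature, condition) pair with matching
# trailing dims and returns fea * (scale + 1) + shift. A [B, L, C] token
# layout is assumed here for illustration.
def _demo_sft_layer():
    sft = SFTLayer(dim_in=32, dim_out=32)
    fea = torch.randn(2, 100, 32)
    cond = torch.randn(2, 100, 32)
    out = sft((fea, cond))
    assert out.shape == fea.shape
    return out.shape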


class MoreUNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """A deeper variant of DepthWiseUNetBlock with four separable conv
        stages after the upsampling step.
        """
        super().__init__()
        stride_h, stride_w = stride
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cin),
            nn.Conv2d(cin, cout, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cout),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cout, cout, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cout),
            nn.Conv2d(cout, cout, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)


class BinaryDecoder(nn.Module):
    """Upsamples encoder features through a stack of UNet blocks and
    predicts a per-pixel class map for the binary character mask.
    """

    def __init__(self,
                 dim,
                 num_classes,
                 strides,
                 use_depthwise_unet=False,
                 use_more_unet=False,
                 binary_loss_type='DiceLoss') -> None:
        super().__init__()
        channels = [dim // 2**i for i in range(4)]
        self.linear_enc2binary = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
            nn.SyncBatchNorm(dim),
        )
        self.strides = strides
        self.use_deformable = False
        self.binary_decoder = nn.ModuleList()
        unet = DepthWiseUNetBlock if use_depthwise_unet else UNetBlock
        unet = MoreUNetBlock if use_more_unet else unet
        # Mirror the encoder strides in reverse to undo its downsampling,
        # halving the channel count at each step.
        for i in range(3):
            up_sample_stride = self.strides[::-1][i]
            cin, cout = channels[i], channels[i + 1]
            self.binary_decoder.append(
                unet(cin, cout, nn.SyncBatchNorm, up_sample_stride,
                     self.use_deformable))
        last_stride = (self.strides[0][0] // 2, self.strides[0][1] // 2)
        self.binary_decoder.append(
            unet(cout, cout, nn.SyncBatchNorm, last_stride,
                 self.use_deformable))
        # The number of predicted classes depends on the loss variant.
        if binary_loss_type in ('CrossEntropyDiceLoss',
                                'BanlanceMultiClassCrossEntropyLoss'):
            segm_num_cls = num_classes - 2
        else:
            segm_num_cls = num_classes - 3
        self.binary_pred = nn.Conv2d(channels[-1],
                                     segm_num_cls,
                                     kernel_size=1,
                                     stride=1,
                                     bias=True)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 * 3)
        """
        p_h, p_w = self.strides[0]
        p_h = p_h // 2
        p_w = p_w // 2
        h = imgs.shape[2] // p_h
        w = imgs.shape[3] // p_w
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p_h, w, p_w))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p_h * p_w * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, patch_size**2, h, w)
        imgs: (N, H, W), single channel, with H = h * p_h and W = w * p_w
        """
        p_h, p_w = self.strides[0]
        p_h = p_h // 2
        p_w = p_w // 2
        _, _, h, w = x.shape
        assert p_h * p_w == x.shape[1]
        x = x.permute(0, 2, 3, 1)  # [N, h, w, p_h*p_w]
        x = x.reshape(shape=(x.shape[0], h, w, p_h, p_w))
        x = torch.einsum('nhwpq->nhpwq', x)
        imgs = x.reshape(shape=(x.shape[0], h * p_h, w * p_w))
        return imgs

    def forward(self, x, time=None):
        """
        x: encoder feature map used to seed the binary prediction.
        time: unused; kept for interface compatibility.
        Returns the segmentation logits during training (softmax
        probabilities at inference) plus all intermediate decoder features.
        """
        binary_feats = []
        x = self.linear_enc2binary(x)
        binary_feats.append(x.clone())
        for d in self.binary_decoder:
            x = d(x)
            binary_feats.append(x.clone())
        x = self.binary_pred(x)
        if self.training:
            return x, binary_feats
        else:
            return x.softmax(1), binary_feats
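

# Hedged sketch: patchify/unpatchify round-trip shapes for the default
# strides, whose first stage (4, 4) gives patches of (2, 2). Constructing
# the decoder only instantiates SyncBatchNorm modules; neither method runs
# them, so no process group is needed for this check.
def _demo_patchify_shapes():
    dec = BinaryDecoder(dim=256,
                        num_classes=71,
                        strides=[(4, 4), (2, 1), (2, 1), (1, 1)])
    tokens = dec.patchify(torch.randn(2, 3, 32, 128))
    assert tokens.shape == (2, 16 * 64, 2 * 2 * 3)
    imgs = dec.unpatchify(torch.randn(2, 4, 16, 64))
    assert imgs.shape == (2, 32, 128)
    return tokens.shape, imgs.shape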


class LayerNormProxy(nn.Module):
    """LayerNorm over the channel dim of NCHW tensors, applied by permuting
    to channels-last and back.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        return x.permute(0, 3, 1, 2)


class DAttentionFuse(nn.Module):

    def __init__(
        self,
        q_size=(4, 32),
        kv_size=(4, 32),
        n_heads=8,
        n_head_channels=80,
        n_groups=4,
        attn_drop=0.0,
        proj_drop=0.0,
        stride=2,
        offset_range_factor=2,
        use_pe=True,
        stage_idx=0,
    ):
        """Deformable cross-attention fusing two feature maps: queries come
        from `y`, keys/values are sampled from `x` at predicted offsets.
        `stage_idx` (expected 2 or 3 here) selects the offset-conv kernel
        size.
        """
        super().__init__()
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels**-0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        self.kv_h, self.kv_w = kv_size
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.offset_range_factor = offset_range_factor
        ksizes = [9, 7, 5, 3]
        kk = ksizes[stage_idx]
        # Predicts a 2D sampling offset per group from the concatenated
        # (x, y) group features.
        self.conv_offset = nn.Sequential(
            nn.Conv2d(2 * self.n_group_channels,
                      2 * self.n_group_channels,
                      kk,
                      stride,
                      kk // 2,
                      groups=self.n_group_channels),
            LayerNormProxy(2 * self.n_group_channels), nn.GELU(),
            nn.Conv2d(2 * self.n_group_channels, 2, 1, 1, 0, bias=False))
        self.proj_q = nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1,
                                padding=0)
        self.proj_k = nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1,
                                padding=0)
        self.proj_v = nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1,
                                padding=0)
        self.proj_out = nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1,
                                  padding=0)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)
        if self.use_pe:
            # Relative-position bias table over all (2*kv_h - 1) x
            # (2*kv_w - 1) integer displacements.
            self.rpe_table = nn.Parameter(
                torch.zeros(self.n_heads, self.kv_h * 2 - 1,
                            self.kv_w * 2 - 1))
            trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):
        # Grid of cell centers, normalized to [-1, 1] on each axis.
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype,
                           device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype,
                           device=device))
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key).mul_(2).sub_(1)
        ref[..., 0].div_(H_key).mul_(2).sub_(1)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B*g H W 2
        return ref

    def forward(self, x, y):
        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device
        # Group-wise concatenation of the two inputs drives offset prediction,
        # i.e. rearrange(cat((x, y), 1), 'b (g c) h w -> (b g) c h w').
        q_off = torch.cat((x, y), dim=1).reshape(
            B, self.n_groups, 2 * self.n_group_channels, H, W).flatten(0, 1)
        offset = self.conv_offset(q_off)  # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk
        if self.offset_range_factor > 0:
            offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk],
                                        device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(
                self.offset_range_factor)
        offset = offset.permute(0, 2, 3, 1)  # b p h w -> b h w p
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)
        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).tanh()
        q = self.proj_q(y)
        x_sampled = F.grid_sample(
            input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
            grid=pos[..., (1, 0)],  # y, x -> x, y
            mode='bilinear',
            align_corners=False)  # B * g, Cg, Hg, Wg
        x_sampled = x_sampled.reshape(B, C, 1, n_sample)
        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
        k = self.proj_k(x_sampled).reshape(B * self.n_heads,
                                           self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads,
                                           self.n_head_channels, n_sample)
        attn = torch.einsum('b c m, b c n -> b m n', q, k)  # B * h, HW, Ns
        attn = attn.mul(self.scale)
        if self.use_pe:
            rpe_table = self.rpe_table
            rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
            q_grid = self._get_ref_points(H, W, B, dtype, device)
            displacement = (
                q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) -
                pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)
            ).mul(0.5)
            attn_bias = F.grid_sample(
                input=rpe_bias.reshape(B * self.n_groups, self.n_group_heads,
                                       2 * H - 1, 2 * W - 1),
                grid=displacement[..., (1, 0)],
                mode='bilinear',
                align_corners=False)
            attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
            attn = attn + attn_bias
        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)
        out = torch.einsum('b m n, b c n -> b c m', attn, v)
        out = out.reshape(B, C, H, W)
        out = self.proj_drop(self.proj_out(out))
        return (out, pos.reshape(B, self.n_groups, Hk, Wk, 2),
                reference.reshape(B, self.n_groups, Hk, Wk, 2))
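

# Hedged sketch: DAttentionFuse queries from `y` and samples keys/values from
# `x` at offset positions. The input spatial size must match q_size/kv_size
# (here (4, 32)) for the relative-position bias reshape to hold, and C must
# equal n_heads * n_head_channels.
def _demo_dattention_fuse():
    attn = DAttentionFuse(q_size=(4, 32), kv_size=(4, 32), n_heads=8,
                          n_head_channels=32, n_groups=4, stage_idx=2)
    x = torch.randn(2, 256, 4, 32)
    y = torch.randn(2, 256, 4, 32)
    out, pos, ref = attn(x, y)
    assert out.shape == (2, 256, 4, 32)
    assert pos.shape == ref.shape  # sampled vs. reference grids [B,g,Hk,Wk,2]
    return out.shape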


class FuseModel(nn.Module):

    def __init__(self,
                 dim,
                 deform_stride=2,
                 stage_idx=2,
                 k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
                 q_size=(2, 32)):
        super().__init__()
        channels = [dim // 2**i for i in range(4)]
        refine_conv = nn.Conv2d
        self.deform_stride = deform_stride
        # (input, output) indices into `channels` for each downsampling stage.
        in_out_ch = [(-1, -2), (-2, -3), (-3, -4), (-4, -4)]
        self.binary_condition_layer = DAttentionFuse(q_size=q_size,
                                                     kv_size=q_size,
                                                     stride=self.deform_stride,
                                                     n_head_channels=dim // 8,
                                                     stage_idx=stage_idx)
        # Downsample the high-resolution binary features back to the encoder
        # resolution, stage by stage.
        self.binary2refine_linear_norm = nn.ModuleList()
        for i in range(len(k_size)):
            self.binary2refine_linear_norm.append(
                nn.Sequential(
                    Block(dim=channels[in_out_ch[i][0]]),
                    LayerNorm(channels[in_out_ch[i][0]],
                              eps=1e-6,
                              data_format='channels_first'),
                    refine_conv(channels[in_out_ch[i][0]],
                                channels[in_out_ch[i][1]],
                                kernel_size=k_size[i],
                                stride=k_size[i])))

    def forward(self, recog_feat, binary_feats, dec_in=None):
        multi_feat = []
        binary_feat = binary_feats[-1]
        for i in range(len(self.binary2refine_linear_norm)):
            binary_feat = self.binary2refine_linear_norm[i](binary_feat)
            multi_feat.append(binary_feat)
        binary_feat = binary_feat + binary_feats[0]
        multi_feat[3] += binary_feats[0]
        # Deformable attention conditioned on the binary features refines the
        # recognition features.
        binary_refined_feat, pos, _ = self.binary_condition_layer(
            recog_feat, binary_feat)
        binary_refined_feat = binary_refined_feat + binary_feat
        return binary_refined_feat, binary_feat
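

# Hedged sketch: FuseModel only reads binary_feats[0] (encoder resolution)
# and binary_feats[-1] (final decoder resolution); the shapes below mimic
# BinaryDecoder's outputs for a 32x128 input with the default strides.
def _demo_fuse_model():
    fuse = FuseModel(dim=256)
    recog_feat = torch.randn(2, 256, 2, 32)
    binary_feats = [torch.randn(2, 256, 2, 32), None, None, None,
                    torch.randn(2, 32, 16, 64)]
    refined, cond = fuse(recog_feat, binary_feats)
    assert refined.shape == (2, 256, 2, 32)
    return refined.shape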


class CAMEncoder(nn.Module):
    """CAM encoder: a recognition backbone plus a binary-mask decoder whose
    features refine the recognition features through deformable attention.

    Args:
        in_channels (int): Number of input image channels. Default: 3
        encoder_config (dict): Backbone config; `name` selects the backbone
            class. Default: {'name': 'ConvNeXtV2'}
        nb_classes (int): Number of mask classes. Default: 71
        strides (list(tuple)): Per-stage (H, W) strides of the backbone.
            Default: [(4, 4), (2, 1), (2, 1), (1, 1)]
        q_size (tuple(int)): Spatial size of the encoder feature map.
            Default: [2, 32]
        mid_size (bool): Halve the backbone channels before decoding.
            Default: True
        d_embedding (int): Output embedding dimension. Default: 384
    """

    def __init__(self,
                 in_channels=3,
                 encoder_config={'name': 'ConvNeXtV2'},
                 nb_classes=71,
                 strides=[(4, 4), (2, 1), (2, 1), (1, 1)],
                 k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
                 q_size=[2, 32],
                 deform_stride=2,
                 stage_idx=2,
                 use_depthwise_unet=True,
                 use_more_unet=False,
                 binary_loss_type='BanlanceMultiClassCrossEntropyLoss',
                 mid_size=True,
                 d_embedding=384):
        super().__init__()
        encoder_name = encoder_config.pop('name')
        encoder_config['in_channels'] = in_channels
        self.backbone = eval(encoder_name)(**encoder_config)
        dim = self.backbone.out_channels
        self.mid_size = mid_size
        if self.mid_size:
            self.enc_downsample = nn.Sequential(
                nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim // 2),
                # nn.ReLU6(inplace=True),
                nn.Conv2d(dim // 2,
                          dim // 2,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False,
                          groups=dim // 2),
                nn.Conv2d(dim // 2,
                          dim // 2,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=False),
                nn.SyncBatchNorm(dim // 2),
            )
            dim = dim // 2
            # recognition decoder
            self.linear_enc2recog = nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim),
                # nn.ReLU6(inplace=True),
                nn.Conv2d(dim,
                          dim,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False,
                          groups=dim),
                nn.Conv2d(dim,
                          dim,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=False),
                nn.SyncBatchNorm(dim),
            )
        else:
            self.linear_enc2recog = nn.Sequential(
                nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim // 2),
                # nn.ReLU6(inplace=True),
                nn.Conv2d(dim // 2, dim, kernel_size=3, stride=1, padding=1),
                nn.SyncBatchNorm(dim),
            )
        self.linear_norm = nn.Sequential(
            nn.Linear(dim, d_embedding),
            nn.LayerNorm(d_embedding, eps=1e-6),
        )
        self.out_channels = d_embedding
        self.binary_decoder = BinaryDecoder(
            dim,
            nb_classes,
            strides,
            use_depthwise_unet=use_depthwise_unet,
            use_more_unet=use_more_unet,
            binary_loss_type=binary_loss_type)
        self.fuse_model = FuseModel(dim,
                                    deform_stride=deform_stride,
                                    stage_idx=stage_idx,
                                    k_size=k_size,
                                    q_size=q_size)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        if isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, (nn.LayerNorm, nn.SyncBatchNorm, nn.BatchNorm2d)):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        return {}

    def forward(self, x):
        output = {}
        enc_feat = self.backbone(x)
        if self.mid_size:
            enc_feat = self.enc_downsample(enc_feat)
        output['enc_feat'] = enc_feat
        # binary mask
        pred_binary, binary_feats = self.binary_decoder(enc_feat)
        output['pred_binary'] = pred_binary
        reg_feat = self.linear_enc2recog(enc_feat)
        B, C, H, W = reg_feat.shape
        # refine the recognition features with the binary-mask features
        last_feat, binary_feat = self.fuse_model(reg_feat, binary_feats)
        dec_in = last_feat.reshape(B, C, H * W).permute(0, 2, 1)
        dec_in = self.linear_norm(dec_in)
        output['refined_feat'] = dec_in
        output['binary_feat'] = binary_feats[-1]
        return output
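

# Hedged end-to-end sketch (commented out: it assumes the default ConvNeXtV2
# config accepts `in_channels`, and the relative imports above mean this file
# must be imported as part of its package rather than run directly). In eval
# mode SyncBatchNorm falls back to its running statistics, so no distributed
# process group is needed for a single forward pass:
#
#     model = CAMEncoder().eval()
#     with torch.no_grad():
#         out = model(torch.randn(1, 3, 32, 128))
#     out['refined_feat'].shape  # [1, 64, 384]: (H/16)*(W/4) tokens, d_embedding
#     out['pred_binary'].shape   # per-pixel class probabilities, ~half input resolution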