
  1. """This code is refer from:
  2. https://github.com/MelosY/CAM
  3. """
  4. # Copyright (c) Meta Platforms, Inc. and affiliates.
  5. # All rights reserved.
  6. # This source code is licensed under the license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from torch.nn.init import trunc_normal_
  12. from .convnextv2 import ConvNeXtV2, Block, LayerNorm
  13. from .svtrv2_lnconv_two33 import SVTRv2LNConvTwo33


class Swish(nn.Module):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
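
# Note: Swish(x) = x * sigmoid(x), functionally the same as nn.SiLU; it is defined
# here as an explicit module but is not used elsewhere in this file.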


class UNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """
        a UNet block with 2x up sampling
        """
        super().__init__()
        stride_h, stride_w = stride
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        conv = nn.Conv2d

        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            conv(cin, cin, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            conv(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)
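
# Shape sketch (illustrative example, not part of the original file): with the
# kernel/padding pairs chosen above, ConvTranspose2d scales each spatial dim by
# exactly its stride, e.g.
#   blk = UNetBlock(cin=64, cout=32, bn2d=nn.BatchNorm2d, stride=(2, 1))
#   blk(torch.randn(1, 64, 4, 32)).shape  # -> (1, 32, 8, 32): H x2, W unchanged
# (BatchNorm2d stands in for the SyncBatchNorm that BinaryDecoder passes in.)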


class DepthWiseUNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """
        a UNet block with 2x up sampling
        """
        super().__init__()
        stride_h, stride_w = stride
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0

        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin,
                      cout,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)
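
# DepthWiseUNetBlock mirrors UNetBlock but factorizes each 3x3 convolution into a
# depthwise 3x3 (groups=cin) followed by a pointwise 1x1, reducing parameters and
# FLOPs; the up-sampling behaviour and output shapes are identical.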


class SFTLayer(nn.Module):

    def __init__(self, dim_in, dim_out):
        super(SFTLayer, self).__init__()
        self.SFT_scale_conv0 = nn.Linear(dim_in, dim_in)
        self.SFT_scale_conv1 = nn.Linear(dim_in, dim_out)
        self.SFT_shift_conv0 = nn.Linear(dim_in, dim_in)
        self.SFT_shift_conv1 = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        # x[0]: fea; x[1]: cond
        scale = self.SFT_scale_conv1(
            F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True))
        shift = self.SFT_shift_conv1(
            F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True))
        return x[0] * (scale + 1) + shift
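
# SFT (spatial feature transform) modulation: given a (fea, cond) pair, the layer
# predicts a per-channel scale and shift from cond and returns fea * (scale + 1) + shift.
# Illustrative call (hypothetical shapes; the layer itself is not wired into
# CAMEncoder below):
#   sft = SFTLayer(dim_in=256, dim_out=256)
#   out = sft((fea, cond))  # fea, cond: (B, N, 256) -> out: (B, N, 256)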


class MoreUNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """
        a UNet block with 2x up sampling
        """
        super().__init__()
        stride_h, stride_w = stride
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0

        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cin),
            nn.Conv2d(cin, cout, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cout),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cout, cout, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cout),
            nn.Conv2d(cout, cout, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)


class BinaryDecoder(nn.Module):

    def __init__(self,
                 dim,
                 num_classes,
                 strides,
                 use_depthwise_unet=False,
                 use_more_unet=False,
                 binary_loss_type='DiceLoss') -> None:
        super().__init__()
        channels = [dim // 2**i for i in range(4)]
        self.linear_enc2binary = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
            nn.SyncBatchNorm(dim),
        )

        self.strides = strides
        self.use_deformable = False
        self.binary_decoder = nn.ModuleList()
        unet = DepthWiseUNetBlock if use_depthwise_unet else UNetBlock
        unet = MoreUNetBlock if use_more_unet else unet
        for i in range(3):
            up_sample_stride = self.strides[::-1][i]
            cin, cout = channels[i], channels[i + 1]
            self.binary_decoder.append(
                unet(cin, cout, nn.SyncBatchNorm, up_sample_stride,
                     self.use_deformable))
        last_stride = (self.strides[0][0] // 2, self.strides[0][1] // 2)
        self.binary_decoder.append(
            unet(cout, cout, nn.SyncBatchNorm, last_stride,
                 self.use_deformable))

        if binary_loss_type in ('CrossEntropyDiceLoss',
                                'BanlanceMultiClassCrossEntropyLoss'):
            segm_num_cls = num_classes - 2
        else:
            segm_num_cls = num_classes - 3
        self.binary_pred = nn.Conv2d(channels[-1],
                                     segm_num_cls,
                                     kernel_size=1,
                                     stride=1,
                                     bias=True)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p_h, p_w = self.strides[0]
        p_h = p_h // 2
        p_w = p_w // 2
        h = imgs.shape[2] // p_h
        w = imgs.shape[3] // p_w
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p_h, w, p_w))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p_h * p_w * 3))
        return x
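
    # patchify/unpatchify round trip (illustrative, hypothetical sizes): with
    # strides[0] == (4, 4) the patch size is (2, 2), so a (N, 3, 32, 128) image
    # becomes (N, 16*64, 2*2*3) = (N, 1024, 12) patch tokens; unpatchify reverses
    # the layout for a single-channel map (N, p_h*p_w, h, w) -> (N, H, W).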
    def unpatchify(self, x):
        """
        x: (N, patch_size**2, h, w)
        imgs: (N, H, W)
        """
        p_h, p_w = self.strides[0]
        p_h = p_h // 2
        p_w = p_w // 2
        _, _, h, w = x.shape
        assert p_h * p_w == x.shape[1]
        x = x.permute(0, 2, 3, 1)  # [N, h, w, 4*4]
        x = x.reshape(shape=(x.shape[0], h, w, p_h, p_w))
        x = torch.einsum('nhwpq->nhpwq', x)
        imgs = x.reshape(shape=(x.shape[0], h * p_h, w * p_w))
        return imgs

    def forward(self, x, time=None):
        """
        x: the encoder feat used to init the query for binary prediction,
           usually the backbone output (`enc_feat`).
        """
        binary_feats = []
        x = self.linear_enc2binary(x)
        binary_feats.append(x.clone())
        for i, d in enumerate(self.binary_decoder):
            x = d(x)
            binary_feats.append(x.clone())

        # return None, binary_feats
        x = self.binary_pred(x)

        if self.training:
            return x, binary_feats
        else:
            # return torch.sigmoid(x), binary_feat
            return x.softmax(1), binary_feats
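
# Usage sketch (illustrative; the SyncBatchNorm layers assume the usual distributed
# GPU training setup): for a 32x128 input image the backbone feature is
# (B, dim, 2, 32) and the decoder upsamples it to half the image resolution.
#   dec = BinaryDecoder(dim=384, num_classes=71,
#                       strides=[(4, 4), (2, 1), (2, 1), (1, 1)],
#                       use_depthwise_unet=True)
#   logits, feats = dec(torch.randn(2, 384, 2, 32))
#   # training mode: logits (2, 68, 16, 64); feats is a list of 5 intermediate maps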


class LayerNormProxy(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        return x.permute(0, 3, 1, 2)
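
# LayerNormProxy applies nn.LayerNorm over the channel dimension of an NCHW tensor
# by permuting to NHWC and back, so convolutional stacks can use LayerNorm without
# reshaping at the call site.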


class DAttentionFuse(nn.Module):

    def __init__(
        self,
        q_size=(4, 32),
        kv_size=(4, 32),
        n_heads=8,
        n_head_channels=80,
        n_groups=4,
        attn_drop=0.0,
        proj_drop=0.0,
        stride=2,
        offset_range_factor=2,
        use_pe=True,
        stage_idx=0,
    ):
        '''
        stage_idx from 2 to 3
        '''
        super().__init__()
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels**-0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        self.kv_h, self.kv_w = kv_size
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.offset_range_factor = offset_range_factor

        ksizes = [9, 7, 5, 3]
        kk = ksizes[stage_idx]

        self.conv_offset = nn.Sequential(
            nn.Conv2d(2 * self.n_group_channels,
                      2 * self.n_group_channels,
                      kk,
                      stride,
                      kk // 2,
                      groups=self.n_group_channels),
            LayerNormProxy(2 * self.n_group_channels),
            nn.GELU(),
            nn.Conv2d(2 * self.n_group_channels, 2, 1, 1, 0, bias=False),
        )

        self.proj_q = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)
        self.proj_k = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)
        self.proj_v = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)
        self.proj_out = nn.Conv2d(self.nc,
                                  self.nc,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)

        if self.use_pe:
            self.rpe_table = nn.Parameter(
                torch.zeros(self.n_heads, self.kv_h * 2 - 1,
                            self.kv_w * 2 - 1))
            trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype,
                           device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype,
                           device=device))
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key).mul_(2).sub_(1)
        ref[..., 0].div_(H_key).mul_(2).sub_(1)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2
        return ref
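
    # The reference points above are pixel centres normalized to [-1, 1]
    # (x divided by W_key, y by H_key), i.e. the coordinate convention expected
    # by F.grid_sample in forward() below.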

    def forward(self, x, y):
        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device

        # equivalent to einops.rearrange(torch.cat((x, y), dim=1),
        #                                'b (g c) h w -> (b g) c h w', g=self.n_groups)
        q_off = torch.cat((x, y), dim=1).reshape(
            B, self.n_groups, 2 * self.n_group_channels, H, W).flatten(0, 1)
        offset = self.conv_offset(q_off)  # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk

        if self.offset_range_factor > 0:
            offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk],
                                        device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(
                self.offset_range_factor)

        offset = offset.permute(0, 2, 3, 1)  # b p h w -> b h w p
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)

        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).tanh()

        q = self.proj_q(y)
        x_sampled = F.grid_sample(
            input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
            grid=pos[..., (1, 0)],  # y, x -> x, y
            mode='bilinear',
            align_corners=False)  # B * g, Cg, Hg, Wg

        x_sampled = x_sampled.reshape(B, C, 1, n_sample)

        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
        k = self.proj_k(x_sampled).reshape(B * self.n_heads,
                                           self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads,
                                           self.n_head_channels, n_sample)

        attn = torch.einsum('b c m, b c n -> b m n', q, k)  # B * h, HW, Ns
        attn = attn.mul(self.scale)

        if self.use_pe:
            rpe_table = self.rpe_table
            rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
            q_grid = self._get_ref_points(H, W, B, dtype, device)
            displacement = (
                q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) -
                pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)
            ).mul(0.5)
            attn_bias = F.grid_sample(
                input=rpe_bias.reshape(B * self.n_groups, self.n_group_heads,
                                       2 * H - 1, 2 * W - 1),
                grid=displacement[..., (1, 0)],
                mode='bilinear',
                align_corners=False)
            attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
            attn = attn + attn_bias

        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)

        out = torch.einsum('b m n, b c n -> b c m', attn, v)
        out = out.reshape(B, C, H, W)
        out = self.proj_drop(self.proj_out(out))

        return out, pos.reshape(B, self.n_groups, Hk, Wk, 2), \
            reference.reshape(B, self.n_groups, Hk, Wk, 2)
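
# Deformable-attention fusion in brief: offsets are predicted from the concatenated
# (x, y) pair, x is re-sampled at reference + offset locations via grid_sample, and
# y provides the queries that attend over the sampled keys/values.
# Illustrative call (hypothetical sizes consistent with the module defaults):
#   attn = DAttentionFuse(q_size=(2, 32), kv_size=(2, 32), n_heads=8,
#                         n_head_channels=48, n_groups=4, stage_idx=2)
#   out, pos, ref = attn(torch.randn(2, 384, 2, 32), torch.randn(2, 384, 2, 32))
#   # out: (2, 384, 2, 32); pos/ref: (2, 4, Hk, Wk, 2) sampling grids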


class FuseModel(nn.Module):

    def __init__(self,
                 dim,
                 deform_stride=2,
                 stage_idx=2,
                 k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
                 q_size=(2, 32)):
        super().__init__()
        channels = [dim // 2**i for i in range(4)]
        refine_conv = nn.Conv2d
        self.deform_stride = deform_stride

        in_out_ch = [(-1, -2), (-2, -3), (-3, -4), (-4, -4)]

        self.binary_condition_layer = DAttentionFuse(q_size=q_size,
                                                     kv_size=q_size,
                                                     stride=self.deform_stride,
                                                     n_head_channels=dim // 8,
                                                     stage_idx=stage_idx)

        self.binary2refine_linear_norm = nn.ModuleList()
        for i in range(len(k_size)):
            self.binary2refine_linear_norm.append(
                nn.Sequential(
                    Block(dim=channels[in_out_ch[i][0]]),
                    LayerNorm(channels[in_out_ch[i][0]],
                              eps=1e-6,
                              data_format='channels_first'),
                    refine_conv(channels[in_out_ch[i][0]],
                                channels[in_out_ch[i][1]],
                                kernel_size=k_size[i],
                                stride=k_size[i])))  # [8, 32]

    def forward(self, recog_feat, binary_feats, dec_in=None):
        multi_feat = []
        binary_feat = binary_feats[-1]
        for i in range(len(self.binary2refine_linear_norm)):
            binary_feat = self.binary2refine_linear_norm[i](binary_feat)
            multi_feat.append(binary_feat)
        binary_feat = binary_feat + binary_feats[0]
        multi_feat[3] += binary_feats[0]

        binary_refined_feat, pos, _ = self.binary_condition_layer(
            recog_feat, binary_feat)
        binary_refined_feat = binary_refined_feat + binary_feat
        return binary_refined_feat, binary_feat
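
# FuseModel walks the finest binary-decoder feature back down through ConvNeXt
# blocks (binary2refine_linear_norm), adds the coarsest binary feature as a
# residual, then fuses the result with the recognition feature via DAttentionFuse;
# binary_refined_feat feeds the recognition embedding, binary_feat is the fused
# condition map.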


class CAMEncoder(nn.Module):
    """
    Args:
        in_channels (int): Number of input image channels. Default: 3
        encoder_config (dict): Backbone config; the 'name' entry selects the
            backbone class (resolved via eval). Default: {'name': 'ConvNeXtV2'}
        nb_classes (int): Number of recognition classes; the binary decoder
            predicts nb_classes - 2 or nb_classes - 3 channels depending on
            binary_loss_type. Default: 71
        strides (list(tuple)): Per-stage strides, reversed to size the binary
            decoder's up-sampling blocks. Default: [(4, 4), (2, 1), (2, 1), (1, 1)]
        k_size (list(tuple)): Kernel/stride sizes of the FuseModel down-sampling path.
        q_size (tuple): Query feature-map size for the deformable attention.
        deform_stride (int): Stride of the offset network in DAttentionFuse.
        stage_idx (int): Stage index selecting the offset kernel size.
        use_depthwise_unet (bool): Use DepthWiseUNetBlock in the binary decoder.
        use_more_unet (bool): Use MoreUNetBlock in the binary decoder.
        binary_loss_type (str): Loss name; controls the segmentation class count.
        mid_size (bool): If True, halve the backbone channels before decoding.
        d_embedding (int): Dimension of the output embedding. Default: 384
    """

    def __init__(self,
                 in_channels=3,
                 encoder_config={'name': 'ConvNeXtV2'},
                 nb_classes=71,
                 strides=[(4, 4), (2, 1), (2, 1), (1, 1)],
                 k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
                 q_size=[2, 32],
                 deform_stride=2,
                 stage_idx=2,
                 use_depthwise_unet=True,
                 use_more_unet=False,
                 binary_loss_type='BanlanceMultiClassCrossEntropyLoss',
                 mid_size=True,
                 d_embedding=384):
        super().__init__()
        # copy so pop() does not mutate the (shared, mutable) default dict
        encoder_config = dict(encoder_config)
        encoder_name = encoder_config.pop('name')
        encoder_config['in_channels'] = in_channels
        self.backbone = eval(encoder_name)(**encoder_config)
        dim = self.backbone.out_channels

        self.mid_size = mid_size
        if self.mid_size:
            self.enc_downsample = nn.Sequential(
                nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim // 2),
                # nn.ReLU6(inplace=True),
                nn.Conv2d(dim // 2,
                          dim // 2,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False,
                          groups=dim // 2),
                nn.Conv2d(dim // 2,
                          dim // 2,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=False),
                nn.SyncBatchNorm(dim // 2),
            )
            dim = dim // 2

            # recognition decoder
            self.linear_enc2recog = nn.Sequential(
                nn.Conv2d(
                    dim,
                    dim,
                    kernel_size=1,
                    stride=1,
                ),
                nn.SyncBatchNorm(dim),
                # nn.ReLU6(inplace=True),
                nn.Conv2d(dim,
                          dim,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False,
                          groups=dim),
                nn.Conv2d(dim,
                          dim,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=False),
                nn.SyncBatchNorm(dim),
            )
        else:
            self.linear_enc2recog = nn.Sequential(
                nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim // 2),
                # nn.ReLU6(inplace=True),
                nn.Conv2d(dim // 2, dim, kernel_size=3, stride=1, padding=1),
                nn.SyncBatchNorm(dim),
            )

        self.linear_norm = nn.Sequential(
            nn.Linear(dim, d_embedding),
            nn.LayerNorm(d_embedding, eps=1e-6),
        )
        self.out_channels = d_embedding

        self.binary_decoder = BinaryDecoder(
            dim,
            nb_classes,
            strides,
            use_depthwise_unet=use_depthwise_unet,
            use_more_unet=use_more_unet,
            binary_loss_type=binary_loss_type)

        self.fuse_model = FuseModel(dim,
                                    deform_stride=deform_stride,
                                    stage_idx=stage_idx,
                                    k_size=k_size,
                                    q_size=q_size)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, (nn.Conv2d, nn.Linear)) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        if isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.SyncBatchNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        return {}

    def forward(self, x):
        output = {}
        enc_feat = self.backbone(x)
        if self.mid_size:
            enc_feat = self.enc_downsample(enc_feat)
        output['enc_feat'] = enc_feat

        # binary mask
        pred_binary, binary_feats = self.binary_decoder(enc_feat)
        output['pred_binary'] = pred_binary

        reg_feat = self.linear_enc2recog(enc_feat)
        B, C, H, W = reg_feat.shape
        last_feat, binary_feat = self.fuse_model(reg_feat, binary_feats)

        dec_in = last_feat.reshape(B, C, H * W).permute(0, 2, 1)
        dec_in = self.linear_norm(dec_in)
        output['refined_feat'] = dec_in
        output['binary_feat'] = binary_feats[-1]
        return output
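
# Output sketch (illustrative, assuming a 32x128 input and the defaults above;
# backbone config elided):
#   enc = CAMEncoder(encoder_config={'name': 'ConvNeXtV2'})
#   out = enc(torch.randn(2, 3, 32, 128))
#   # out['enc_feat']:     (2, dim, 2, 32) backbone feature (dim halved if mid_size)
#   # out['pred_binary']:  segmentation logits at half the input resolution
#   # out['refined_feat']: (2, 64, d_embedding) tokens for the recognition decoder
#   # out['binary_feat']:  the finest binary-decoder feature map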