# db_fpn.py

import torch
from torch import nn
import torch.nn.functional as F


class SEModule(nn.Module):
    def __init__(self, in_channels, reduction=4):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels // reduction,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.conv2 = nn.Conv2d(
            in_channels=in_channels // reduction,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, inputs):
        outputs = self.avg_pool(inputs)
        outputs = self.conv1(outputs)
        outputs = F.relu(outputs)
        outputs = self.conv2(outputs)
        outputs = F.hardsigmoid(outputs)
        return inputs * outputs
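

# Usage sketch (illustrative only; the tensor shape below is an assumption, not
# taken from any particular config). SEModule preserves the input shape and
# rescales each channel by a learned gate in [0, 1]:
#   se = SEModule(in_channels=64)
#   y = se(torch.randn(1, 64, 32, 32))   # y.shape == (1, 64, 32, 32)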


class IntraCLBlock(nn.Module):
    def __init__(self, in_channels=96, reduce_factor=4):
        super(IntraCLBlock, self).__init__()
        self.channels = in_channels
        self.rf = reduce_factor
        # weight_attr = paddle.nn.initializer.KaimingUniform()
        self.conv1x1_reduce_channel = nn.Conv2d(self.channels,
                                                self.channels // self.rf,
                                                kernel_size=1,
                                                stride=1,
                                                padding=0)
        self.conv1x1_return_channel = nn.Conv2d(self.channels // self.rf,
                                                self.channels,
                                                kernel_size=1,
                                                stride=1,
                                                padding=0)
        self.v_layer_7x1 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(7, 1),
            stride=(1, 1),
            padding=(3, 0),
        )
        self.v_layer_5x1 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(5, 1),
            stride=(1, 1),
            padding=(2, 0),
        )
        self.v_layer_3x1 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(3, 1),
            stride=(1, 1),
            padding=(1, 0),
        )
        self.q_layer_1x7 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(1, 7),
            stride=(1, 1),
            padding=(0, 3),
        )
        self.q_layer_1x5 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(1, 5),
            stride=(1, 1),
            padding=(0, 2),
        )
        self.q_layer_1x3 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(1, 3),
            stride=(1, 1),
            padding=(0, 1),
        )
        # base
        self.c_layer_7x7 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(7, 7),
            stride=(1, 1),
            padding=(3, 3),
        )
        self.c_layer_5x5 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(2, 2),
        )
        self.c_layer_3x3 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.bn = nn.BatchNorm2d(self.channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x_new = self.conv1x1_reduce_channel(x)

        x_7_c = self.c_layer_7x7(x_new)
        x_7_v = self.v_layer_7x1(x_new)
        x_7_q = self.q_layer_1x7(x_new)
        x_7 = x_7_c + x_7_v + x_7_q

        x_5_c = self.c_layer_5x5(x_7)
        x_5_v = self.v_layer_5x1(x_7)
        x_5_q = self.q_layer_1x5(x_7)
        x_5 = x_5_c + x_5_v + x_5_q

        x_3_c = self.c_layer_3x3(x_5)
        x_3_v = self.v_layer_3x1(x_5)
        x_3_q = self.q_layer_1x3(x_5)
        x_3 = x_3_c + x_3_v + x_3_q

        x_relation = self.conv1x1_return_channel(x_3)
        x_relation = self.bn(x_relation)
        x_relation = self.relu(x_relation)

        return x + x_relation
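

# Note: IntraCLBlock applies a cascade of square (7x7 -> 5x5 -> 3x3) and strip
# (7x1/1x7, 5x1/1x5, 3x1/1x3) convolutions in a reduced channel space, projects
# back to the original channel count, and adds the result to the input, so the
# output shape matches the input shape.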


class DSConv(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 padding,
                 stride=1,
                 groups=None,
                 if_act=True,
                 act='relu',
                 **kwargs):
        super(DSConv, self).__init__()
        if groups is None:
            groups = in_channels
        self.if_act = if_act
        self.act = act
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(num_features=in_channels)
        self.conv2 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=int(in_channels * 4),
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(num_features=int(in_channels * 4))
        self.conv3 = nn.Conv2d(
            in_channels=int(in_channels * 4),
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self._c = [in_channels, out_channels]
        if in_channels != out_channels:
            self.conv_end = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                bias=False,
            )

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.if_act:
            if self.act == 'relu':
                x = F.relu(x)
            elif self.act == 'hardswish':
                x = F.hardswish(x)
            else:
                print('The activation function({}) is selected incorrectly.'.format(self.act))
                exit()
        x = self.conv3(x)
        if self._c[0] != self._c[1]:
            x = x + self.conv_end(inputs)
        return x
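

# Note: DSConv is a depthwise-separable block (depthwise conv, then a 1x1
# expansion to 4x the channels, then a 1x1 projection), with a 1x1 shortcut on
# the input added only when in_channels != out_channels. Spatial size is
# preserved when stride=1 and padding matches the kernel size.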


class DBFPN(nn.Module):
    def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
        super(DBFPN, self).__init__()
        self.out_channels = out_channels
        self.use_asf = use_asf
        # weight_attr = paddle.nn.initializer.KaimingUniform()
        self.in2_conv = nn.Conv2d(
            in_channels=in_channels[0],
            out_channels=self.out_channels,
            kernel_size=1,
            bias=False,
        )
        self.in3_conv = nn.Conv2d(
            in_channels=in_channels[1],
            out_channels=self.out_channels,
            kernel_size=1,
            bias=False,
        )
        self.in4_conv = nn.Conv2d(
            in_channels=in_channels[2],
            out_channels=self.out_channels,
            kernel_size=1,
            bias=False,
        )
        self.in5_conv = nn.Conv2d(
            in_channels=in_channels[3],
            out_channels=self.out_channels,
            kernel_size=1,
            bias=False,
        )
        self.p5_conv = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.p4_conv = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.p3_conv = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.p2_conv = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        if self.use_asf is True:
            self.asf = ASFBlock(self.out_channels, self.out_channels // 4)

    def forward(self, x):
        c2, c3, c4, c5 = x

        in5 = self.in5_conv(c5)
        in4 = self.in4_conv(c4)
        in3 = self.in3_conv(c3)
        in2 = self.in2_conv(c2)

        out4 = in4 + F.interpolate(
            in5, scale_factor=2, mode='nearest', align_corners=None)  # 1/16
        out3 = in3 + F.interpolate(
            out4, scale_factor=2, mode='nearest', align_corners=None)  # 1/8
        out2 = in2 + F.interpolate(
            out3, scale_factor=2, mode='nearest', align_corners=None)  # 1/4

        p5 = self.p5_conv(in5)
        p4 = self.p4_conv(out4)
        p3 = self.p3_conv(out3)
        p2 = self.p2_conv(out2)
        p5 = F.interpolate(p5, scale_factor=8, mode='nearest', align_corners=None)
        p4 = F.interpolate(p4, scale_factor=4, mode='nearest', align_corners=None)
        p3 = F.interpolate(p3, scale_factor=2, mode='nearest', align_corners=None)

        fuse = torch.concat([p5, p4, p3, p2], dim=1)

        if self.use_asf is True:
            fuse = self.asf(fuse, [p5, p4, p3, p2])

        return fuse
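

# Note: DBFPN expects the four backbone feature maps (c2..c5) at 1/4, 1/8, 1/16
# and 1/32 of the input resolution. All pyramid levels are upsampled back to the
# 1/4 scale and concatenated, so the fused output has `out_channels` channels
# (4 x out_channels // 4). See the __main__ sketch at the end of the file.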


class RSELayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
        super(RSELayer, self).__init__()
        # weight_attr = paddle.nn.initializer.KaimingUniform()
        self.out_channels = out_channels
        self.in_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=kernel_size,
            padding=int(kernel_size // 2),
            # weight_attr=ParamAttr(initializer=weight_attr),
            bias=False,
        )
        self.se_block = SEModule(self.out_channels)
        self.shortcut = shortcut

    def forward(self, ins):
        x = self.in_conv(ins)
        if self.shortcut:
            out = x + self.se_block(x)
        else:
            out = self.se_block(x)
        return out
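

# Note: RSELayer is a plain convolution followed by SEModule; with shortcut=True
# the SE-weighted features are added back onto the convolution output.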


class RSEFPN(nn.Module):
    def __init__(self, in_channels, out_channels, shortcut=True, **kwargs):
        super(RSEFPN, self).__init__()
        self.out_channels = out_channels
        self.ins_conv = nn.ModuleList()
        self.inp_conv = nn.ModuleList()

        self.intracl = False
        if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
            self.intracl = kwargs['intracl']
            self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)

        for i in range(len(in_channels)):
            self.ins_conv.append(
                RSELayer(in_channels[i],
                         out_channels,
                         kernel_size=1,
                         shortcut=shortcut))
            self.inp_conv.append(
                RSELayer(out_channels,
                         out_channels // 4,
                         kernel_size=3,
                         shortcut=shortcut))

    def forward(self, x):
        c2, c3, c4, c5 = x

        in5 = self.ins_conv[3](c5)
        in4 = self.ins_conv[2](c4)
        in3 = self.ins_conv[1](c3)
        in2 = self.ins_conv[0](c2)

        out4 = in4 + F.interpolate(
            in5, scale_factor=2, mode='nearest', align_corners=None)  # 1/16
        out3 = in3 + F.interpolate(
            out4, scale_factor=2, mode='nearest', align_corners=None)  # 1/8
        out2 = in2 + F.interpolate(
            out3, scale_factor=2, mode='nearest', align_corners=None)  # 1/4

        p5 = self.inp_conv[3](in5)
        p4 = self.inp_conv[2](out4)
        p3 = self.inp_conv[1](out3)
        p2 = self.inp_conv[0](out2)

        if self.intracl is True:
            p5 = self.incl4(p5)
            p4 = self.incl3(p4)
            p3 = self.incl2(p3)
            p2 = self.incl1(p2)

        p5 = F.interpolate(p5, scale_factor=8, mode='nearest', align_corners=None)
        p4 = F.interpolate(p4, scale_factor=4, mode='nearest', align_corners=None)
        p3 = F.interpolate(p3, scale_factor=2, mode='nearest', align_corners=None)

        fuse = torch.concat([p5, p4, p3, p2], dim=1)
        return fuse
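

# Note: RSEFPN has the same top-down structure as DBFPN, but every lateral and
# output convolution is an RSELayer (conv + squeeze-and-excitation), and the
# optional `intracl=True` kwarg inserts an IntraCLBlock on each pyramid level
# before the final upsampling.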


class LKPAN(nn.Module):
    def __init__(self, in_channels, out_channels, mode='large', **kwargs):
        super(LKPAN, self).__init__()
        self.out_channels = out_channels
        # weight_attr = paddle.nn.initializer.KaimingUniform()
        self.ins_conv = nn.ModuleList()
        self.inp_conv = nn.ModuleList()
        # pan head
        self.pan_head_conv = nn.ModuleList()
        self.pan_lat_conv = nn.ModuleList()

        if mode.lower() == 'lite':
            p_layer = DSConv
        elif mode.lower() == 'large':
            p_layer = nn.Conv2d
        else:
            raise ValueError(
                "mode can only be one of ['lite', 'large'], but received {}".format(mode))

        for i in range(len(in_channels)):
            self.ins_conv.append(
                nn.Conv2d(
                    in_channels=in_channels[i],
                    out_channels=self.out_channels,
                    kernel_size=1,
                    bias=False,
                ))
            self.inp_conv.append(
                p_layer(
                    in_channels=self.out_channels,
                    out_channels=self.out_channels // 4,
                    kernel_size=9,
                    padding=4,
                    bias=False,
                ))
            if i > 0:
                self.pan_head_conv.append(
                    nn.Conv2d(
                        in_channels=self.out_channels // 4,
                        out_channels=self.out_channels // 4,
                        kernel_size=3,
                        padding=1,
                        stride=2,
                        bias=False,
                    ))
            self.pan_lat_conv.append(
                p_layer(
                    in_channels=self.out_channels // 4,
                    out_channels=self.out_channels // 4,
                    kernel_size=9,
                    padding=4,
                    bias=False,
                ))

        self.intracl = False
        if 'intracl' in kwargs.keys() and kwargs['intracl'] is True:
            self.intracl = kwargs['intracl']
            self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)

    def forward(self, x):
        c2, c3, c4, c5 = x

        in5 = self.ins_conv[3](c5)
        in4 = self.ins_conv[2](c4)
        in3 = self.ins_conv[1](c3)
        in2 = self.ins_conv[0](c2)

        out4 = in4 + F.interpolate(
            in5, scale_factor=2, mode='nearest', align_corners=None)  # 1/16
        out3 = in3 + F.interpolate(
            out4, scale_factor=2, mode='nearest', align_corners=None)  # 1/8
        out2 = in2 + F.interpolate(
            out3, scale_factor=2, mode='nearest', align_corners=None)  # 1/4

        f5 = self.inp_conv[3](in5)
        f4 = self.inp_conv[2](out4)
        f3 = self.inp_conv[1](out3)
        f2 = self.inp_conv[0](out2)

        pan3 = f3 + self.pan_head_conv[0](f2)
        pan4 = f4 + self.pan_head_conv[1](pan3)
        pan5 = f5 + self.pan_head_conv[2](pan4)

        p2 = self.pan_lat_conv[0](f2)
        p3 = self.pan_lat_conv[1](pan3)
        p4 = self.pan_lat_conv[2](pan4)
        p5 = self.pan_lat_conv[3](pan5)

        if self.intracl is True:
            p5 = self.incl4(p5)
            p4 = self.incl3(p4)
            p3 = self.incl2(p3)
            p2 = self.incl1(p2)

        p5 = F.interpolate(p5, scale_factor=8, mode='nearest', align_corners=None)
        p4 = F.interpolate(p4, scale_factor=4, mode='nearest', align_corners=None)
        p3 = F.interpolate(p3, scale_factor=2, mode='nearest', align_corners=None)

        fuse = torch.concat([p5, p4, p3, p2], dim=1)
        return fuse
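

# Note: LKPAN adds a bottom-up path-aggregation stage (pan_head_conv with
# stride 2) on top of the FPN top-down pass and uses large 9x9 kernels for the
# lateral convolutions; 'large' mode uses plain nn.Conv2d, while 'lite' mode
# swaps in the depthwise-separable DSConv defined above.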


class ASFBlock(nn.Module):
    """
    This code is adapted from:
    https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
    """

    def __init__(self, in_channels, inter_channels, out_features_num=4):
        """
        Adaptive Scale Fusion (ASF) block of DBNet++

        Args:
            in_channels: the number of channels in the input data
            inter_channels: the number of middle channels
            out_features_num: the number of fused stages
        """
        super(ASFBlock, self).__init__()
        # weight_attr = paddle.nn.initializer.KaimingUniform()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.out_features_num = out_features_num
        self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)

        self.spatial_scale = nn.Sequential(
            # Nx1xHxW
            nn.Conv2d(
                in_channels=1,
                out_channels=1,
                kernel_size=3,
                bias=False,
                padding=1,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=1,
                out_channels=1,
                kernel_size=1,
                bias=False,
            ),
            nn.Sigmoid(),
        )

        self.channel_scale = nn.Sequential(
            nn.Conv2d(
                in_channels=inter_channels,
                out_channels=out_features_num,
                kernel_size=1,
                bias=False,
            ),
            nn.Sigmoid(),
        )

    def forward(self, fuse_features, features_list):
        fuse_features = self.conv(fuse_features)
        spatial_x = torch.mean(fuse_features, dim=1, keepdim=True)
        attention_scores = self.spatial_scale(spatial_x) + fuse_features
        attention_scores = self.channel_scale(attention_scores)
        assert len(features_list) == self.out_features_num

        out_list = []
        for i in range(self.out_features_num):
            out_list.append(attention_scores[:, i:i + 1] * features_list[i])
        return torch.concat(out_list, dim=1)
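

# Minimal smoke test (a sketch only: the channel list and feature-map sizes
# below are assumptions for illustration, not values taken from any specific
# backbone configuration).
if __name__ == "__main__":
    feats = [
        torch.randn(1, 16, 160, 160),   # c2, 1/4 of a 640x640 input
        torch.randn(1, 24, 80, 80),     # c3, 1/8
        torch.randn(1, 56, 40, 40),     # c4, 1/16
        torch.randn(1, 480, 20, 20),    # c5, 1/32
    ]
    for neck in (
        DBFPN(in_channels=[16, 24, 56, 480], out_channels=256, use_asf=True),
        RSEFPN(in_channels=[16, 24, 56, 480], out_channels=96, shortcut=True),
        LKPAN(in_channels=[16, 24, 56, 480], out_channels=256, mode='large'),
    ):
        out = neck(feats)
        # Each neck returns a single fused map at the 1/4 scale.
        print(type(neck).__name__, tuple(out.shape))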