# cppd_decoder.py

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_

from openrec.modeling.common import DropPath, Identity, Mlp
from openrec.modeling.decoders.nrtr_decoder import Embeddings


class Attention(nn.Module):

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, key_mask=None):
        N, C = kv.shape[1:]
        QN = q.shape[1]
        q = self.q(q).reshape([-1, QN, self.num_heads,
                               C // self.num_heads]).transpose(1, 2)
        q = q * self.scale
        k, v = self.kv(kv).reshape(
            [-1, N, 2, self.num_heads,
             C // self.num_heads]).permute(2, 0, 3, 1, 4)
        attn = q.matmul(k.transpose(2, 3))
        if key_mask is not None:
            attn = attn + key_mask.unsqueeze(1)
        attn = F.softmax(attn, -1)
        # if not self.training:
        #     self.attn_map = attn
        attn = self.attn_drop(attn)
        x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class EdgeDecoderLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=[0.0, 0.0],
        act_layer=nn.GELU,
        norm_layer='nn.LayerNorm',
        epsilon=1e-6,
    ):
        super().__init__()
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(
            drop_path[0]) if drop_path[0] > 0.0 else Identity()
        # nn.LayerNorm expects the keyword `eps`, not `epsilon`.
        self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        self.norm2 = eval(norm_layer)(dim, eps=epsilon)
        self.p = nn.Linear(dim, dim)
        self.cv = nn.Linear(dim, dim)
        self.pv = nn.Linear(dim, dim)
        self.dim = dim
        self.num_heads = num_heads
        self.p_proj = nn.Linear(dim, dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, p, cv, pv):
        pN = p.shape[1]
        vN = cv.shape[1]
        p_shortcut = p

        p1 = self.p(p).reshape(
            [-1, pN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        cv1 = self.cv(cv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        pv1 = self.pv(pv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)

        edge = F.softmax(p1.matmul(pv1.transpose(2, 3)), -1)  # B h N N
        p_c = (edge @ cv1).transpose(1, 2).reshape((-1, pN, self.dim))
        x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
        x = self.norm2(x1 + self.drop_path1(self.mlp(x1)))
        return x


class DecoderLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        epsilon=1e-6,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim, eps=epsilon)
        self.mixer = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = norm_layer(dim, eps=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, q, kv, key_mask=None):
        x1 = self.norm1(q + self.drop_path(self.mixer(q, kv, key_mask)))
        x = self.norm2(x1 + self.drop_path(self.mlp(x1)))
        return x


class CPPDDecoder(nn.Module):
    """Context Perception Parallel Decoder (CPPD) recognition head.

    Character-counting nodes and position (ordering) nodes attend to the
    visual features in parallel; the edge (recognition) layers then map the
    position nodes to per-position character logits.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layer=2,
                 drop_path_rate=0.1,
                 max_len=25,
                 vis_seq=50,
                 iters=1,
                 pos_len=False,
                 ch=False,
                 rec_layer=1,
                 num_heads=None,
                 ds=False,
                 **kwargs):
        super(CPPDDecoder, self).__init__()
        self.out_channels = out_channels  # none + 26 + 10
        dim = in_channels
        self.dim = dim
        self.iters = iters
        self.max_len = max_len + 1  # max_len + eos
        self.pos_len = pos_len
        self.ch = ch
        self.char_node_embed = Embeddings(d_model=dim,
                                          vocab=self.out_channels,
                                          scale_embedding=True)
        self.pos_node_embed = Embeddings(d_model=dim,
                                         vocab=self.max_len,
                                         scale_embedding=True)
        dpr = np.linspace(0, drop_path_rate, num_layer + rec_layer)

        self.char_node_decoder = nn.ModuleList([
            DecoderLayer(
                dim=dim,
                num_heads=dim // 32 if num_heads is None else num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                drop_path=dpr[i],
            ) for i in range(num_layer)
        ])
        self.pos_node_decoder = nn.ModuleList([
            DecoderLayer(
                dim=dim,
                num_heads=dim // 32 if num_heads is None else num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                drop_path=dpr[i],
            ) for i in range(num_layer)
        ])
        self.edge_decoder = nn.ModuleList([
            DecoderLayer(
                dim=dim,
                num_heads=dim // 32 if num_heads is None else num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                qk_scale=1.0 if (rec_layer + i) % 2 != 0 else None,
                drop_path=dpr[num_layer + i],
            ) for i in range(rec_layer)
        ])
        self.rec_layer_num = rec_layer

        self_mask = torch.tril(
            torch.ones([self.max_len, self.max_len], dtype=torch.float32))
        self_mask = torch.where(
            self_mask > 0,
            torch.zeros_like(self_mask, dtype=torch.float32),
            torch.full([self.max_len, self.max_len],
                       float('-inf'),
                       dtype=torch.float32),
        )
        self.self_mask = self_mask.unsqueeze(0)

        self.char_pos_embed = nn.Parameter(torch.zeros([1, self.max_len, dim],
                                                       dtype=torch.float32),
                                           requires_grad=True)
        self.ds = ds
        if not self.ds:
            self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
                                                          dtype=torch.float32),
                                              requires_grad=True)
            trunc_normal_(self.vis_pos_embed, std=0.02)
        self.char_node_fc1 = nn.Linear(dim, max_len)
        self.pos_node_fc1 = nn.Linear(dim, self.max_len)
        self.edge_fc = nn.Linear(dim, self.out_channels)
        trunc_normal_(self.char_pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {
            'char_pos_embed', 'vis_pos_embed', 'char_node_embed',
            'pos_node_embed'
        }

    def forward(self, x, data=None):
        if self.training:
            return self.forward_train(x, data)
        else:
            return self.forward_test(x)

    def forward_test(self, x):
        if not self.ds:
            visual_feats = x + self.vis_pos_embed
        else:
            visual_feats = x
        bs = visual_feats.shape[0]

        pos_node_embed = self.pos_node_embed(
            torch.arange(self.max_len).cuda(
                x.get_device())).unsqueeze(0) + self.char_pos_embed
        pos_node_embed = torch.tile(pos_node_embed, [bs, 1, 1])

        char_vis_node_query = visual_feats
        pos_vis_node_query = torch.concat([pos_node_embed, visual_feats], 1)

        for char_decoder_layer, pos_decoder_layer in zip(
                self.char_node_decoder, self.pos_node_decoder):
            char_vis_node_query = char_decoder_layer(char_vis_node_query,
                                                     char_vis_node_query)
            pos_vis_node_query = pos_decoder_layer(
                pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])
        pos_node_query = pos_vis_node_query[:, :self.max_len, :]
        char_vis_feats = char_vis_node_query
        # pos_vis_feats = pos_vis_node_query[:, self.max_len:, :]
        # pos_node_feats = self.edge_decoder(
        #     pos_node_query, char_vis_feats, pos_vis_feats
        # )  # B, 26, dim

        pos_node_feats = pos_node_query
        # self_mask is a plain tensor attribute (not a registered buffer), so
        # move it onto the feature device explicitly before it is used.
        self_mask = self.self_mask.to(pos_node_feats.device)
        for layer_i in range(self.rec_layer_num):
            rec_layer = self.edge_decoder[layer_i]
            if (self.rec_layer_num + layer_i) % 2 == 0:
                pos_node_feats = rec_layer(pos_node_feats, pos_node_feats,
                                           self_mask)
            else:
                pos_node_feats = rec_layer(pos_node_feats, char_vis_feats)

        edge_feats = self.edge_fc(pos_node_feats)  # B, 26, 37
        edge_logits = F.softmax(
            edge_feats,
            -1)  # * F.sigmoid(pos_node_feats1.unsqueeze(-1)) # B, 26, 37
        return edge_logits

    def forward_train(self, x, targets=None):
        if not self.ds:
            visual_feats = x + self.vis_pos_embed
        else:
            visual_feats = x
        bs = visual_feats.shape[0]

        if self.ch:
            char_node_embed = self.char_node_embed(targets[-2])
        else:
            char_node_embed = self.char_node_embed(
                torch.arange(self.out_channels).cuda(
                    x.get_device())).unsqueeze(0)
            char_node_embed = torch.tile(char_node_embed, [bs, 1, 1])
        counting_char_num = char_node_embed.shape[1]

        pos_node_embed = self.pos_node_embed(
            torch.arange(self.max_len).cuda(
                x.get_device())).unsqueeze(0) + self.char_pos_embed
        pos_node_embed = torch.tile(pos_node_embed, [bs, 1, 1])

        node_feats = []
        char_vis_node_query = torch.concat([char_node_embed, visual_feats], 1)
        pos_vis_node_query = torch.concat([pos_node_embed, visual_feats], 1)

        for char_decoder_layer, pos_decoder_layer in zip(
                self.char_node_decoder, self.pos_node_decoder):
            char_vis_node_query = char_decoder_layer(
                char_vis_node_query,
                char_vis_node_query[:, counting_char_num:, :])
            pos_vis_node_query = pos_decoder_layer(
                pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])
        char_node_query = char_vis_node_query[:, :counting_char_num, :]
        pos_node_query = pos_vis_node_query[:, :self.max_len, :]
        char_vis_feats = char_vis_node_query[:, counting_char_num:, :]

        char_node_feats1 = self.char_node_fc1(char_node_query)
        pos_node_feats1 = self.pos_node_fc1(pos_node_query)
        if not self.pos_len:
            diag_mask = torch.eye(pos_node_feats1.shape[1]).unsqueeze(0).tile(
                [pos_node_feats1.shape[0], 1, 1])
            pos_node_feats1 = (
                pos_node_feats1 *
                diag_mask.cuda(pos_node_feats1.get_device())).sum(-1)
        node_feats.append(char_node_feats1)
        node_feats.append(pos_node_feats1)

        pos_node_feats = pos_node_query
        # As in forward_test, move the causal mask onto the feature device.
        self_mask = self.self_mask.to(pos_node_feats.device)
        for layer_i in range(self.rec_layer_num):
            rec_layer = self.edge_decoder[layer_i]
            if (self.rec_layer_num + layer_i) % 2 == 0:
                pos_node_feats = rec_layer(pos_node_feats, pos_node_feats,
                                           self_mask)
            else:
                pos_node_feats = rec_layer(pos_node_feats, char_vis_feats)

        edge_feats = self.edge_fc(pos_node_feats)  # B, 26, 37
        return node_feats, edge_feats
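

# Minimal usage sketch: a rough smoke test of the inference path, assuming a
# visual backbone that emits `vis_seq` feature tokens of width `in_channels`.
# The channel width (384), token count (50), batch size (2) and charset size
# (37) below are illustrative assumptions, not values required by the decoder.
if __name__ == '__main__':
    decoder = CPPDDecoder(in_channels=384, out_channels=37, vis_seq=50)
    decoder.eval()
    feats = torch.randn(2, 50, 384)  # [batch, vis_seq, in_channels]
    if torch.cuda.is_available():
        # forward_test/forward_train pick the current CUDA device via
        # x.get_device(), so the decoder expects CUDA inputs in practice.
        decoder, feats = decoder.cuda(), feats.cuda()
        with torch.no_grad():
            probs = decoder(feats)  # [batch, max_len + 1, out_channels]
        print(probs.shape)  # torch.Size([2, 26, 37])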