nrtr_decoder.py

import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from openrec.modeling.common import Mlp


class NRTRDecoder(nn.Module):
    """A transformer-based decoder. Users are able to modify the attributes as
    needed.

    The architecture is based on the paper "Attention Is All You Need". Ashish
    Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N
    Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you
    need. In Advances in Neural Information Processing Systems, pages
    6000-6010.

    Args:
        in_channels: the number of expected features in the decoder inputs,
            used as the model dimension d_model.
        out_channels: the size of the recognition dictionary; the last two
            indices are reserved for the bos and padding (ignore) symbols.
        nhead: the number of heads in the multi-head attention models
            (default: in_channels // 32).
        num_encoder_layers: the number of sub-encoder-layers in the encoder
            (default=6); 0 disables the encoder.
        beam_size: beam width reserved for beam-search decoding (default=0).
        num_decoder_layers: the number of sub-decoder-layers in the decoder
            (default=6).
        max_len: the maximum decoded sequence length (default=25).
        attention_dropout_rate: dropout applied to the attention weights
            (default=0.0).
        residual_dropout_rate: dropout applied to the residual and
            feed-forward branches (default=0.1).
        scale_embedding: whether to scale token embeddings by sqrt(d_model)
            (default=True).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super(NRTRDecoder, self).__init__()
        self.out_channels = out_channels
        self.ignore_index = out_channels - 1
        self.bos = out_channels - 2
        self.eos = 0
        self.max_len = max_len
        d_model = in_channels
        dim_feedforward = d_model * 4
        nhead = nhead if nhead is not None else d_model // 32
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding,
        )
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate, dim=d_model)
        if num_encoder_layers > 0:
            self.encoder = nn.ModuleList([
                TransformerBlock(
                    d_model,
                    nhead,
                    dim_feedforward,
                    attention_dropout_rate,
                    residual_dropout_rate,
                    with_self_attn=True,
                    with_cross_attn=False,
                ) for i in range(num_encoder_layers)
            ])
        else:
            self.encoder = None
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=True,
            ) for i in range(num_decoder_layers)
        ])
        self.beam_size = beam_size
        self.d_model = d_model
        self.nhead = nhead
        self.tgt_word_prj = nn.Linear(d_model,
                                      self.out_channels - 2,
                                      bias=False)
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, self.out_channels - 2)).astype(
                                  np.float32)
        self.tgt_word_prj.weight.data = torch.from_numpy(w0.transpose())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_train(self, src, tgt):
        tgt = tgt[:, :-1]  # teacher forcing: the decoder input drops the last token
        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        tgt_mask = self.generate_square_subsequent_mask(
            tgt.shape[1], device=src.device)
        if self.encoder is not None:
            src = self.positional_encoding(src)
            for encoder_layer in self.encoder:
                src = encoder_layer(src)
            memory = src  # B N C
        else:
            memory = src  # B N C
        for decoder_layer in self.decoder:
            tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
        output = tgt
        logit = self.tgt_word_prj(output)
        return logit

    def forward(self, src, data=None):
        """Take in and process the source features and, during training, the
        masked target sequences.

        Args:
            src: the sequence of features fed to the decoder (required).
            data: during training, data[0] holds the padded target token ids
                and data[1] holds the target lengths; unused at inference.

        Shape:
            - src: :math:`(B, sN, C)`.
            - data[0]: :math:`(B, tN)`.

        Examples:
            >>> logits = decoder(src, data)  # training
            >>> probs = decoder(src)         # inference
        """
        if self.training:
            max_len = data[1].max()
            tgt = data[0][:, :2 + max_len]
            res = self.forward_train(src, tgt)
        else:
            res = self.forward_test(src)
        return res

    def forward_test(self, src):
        bs = src.shape[0]
        if self.encoder is not None:
            src = self.positional_encoding(src)
            for encoder_layer in self.encoder:
                src = encoder_layer(src)
            memory = src  # B N C
        else:
            memory = src
        dec_seq = torch.full((bs, self.max_len + 1),
                             self.ignore_index,
                             dtype=torch.int64,
                             device=src.device)
        dec_seq[:, 0] = self.bos
        logits = []
        self.attn_maps = []
        for len_dec_seq in range(0, self.max_len):
            dec_seq_embed = self.embedding(
                dec_seq[:, :len_dec_seq + 1])  # tokens decoded so far: bos, c1, c2, ...
            dec_seq_embed = self.positional_encoding(dec_seq_embed)
            tgt_mask = self.generate_square_subsequent_mask(
                dec_seq_embed.shape[1], src.device)
            tgt = dec_seq_embed  # B, len_dec_seq + 1, dim
            for decoder_layer in self.decoder:
                tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
            self.attn_maps.append(
                self.decoder[-1].cross_attn.attn_map[0][:, -1:, :])
            dec_output = tgt
            dec_output = dec_output[:, -1:, :]
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            if len_dec_seq < self.max_len:
                # greedy decode: append the next token index to the target input
                dec_seq[:, len_dec_seq + 1] = word_prob.squeeze().argmax(-1)
                # Efficient batch decoding: if every sequence contains at least
                # one EOS token, stop decoding.
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
        logits = torch.cat(logits, dim=1)
        return logits

    def generate_square_subsequent_mask(self, sz, device):
        """Generate a square mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked positions
        are filled with float(0.0).
        """
        mask = torch.zeros([sz, sz], dtype=torch.float32)
        mask_inf = torch.triu(
            torch.full((sz, sz), dtype=torch.float32, fill_value=-torch.inf),
            diagonal=1,
        )
        mask = mask + mask_inf
        return mask.unsqueeze(0).unsqueeze(0).to(device)
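

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal,
# hedged example of driving NRTRDecoder at inference time. The feature shape
# (B, N, C) and the channel/vocabulary sizes below are assumptions chosen for
# illustration only; real values come from the surrounding openrec config.
def _nrtr_decoder_inference_example():
    decoder = NRTRDecoder(in_channels=512, out_channels=40, max_len=25)
    decoder.eval()
    # Pretend visual features from a backbone/encoder: B=2, N=64 tokens, C=512.
    src = torch.randn(2, 64, 512)
    with torch.no_grad():
        probs = decoder(src)  # (B, <=max_len, out_channels - 2) softmax scores
    preds = probs.argmax(-1)  # greedy character indices per decoding step
    return preds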


class MultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads, dropout=0.0, self_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim
                ), 'embed_dim must be divisible by num_heads'
        self.scale = self.head_dim**-0.5
        self.self_attn = self_attn
        if self_attn:
            # Self-attention: one projection yields queries, keys and values.
            self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        else:
            # Cross-attention: queries from the decoder, keys/values from memory.
            self.q = nn.Linear(embed_dim, embed_dim)
            self.kv = nn.Linear(embed_dim, embed_dim * 2)
        self.attn_drop = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key=None, attn_mask=None):
        B, qN = query.shape[:2]
        if self.self_attn:
            qkv = self.qkv(query)
            qkv = qkv.reshape(B, qN, 3, self.num_heads,
                              self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
        else:
            kN = key.shape[1]
            q = self.q(query)
            q = q.reshape(B, qN, self.num_heads, self.head_dim).transpose(1, 2)
            kv = self.kv(key)
            kv = kv.reshape(B, kN, 2, self.num_heads,
                            self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)
        attn = (q.matmul(k.transpose(2, 3))) * self.scale
        if attn_mask is not None:
            attn += attn_mask
        attn = F.softmax(attn, dim=-1)
        if not self.training:
            # Keep the attention map for visualization at inference time.
            self.attn_map = attn
        attn = self.attn_drop(attn)
        x = (attn.matmul(v)).transpose(1, 2)
        x = x.reshape(B, qN, self.embed_dim)
        x = self.out_proj(x)
        return x
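

# ---------------------------------------------------------------------------
# Illustrative shape check (not part of the original module): a small, hedged
# sketch of the two MultiheadAttention modes used in this file. The tensor
# sizes are arbitrary assumptions chosen for demonstration only.
def _multihead_attention_shape_example():
    self_attn = MultiheadAttention(embed_dim=64, num_heads=8, self_attn=True)
    cross_attn = MultiheadAttention(embed_dim=64, num_heads=8, self_attn=False)
    query = torch.randn(2, 5, 64)    # B=2, qN=5, C=64
    memory = torch.randn(2, 12, 64)  # B=2, kN=12, C=64
    out_self = self_attn(query)                # (2, 5, 64)
    out_cross = cross_attn(query, key=memory)  # (2, 5, 64)
    return out_self.shape, out_cross.shape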


class TransformerBlock(nn.Module):

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        with_self_attn=True,
        with_cross_attn=False,
        epsilon=1e-5,
    ):
        super(TransformerBlock, self).__init__()
        self.with_self_attn = with_self_attn
        if with_self_attn:
            self.self_attn = MultiheadAttention(d_model,
                                                nhead,
                                                dropout=attention_dropout_rate,
                                                self_attn=with_self_attn)
            self.norm1 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout1 = nn.Dropout(residual_dropout_rate)
        self.with_cross_attn = with_cross_attn
        if with_cross_attn:
            self.cross_attn = MultiheadAttention(
                d_model, nhead, dropout=attention_dropout_rate
            )  # for self_attn of encoder or cross_attn of decoder
            self.norm2 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout2 = nn.Dropout(residual_dropout_rate)
        self.mlp = Mlp(
            in_features=d_model,
            hidden_features=dim_feedforward,
            act_layer=nn.ReLU,
            drop=residual_dropout_rate,
        )
        self.norm3 = nn.LayerNorm(d_model, eps=epsilon)
        self.dropout3 = nn.Dropout(residual_dropout_rate)

    def forward(self, tgt, memory=None, self_mask=None, cross_mask=None):
        if self.with_self_attn:
            tgt1 = self.self_attn(tgt, attn_mask=self_mask)
            tgt = self.norm1(tgt + self.dropout1(tgt1))
        if self.with_cross_attn:
            tgt2 = self.cross_attn(tgt, key=memory, attn_mask=cross_mask)
            tgt = self.norm2(tgt + self.dropout2(tgt2))
        tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))
        return tgt
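

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): wiring a single
# TransformerBlock as a decoder layer with a causal self-attention mask, the
# way NRTRDecoder does. Shapes and sizes are assumptions for demonstration.
def _transformer_block_decoder_example():
    block = TransformerBlock(d_model=64, nhead=8, dim_feedforward=256,
                             with_self_attn=True, with_cross_attn=True)
    tgt = torch.randn(2, 5, 64)      # partially decoded target tokens
    memory = torch.randn(2, 12, 64)  # encoder/backbone features
    # Upper-triangular -inf mask so position i cannot attend to positions > i.
    causal = torch.triu(torch.full((5, 5), float('-inf')), diagonal=1)
    causal = causal.unsqueeze(0).unsqueeze(0)  # broadcast over batch and heads
    out = block(tgt, memory, self_mask=causal)  # (2, 5, 64)
    return out.shape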


class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the
    tokens in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.

    .. math::
        \text{PE}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})

        \text{PE}(pos, 2i + 1) = \cos(pos / 10000^{2i / d_{model}})

    where :math:`pos` is the word position and :math:`i` is the embedding index.

    Args:
        dropout: the dropout value (required).
        dim: the embedding dimension (required).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros([max_len, dim])
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = torch.unsqueeze(pe, 0)  # (1, max_len, dim) for batch-first inputs
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to the input.

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)
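

# ---------------------------------------------------------------------------
# Illustrative check (not part of the original module): a hedged sketch
# showing that PositionalEncoding adds a fixed sinusoidal table to a
# batch-first (B, N, C) input. Sizes are assumptions for demonstration only.
def _positional_encoding_example():
    pos_encoder = PositionalEncoding(dropout=0.0, dim=16, max_len=100)
    x = torch.zeros(2, 10, 16)  # a zero input isolates the encoding itself
    out = pos_encoder(x)
    # With dropout disabled and a zero input, the output equals the table.
    assert torch.allclose(out, pos_encoder.pe[:, :10, :].expand(2, -1, -1))
    return out.shape  # (2, 10, 16)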


class PositionalEncoding_2d(nn.Module):
    r"""Inject some information about the relative or absolute position of the
    tokens in a 2-D feature map. The positional encodings have the same
    dimension as the embeddings, so that the two can be summed. Here, we use
    sine and cosine functions of different frequencies along both the width
    and the height of the feature map.

    .. math::
        \text{PE}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})

        \text{PE}(pos, 2i + 1) = \cos(pos / 10000^{2i / d_{model}})

    where :math:`pos` is the position and :math:`i` is the embedding index.

    Args:
        dropout: the dropout value (required).
        dim: the embedding dimension (required).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding_2d(dropout=0.1, dim=d_model)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding_2d, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros([max_len, dim])
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = torch.permute(torch.unsqueeze(pe, 0), [1, 0, 2])  # (max_len, 1, dim)
        self.register_buffer('pe', pe)
        self.avg_pool_1 = nn.AdaptiveAvgPool2d((1, 1))
        self.linear1 = nn.Linear(dim, dim)
        self.linear1.weight.data.fill_(1.0)
        self.avg_pool_2 = nn.AdaptiveAvgPool2d((1, 1))
        self.linear2 = nn.Linear(dim, dim)
        self.linear2.weight.data.fill_(1.0)

    def forward(self, x):
        """Add 2-D positional encodings to a feature map and flatten it.

        Args:
            x: the feature map fed to the positional encoder model (required).

        Shape:
            x: [batch size, embed dim, height, width]
            output: [height * width, batch size, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        w_pe = self.pe[:x.shape[-1], :]
        w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
        w_pe = w_pe * w1
        w_pe = torch.permute(w_pe, [1, 2, 0])
        w_pe = torch.unsqueeze(w_pe, 2)
        h_pe = self.pe[:x.shape[-2], :]
        w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
        h_pe = h_pe * w2
        h_pe = torch.permute(h_pe, [1, 2, 0])
        h_pe = torch.unsqueeze(h_pe, 3)
        x = x + w_pe + h_pe
        x = torch.permute(
            torch.reshape(x,
                          [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
            [2, 0, 1],
        )
        return self.dropout(x)
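

# ---------------------------------------------------------------------------
# Illustrative shape check (not part of the original module): a hedged sketch
# of PositionalEncoding_2d on a CNN-style feature map. The (B, C, H, W) sizes
# are assumptions chosen for demonstration only.
def _positional_encoding_2d_example():
    pos_encoder_2d = PositionalEncoding_2d(dropout=0.0, dim=32, max_len=100)
    feat = torch.randn(2, 32, 8, 25)  # B=2, C=32 channels, H=8, W=25
    out = pos_encoder_2d(feat)
    return out.shape  # (H * W, B, C) = (200, 2, 32)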


class Embeddings(nn.Module):

    def __init__(self, d_model, vocab, padding_idx=None, scale_embedding=True):
        super(Embeddings, self).__init__()
        self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
        self.embedding.weight.data.normal_(mean=0.0, std=d_model**-0.5)
        self.d_model = d_model
        self.scale_embedding = scale_embedding

    def forward(self, x):
        if self.scale_embedding:
            x = self.embedding(x)
            return x * math.sqrt(self.d_model)
        return self.embedding(x)
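

# ---------------------------------------------------------------------------
# Illustrative check (not part of the original module): a hedged sketch of the
# sqrt(d_model) embedding scaling performed by Embeddings. Sizes are
# assumptions chosen for demonstration only.
def _embeddings_scale_example():
    emb = Embeddings(d_model=16, vocab=10, padding_idx=0, scale_embedding=True)
    tokens = torch.tensor([[1, 2, 3]])
    scaled = emb(tokens)         # (1, 3, 16), scaled by sqrt(16) = 4
    raw = emb.embedding(tokens)  # unscaled lookup
    assert torch.allclose(scaled, raw * math.sqrt(16))
    return scaled.shape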