import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from openrec.modeling.common import Mlp


class NRTRDecoder(nn.Module):
    """A transformer model. Users can modify the attributes as needed.

    The architecture is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is
    all you need. In Advances in Neural Information Processing Systems,
    pages 6000-6010.

    Args:
        in_channels: the number of expected features in the encoder/decoder inputs (sets d_model).
        out_channels: the size of the output vocabulary, including the special tokens.
        nhead: the number of heads in the multi-head attention models (default=None, i.e. in_channels // 32).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6); 0 disables the encoder.
        beam_size: beam width for decoding (default=0, i.e. greedy decoding).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        max_len: the maximum length of the decoded sequence (default=25).
        attention_dropout_rate: the dropout value inside attention (default=0.0).
        residual_dropout_rate: the dropout value on residual branches and positional encoding (default=0.1).
        scale_embedding: whether to scale embeddings by sqrt(d_model) (default=True).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super(NRTRDecoder, self).__init__()
        self.out_channels = out_channels
        self.ignore_index = out_channels - 1
        self.bos = out_channels - 2
        self.eos = 0
        self.max_len = max_len
        d_model = in_channels
        dim_feedforward = d_model * 4
        nhead = nhead if nhead is not None else d_model // 32

        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding,
        )
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate, dim=d_model)

        if num_encoder_layers > 0:
            self.encoder = nn.ModuleList([
                TransformerBlock(
                    d_model,
                    nhead,
                    dim_feedforward,
                    attention_dropout_rate,
                    residual_dropout_rate,
                    with_self_attn=True,
                    with_cross_attn=False,
                ) for i in range(num_encoder_layers)
            ])
        else:
            self.encoder = None

        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=True,
            ) for i in range(num_decoder_layers)
        ])

        self.beam_size = beam_size
        self.d_model = d_model
        self.nhead = nhead
        self.tgt_word_prj = nn.Linear(d_model,
                                      self.out_channels - 2,
                                      bias=False)
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, self.out_channels - 2)).astype(
                                  np.float32)
        self.tgt_word_prj.weight.data = torch.from_numpy(w0.transpose())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_train(self, src, tgt):
        tgt = tgt[:, :-1]  # drop the last token; the decoder predicts it
        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        # use src.device so CPU tensors also work (get_device() returns -1 on CPU)
        tgt_mask = self.generate_square_subsequent_mask(tgt.shape[1],
                                                        device=src.device)

        if self.encoder is not None:
            src = self.positional_encoding(src)
            for encoder_layer in self.encoder:
                src = encoder_layer(src)
            memory = src  # B N C
        else:
            memory = src  # B N C
        for decoder_layer in self.decoder:
            tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
        output = tgt
        logit = self.tgt_word_prj(output)
        return logit
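
    # Teacher-forcing sketch (illustrative, not part of the original comments): the
    # padded target "<bos> c1 c2 ... <eos> <pad> ..." is fed to forward_train, which
    # uses tgt[:, :-1] as decoder input; the loss, computed outside this module,
    # typically compares the logits against the shifted labels tgt[:, 1:], while the
    # causal mask keeps position i from attending to positions > i.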

    def forward(self, src, data=None):
        """Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            data: ``[target token ids, target lengths]``; required in training
                mode, ignored at inference time.

        Shape:
            - src: :math:`(B, sN, C)`.
            - data[0]: :math:`(B, tN)`.

        Examples:
            >>> output = decoder(src, data=[tgt_ids, tgt_lens])
        """
        if self.training:
            max_len = data[1].max()
            tgt = data[0][:, :2 + max_len]
            res = self.forward_train(src, tgt)
        else:
            res = self.forward_test(src)
        return res

    def forward_test(self, src):
        bs = src.shape[0]
        if self.encoder is not None:
            src = self.positional_encoding(src)
            for encoder_layer in self.encoder:
                src = encoder_layer(src)
            memory = src  # B N C
        else:
            memory = src
        dec_seq = torch.full((bs, self.max_len + 1),
                             self.ignore_index,
                             dtype=torch.int64,
                             device=src.device)
        dec_seq[:, 0] = self.bos
        logits = []
        self.attn_maps = []
        for len_dec_seq in range(0, self.max_len):
            dec_seq_embed = self.embedding(
                dec_seq[:, :len_dec_seq + 1])  # tokens decoded so far: <bos>, c1, c2, ...
            dec_seq_embed = self.positional_encoding(dec_seq_embed)
            tgt_mask = self.generate_square_subsequent_mask(
                dec_seq_embed.shape[1], src.device)
            tgt = dec_seq_embed  # (bs, len_dec_seq + 1, dim)
            for decoder_layer in self.decoder:
                tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
            self.attn_maps.append(
                self.decoder[-1].cross_attn.attn_map[0][:, -1:, :])
            dec_output = tgt
            dec_output = dec_output[:, -1:, :]  # keep only the newest position
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            if len_dec_seq < self.max_len:
                # greedy decode: append the most probable token to the target input
                dec_seq[:, len_dec_seq + 1] = word_prob.squeeze().argmax(-1)
                # Efficient batch decoding: if every sequence contains at least
                # one EOS token, stop decoding early.
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
        logits = torch.cat(logits, dim=1)
        return logits

    def generate_square_subsequent_mask(self, sz, device):
        """Generate a square mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked positions
        are filled with float(0.0).
        """
        mask = torch.zeros([sz, sz], dtype=torch.float32)
        mask_inf = torch.triu(
            torch.full((sz, sz), dtype=torch.float32, fill_value=-torch.inf),
            diagonal=1,
        )
        mask = mask + mask_inf
        return mask.unsqueeze(0).unsqueeze(0).to(device)
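

# Illustrative sketch (not part of the original module): for sz=3 the mask returned
# by NRTRDecoder.generate_square_subsequent_mask, before the two unsqueeze calls that
# add broadcastable (batch, head) dimensions, is
#
#     [[0., -inf, -inf],
#      [0.,   0., -inf],
#      [0.,   0.,   0.]]
#
# It is added to the attention logits before the softmax, so position i can only
# attend to positions <= i.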


class MultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads, dropout=0.0, self_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim
                ), 'embed_dim must be divisible by num_heads'
        self.scale = self.head_dim**-0.5
        self.self_attn = self_attn
        if self_attn:
            self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        else:
            self.q = nn.Linear(embed_dim, embed_dim)
            self.kv = nn.Linear(embed_dim, embed_dim * 2)
        self.attn_drop = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key=None, attn_mask=None):
        B, qN = query.shape[:2]

        if self.self_attn:
            qkv = self.qkv(query)
            qkv = qkv.reshape(B, qN, 3, self.num_heads,
                              self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
        else:
            kN = key.shape[1]
            q = self.q(query)
            q = q.reshape(B, qN, self.num_heads, self.head_dim).transpose(1, 2)
            kv = self.kv(key)
            kv = kv.reshape(B, kN, 2, self.num_heads,
                            self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)

        attn = (q.matmul(k.transpose(2, 3))) * self.scale
        if attn_mask is not None:
            attn += attn_mask
        attn = F.softmax(attn, dim=-1)
        if not self.training:
            self.attn_map = attn
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).transpose(1, 2)
        x = x.reshape(B, qN, self.embed_dim)
        x = self.out_proj(x)
        return x
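

# Shape walkthrough (illustrative values, based on the forward pass above): for a
# self-attention call with B=2, qN=5, embed_dim=512, num_heads=8, the projections
# give q, k, v of shape (2, 8, 5, 64); q @ k^T is (2, 8, 5, 5); after the softmax,
# attn @ v is (2, 8, 5, 64), which is transposed and reshaped back to (2, 5, 512)
# before out_proj.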


class TransformerBlock(nn.Module):

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        with_self_attn=True,
        with_cross_attn=False,
        epsilon=1e-5,
    ):
        super(TransformerBlock, self).__init__()
        self.with_self_attn = with_self_attn
        if with_self_attn:
            self.self_attn = MultiheadAttention(d_model,
                                                nhead,
                                                dropout=attention_dropout_rate,
                                                self_attn=with_self_attn)
            self.norm1 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout1 = nn.Dropout(residual_dropout_rate)
        self.with_cross_attn = with_cross_attn
        if with_cross_attn:
            self.cross_attn = MultiheadAttention(
                d_model, nhead, dropout=attention_dropout_rate
            )  # for self_attn of encoder or cross_attn of decoder
            self.norm2 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout2 = nn.Dropout(residual_dropout_rate)

        self.mlp = Mlp(
            in_features=d_model,
            hidden_features=dim_feedforward,
            act_layer=nn.ReLU,
            drop=residual_dropout_rate,
        )
        self.norm3 = nn.LayerNorm(d_model, eps=epsilon)
        self.dropout3 = nn.Dropout(residual_dropout_rate)

    def forward(self, tgt, memory=None, self_mask=None, cross_mask=None):
        if self.with_self_attn:
            tgt1 = self.self_attn(tgt, attn_mask=self_mask)
            tgt = self.norm1(tgt + self.dropout1(tgt1))

        if self.with_cross_attn:
            tgt2 = self.cross_attn(tgt, key=memory, attn_mask=cross_mask)
            tgt = self.norm2(tgt + self.dropout2(tgt2))
        tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))
        return tgt
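

# Usage sketch (illustrative tensor sizes, not from the original module): a
# decoder-style block uses post-norm residuals, x = LayerNorm(x + Dropout(SubLayer(x))),
# for self-attention over its own tokens and cross-attention over encoder memory.
#
#     block = TransformerBlock(d_model=512, nhead=8,
#                              with_self_attn=True, with_cross_attn=True)
#     tgt = torch.rand(2, 5, 512)      # (B, tN, C) target features
#     memory = torch.rand(2, 64, 512)  # (B, sN, C) encoder output
#     out = block(tgt, memory)         # -> (2, 5, 512)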


class PositionalEncoding(nn.Module):
    r"""Inject information about the relative or absolute position of the
    tokens in the sequence. The positional encodings have the same dimension
    as the embeddings, so that the two can be summed. Here, we use sine and
    cosine functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos/10000^{2i/d\_model})

        \text{PosEncoder}(pos, 2i+1) = \cos(pos/10000^{2i/d\_model})

    where :math:`pos` is the word position and :math:`i` is the embedding index.

    Args:
        dropout: the dropout value (default=0.1).
        dim: the embedding dimension (required).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(dropout=0.1, dim=512)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros([max_len, dim])
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = torch.unsqueeze(pe, 0)  # (1, max_len, dim)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the positional encoding to the input.

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)
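

# Worked sketch of the sinusoidal table built above (illustrative numbers): with
# dim=4, div_term = exp([0, 2] * -ln(10000)/4) = [1.0, 0.01], so the row for pos=1
# is [sin(1.0), cos(1.0), sin(0.01), cos(0.01)]; even columns hold sines and odd
# columns hold cosines of the same scaled positions.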


class PositionalEncoding_2d(nn.Module):
    r"""Inject information about the relative or absolute position of the
    tokens in the sequence. The positional encodings have the same dimension
    as the embeddings, so that the two can be summed. Here, we use sine and
    cosine functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos/10000^{2i/d\_model})

        \text{PosEncoder}(pos, 2i+1) = \cos(pos/10000^{2i/d\_model})

    where :math:`pos` is the word position and :math:`i` is the embedding index.

    Args:
        dropout: the dropout value (default=0.1).
        dim: the embedding dimension (required).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding_2d(dropout=0.1, dim=512)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding_2d, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros([max_len, dim])
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = torch.permute(torch.unsqueeze(pe, 0), [1, 0, 2])  # (max_len, 1, dim)
        self.register_buffer('pe', pe)

        self.avg_pool_1 = nn.AdaptiveAvgPool2d((1, 1))
        self.linear1 = nn.Linear(dim, dim)
        self.linear1.weight.data.fill_(1.0)
        self.avg_pool_2 = nn.AdaptiveAvgPool2d((1, 1))
        self.linear2 = nn.Linear(dim, dim)
        self.linear2.weight.data.fill_(1.0)

    def forward(self, x):
        """Add width- and height-wise positional encodings to a 2D feature map.

        Args:
            x: the feature map fed to the positional encoder model (required).

        Shape:
            x: [batch size, embed dim, height, width]
            output: [height * width, batch size, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        w_pe = self.pe[:x.shape[-1], :]
        w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
        w_pe = w_pe * w1
        w_pe = torch.permute(w_pe, [1, 2, 0])
        w_pe = torch.unsqueeze(w_pe, 2)

        h_pe = self.pe[:x.shape[-2], :]
        w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
        h_pe = h_pe * w2
        h_pe = torch.permute(h_pe, [1, 2, 0])
        h_pe = torch.unsqueeze(h_pe, 3)

        x = x + w_pe + h_pe
        x = torch.permute(
            torch.reshape(x,
                          [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
            [2, 0, 1],
        )
        return self.dropout(x)
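

# Reading of the forward pass above (shape sketch, assuming an input feature map of
# shape (B, C, H, W)): the width encoding self.pe[:W] is scaled by a gate computed
# from global average pooling through linear1 and broadcast along the height axis,
# the height encoding self.pe[:H] is gated by linear2 and broadcast along the width
# axis, and the sum is flattened to a sequence of shape (H * W, B, C).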


class Embeddings(nn.Module):

    def __init__(self, d_model, vocab, padding_idx=None, scale_embedding=True):
        super(Embeddings, self).__init__()
        self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
        self.embedding.weight.data.normal_(mean=0.0, std=d_model**-0.5)
        self.d_model = d_model
        self.scale_embedding = scale_embedding

    def forward(self, x):
        if self.scale_embedding:
            x = self.embedding(x)
            return x * math.sqrt(self.d_model)
        return self.embedding(x)
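

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module); the tensor sizes
    # below are illustrative assumptions: 512-dim visual features over 64 positions
    # and a 40-symbol vocabulary (38 characters plus the <bos>/<pad> specials).
    decoder = NRTRDecoder(in_channels=512, out_channels=40,
                          num_encoder_layers=2, num_decoder_layers=2, max_len=25)
    feats = torch.rand(2, 64, 512)  # (B, sN, C) features from a visual backbone

    # Training-style call: data = [padded target token ids, target lengths].
    targets = torch.randint(0, 38, (2, 27), dtype=torch.int64)
    lengths = torch.tensor([10, 12])
    decoder.train()
    logits = decoder(feats, data=[targets, lengths])
    print(logits.shape)  # (B, tN, out_channels - 2)

    # Inference-style call: greedy decoding for up to max_len steps.
    decoder.eval()
    with torch.no_grad():
        probs = decoder(feats)
    print(probs.shape)  # (B, decoded length, out_channels - 2)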