- # Scene Text Recognition Model Hub
- # Copyright 2022 Darwin Bautista
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # https://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from itertools import permutations
- from collections import OrderedDict
- import hashlib
- import os
- import gzip
- import html
- import urllib.request
- import warnings
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch import Tensor
- from torch.nn.modules import transformer
- from typing import Any, Optional, Tuple, List, Union
- from pkg_resources import packaging
- from PIL import Image
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
- from tqdm import tqdm
- from functools import lru_cache
- import ftfy
- import regex as re
- try:
- from torchvision.transforms import InterpolationMode
- BICUBIC = InterpolationMode.BICUBIC
- except ImportError:
- BICUBIC = Image.BICUBIC
- @lru_cache()
- def default_bpe():
- return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
- @lru_cache()
- def bytes_to_unicode():
- """
- Returns a dict mapping utf-8 bytes to corresponding unicode strings.
- The reversible BPE codes work on unicode strings.
- This means you need a large number of unicode characters in your vocab if you want to avoid UNKs.
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
- This is a significant percentage of your normal, say, 32K BPE vocab.
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
- and we avoid mapping to whitespace/control characters that the BPE code barfs on.
- """
- bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
- cs = bs[:]
- n = 0
- for b in range(2**8):
- if b not in bs:
- bs.append(b)
- cs.append(2**8+n)
- n += 1
- cs = [chr(n) for n in cs]
- return dict(zip(bs, cs))
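- # Illustrative sketch (not part of the original module): the mapping above covers all 256 byte
- # values, and printable bytes map to themselves. `_demo_bytes_to_unicode` is a hypothetical
- # helper added purely as documentation; it is never called by the library.
- def _demo_bytes_to_unicode():
-     mapping = bytes_to_unicode()
-     return len(mapping), mapping[ord('a')]  # (256, 'a')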
- def get_pairs(word):
- """Return set of symbol pairs in a word.
- Word is represented as tuple of symbols (symbols being variable-length strings).
- """
- pairs = set()
- prev_char = word[0]
- for char in word[1:]:
- pairs.add((prev_char, char))
- prev_char = char
- return pairs
- def basic_clean(text):
- text = ftfy.fix_text(text)
- text = html.unescape(html.unescape(text))
- return text.strip()
- def whitespace_clean(text):
- text = re.sub(r'\s+', ' ', text)
- text = text.strip()
- return text
- class SimpleTokenizer(object):
- def __init__(self, bpe_path: str = default_bpe()):
- self.byte_encoder = bytes_to_unicode()
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
- merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
- merges = merges[1:49152-256-2+1]
- merges = [tuple(merge.split()) for merge in merges]
- vocab = list(bytes_to_unicode().values())
- vocab = vocab + [v+'</w>' for v in vocab]
- for merge in merges:
- vocab.append(''.join(merge))
- vocab.extend(['<|startoftext|>', '<|endoftext|>'])
- self.encoder = dict(zip(vocab, range(len(vocab))))
- self.decoder = {v: k for k, v in self.encoder.items()}
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
- self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
- self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
- def bpe(self, token):
- if token in self.cache:
- return self.cache[token]
- word = tuple(token[:-1]) + ( token[-1] + '</w>',)
- pairs = get_pairs(word)
- if not pairs:
- return token+'</w>'
- while True:
- bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
- if bigram not in self.bpe_ranks:
- break
- first, second = bigram
- new_word = []
- i = 0
- while i < len(word):
- try:
- j = word.index(first, i)
- new_word.extend(word[i:j])
- i = j
- except ValueError:  # str.index raises ValueError when `first` does not occur at or after position i
- new_word.extend(word[i:])
- break
- if word[i] == first and i < len(word)-1 and word[i+1] == second:
- new_word.append(first+second)
- i += 2
- else:
- new_word.append(word[i])
- i += 1
- new_word = tuple(new_word)
- word = new_word
- if len(word) == 1:
- break
- else:
- pairs = get_pairs(word)
- word = ' '.join(word)
- self.cache[token] = word
- return word
- def encode(self, text):
- bpe_tokens = []
- text = whitespace_clean(basic_clean(text)).lower()
- for token in re.findall(self.pat, text):
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
- bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
- return bpe_tokens
- def decode(self, tokens):
- text = ''.join([self.decoder[token] for token in tokens])
- text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
- return text
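- # Illustrative sketch (not part of the original module): a round trip through the BPE tokenizer.
- # Assumes the bundled `bpe_simple_vocab_16e6.txt.gz` file sits next to this module.
- # `_demo_tokenizer_roundtrip` is a hypothetical helper and is never called by the library.
- def _demo_tokenizer_roundtrip():
-     tok = SimpleTokenizer()
-     ids = tok.encode("hello world")  # list of BPE token ids
-     return tok.decode(ids)           # "hello world " (the trailing space comes from '</w>')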
- if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
- warnings.warn("PyTorch version 1.7.1 or higher is recommended")
- __all__ = ["available_models", "load", "tokenize"]
- _tokenizer = SimpleTokenizer()
- _MODELS = {
- "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
- "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
- "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
- "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
- "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
- "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
- "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
- "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
- "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
- }
- def convert_weights(model: nn.Module):
- """Convert applicable model parameters to fp16"""
- def _convert_weights_to_fp16(l):
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
- l.weight.data = l.weight.data.half()
- if l.bias is not None:
- l.bias.data = l.bias.data.half()
- if isinstance(l, nn.MultiheadAttention):
- for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
- tensor = getattr(l, attr)
- if tensor is not None:
- tensor.data = tensor.data.half()
- for name in ["text_projection", "proj"]:
- if hasattr(l, name):
- attr = getattr(l, name)
- if attr is not None:
- attr.data = attr.data.half()
- model.apply(_convert_weights_to_fp16)
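- # Illustrative sketch (not part of the original module): `convert_weights` casts Conv/Linear/
- # attention weights to fp16 in place. `_demo_convert_weights` is a hypothetical helper shown
- # purely as documentation.
- def _demo_convert_weights():
-     tiny = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Conv2d(1, 1, 3))
-     convert_weights(tiny)
-     return tiny[0].weight.dtype, tiny[2].weight.dtype  # (torch.float16, torch.float16)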
- def build_model(state_dict: dict):
- vit = "visual.proj" in state_dict
- if vit:
- vision_width = state_dict["visual.conv1.weight"].shape[0]
- vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
- vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
- grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
- image_resolution = vision_patch_size * grid_size
- else:
- counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
- vision_layers = tuple(counts)
- vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
- output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
- vision_patch_size = None
- assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
- image_resolution = output_width * 32
- embed_dim = state_dict["text_projection"].shape[1]
- context_length = state_dict["positional_embedding"].shape[0]
- vocab_size = state_dict["token_embedding.weight"].shape[0]
- transformer_width = state_dict["ln_final.weight"].shape[0]
- transformer_heads = transformer_width // 64
- transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
- model = CLIP(
- embed_dim,
- image_resolution, vision_layers, vision_width, vision_patch_size,
- context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
- )
- for key in ["input_resolution", "context_length", "vocab_size"]:
- if key in state_dict:
- del state_dict[key]
- convert_weights(model)
- model.load_state_dict(state_dict)
- return model.eval()
- def _download(url: str, root: str):
- os.makedirs(root, exist_ok=True)
- filename = os.path.basename(url)
- expected_sha256 = url.split("/")[-2]
- download_target = os.path.join(root, filename)
- if os.path.exists(download_target) and not os.path.isfile(download_target):
- raise RuntimeError(f"{download_target} exists and is not a regular file")
- if os.path.isfile(download_target):
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
- return download_target
- else:
- warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
- with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
- while True:
- buffer = source.read(8192)
- if not buffer:
- break
- output.write(buffer)
- loop.update(len(buffer))
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
- raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
- return download_target
- def _convert_image_to_rgb(image):
- return image.convert("RGB")
- def _transform(n_px):
- return Compose([
- Resize(n_px, interpolation=BICUBIC),
- CenterCrop(n_px),
- _convert_image_to_rgb,
- ToTensor(),
- Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
- ])
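- # Illustrative sketch (not part of the original module): the preprocessing pipeline resizes the
- # shorter side to n_px, center-crops, and normalizes with CLIP's RGB statistics.
- # `_demo_preprocess` is a hypothetical helper shown purely as documentation.
- def _demo_preprocess():
-     preprocess = _transform(224)
-     dummy = Image.new("RGB", (320, 240))
-     return preprocess(dummy).shape  # torch.Size([3, 224, 224])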
- def available_models() -> List[str]:
- """Returns the names of available CLIP models"""
- return list(_MODELS.keys())
- class Bottleneck(nn.Module):
- expansion = 4
- def __init__(self, inplanes, planes, stride=1):
- super().__init__()
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
- self.bn1 = nn.BatchNorm2d(planes)
- self.relu1 = nn.ReLU(inplace=True)
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
- self.bn2 = nn.BatchNorm2d(planes)
- self.relu2 = nn.ReLU(inplace=True)
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
- self.relu3 = nn.ReLU(inplace=True)
- self.downsample = None
- self.stride = stride
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
- self.downsample = nn.Sequential(OrderedDict([
- ("-1", nn.AvgPool2d(stride)),
- ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
- ("1", nn.BatchNorm2d(planes * self.expansion))
- ]))
- def forward(self, x: torch.Tensor):
- identity = x
- out = self.relu1(self.bn1(self.conv1(x)))
- out = self.relu2(self.bn2(self.conv2(out)))
- out = self.avgpool(out)
- out = self.bn3(self.conv3(out))
- if self.downsample is not None:
- identity = self.downsample(x)
- out += identity
- out = self.relu3(out)
- return out
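- # Illustrative sketch (not part of the original module): shape check for the anti-aliased
- # bottleneck (stride > 1 is realized by an AvgPool2d rather than a strided convolution).
- # `_demo_bottleneck_shapes` is a hypothetical helper shown purely as documentation.
- def _demo_bottleneck_shapes():
-     block = Bottleneck(inplanes=64, planes=64, stride=2)
-     x = torch.randn(1, 64, 56, 56)
-     return block(x).shape  # torch.Size([1, 256, 28, 28])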
- class AttentionPool2d(nn.Module):
- def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
- super().__init__()
- self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
- self.k_proj = nn.Linear(embed_dim, embed_dim)
- self.q_proj = nn.Linear(embed_dim, embed_dim)
- self.v_proj = nn.Linear(embed_dim, embed_dim)
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
- self.num_heads = num_heads
- def forward(self, x):
- x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
- x, _ = F.multi_head_attention_forward(
- query=x[:1], key=x, value=x,
- embed_dim_to_check=x.shape[-1],
- num_heads=self.num_heads,
- q_proj_weight=self.q_proj.weight,
- k_proj_weight=self.k_proj.weight,
- v_proj_weight=self.v_proj.weight,
- in_proj_weight=None,
- in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
- bias_k=None,
- bias_v=None,
- add_zero_attn=False,
- dropout_p=0,
- out_proj_weight=self.c_proj.weight,
- out_proj_bias=self.c_proj.bias,
- use_separate_proj_weight=True,
- training=self.training,
- need_weights=False
- )
- return x.squeeze(0)
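- # Illustrative sketch (not part of the original module): the attention pool turns an NCHW
- # feature map into one pooled vector per image, using the mean token as the query.
- # `_demo_attention_pool` is a hypothetical helper shown purely as documentation.
- def _demo_attention_pool():
-     pool = AttentionPool2d(spacial_dim=7, embed_dim=64, num_heads=8, output_dim=32)
-     x = torch.randn(2, 64, 7, 7)
-     return pool(x).shape  # torch.Size([2, 32])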
- class ModifiedResNet(nn.Module):
- """
- A ResNet class that is similar to torchvision's but contains the following changes:
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- - The final pooling layer is a QKV attention instead of an average pool
- """
- def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
- super().__init__()
- self.output_dim = output_dim
- self.input_resolution = input_resolution
- # the 3-layer stem
- self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
- self.bn1 = nn.BatchNorm2d(width // 2)
- self.relu1 = nn.ReLU(inplace=True)
- self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
- self.bn2 = nn.BatchNorm2d(width // 2)
- self.relu2 = nn.ReLU(inplace=True)
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
- self.bn3 = nn.BatchNorm2d(width)
- self.relu3 = nn.ReLU(inplace=True)
- self.avgpool = nn.AvgPool2d(2)
- # residual layers
- self._inplanes = width # this is a *mutable* variable used during construction
- self.layer1 = self._make_layer(width, layers[0])
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
- embed_dim = width * 32 # the ResNet feature dimension
- self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
- def _make_layer(self, planes, blocks, stride=1):
- layers = [Bottleneck(self._inplanes, planes, stride)]
- self._inplanes = planes * Bottleneck.expansion
- for _ in range(1, blocks):
- layers.append(Bottleneck(self._inplanes, planes))
- return nn.Sequential(*layers)
- def forward(self, x):
- def stem(x):
- x = self.relu1(self.bn1(self.conv1(x)))
- x = self.relu2(self.bn2(self.conv2(x)))
- x = self.relu3(self.bn3(self.conv3(x)))
- x = self.avgpool(x)
- return x
- x = x.type(self.conv1.weight.dtype)
- x = stem(x)
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
- x = self.attnpool(x)
- return x
- class LayerNorm(nn.LayerNorm):
- """Subclass torch's LayerNorm to handle fp16."""
- def forward(self, x: torch.Tensor):
- orig_type = x.dtype
- ret = super().forward(x.type(torch.float32))
- return ret.type(orig_type)
- class QuickGELU(nn.Module):
- def forward(self, x: torch.Tensor):
- return x * torch.sigmoid(1.702 * x)
- class ResidualAttentionBlock(nn.Module):
- def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
- super().__init__()
- self.attn = nn.MultiheadAttention(d_model, n_head)
- self.ln_1 = LayerNorm(d_model)
- self.mlp = nn.Sequential(OrderedDict([
- ("c_fc", nn.Linear(d_model, d_model * 4)),
- ("gelu", QuickGELU()),
- ("c_proj", nn.Linear(d_model * 4, d_model))
- ]))
- self.ln_2 = LayerNorm(d_model)
- self.attn_mask = attn_mask
- def attention(self, x: torch.Tensor):
- self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
- return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
- def forward(self, x: torch.Tensor):
- x = x + self.attention(self.ln_1(x))
- x = x + self.mlp(self.ln_2(x))
- return x
- class Transformer(nn.Module):
- def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
- super().__init__()
- self.width = width
- self.layers = layers
- self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
- def forward(self, x: torch.Tensor):
- return self.resblocks(x)
- class VisionTransformer(nn.Module):
- def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
- super().__init__()
- self.input_resolution = input_resolution
- self.output_dim = output_dim
- self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
- scale = width ** -0.5
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
- self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
- self.ln_pre = LayerNorm(width)
- self.transformer = Transformer(width, layers, heads)
- self.ln_post = LayerNorm(width)
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
- def forward(self, x: torch.Tensor):
- x = self.conv1(x) # shape = [*, width, grid, grid]
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
- x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
- x = x + self.positional_embedding.to(x.dtype)
- x = self.ln_pre(x)
- x = x.permute(1, 0, 2) # NLD -> LND
- x = self.transformer(x)
- x = x.permute(1, 0, 2) # LND -> NLD
- x = self.ln_post(x)  # note: all tokens are kept here (the reference CLIP ViT keeps only the class token)
- if self.proj is not None:
- x = x @ self.proj
- return x
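- # Illustrative sketch (not part of the original module): unlike the reference CLIP ViT, this
- # variant projects and returns *all* tokens so that downstream modules can cross-attend to
- # patch features. `_demo_vit_tokens` is a hypothetical helper shown purely as documentation.
- def _demo_vit_tokens():
-     vit = VisionTransformer(input_resolution=32, patch_size=8, width=64, layers=1, heads=4, output_dim=32)
-     x = torch.randn(2, 3, 32, 32)
-     return vit(x).shape  # torch.Size([2, 17, 32]): 4*4 patch tokens + 1 class token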
- class CLIP(nn.Module):
- def __init__(self,
- embed_dim: int,
- # vision
- image_resolution: int,
- vision_layers: Union[Tuple[int, int, int, int], int],
- vision_width: int,
- vision_patch_size: int,
- # text
- context_length: int,
- vocab_size: int,
- transformer_width: int,
- transformer_heads: int,
- transformer_layers: int
- ):
- super().__init__()
- self.context_length = context_length
- if isinstance(vision_layers, (tuple, list)):
- vision_heads = vision_width * 32 // 64
- self.visual = ModifiedResNet(
- layers=vision_layers,
- output_dim=embed_dim,
- heads=vision_heads,
- input_resolution=image_resolution,
- width=vision_width
- )
- else:
- vision_heads = vision_width // 64
- self.visual = VisionTransformer(
- input_resolution=image_resolution,
- patch_size=vision_patch_size,
- width=vision_width,
- layers=vision_layers,
- heads=vision_heads,
- output_dim=embed_dim
- )
- self.transformer = Transformer(
- width=transformer_width,
- layers=transformer_layers,
- heads=transformer_heads,
- attn_mask=self.build_attention_mask()
- )
- self.vocab_size = vocab_size
- self.token_embedding = nn.Embedding(vocab_size, transformer_width)
- self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
- self.ln_final = LayerNorm(transformer_width)
- self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
- self.initialize_parameters()
- def initialize_parameters(self):
- nn.init.normal_(self.token_embedding.weight, std=0.02)
- nn.init.normal_(self.positional_embedding, std=0.01)
- if isinstance(self.visual, ModifiedResNet):
- if self.visual.attnpool is not None:
- std = self.visual.attnpool.c_proj.in_features ** -0.5
- nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
- nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
- nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
- nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
- for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
- for name, param in resnet_block.named_parameters():
- if name.endswith("bn3.weight"):
- nn.init.zeros_(param)
- proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
- attn_std = self.transformer.width ** -0.5
- fc_std = (2 * self.transformer.width) ** -0.5
- for block in self.transformer.resblocks:
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
- if self.text_projection is not None:
- nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
- def build_attention_mask(self):
- # lazily create causal attention mask, with full attention between the vision tokens
- # pytorch uses additive attention mask; fill with -inf
- mask = torch.empty(self.context_length, self.context_length)
- mask.fill_(float("-inf"))
- mask.triu_(1) # zero out the lower diagonal
- return mask
- @property
- def dtype(self):
- return self.visual.conv1.weight.dtype
- def encode_image(self, image):
- return self.visual(image.type(self.dtype))
- def encode_text(self, text):
- x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
- x = x + self.positional_embedding.type(self.dtype)
- x = x.permute(1, 0, 2) # NLD -> LND
- x = self.transformer(x)
- x = x.permute(1, 0, 2) # LND -> NLD
- x = self.ln_final(x).type(self.dtype)
- # take features from the eot embedding (eot_token is the highest number in each sequence)
- output = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
- # prepend the pooled EOT embedding to the per-token features (assumes transformer_width == embed_dim)
- output = torch.cat([output.unsqueeze(1), x], dim=1)
- return output
- def forward(self, image, text):
- image_features = self.encode_image(image)
- text_features = self.encode_text(text)
- # normalized features
- image_features = image_features / image_features.norm(dim=1, keepdim=True)
- text_features = text_features / text_features.norm(dim=1, keepdim=True)
- # cosine similarity as logits
- logit_scale = self.logit_scale.exp()
- logits_per_image = logit_scale * image_features @ text_features.t()
- logits_per_text = logits_per_image.t()
- # shape = [global_batch_size, global_batch_size]
- return logits_per_image, logits_per_text
- class FMU(nn.Module):
- """A Transformer decoder layer supporting two-stream attention (XLNet)
- This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
- def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu',
- layer_norm_eps=1e-5):
- super().__init__()
- self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
- # Implementation of Feedforward model
- self.linear1 = nn.Linear(d_model, dim_feedforward)
- self.linear2 = nn.Linear(dim_feedforward, d_model)
- self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
- self.dropout3 = nn.Dropout(dropout)
- self.activation = transformer._get_activation_fn(activation)
- def __setstate__(self, state):
- if 'activation' not in state:
- state['activation'] = F.gelu
- super().__setstate__(state)
- def forward(self, query: Tensor, memory: Tensor):
- """Forward pass for a single stream (i.e. content or query)
- tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
- Both tgt_kv and memory are expected to be LayerNorm'd too.
- memory is LayerNorm'd by ViT.
- """
- query1, ca_weights = self.cross_attn(query, memory, memory)
- query = query + self.dropout1(query1)
- query2 = self.linear2(self.dropout2(self.activation(self.linear1(self.norm(query)))))
- query = query + self.dropout3(query2)
- return query
- class DecoderLayer(nn.Module):
- """A Transformer decoder layer supporting two-stream attention (XLNet) This
- implements a pre-LN decoder, as opposed to the post-LN default in
- PyTorch."""
- def __init__(
- self,
- d_model,
- nhead,
- dim_feedforward=2048,
- dropout=0.1,
- activation='gelu',
- layer_norm_eps=1e-5,
- ):
- super().__init__()
- self.self_attn = nn.MultiheadAttention(d_model,
- nhead,
- dropout=dropout,
- batch_first=True)
- self.cross_attn = nn.MultiheadAttention(d_model,
- nhead,
- dropout=dropout,
- batch_first=True)
- # Implementation of Feedforward model
- self.linear1 = nn.Linear(d_model, dim_feedforward)
- self.dropout = nn.Dropout(dropout)
- self.linear2 = nn.Linear(dim_feedforward, d_model)
- self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
- self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
- self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
- self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
- self.dropout3 = nn.Dropout(dropout)
- self.activation = transformer._get_activation_fn(activation)
- def __setstate__(self, state):
- if 'activation' not in state:
- state['activation'] = F.gelu
- super().__setstate__(state)
- def forward_stream(
- self,
- tgt: Tensor,
- tgt_norm: Tensor,
- tgt_kv: Tensor,
- memory: Tensor,
- tgt_mask: Optional[Tensor],
- tgt_key_padding_mask: Optional[Tensor],
- ):
- """Forward pass for a single stream (i.e. content or query) tgt_norm is
- just a LayerNorm'd tgt.
- Added as a separate parameter for efficiency. Both tgt_kv and memory
- are expected to be LayerNorm'd too. memory is LayerNorm'd by ViT.
- """
- tgt2, sa_weights = self.self_attn(
- tgt_norm,
- tgt_kv,
- tgt_kv,
- attn_mask=tgt_mask,
- key_padding_mask=tgt_key_padding_mask)
- tgt = tgt + self.dropout1(tgt2)
- tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
- self.attn_map = ca_weights
- tgt = tgt + self.dropout2(tgt2)
- tgt2 = self.linear2(
- self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
- tgt = tgt + self.dropout3(tgt2)
- return tgt, sa_weights, ca_weights
- def forward(
- self,
- query,
- content,
- memory,
- query_mask: Optional[Tensor] = None,
- content_mask: Optional[Tensor] = None,
- content_key_padding_mask: Optional[Tensor] = None,
- update_content: bool = True,
- ):
- query_norm = self.norm_q(query)
- content_norm = self.norm_c(content)
- query = self.forward_stream(query, query_norm, content_norm, memory,
- query_mask, content_key_padding_mask)[0]
- if update_content:
- content = self.forward_stream(content, content_norm, content_norm,
- memory, content_mask,
- content_key_padding_mask)[0]
- return query, content
- class Decoder(nn.Module):
- __constants__ = ['norm']
- def __init__(self, decoder_layer, num_layers, norm):
- super().__init__()
- self.layers = transformer._get_clones(decoder_layer, num_layers)
- self.num_layers = num_layers
- self.norm = norm
- def forward(
- self,
- query,
- content,
- memory,
- query_mask: Optional[Tensor] = None,
- content_mask: Optional[Tensor] = None,
- content_key_padding_mask: Optional[Tensor] = None,
- ):
- for i, mod in enumerate(self.layers):
- last = i == len(self.layers) - 1
- query, content = mod(
- query,
- content,
- memory,
- query_mask,
- content_mask,
- content_key_padding_mask,
- update_content=not last,
- )
- query = self.norm(query)
- return query
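- # Illustrative sketch (not part of the original module): the decoder consumes three streams
- # (positional queries, content embeddings, encoder memory) and returns refined query features.
- # `_demo_decoder_shapes` is a hypothetical helper shown purely as documentation.
- def _demo_decoder_shapes():
-     layer = DecoderLayer(d_model=32, nhead=4, dim_feedforward=64)
-     dec = Decoder(layer, num_layers=2, norm=nn.LayerNorm(32))
-     query = torch.randn(2, 5, 32)    # positional queries
-     content = torch.randn(2, 5, 32)  # context/token embeddings
-     memory = torch.randn(2, 10, 32)  # encoder features
-     return dec(query, content, memory).shape  # torch.Size([2, 5, 32])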
- class TokenEmbedding(nn.Module):
- def __init__(self, charset_size: int, embed_dim: int):
- super().__init__()
- self.embedding = nn.Embedding(charset_size, embed_dim)
- self.embed_dim = embed_dim
- def forward(self, tokens: torch.Tensor):
- return math.sqrt(self.embed_dim) * self.embedding(tokens)
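- # Illustrative sketch (not part of the original module): token embeddings are scaled by
- # sqrt(embed_dim), the usual Transformer convention. `_demo_token_embedding` is hypothetical.
- def _demo_token_embedding():
-     emb = TokenEmbedding(charset_size=10, embed_dim=16)
-     return emb(torch.tensor([[1, 2, 3]])).shape  # torch.Size([1, 3, 16]); values scaled by 4.0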
- def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
- """Load a CLIP model
- Parameters
- ----------
- name : str
- A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
- device : Union[str, torch.device]
- The device to put the loaded model
- jit : bool
- Whether to load the optimized JIT model or more hackable non-JIT model (default).
- download_root: str
- path to download the model files; by default, it uses "~/.cache/clip"
- Returns
- -------
- model : torch.nn.Module
- The CLIP model
- preprocess : Callable[[PIL.Image], torch.Tensor]
- A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
- """
- if name in _MODELS:
- model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
- elif os.path.isfile(name):
- model_path = name
- else:
- raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
- with open(model_path, 'rb') as opened_file:
- try:
- # loading JIT archive
- model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
- state_dict = None
- except RuntimeError:
- # loading saved state dict
- if jit:
- warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
- jit = False
- state_dict = torch.load(opened_file, map_location="cpu")
- if not jit:
- model = build_model(state_dict or model.state_dict()).to(device)
- if str(device) == "cpu":
- model.float()
- return model, _transform(model.visual.input_resolution)
- # patch the device names
- device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
- device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
- def patch_device(module):
- try:
- graphs = [module.graph] if hasattr(module, "graph") else []
- except RuntimeError:
- graphs = []
- if hasattr(module, "forward1"):
- graphs.append(module.forward1.graph)
- for graph in graphs:
- for node in graph.findAllNodes("prim::Constant"):
- if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
- node.copyAttributes(device_node)
- model.apply(patch_device)
- patch_device(model.encode_image)
- patch_device(model.encode_text)
- # patch dtype to float32 on CPU
- if str(device) == "cpu":
- float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
- float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
- float_node = float_input.node()
- def patch_float(module):
- try:
- graphs = [module.graph] if hasattr(module, "graph") else []
- except RuntimeError:
- graphs = []
- if hasattr(module, "forward1"):
- graphs.append(module.forward1.graph)
- for graph in graphs:
- for node in graph.findAllNodes("aten::to"):
- inputs = list(node.inputs())
- for i in [1, 2]: # dtype can be the second or third argument to aten::to()
- if inputs[i].node()["value"] == 5:
- inputs[i].node().copyAttributes(float_node)
- model.apply(patch_float)
- patch_float(model.encode_image)
- patch_float(model.encode_text)
- model.float()
- return model, _transform(model.input_resolution.item())
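- # Illustrative sketch (not part of the original module): typical use of `load`. The first call
- # downloads the checkpoint (a few hundred MB for ViT-B/16) into the cache directory.
- # `_demo_load_clip` is a hypothetical helper and is never called by the library.
- def _demo_load_clip():
-     model, preprocess = load("ViT-B/16", device="cpu")
-     image = preprocess(Image.new("RGB", (256, 256))).unsqueeze(0)
-     with torch.no_grad():
-         feats = model.encode_image(image)
-     return feats.shape  # [1, num_tokens, embed_dim] with this modified ViT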
- def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
- """
- Returns the tokenized representation of given input string(s)
- Parameters
- ----------
- texts : Union[str, List[str]]
- An input string or a list of input strings to tokenize
- context_length : int
- The context length to use; all CLIP models use 77 as the context length
- truncate: bool
- Whether to truncate the text in case its encoding is longer than the context length
- Returns
- -------
- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
- We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
- """
- if isinstance(texts, str):
- texts = [texts]
- sot_token = _tokenizer.encoder["<|startoftext|>"]
- eot_token = _tokenizer.encoder["<|endoftext|>"]
- all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
- if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
- else:
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
- for i, tokens in enumerate(all_tokens):
- if len(tokens) > context_length:
- if truncate:
- tokens = tokens[:context_length]
- tokens[-1] = eot_token
- else:
- raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
- result[i, :len(tokens)] = torch.tensor(tokens)
- return result
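- # Illustrative sketch (not part of the original module): `tokenize` pads (or truncates) to the
- # CLIP context length. `_demo_tokenize` is a hypothetical helper shown purely as documentation.
- def _demo_tokenize():
-     tokens = tokenize(["a photo of a 'cat'", "hello world"])
-     return tokens.shape  # torch.Size([2, 77])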
- class DptrParseq(nn.Module):
- def __init__(self,
- in_channels,
- out_channels,
- max_label_length=25,
- embed_dim=512,
- dec_num_heads=8,
- dec_mlp_ratio=4,
- dec_depth=6,
- perm_num=6,
- perm_forward=True,
- perm_mirrored=True,
- decode_ar=True,
- refine_iters=1,
- dropout=0.1,
- is_pretrain=True,
- ORP_path=None,
- **kwargs: Any) -> None:
- super().__init__()
- self.pad_id = out_channels - 1
- self.eos_id = 0
- self.bos_id = out_channels - 2
- self.max_label_length = max_label_length
- self.decode_ar = decode_ar
- self.refine_iters = refine_iters
- self.is_pretrain = is_pretrain
- if not is_pretrain:
- self.token_query = nn.Parameter(torch.Tensor(1, 26, embed_dim))  # 26 presumably = default max_label_length (25) + 1
- self.fmu = FMU(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
- decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
- self.decoder = Decoder(decoder_layer,
- num_layers=dec_depth,
- norm=nn.LayerNorm(embed_dim))
- # Perm/attn mask stuff
- self.rng = np.random.default_rng()
- self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
- self.perm_forward = perm_forward
- self.perm_mirrored = perm_mirrored
- # We don't predict <bos> nor <pad>
- self.head = nn.Linear(embed_dim, out_channels - 2)
- self.text_embed = TokenEmbedding(out_channels, embed_dim)
- # +1 for <eos>
- self.pos_queries = nn.Parameter(
- torch.Tensor(1, max_label_length + 1, embed_dim))
- self.dropout = nn.Dropout(p=dropout)
- # Encoder has its own init.
- self.apply(self._init_weights)
- nn.init.trunc_normal_(self.pos_queries, std=0.02)
- if is_pretrain:
- self.clip_encoder, preprocess = load("ViT-B/16")
- for p in self.clip_encoder.parameters():
- p.requires_grad = False
- if ORP_path is None:
- background_image_folder_path = 'background_images_folder/path'  # placeholder path
- self.background_features = self.get_noise(background_image_folder_path, preprocess)
- torch.save(self.background_features, 'save/noise/to/ORP_path')
- else:
- self.background_features = torch.load(ORP_path, map_location='cpu')
- def _init_weights(self, module: nn.Module):
- """Initialize the weights using the typical initialization schemes used
- in SOTA models."""
- if isinstance(module, nn.Linear):
- nn.init.trunc_normal_(module.weight, std=0.02)
- if module.bias is not None:
- nn.init.zeros_(module.bias)
- elif isinstance(module, nn.Embedding):
- nn.init.trunc_normal_(module.weight, std=0.02)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
- elif isinstance(module, nn.Conv2d):
- nn.init.kaiming_normal_(module.weight,
- mode='fan_out',
- nonlinearity='relu')
- if module.bias is not None:
- nn.init.zeros_(module.bias)
- elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
- nn.init.ones_(module.weight)
- nn.init.zeros_(module.bias)
- @torch.jit.ignore
- def no_weight_decay(self):
- param_names = {'text_embed.embedding.weight', 'pos_queries'}
- return param_names
- def get_noise(self, background_image_path, preprocess):
- """Encode a folder of background images with the frozen CLIP image encoder; the resulting
- features are later added to the CLIP text features as noise during pre-training."""
- image_paths = [os.path.join(background_image_path, filename) for filename in os.listdir(background_image_path) if
- filename.endswith(('.png', '.jpg', '.jpeg'))]
- features = []
- for image_path in image_paths:
- image = Image.open(image_path)
- input = preprocess(image).unsqueeze(0).to(self._device)
- with torch.no_grad():
- feature = self.clip_encoder.encode_image(input)
- features.append(feature)
- image.close()
- return torch.cat(features).cpu().numpy()
- def clip_encode(self, labels):
- text_inputs = torch.cat([tokenize(f"a photo of a '{c}'") for c in labels]).to(self._device)  # self._device is expected to be provided by the surrounding training framework
- return self.clip_encoder.encode_text(text_inputs)
- def decode(
- self,
- tgt: torch.Tensor,
- memory: torch.Tensor,
- tgt_mask: Optional[Tensor] = None,
- tgt_padding_mask: Optional[Tensor] = None,
- tgt_query: Optional[Tensor] = None,
- tgt_query_mask: Optional[Tensor] = None,
- pos_query: torch.Tensor = None,
- ):
- N, L = tgt.shape
- # <bos> stands for the null context. We only supply position information for characters after <bos>.
- null_ctx = self.text_embed(tgt[:, :1])
- if tgt_query is None:
- tgt_query = pos_query[:, :L]
- tgt_emb = pos_query[:, :L - 1] + self.text_embed(tgt[:, 1:])
- tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
- tgt_query = self.dropout(tgt_query)
- return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask,
- tgt_mask, tgt_padding_mask)
- def forward(self, memory, data=None, pos_query=None):
- # print(memory.shape, data[0].shape)
- if self.training:
- if self.is_pretrain:
- return self.training_step(None, pos_query, data[0], memory)
- return self.training_step(memory, pos_query, data[0], None)
- else:
- if self.is_pretrain:
- return self.forward_test(None, memory, pos_query)
- return self.forward_test(memory, None, pos_query)
- def forward_test(self,
- memory: Tensor, clip_ids,
- pos_query: Tensor = None,
- max_length: Optional[int] = None) -> Tensor:
- testing = max_length is None
- max_length = (self.max_label_length if max_length is None else min(
- max_length, self.max_label_length))
- if self.is_pretrain:
- memory = self.clip_encoder.encode_text(clip_ids)
- else:
- bs = memory.shape[0]
- token_query = self.token_query.expand(bs, -1, -1)
- memory = self.fmu(token_query, memory)
- _device = memory.get_device()  # note: Tensor.get_device() returns -1 for CPU tensors, so CUDA inputs are assumed
- bs = memory.shape[0]
- # +1 for <eos> at end of sequence.
- num_steps = max_length + 1
- # memory = self.encode(images)
- # Query positions up to `num_steps`
- if pos_query is None:
- pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
- else:
- pos_queries = pos_query
- # Special case for the forward permutation. Faster than using `generate_attn_masks()`
- tgt_mask = query_mask = torch.triu(
- torch.full((num_steps, num_steps), float('-inf'), device=_device),
- 1)
- self.attn_maps = []
- if self.decode_ar:
- tgt_in = torch.full((bs, num_steps),
- self.pad_id,
- dtype=torch.long,
- device=_device)
- tgt_in[:, 0] = self.bos_id
- logits = []
- for i in range(num_steps):
- j = i + 1 # next token index
- # Efficient decoding:
- # Input the context up to the ith token. We use only one query (at position = i) at a time.
- # This works because of the lookahead masking effect of the canonical (forward) AR context.
- # Past tokens have no access to future tokens, hence are fixed once computed.
- tgt_out = self.decode(
- tgt_in[:, :j],
- memory,
- tgt_mask[:j, :j],
- tgt_query=pos_queries[:, i:j],
- tgt_query_mask=query_mask[i:j, :j],
- pos_query=pos_queries,
- )
- self.attn_maps.append(self.decoder.layers[-1].attn_map)
- # the next token probability is in the output's ith token position
- p_i = self.head(tgt_out)
- logits.append(p_i)
- if j < num_steps:
- # greedy decode. add the next token index to the target input
- tgt_in[:, j] = p_i.squeeze().argmax(-1)
- # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
- if testing and (tgt_in == self.eos_id).any(dim=-1).all():
- break
- logits = torch.cat(logits, dim=1)
- else:
- # No prior context, so input is just <bos>. We query all positions.
- tgt_in = torch.full((bs, 1),
- self.bos_id,
- dtype=torch.long,
- device=_device)
- tgt_out = self.decode(tgt_in,
- memory,
- tgt_query=pos_queries,
- pos_query=pos_queries)
- logits = self.head(tgt_out)
- if self.refine_iters:
- # For iterative refinement, we always use a 'cloze' mask.
- # We can derive it from the AR forward mask by unmasking the token context to the right.
- query_mask[torch.triu(
- torch.ones(num_steps,
- num_steps,
- dtype=torch.bool,
- device=_device), 2)] = 0
- bos = torch.full((bs, 1),
- self.bos_id,
- dtype=torch.long,
- device=_device)
- for i in range(self.refine_iters):
- # Prior context is the previous output.
- tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
- tgt_len = tgt_in.shape[1]
- tgt_padding_mask = (tgt_in == self.eos_id).int().cumsum(
- -1) > 0 # mask tokens beyond the first EOS token.
- tgt_out = self.decode(
- tgt_in,
- memory,
- tgt_mask[:tgt_len, :tgt_len],
- tgt_padding_mask,
- tgt_query=pos_queries,
- tgt_query_mask=query_mask[:, :tgt_len],
- pos_query=pos_queries,
- )
- logits = self.head(tgt_out)
- return F.softmax(logits, -1)
- def gen_tgt_perms(self, tgt, _device):
- """Generate shared permutations for the whole batch.
- This works because, thanks to the padding mask, the same attention masks can be
- reused for the shorter sequences in the batch.
- """
- # We don't permute the position of BOS, we permute EOS separately
- max_num_chars = tgt.shape[1] - 2
- # Special handling for 1-character sequences
- if max_num_chars == 1:
- return torch.arange(3, device=_device).unsqueeze(0)
- perms = [torch.arange(max_num_chars, device=_device)
- ] if self.perm_forward else []
- # Additional permutations if needed
- max_perms = math.factorial(max_num_chars)
- if self.perm_mirrored:
- max_perms //= 2
- num_gen_perms = min(self.max_gen_perms, max_perms)
- # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions
- # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars.
- if max_num_chars < 5:
- # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
- # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
- if max_num_chars == 4 and self.perm_mirrored:
- selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
- else:
- selector = list(range(max_perms))
- perm_pool = torch.as_tensor(list(
- permutations(range(max_num_chars), max_num_chars)),
- device=_device)[selector]
- # If the forward permutation is always selected, no need to add it to the pool for sampling
- if self.perm_forward:
- perm_pool = perm_pool[1:]
- perms = torch.stack(perms)
- if len(perm_pool):
- i = self.rng.choice(len(perm_pool),
- size=num_gen_perms - len(perms),
- replace=False)
- perms = torch.cat([perms, perm_pool[i]])
- else:
- perms.extend([
- torch.randperm(max_num_chars, device=_device)
- for _ in range(num_gen_perms - len(perms))
- ])
- perms = torch.stack(perms)
- if self.perm_mirrored:
- # Add complementary pairs
- comp = perms.flip(-1)
- # Stack in such a way that the pairs are next to each other.
- perms = torch.stack([perms, comp
- ]).transpose(0, 1).reshape(-1, max_num_chars)
- # NOTE:
- # The only meaningful way of permuting the EOS position is by moving it one character position at a time.
- # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS
- # positions will always be much less than the number of permutations (unless a low perm_num is set).
- # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly
- # distribute it across the chosen number of permutations.
- # Add position indices of BOS and EOS
- bos_idx = perms.new_zeros((len(perms), 1))
- eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
- perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
- # Special handling for the reverse direction. This does two things:
- # 1. Reverse context for the characters
- # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode)
- if len(perms) > 1:
- perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1,
- device=_device)
- return perms
- def generate_attn_masks(self, perm, _device):
- """Generate attention masks given a sequence permutation (includes pos.
- for bos and eos tokens)
- :param perm: the permutation sequence. i = 0 is always the BOS
- :return: lookahead attention masks
- """
- sz = perm.shape[0]
- mask = torch.zeros((sz, sz), device=_device)
- for i in range(sz):
- query_idx = perm[i]
- masked_keys = perm[i + 1:]
- mask[query_idx, masked_keys] = float('-inf')
- content_mask = mask[:-1, :-1].clone()
- mask[torch.eye(sz, dtype=torch.bool,
- device=_device)] = float('-inf') # mask "self"
- query_mask = mask[1:, :-1]
- return content_mask, query_mask
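- # Worked example (added for documentation): for a 2-character label the padded sequence is
- # [BOS, c1, c2, EOS] (positions 0..3). With perm = [0, 2, 1, 3], i.e. generation order
- # BOS -> c2 -> c1 -> EOS, generate_attn_masks returns (0 = attend, X = -inf), keys = positions 0..2:
- #   content_mask:                      query_mask:
- #     pos0 (BOS): [0, X, X]              pos1 (c1):  [0, X, 0]  c1 sees BOS and c2
- #     pos1 (c1):  [0, 0, 0]              pos2 (c2):  [0, X, X]  c2 sees only BOS
- #     pos2 (c2):  [0, X, 0]              pos3 (EOS): [0, 0, 0]  EOS sees everything
- # i.e. each position may only attend to positions that precede it in the permutation (plus
- # itself in the content stream); the query stream never attends to its own position.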
- def training_step(self, memory, pos_query, tgt_ids, clip_ids):
- bs = tgt_ids.shape[0]
- if self.is_pretrain:
- memory = self.clip_encoder.encode_text(clip_ids)
- n = memory.shape[1]
- B, N, D = self.background_features.shape
- random_B = np.random.choice(B, bs, replace=False)
- random_N = np.random.choice(N, n, replace=False)
- noise = self.background_features[random_B][:, random_N]
- noise = torch.from_numpy(noise).to(memory.get_device())
- memory = memory + noise * 1e-1
- else:
- token_query = self.token_query.expand(bs, -1, -1)
- memory = self.fmu(token_query, memory)
- if pos_query is None:
- pos_query = self.pos_queries.expand(bs, -1, -1)
- # Prepare the target sequences (input and output)
- tgt_perms = self.gen_tgt_perms(tgt_ids, memory.get_device())
- tgt_in = tgt_ids[:, :-1]
- tgt_out = tgt_ids[:, 1:]
- # The [EOS] token is not depended upon by any other token in any permutation ordering
- tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
- loss = 0
- loss_numel = 0
- n = (tgt_out != self.pad_id).sum().item()
- for i, perm in enumerate(tgt_perms):
- tgt_mask, query_mask = self.generate_attn_masks(
- perm, memory.get_device())
- # print("tgt_in:", tgt_in, "tgt_out:", tgt_out, "tgt_padding_mask:", tgt_padding_mask)
- # print('tgt_mask:', tgt_mask)
- # print('query_mask:', query_mask)
- # print(tgt_in.shape, memory.shape, tgt_mask.shape, tgt_padding_mask.shape, query_mask.shape, pos_query.shape)
- out = self.decode(
- tgt_in,
- memory,
- tgt_mask,
- tgt_padding_mask,
- tgt_query_mask=query_mask,
- pos_query=pos_query,
- )
- # print('out:', out)
- logits = self.head(out)
- # print('logits:', logits)
- if i == 0:
- final_out = logits
- loss += n * F.cross_entropy(logits.flatten(end_dim=1),
- tgt_out.flatten(),
- ignore_index=self.pad_id)
- loss_numel += n
- # After the second iteration (i.e. done with canonical and reverse orderings),
- # remove the [EOS] tokens for the succeeding perms
- if i == 1:
- tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id,
- tgt_out)
- n = (tgt_out != self.pad_id).sum().item()
- loss /= loss_numel
- # self.log('loss', loss)
- return [loss, final_out]
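- # Illustrative sketch (not part of the original module): rough inference shape walk-through for
- # the recognition head in fine-tuning mode (is_pretrain=False, so no CLIP checkpoint is loaded).
- # A CUDA device is assumed because the decoder derives its device via Tensor.get_device(), which
- # returns -1 for CPU tensors. `_demo_dptr_parseq_inference` is a hypothetical helper.
- def _demo_dptr_parseq_inference():
-     model = DptrParseq(in_channels=3, out_channels=97, is_pretrain=False, dec_depth=1).cuda().eval()
-     memory = torch.randn(2, 196, 512, device='cuda')  # e.g. ViT patch features (B, N, D)
-     with torch.no_grad():
-         probs = model(memory)  # autoregressive greedy decoding followed by one refinement pass
-     return probs.shape  # (2, max_label_length + 1, out_channels - 2)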