# dptr_parseq_clip_b_decoder.py

  1. # Scene Text Recognition Model Hub
  2. # Copyright 2022 Darwin Bautista
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # https://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. from itertools import permutations
  17. from collections import OrderedDict
  18. import hashlib
  19. import os
  20. import gzip
  21. import html
  22. import urllib
  23. import warnings
  24. import numpy as np
  25. import torch
  26. import torch.nn as nn
  27. import torch.nn.functional as F
  28. from torch import Tensor
  29. from torch.nn.modules import transformer
  30. from typing import Any, Optional, Tuple, List, Union
  31. from pkg_resources import packaging
  32. from PIL import Image
  33. from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
  34. from tqdm import tqdm
  35. from functools import lru_cache
  36. import ftfy
  37. import regex as re
  38. try:
  39. from torchvision.transforms import InterpolationMode
  40. BICUBIC = InterpolationMode.BICUBIC
  41. except ImportError:
  42. BICUBIC = Image.BICUBIC
  43. @lru_cache()
  44. def default_bpe():
  45. return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
  46. @lru_cache()
  47. def bytes_to_unicode():
  48. """
Returns a mapping between utf-8 bytes and corresponding unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
This also avoids mapping to whitespace/control characters that the bpe code barfs on.
  56. """
  57. bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
  58. cs = bs[:]
  59. n = 0
  60. for b in range(2**8):
  61. if b not in bs:
  62. bs.append(b)
  63. cs.append(2**8+n)
  64. n += 1
  65. cs = [chr(n) for n in cs]
  66. return dict(zip(bs, cs))
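# Minimal illustrative sketch (the _example_* helper below is hypothetical, not part
# of the original module): round-tripping a string through the reversible
# byte<->unicode tables built by bytes_to_unicode().
def _example_byte_roundtrip(text: str = "héllo") -> str:
    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}
    encoded = ''.join(byte_encoder[b] for b in text.encode('utf-8'))
    decoded = bytearray(byte_decoder[c] for c in encoded).decode('utf-8')
    assert decoded == text  # the mapping is lossless
    return encoded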
  67. def get_pairs(word):
  68. """Return set of symbol pairs in a word.
  69. Word is represented as tuple of symbols (symbols being variable-length strings).
  70. """
  71. pairs = set()
  72. prev_char = word[0]
  73. for char in word[1:]:
  74. pairs.add((prev_char, char))
  75. prev_char = char
  76. return pairs
  77. def basic_clean(text):
  78. text = ftfy.fix_text(text)
  79. text = html.unescape(html.unescape(text))
  80. return text.strip()
  81. def whitespace_clean(text):
  82. text = re.sub(r'\s+', ' ', text)
  83. text = text.strip()
  84. return text
  85. class SimpleTokenizer(object):
  86. def __init__(self, bpe_path: str = default_bpe()):
  87. self.byte_encoder = bytes_to_unicode()
  88. self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
  89. merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
  90. merges = merges[1:49152-256-2+1]
  91. merges = [tuple(merge.split()) for merge in merges]
  92. vocab = list(bytes_to_unicode().values())
  93. vocab = vocab + [v+'</w>' for v in vocab]
  94. for merge in merges:
  95. vocab.append(''.join(merge))
  96. vocab.extend(['<|startoftext|>', '<|endoftext|>'])
  97. self.encoder = dict(zip(vocab, range(len(vocab))))
  98. self.decoder = {v: k for k, v in self.encoder.items()}
  99. self.bpe_ranks = dict(zip(merges, range(len(merges))))
  100. self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
  101. self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
  102. def bpe(self, token):
  103. if token in self.cache:
  104. return self.cache[token]
  105. word = tuple(token[:-1]) + ( token[-1] + '</w>',)
  106. pairs = get_pairs(word)
  107. if not pairs:
  108. return token+'</w>'
  109. while True:
  110. bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
  111. if bigram not in self.bpe_ranks:
  112. break
  113. first, second = bigram
  114. new_word = []
  115. i = 0
  116. while i < len(word):
  117. try:
  118. j = word.index(first, i)
  119. new_word.extend(word[i:j])
  120. i = j
except ValueError:
  122. new_word.extend(word[i:])
  123. break
  124. if word[i] == first and i < len(word)-1 and word[i+1] == second:
  125. new_word.append(first+second)
  126. i += 2
  127. else:
  128. new_word.append(word[i])
  129. i += 1
  130. new_word = tuple(new_word)
  131. word = new_word
  132. if len(word) == 1:
  133. break
  134. else:
  135. pairs = get_pairs(word)
  136. word = ' '.join(word)
  137. self.cache[token] = word
  138. return word
  139. def encode(self, text):
  140. bpe_tokens = []
  141. text = whitespace_clean(basic_clean(text)).lower()
  142. for token in re.findall(self.pat, text):
  143. token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
  144. bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
  145. return bpe_tokens
  146. def decode(self, tokens):
  147. text = ''.join([self.decoder[token] for token in tokens])
  148. text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
  149. return text
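# Minimal usage sketch for SimpleTokenizer (the _example_* helper is hypothetical and
# assumes the bundled BPE vocab at default_bpe() is present): encode() maps cleaned,
# lowercased text to BPE ids; decode() maps ids back, re-inserting spaces at the
# '</w>' end-of-word markers.
def _example_tokenizer_roundtrip(text: str = "a photo of a 'cat'"):
    tok = SimpleTokenizer()
    ids = tok.encode(text)
    return ids, tok.decode(ids)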
  150. if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
  151. warnings.warn("PyTorch version 1.7.1 or higher is recommended")
  152. __all__ = ["available_models", "load", "tokenize"]
  153. _tokenizer = SimpleTokenizer()
  154. _MODELS = {
  155. "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
  156. "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
  157. "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
  158. "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
  159. "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
  160. "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
  161. "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
  162. "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
  163. "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
  164. }
  165. def convert_weights(model: nn.Module):
  166. """Convert applicable model parameters to fp16"""
  167. def _convert_weights_to_fp16(l):
  168. if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
  169. l.weight.data = l.weight.data.half()
  170. if l.bias is not None:
  171. l.bias.data = l.bias.data.half()
  172. if isinstance(l, nn.MultiheadAttention):
  173. for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
  174. tensor = getattr(l, attr)
  175. if tensor is not None:
  176. tensor.data = tensor.data.half()
  177. for name in ["text_projection", "proj"]:
  178. if hasattr(l, name):
  179. attr = getattr(l, name)
  180. if attr is not None:
  181. attr.data = attr.data.half()
  182. model.apply(_convert_weights_to_fp16)
  183. def build_model(state_dict: dict):
  184. vit = "visual.proj" in state_dict
  185. if vit:
  186. vision_width = state_dict["visual.conv1.weight"].shape[0]
  187. vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
  188. vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
  189. grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
  190. image_resolution = vision_patch_size * grid_size
  191. else:
  192. counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
  193. vision_layers = tuple(counts)
  194. vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
  195. output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
  196. vision_patch_size = None
  197. assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
  198. image_resolution = output_width * 32
  199. embed_dim = state_dict["text_projection"].shape[1]
  200. context_length = state_dict["positional_embedding"].shape[0]
  201. vocab_size = state_dict["token_embedding.weight"].shape[0]
  202. transformer_width = state_dict["ln_final.weight"].shape[0]
  203. transformer_heads = transformer_width // 64
  204. transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
  205. model = CLIP(
  206. embed_dim,
  207. image_resolution, vision_layers, vision_width, vision_patch_size,
  208. context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
  209. )
  210. for key in ["input_resolution", "context_length", "vocab_size"]:
  211. if key in state_dict:
  212. del state_dict[key]
  213. convert_weights(model)
  214. model.load_state_dict(state_dict)
  215. return model.eval()
  216. def _download(url: str, root: str):
  217. os.makedirs(root, exist_ok=True)
  218. filename = os.path.basename(url)
  219. expected_sha256 = url.split("/")[-2]
  220. download_target = os.path.join(root, filename)
  221. if os.path.exists(download_target) and not os.path.isfile(download_target):
  222. raise RuntimeError(f"{download_target} exists and is not a regular file")
  223. if os.path.isfile(download_target):
  224. if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
  225. return download_target
  226. else:
  227. warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
  228. with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
  229. with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
  230. while True:
  231. buffer = source.read(8192)
  232. if not buffer:
  233. break
  234. output.write(buffer)
  235. loop.update(len(buffer))
  236. if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
  238. return download_target
  239. def _convert_image_to_rgb(image):
  240. return image.convert("RGB")
  241. def _transform(n_px):
  242. return Compose([
  243. Resize(n_px, interpolation=BICUBIC),
  244. CenterCrop(n_px),
  245. _convert_image_to_rgb,
  246. ToTensor(),
  247. Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
  248. ])
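# Minimal illustrative sketch (hypothetical helper): applying the preprocessing
# pipeline from _transform() to a dummy PIL image yields a (3, n_px, n_px) float
# tensor normalized with the CLIP statistics above.
def _example_preprocess(n_px: int = 224) -> torch.Tensor:
    dummy = Image.new("RGB", (320, 240), color=(127, 127, 127))
    return _transform(n_px)(dummy)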
  249. def available_models() -> List[str]:
  250. """Returns the names of available CLIP models"""
  251. return list(_MODELS.keys())
  252. class Bottleneck(nn.Module):
  253. expansion = 4
  254. def __init__(self, inplanes, planes, stride=1):
  255. super().__init__()
  256. # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
  257. self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
  258. self.bn1 = nn.BatchNorm2d(planes)
  259. self.relu1 = nn.ReLU(inplace=True)
  260. self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
  261. self.bn2 = nn.BatchNorm2d(planes)
  262. self.relu2 = nn.ReLU(inplace=True)
  263. self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
  264. self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
  265. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  266. self.relu3 = nn.ReLU(inplace=True)
  267. self.downsample = None
  268. self.stride = stride
  269. if stride > 1 or inplanes != planes * Bottleneck.expansion:
  270. # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
  271. self.downsample = nn.Sequential(OrderedDict([
  272. ("-1", nn.AvgPool2d(stride)),
  273. ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
  274. ("1", nn.BatchNorm2d(planes * self.expansion))
  275. ]))
  276. def forward(self, x: torch.Tensor):
  277. identity = x
  278. out = self.relu1(self.bn1(self.conv1(x)))
  279. out = self.relu2(self.bn2(self.conv2(out)))
  280. out = self.avgpool(out)
  281. out = self.bn3(self.conv3(out))
  282. if self.downsample is not None:
  283. identity = self.downsample(x)
  284. out += identity
  285. out = self.relu3(out)
  286. return out
  287. class AttentionPool2d(nn.Module):
  288. def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
  289. super().__init__()
  290. self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
  291. self.k_proj = nn.Linear(embed_dim, embed_dim)
  292. self.q_proj = nn.Linear(embed_dim, embed_dim)
  293. self.v_proj = nn.Linear(embed_dim, embed_dim)
  294. self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
  295. self.num_heads = num_heads
  296. def forward(self, x):
  297. x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
  298. x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
  299. x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
  300. x, _ = F.multi_head_attention_forward(
  301. query=x[:1], key=x, value=x,
  302. embed_dim_to_check=x.shape[-1],
  303. num_heads=self.num_heads,
  304. q_proj_weight=self.q_proj.weight,
  305. k_proj_weight=self.k_proj.weight,
  306. v_proj_weight=self.v_proj.weight,
  307. in_proj_weight=None,
  308. in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
  309. bias_k=None,
  310. bias_v=None,
  311. add_zero_attn=False,
  312. dropout_p=0,
  313. out_proj_weight=self.c_proj.weight,
  314. out_proj_bias=self.c_proj.bias,
  315. use_separate_proj_weight=True,
  316. training=self.training,
  317. need_weights=False
  318. )
  319. return x.squeeze(0)
  320. class ModifiedResNet(nn.Module):
  321. """
  322. A ResNet class that is similar to torchvision's but contains the following changes:
  323. - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
  324. - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
  325. - The final pooling layer is a QKV attention instead of an average pool
  326. """
  327. def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
  328. super().__init__()
  329. self.output_dim = output_dim
  330. self.input_resolution = input_resolution
  331. # the 3-layer stem
  332. self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
  333. self.bn1 = nn.BatchNorm2d(width // 2)
  334. self.relu1 = nn.ReLU(inplace=True)
  335. self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
  336. self.bn2 = nn.BatchNorm2d(width // 2)
  337. self.relu2 = nn.ReLU(inplace=True)
  338. self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
  339. self.bn3 = nn.BatchNorm2d(width)
  340. self.relu3 = nn.ReLU(inplace=True)
  341. self.avgpool = nn.AvgPool2d(2)
  342. # residual layers
  343. self._inplanes = width # this is a *mutable* variable used during construction
  344. self.layer1 = self._make_layer(width, layers[0])
  345. self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
  346. self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
  347. self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
  348. embed_dim = width * 32 # the ResNet feature dimension
  349. self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
  350. def _make_layer(self, planes, blocks, stride=1):
  351. layers = [Bottleneck(self._inplanes, planes, stride)]
  352. self._inplanes = planes * Bottleneck.expansion
  353. for _ in range(1, blocks):
  354. layers.append(Bottleneck(self._inplanes, planes))
  355. return nn.Sequential(*layers)
  356. def forward(self, x):
  357. def stem(x):
  358. x = self.relu1(self.bn1(self.conv1(x)))
  359. x = self.relu2(self.bn2(self.conv2(x)))
  360. x = self.relu3(self.bn3(self.conv3(x)))
  361. x = self.avgpool(x)
  362. return x
  363. x = x.type(self.conv1.weight.dtype)
  364. x = stem(x)
  365. x = self.layer1(x)
  366. x = self.layer2(x)
  367. x = self.layer3(x)
  368. x = self.layer4(x)
  369. x = self.attnpool(x)
  370. return x
  371. class LayerNorm(nn.LayerNorm):
  372. """Subclass torch's LayerNorm to handle fp16."""
  373. def forward(self, x: torch.Tensor):
  374. orig_type = x.dtype
  375. ret = super().forward(x.type(torch.float32))
  376. return ret.type(orig_type)
  377. class QuickGELU(nn.Module):
  378. def forward(self, x: torch.Tensor):
  379. return x * torch.sigmoid(1.702 * x)
  380. class ResidualAttentionBlock(nn.Module):
  381. def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
  382. super().__init__()
  383. self.attn = nn.MultiheadAttention(d_model, n_head)
  384. self.ln_1 = LayerNorm(d_model)
  385. self.mlp = nn.Sequential(OrderedDict([
  386. ("c_fc", nn.Linear(d_model, d_model * 4)),
  387. ("gelu", QuickGELU()),
  388. ("c_proj", nn.Linear(d_model * 4, d_model))
  389. ]))
  390. self.ln_2 = LayerNorm(d_model)
  391. self.attn_mask = attn_mask
  392. def attention(self, x: torch.Tensor):
  393. self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
  394. return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
  395. def forward(self, x: torch.Tensor):
  396. x = x + self.attention(self.ln_1(x))
  397. x = x + self.mlp(self.ln_2(x))
  398. return x
  399. class Transformer(nn.Module):
  400. def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
  401. super().__init__()
  402. self.width = width
  403. self.layers = layers
  404. self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
  405. def forward(self, x: torch.Tensor):
  406. return self.resblocks(x)
  407. class VisionTransformer(nn.Module):
  408. def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
  409. super().__init__()
  410. self.input_resolution = input_resolution
  411. self.output_dim = output_dim
  412. self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
  413. scale = width ** -0.5
  414. self.class_embedding = nn.Parameter(scale * torch.randn(width))
  415. self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
  416. self.ln_pre = LayerNorm(width)
  417. self.transformer = Transformer(width, layers, heads)
  418. self.ln_post = LayerNorm(width)
  419. self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
  420. def forward(self, x: torch.Tensor):
  421. x = self.conv1(x) # shape = [*, width, grid, grid]
  422. x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
  423. x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
  424. x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
  425. x = x + self.positional_embedding.to(x.dtype)
  426. x = self.ln_pre(x)
  427. x = x.permute(1, 0, 2) # NLD -> LND
  428. x = self.transformer(x)
  429. x = x.permute(1, 0, 2) # LND -> NLD
  430. x = self.ln_post(x)
  431. if self.proj is not None:
  432. x = x @ self.proj
  433. return x
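# Minimal illustrative sketch (hypothetical sizes, deliberately few layers to keep it
# cheap): this modified VisionTransformer returns projected features for the class
# token plus every patch token, i.e. [batch, 1 + (input_resolution // patch_size) ** 2,
# output_dim], rather than the class token alone.
def _example_vit_features() -> torch.Size:
    vit = VisionTransformer(input_resolution=224, patch_size=16, width=768,
                            layers=2, heads=12, output_dim=512)
    images = torch.randn(1, 3, 224, 224)
    return vit(images).shape  # torch.Size([1, 197, 512])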
  434. class CLIP(nn.Module):
  435. def __init__(self,
  436. embed_dim: int,
  437. # vision
  438. image_resolution: int,
  439. vision_layers: Union[Tuple[int, int, int, int], int],
  440. vision_width: int,
  441. vision_patch_size: int,
  442. # text
  443. context_length: int,
  444. vocab_size: int,
  445. transformer_width: int,
  446. transformer_heads: int,
  447. transformer_layers: int
  448. ):
  449. super().__init__()
  450. self.context_length = context_length
  451. if isinstance(vision_layers, (tuple, list)):
  452. vision_heads = vision_width * 32 // 64
  453. self.visual = ModifiedResNet(
  454. layers=vision_layers,
  455. output_dim=embed_dim,
  456. heads=vision_heads,
  457. input_resolution=image_resolution,
  458. width=vision_width
  459. )
  460. else:
  461. vision_heads = vision_width // 64
  462. self.visual = VisionTransformer(
  463. input_resolution=image_resolution,
  464. patch_size=vision_patch_size,
  465. width=vision_width,
  466. layers=vision_layers,
  467. heads=vision_heads,
  468. output_dim=embed_dim
  469. )
  470. self.transformer = Transformer(
  471. width=transformer_width,
  472. layers=transformer_layers,
  473. heads=transformer_heads,
  474. attn_mask=self.build_attention_mask()
  475. )
  476. self.vocab_size = vocab_size
  477. self.token_embedding = nn.Embedding(vocab_size, transformer_width)
  478. self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
  479. self.ln_final = LayerNorm(transformer_width)
  480. self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
  481. self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
  482. self.initialize_parameters()
  483. def initialize_parameters(self):
  484. nn.init.normal_(self.token_embedding.weight, std=0.02)
  485. nn.init.normal_(self.positional_embedding, std=0.01)
  486. if isinstance(self.visual, ModifiedResNet):
  487. if self.visual.attnpool is not None:
  488. std = self.visual.attnpool.c_proj.in_features ** -0.5
  489. nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
  490. nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
  491. nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
  492. nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
  493. for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
  494. for name, param in resnet_block.named_parameters():
  495. if name.endswith("bn3.weight"):
  496. nn.init.zeros_(param)
  497. proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
  498. attn_std = self.transformer.width ** -0.5
  499. fc_std = (2 * self.transformer.width) ** -0.5
  500. for block in self.transformer.resblocks:
  501. nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
  502. nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
  503. nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
  504. nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
  505. if self.text_projection is not None:
  506. nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
  507. def build_attention_mask(self):
  508. # lazily create causal attention mask, with full attention between the vision tokens
  509. # pytorch uses additive attention mask; fill with -inf
  510. mask = torch.empty(self.context_length, self.context_length)
  511. mask.fill_(float("-inf"))
  512. mask.triu_(1) # zero out the lower diagonal
  513. return mask
  514. @property
  515. def dtype(self):
  516. return self.visual.conv1.weight.dtype
  517. def encode_image(self, image):
  518. return self.visual(image.type(self.dtype))
  519. def encode_text(self, text):
  520. x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
  521. x = x + self.positional_embedding.type(self.dtype)
  522. x = x.permute(1, 0, 2) # NLD -> LND
  523. x = self.transformer(x)
  524. x = x.permute(1, 0, 2) # LND -> NLD
  525. x = self.ln_final(x).type(self.dtype)
  526. # take features from the eot embedding (eot_token is the highest number in each sequence)
  527. output = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
  528. output = torch.cat([output.unsqueeze(1), x], dim=1)
  529. return output
  530. def forward(self, image, text):
  531. image_features = self.encode_image(image)
  532. text_features = self.encode_text(text)
  533. # normalized features
  534. image_features = image_features / image_features.norm(dim=1, keepdim=True)
  535. text_features = text_features / text_features.norm(dim=1, keepdim=True)
  536. # cosine similarity as logits
  537. logit_scale = self.logit_scale.exp()
  538. logits_per_image = logit_scale * image_features @ text_features.t()
  539. logits_per_text = logits_per_image.t()
  540. # shape = [global_batch_size, global_batch_size]
  541. return logits_per_image, logits_per_text
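# Minimal illustrative sketch (hypothetical helper): the additive causal mask produced
# by CLIP.build_attention_mask(). Entries strictly above the diagonal are -inf (future
# tokens are hidden); everything on and below the diagonal is 0 (visible).
def _example_causal_mask(context_length: int = 4) -> torch.Tensor:
    mask = torch.full((context_length, context_length), float("-inf"))
    return mask.triu_(1)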
  542. class FMU(nn.Module):
  543. """A Transformer decoder layer supporting two-stream attention (XLNet)
  544. This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
  545. def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu',
  546. layer_norm_eps=1e-5):
  547. super().__init__()
  548. self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
  549. # Implementation of Feedforward model
  550. self.linear1 = nn.Linear(d_model, dim_feedforward)
  551. self.linear2 = nn.Linear(dim_feedforward, d_model)
  552. self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
  553. self.dropout1 = nn.Dropout(dropout)
  554. self.dropout2 = nn.Dropout(dropout)
  555. self.dropout3 = nn.Dropout(dropout)
  556. self.activation = transformer._get_activation_fn(activation)
  557. def __setstate__(self, state):
  558. if 'activation' not in state:
  559. state['activation'] = F.gelu
  560. super().__setstate__(state)
  561. def forward(self, query: Tensor, memory: Tensor):
  562. """Forward pass for a single stream (i.e. content or query)
  563. tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
  564. Both tgt_kv and memory are expected to be LayerNorm'd too.
  565. memory is LayerNorm'd by ViT.
  566. """
  567. query1, ca_weights = self.cross_attn(query, memory, memory)
  568. query = query + self.dropout1(query1)
  569. query2 = self.linear2(self.dropout2(self.activation(self.linear1(self.norm(query)))))
  570. query = query + self.dropout3(query2)
  571. return query
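# Minimal illustrative sketch (hypothetical shapes): the FMU cross-attends a set of
# learned token queries to encoder memory and returns one fused feature per query slot.
def _example_fmu_call(d_model: int = 512, nhead: int = 8) -> torch.Size:
    fmu = FMU(d_model, nhead)
    token_query = torch.randn(2, 26, d_model)  # queries, already expanded per batch
    memory = torch.randn(2, 197, d_model)      # e.g. ViT patch features
    return fmu(token_query, memory).shape      # torch.Size([2, 26, 512])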
  572. class DecoderLayer(nn.Module):
  573. """A Transformer decoder layer supporting two-stream attention (XLNet) This
  574. implements a pre-LN decoder, as opposed to the post-LN default in
  575. PyTorch."""
  576. def __init__(
  577. self,
  578. d_model,
  579. nhead,
  580. dim_feedforward=2048,
  581. dropout=0.1,
  582. activation='gelu',
  583. layer_norm_eps=1e-5,
  584. ):
  585. super().__init__()
  586. self.self_attn = nn.MultiheadAttention(d_model,
  587. nhead,
  588. dropout=dropout,
  589. batch_first=True)
  590. self.cross_attn = nn.MultiheadAttention(d_model,
  591. nhead,
  592. dropout=dropout,
  593. batch_first=True)
  594. # Implementation of Feedforward model
  595. self.linear1 = nn.Linear(d_model, dim_feedforward)
  596. self.dropout = nn.Dropout(dropout)
  597. self.linear2 = nn.Linear(dim_feedforward, d_model)
  598. self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
  599. self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
  600. self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
  601. self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)
  602. self.dropout1 = nn.Dropout(dropout)
  603. self.dropout2 = nn.Dropout(dropout)
  604. self.dropout3 = nn.Dropout(dropout)
  605. self.activation = transformer._get_activation_fn(activation)
  606. def __setstate__(self, state):
  607. if 'activation' not in state:
  608. state['activation'] = F.gelu
  609. super().__setstate__(state)
  610. def forward_stream(
  611. self,
  612. tgt: Tensor,
  613. tgt_norm: Tensor,
  614. tgt_kv: Tensor,
  615. memory: Tensor,
  616. tgt_mask: Optional[Tensor],
  617. tgt_key_padding_mask: Optional[Tensor],
  618. ):
  619. """Forward pass for a single stream (i.e. content or query) tgt_norm is
  620. just a LayerNorm'd tgt.
  621. Added as a separate parameter for efficiency. Both tgt_kv and memory
  622. are expected to be LayerNorm'd too. memory is LayerNorm'd by ViT.
  623. """
  624. tgt2, sa_weights = self.self_attn(
  625. tgt_norm,
  626. tgt_kv,
  627. tgt_kv,
  628. attn_mask=tgt_mask,
  629. key_padding_mask=tgt_key_padding_mask)
  630. tgt = tgt + self.dropout1(tgt2)
  631. tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
  632. self.attn_map = ca_weights
  633. tgt = tgt + self.dropout2(tgt2)
  634. tgt2 = self.linear2(
  635. self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
  636. tgt = tgt + self.dropout3(tgt2)
  637. return tgt, sa_weights, ca_weights
  638. def forward(
  639. self,
  640. query,
  641. content,
  642. memory,
  643. query_mask: Optional[Tensor] = None,
  644. content_mask: Optional[Tensor] = None,
  645. content_key_padding_mask: Optional[Tensor] = None,
  646. update_content: bool = True,
  647. ):
  648. query_norm = self.norm_q(query)
  649. content_norm = self.norm_c(content)
  650. query = self.forward_stream(query, query_norm, content_norm, memory,
  651. query_mask, content_key_padding_mask)[0]
  652. if update_content:
  653. content = self.forward_stream(content, content_norm, content_norm,
  654. memory, content_mask,
  655. content_key_padding_mask)[0]
  656. return query, content
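# Minimal illustrative sketch (hypothetical shapes): one DecoderLayer step in the
# two-stream setup. `query` is the position/query stream, `content` the (masked)
# token-embedding stream, and `memory` the encoder output; masks are omitted here.
def _example_decoder_layer_call(d_model: int = 512, nhead: int = 8):
    layer = DecoderLayer(d_model, nhead)
    query = torch.randn(2, 10, d_model)
    content = torch.randn(2, 10, d_model)
    memory = torch.randn(2, 26, d_model)
    q, c = layer(query, content, memory)
    return q.shape, c.shape  # both torch.Size([2, 10, 512])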
  657. class Decoder(nn.Module):
  658. __constants__ = ['norm']
  659. def __init__(self, decoder_layer, num_layers, norm):
  660. super().__init__()
  661. self.layers = transformer._get_clones(decoder_layer, num_layers)
  662. self.num_layers = num_layers
  663. self.norm = norm
  664. def forward(
  665. self,
  666. query,
  667. content,
  668. memory,
  669. query_mask: Optional[Tensor] = None,
  670. content_mask: Optional[Tensor] = None,
  671. content_key_padding_mask: Optional[Tensor] = None,
  672. ):
  673. for i, mod in enumerate(self.layers):
  674. last = i == len(self.layers) - 1
  675. query, content = mod(
  676. query,
  677. content,
  678. memory,
  679. query_mask,
  680. content_mask,
  681. content_key_padding_mask,
  682. update_content=not last,
  683. )
  684. query = self.norm(query)
  685. return query
  686. class TokenEmbedding(nn.Module):
  687. def __init__(self, charset_size: int, embed_dim: int):
  688. super().__init__()
  689. self.embedding = nn.Embedding(charset_size, embed_dim)
  690. self.embed_dim = embed_dim
  691. def forward(self, tokens: torch.Tensor):
  692. return math.sqrt(self.embed_dim) * self.embedding(tokens)
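# Minimal illustrative sketch (hypothetical sizes): TokenEmbedding scales the embedding
# output by sqrt(embed_dim), the usual Transformer convention, so token and positional
# embeddings start out at comparable magnitudes.
def _example_token_embedding(charset_size: int = 97, embed_dim: int = 512) -> torch.Size:
    embed = TokenEmbedding(charset_size, embed_dim)
    tokens = torch.zeros(2, 5, dtype=torch.long)
    return embed(tokens).shape  # torch.Size([2, 5, 512])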
  693. def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
  694. """Load a CLIP model
  695. Parameters
  696. ----------
  697. name : str
  698. A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
  699. device : Union[str, torch.device]
  700. The device to put the loaded model
  701. jit : bool
  702. Whether to load the optimized JIT model or more hackable non-JIT model (default).
  703. download_root: str
  704. path to download the model files; by default, it uses "~/.cache/clip"
  705. Returns
  706. -------
  707. model : torch.nn.Module
  708. The CLIP model
  709. preprocess : Callable[[PIL.Image], torch.Tensor]
  710. A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
  711. """
  712. if name in _MODELS:
  713. model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
  714. elif os.path.isfile(name):
  715. model_path = name
  716. else:
  717. raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
  718. with open(model_path, 'rb') as opened_file:
  719. try:
  720. # loading JIT archive
  721. model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
  722. state_dict = None
  723. except RuntimeError:
  724. # loading saved state dict
  725. if jit:
  726. warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
  727. jit = False
  728. state_dict = torch.load(opened_file, map_location="cpu")
  729. if not jit:
  730. model = build_model(state_dict or model.state_dict()).to(device)
  731. if str(device) == "cpu":
  732. model.float()
  733. return model, _transform(model.visual.input_resolution)
  734. # patch the device names
  735. device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
  736. device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
  737. def patch_device(module):
  738. try:
  739. graphs = [module.graph] if hasattr(module, "graph") else []
  740. except RuntimeError:
  741. graphs = []
  742. if hasattr(module, "forward1"):
  743. graphs.append(module.forward1.graph)
  744. for graph in graphs:
  745. for node in graph.findAllNodes("prim::Constant"):
  746. if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
  747. node.copyAttributes(device_node)
  748. model.apply(patch_device)
  749. patch_device(model.encode_image)
  750. patch_device(model.encode_text)
  751. # patch dtype to float32 on CPU
  752. if str(device) == "cpu":
  753. float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
  754. float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
  755. float_node = float_input.node()
  756. def patch_float(module):
  757. try:
  758. graphs = [module.graph] if hasattr(module, "graph") else []
  759. except RuntimeError:
  760. graphs = []
  761. if hasattr(module, "forward1"):
  762. graphs.append(module.forward1.graph)
  763. for graph in graphs:
  764. for node in graph.findAllNodes("aten::to"):
  765. inputs = list(node.inputs())
  766. for i in [1, 2]: # dtype can be the second or third argument to aten::to()
  767. if inputs[i].node()["value"] == 5:
  768. inputs[i].node().copyAttributes(float_node)
  769. model.apply(patch_float)
  770. patch_float(model.encode_image)
  771. patch_float(model.encode_text)
  772. model.float()
  773. return model, _transform(model.input_resolution.item())
  774. def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
  775. """
  776. Returns the tokenized representation of given input string(s)
  777. Parameters
  778. ----------
  779. texts : Union[str, List[str]]
  780. An input string or a list of input strings to tokenize
  781. context_length : int
  782. The context length to use; all CLIP models use 77 as the context length
  783. truncate: bool
  784. Whether to truncate the text in case its encoding is longer than the context length
  785. Returns
  786. -------
  787. A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
  788. We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
  789. """
  790. if isinstance(texts, str):
  791. texts = [texts]
  792. sot_token = _tokenizer.encoder["<|startoftext|>"]
  793. eot_token = _tokenizer.encoder["<|endoftext|>"]
  794. all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
  795. if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
  796. result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
  797. else:
  798. result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
  799. for i, tokens in enumerate(all_tokens):
  800. if len(tokens) > context_length:
  801. if truncate:
  802. tokens = tokens[:context_length]
  803. tokens[-1] = eot_token
  804. else:
  805. raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
  806. result[i, :len(tokens)] = torch.tensor(tokens)
  807. return result
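# Minimal usage sketch for tokenize() (hypothetical helper): the result is a
# [batch, context_length] tensor with <|startoftext|>/<|endoftext|> ids framing the
# BPE tokens and zero padding after the end-of-text token.
def _example_tokenize() -> torch.Size:
    ids = tokenize(["a photo of a 'cat'", "a photo of a 'dog'"])
    return ids.shape  # torch.Size([2, 77])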
  808. class DptrParseq(nn.Module):
  809. def __init__(self,
  810. in_channels,
  811. out_channels,
  812. max_label_length=25,
  813. embed_dim=512,
  814. dec_num_heads=8,
  815. dec_mlp_ratio=4,
  816. dec_depth=6,
  817. perm_num=6,
  818. perm_forward=True,
  819. perm_mirrored=True,
  820. decode_ar=True,
  821. refine_iters=1,
  822. dropout=0.1,
  823. is_pretrain=True,
  824. ORP_path=None,
  825. **kwargs: Any) -> None:
  826. super().__init__()
  827. self.pad_id = out_channels - 1
  828. self.eos_id = 0
  829. self.bos_id = out_channels - 2
  830. self.max_label_length = max_label_length
  831. self.decode_ar = decode_ar
  832. self.refine_iters = refine_iters
  833. self.is_pretrain = is_pretrain
  834. if not is_pretrain:
self.token_query = nn.Parameter(torch.Tensor(1, 26, embed_dim))  # 26 query slots (presumably max_label_length + 1 with the default of 25)
  836. self.fmu = FMU(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
  837. decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
  838. self.decoder = Decoder(decoder_layer,
  839. num_layers=dec_depth,
  840. norm=nn.LayerNorm(embed_dim))
  841. # Perm/attn mask stuff
  842. self.rng = np.random.default_rng()
  843. self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
  844. self.perm_forward = perm_forward
  845. self.perm_mirrored = perm_mirrored
  846. # We don't predict <bos> nor <pad>
  847. self.head = nn.Linear(embed_dim, out_channels - 2)
  848. self.text_embed = TokenEmbedding(out_channels, embed_dim)
  849. # +1 for <eos>
  850. self.pos_queries = nn.Parameter(
  851. torch.Tensor(1, max_label_length + 1, embed_dim))
  852. self.dropout = nn.Dropout(p=dropout)
  853. # Encoder has its own init.
  854. self.apply(self._init_weights)
  855. nn.init.trunc_normal_(self.pos_queries, std=0.02)
  856. if is_pretrain:
  857. self.clip_encoder, preprocess = load("ViT-B/16")
  858. for p in self.clip_encoder.parameters():
  859. p.requires_grad = False
if ORP_path is None:
# Placeholder paths: point these at a folder of background images and at the
# file where the precomputed ORP noise features should be saved.
background_image_folder_path = 'background_images_folder/path'
self.background_features = self.get_noise(background_image_folder_path, preprocess)
torch.save(self.background_features, 'save/noise/to/ORP_path')
  864. else:
  865. self.background_features = torch.load(ORP_path, map_location='cpu')
  866. def _init_weights(self, module: nn.Module):
  867. """Initialize the weights using the typical initialization schemes used
  868. in SOTA models."""
  869. if isinstance(module, nn.Linear):
  870. nn.init.trunc_normal_(module.weight, std=0.02)
  871. if module.bias is not None:
  872. nn.init.zeros_(module.bias)
  873. elif isinstance(module, nn.Embedding):
  874. nn.init.trunc_normal_(module.weight, std=0.02)
  875. if module.padding_idx is not None:
  876. module.weight.data[module.padding_idx].zero_()
  877. elif isinstance(module, nn.Conv2d):
  878. nn.init.kaiming_normal_(module.weight,
  879. mode='fan_out',
  880. nonlinearity='relu')
  881. if module.bias is not None:
  882. nn.init.zeros_(module.bias)
  883. elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
  884. nn.init.ones_(module.weight)
  885. nn.init.zeros_(module.bias)
  886. @torch.jit.ignore
  887. def no_weight_decay(self):
  888. param_names = {'text_embed.embedding.weight', 'pos_queries'}
  889. return param_names
  890. def get_noise(self, background_image_path, preprocess):
image_paths = [os.path.join(background_image_path, filename) for filename in os.listdir(background_image_path) if
filename.endswith(('.png', '.jpg', '.jpeg'))]
  893. features = []
  894. for image_path in image_paths:
  895. image = Image.open(image_path)
  896. input = preprocess(image).unsqueeze(0).to(self._device)
  897. with torch.no_grad():
  898. feature = self.clip_encoder.encode_image(input)
  899. features.append(feature)
  900. image.close()
  901. return torch.cat(features).cpu().numpy()
  902. def clip_encode(self, labels):
  903. text_inputs = torch.cat([tokenize(f"a photo of a '{c}'") for c in labels]).to(self._device)
  904. return self.clip_encoder.encode_text(text_inputs)
  905. def decode(
  906. self,
  907. tgt: torch.Tensor,
  908. memory: torch.Tensor,
  909. tgt_mask: Optional[Tensor] = None,
  910. tgt_padding_mask: Optional[Tensor] = None,
  911. tgt_query: Optional[Tensor] = None,
  912. tgt_query_mask: Optional[Tensor] = None,
  913. pos_query: torch.Tensor = None,
  914. ):
  915. N, L = tgt.shape
  916. # <bos> stands for the null context. We only supply position information for characters after <bos>.
  917. null_ctx = self.text_embed(tgt[:, :1])
  918. if tgt_query is None:
  919. tgt_query = pos_query[:, :L]
  920. tgt_emb = pos_query[:, :L - 1] + self.text_embed(tgt[:, 1:])
  921. tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
  922. tgt_query = self.dropout(tgt_query)
  923. return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask,
  924. tgt_mask, tgt_padding_mask)
  925. def forward(self, memory, data=None, pos_query=None):
  926. # print(memory.shape, data[0].shape)
  927. if self.training:
  928. if self.is_pretrain:
  929. return self.training_step(None, pos_query, data[0], memory)
  930. return self.training_step(memory, pos_query, data[0], None)
  931. else:
  932. if self.is_pretrain:
  933. return self.forward_test(None, memory, pos_query)
  934. return self.forward_test(memory, None, pos_query)
  935. def forward_test(self,
  936. memory: Tensor, clip_ids,
  937. pos_query: Tensor = None,
  938. max_length: Optional[int] = None) -> Tensor:
  939. testing = max_length is None
  940. max_length = (self.max_label_length if max_length is None else min(
  941. max_length, self.max_label_length))
  942. if self.is_pretrain:
  943. memory = self.clip_encoder.encode_text(clip_ids)
  944. else:
  945. bs = memory.shape[0]
  946. token_query = self.token_query.expand(bs, -1, -1)
  947. memory = self.fmu(token_query, memory)
  948. _device = memory.get_device()
  949. bs = memory.shape[0]
  950. # +1 for <eos> at end of sequence.
  951. num_steps = max_length + 1
  952. # memory = self.encode(images)
  953. # Query positions up to `num_steps`
  954. if pos_query is None:
  955. pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
  956. else:
  957. pos_queries = pos_query
  958. # Special case for the forward permutation. Faster than using `generate_attn_masks()`
  959. tgt_mask = query_mask = torch.triu(
  960. torch.full((num_steps, num_steps), float('-inf'), device=_device),
  961. 1)
  962. self.attn_maps = []
  963. if self.decode_ar:
  964. tgt_in = torch.full((bs, num_steps),
  965. self.pad_id,
  966. dtype=torch.long,
  967. device=_device)
  968. tgt_in[:, 0] = self.bos_id
  969. logits = []
  970. for i in range(num_steps):
  971. j = i + 1 # next token index
  972. # Efficient decoding:
  973. # Input the context up to the ith token. We use only one query (at position = i) at a time.
  974. # This works because of the lookahead masking effect of the canonical (forward) AR context.
  975. # Past tokens have no access to future tokens, hence are fixed once computed.
  976. tgt_out = self.decode(
  977. tgt_in[:, :j],
  978. memory,
  979. tgt_mask[:j, :j],
  980. tgt_query=pos_queries[:, i:j],
  981. tgt_query_mask=query_mask[i:j, :j],
  982. pos_query=pos_queries,
  983. )
  984. self.attn_maps.append(self.decoder.layers[-1].attn_map)
  985. # the next token probability is in the output's ith token position
  986. p_i = self.head(tgt_out)
  987. logits.append(p_i)
  988. if j < num_steps:
  989. # greedy decode. add the next token index to the target input
  990. tgt_in[:, j] = p_i.squeeze().argmax(-1)
  991. # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
  992. if testing and (tgt_in == self.eos_id).any(dim=-1).all():
  993. break
  994. logits = torch.cat(logits, dim=1)
  995. else:
  996. # No prior context, so input is just <bos>. We query all positions.
  997. tgt_in = torch.full((bs, 1),
  998. self.bos_id,
  999. dtype=torch.long,
  1000. device=_device)
  1001. tgt_out = self.decode(tgt_in,
  1002. memory,
  1003. tgt_query=pos_queries,
  1004. pos_query=pos_queries)
  1005. logits = self.head(tgt_out)
  1006. if self.refine_iters:
  1007. # For iterative refinement, we always use a 'cloze' mask.
  1008. # We can derive it from the AR forward mask by unmasking the token context to the right.
  1009. query_mask[torch.triu(
  1010. torch.ones(num_steps,
  1011. num_steps,
  1012. dtype=torch.bool,
  1013. device=_device), 2)] = 0
  1014. bos = torch.full((bs, 1),
  1015. self.bos_id,
  1016. dtype=torch.long,
  1017. device=_device)
  1018. for i in range(self.refine_iters):
  1019. # Prior context is the previous output.
  1020. tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
  1021. tgt_len = tgt_in.shape[1]
  1022. tgt_padding_mask = (tgt_in == self.eos_id).int().cumsum(
  1023. -1) > 0 # mask tokens beyond the first EOS token.
  1024. tgt_out = self.decode(
  1025. tgt_in,
  1026. memory,
  1027. tgt_mask[:tgt_len, :tgt_len],
  1028. tgt_padding_mask,
  1029. tgt_query=pos_queries,
  1030. tgt_query_mask=query_mask[:, :tgt_len],
  1031. pos_query=pos_queries,
  1032. )
  1033. logits = self.head(tgt_out)
  1034. return F.softmax(logits, -1)
  1035. def gen_tgt_perms(self, tgt, _device):
  1036. """Generate shared permutations for the whole batch.
  1037. This works because the same attention mask can be used for the shorter
  1038. sequences because of the padding mask.
  1039. """
  1040. # We don't permute the position of BOS, we permute EOS separately
  1041. max_num_chars = tgt.shape[1] - 2
  1042. # Special handling for 1-character sequences
  1043. if max_num_chars == 1:
  1044. return torch.arange(3, device=_device).unsqueeze(0)
  1045. perms = [torch.arange(max_num_chars, device=_device)
  1046. ] if self.perm_forward else []
  1047. # Additional permutations if needed
  1048. max_perms = math.factorial(max_num_chars)
  1049. if self.perm_mirrored:
  1050. max_perms //= 2
  1051. num_gen_perms = min(self.max_gen_perms, max_perms)
  1052. # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions
  1053. # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars.
  1054. if max_num_chars < 5:
  1055. # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
  1056. # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
  1057. if max_num_chars == 4 and self.perm_mirrored:
  1058. selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
  1059. else:
  1060. selector = list(range(max_perms))
  1061. perm_pool = torch.as_tensor(list(
  1062. permutations(range(max_num_chars), max_num_chars)),
  1063. device=_device)[selector]
  1064. # If the forward permutation is always selected, no need to add it to the pool for sampling
  1065. if self.perm_forward:
  1066. perm_pool = perm_pool[1:]
  1067. perms = torch.stack(perms)
  1068. if len(perm_pool):
  1069. i = self.rng.choice(len(perm_pool),
  1070. size=num_gen_perms - len(perms),
  1071. replace=False)
  1072. perms = torch.cat([perms, perm_pool[i]])
  1073. else:
  1074. perms.extend([
  1075. torch.randperm(max_num_chars, device=_device)
  1076. for _ in range(num_gen_perms - len(perms))
  1077. ])
  1078. perms = torch.stack(perms)
  1079. if self.perm_mirrored:
  1080. # Add complementary pairs
  1081. comp = perms.flip(-1)
  1082. # Stack in such a way that the pairs are next to each other.
  1083. perms = torch.stack([perms, comp
  1084. ]).transpose(0, 1).reshape(-1, max_num_chars)
  1085. # NOTE:
  1086. # The only meaningful way of permuting the EOS position is by moving it one character position at a time.
  1087. # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS
  1088. # positions will always be much less than the number of permutations (unless a low perm_num is set).
  1089. # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly
  1090. # distribute it across the chosen number of permutations.
  1091. # Add position indices of BOS and EOS
  1092. bos_idx = perms.new_zeros((len(perms), 1))
  1093. eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
  1094. perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
  1095. # Special handling for the reverse direction. This does two things:
  1096. # 1. Reverse context for the characters
  1097. # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode)
  1098. if len(perms) > 1:
  1099. perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1,
  1100. device=_device)
  1101. return perms
  1102. def generate_attn_masks(self, perm, _device):
  1103. """Generate attention masks given a sequence permutation (includes pos.
  1104. for bos and eos tokens)
  1105. :param perm: the permutation sequence. i = 0 is always the BOS
  1106. :return: lookahead attention masks
  1107. """
  1108. sz = perm.shape[0]
  1109. mask = torch.zeros((sz, sz), device=_device)
  1110. for i in range(sz):
  1111. query_idx = perm[i]
  1112. masked_keys = perm[i + 1:]
  1113. mask[query_idx, masked_keys] = float('-inf')
  1114. content_mask = mask[:-1, :-1].clone()
  1115. mask[torch.eye(sz, dtype=torch.bool,
  1116. device=_device)] = float('-inf') # mask "self"
  1117. query_mask = mask[1:, :-1]
  1118. return content_mask, query_mask
  1119. def training_step(self, memory, pos_query, tgt_ids, clip_ids):
  1120. bs = tgt_ids.shape[0]
  1121. if self.is_pretrain:
  1122. memory = self.clip_encoder.encode_text(clip_ids)
  1123. n = memory.shape[1]
  1124. B, N, D = self.background_features.shape
  1125. random_B = np.random.choice(B, bs, replace=False)
  1126. random_N = np.random.choice(N, n, replace=False)
  1127. noise = self.background_features[random_B][:, random_N]
  1128. noise = torch.from_numpy(noise).to(memory.get_device())
  1129. memory = memory + noise * 1e-1
  1130. else:
  1131. token_query = self.token_query.expand(bs, -1, -1)
  1132. memory = self.fmu(token_query, memory)
  1133. if pos_query is None:
  1134. pos_query = self.pos_queries.expand(bs, -1, -1)
  1135. # Prepare the target sequences (input and output)
  1136. tgt_perms = self.gen_tgt_perms(tgt_ids, memory.get_device())
  1137. tgt_in = tgt_ids[:, :-1]
  1138. tgt_out = tgt_ids[:, 1:]
  1139. # The [EOS] token is not depended upon by any other token in any permutation ordering
  1140. tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
  1141. loss = 0
  1142. loss_numel = 0
  1143. n = (tgt_out != self.pad_id).sum().item()
  1144. for i, perm in enumerate(tgt_perms):
  1145. tgt_mask, query_mask = self.generate_attn_masks(
  1146. perm, memory.get_device())
  1147. # print("tgt_in:", tgt_in, "tgt_out:", tgt_out, "tgt_padding_mask:", tgt_padding_mask)
  1148. # print('tgt_mask:', tgt_mask)
  1149. # print('query_mask:', query_mask)
  1150. # print(tgt_in.shape, memory.shape, tgt_mask.shape, tgt_padding_mask.shape, query_mask.shape, pos_query.shape)
  1151. out = self.decode(
  1152. tgt_in,
  1153. memory,
  1154. tgt_mask,
  1155. tgt_padding_mask,
  1156. tgt_query_mask=query_mask,
  1157. pos_query=pos_query,
  1158. )
  1159. # print('out:', out)
  1160. logits = self.head(out)
  1161. # print('logits:', logits)
  1162. if i == 0:
  1163. final_out = logits
  1164. loss += n * F.cross_entropy(logits.flatten(end_dim=1),
  1165. tgt_out.flatten(),
  1166. ignore_index=self.pad_id)
  1167. loss_numel += n
  1168. # After the second iteration (i.e. done with canonical and reverse orderings),
  1169. # remove the [EOS] tokens for the succeeding perms
  1170. if i == 1:
  1171. tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id,
  1172. tgt_out)
  1173. n = (tgt_out != self.pad_id).sum().item()
  1174. loss /= loss_numel
  1175. # self.log('loss', loss)
  1176. return [loss, final_out]
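# Minimal illustrative sketch (hypothetical helper, CPU-only): how generate_attn_masks()
# turns a permutation over [BOS, c1, ..., cT, EOS] positions into the content and query
# attention masks, shown here for a 3-character permutation (2, 1, 3).
def _example_permutation_masks():
    perm = torch.tensor([0, 2, 1, 3, 4])  # BOS, chars in order 2,1,3, then EOS
    sz = perm.shape[0]
    mask = torch.zeros((sz, sz))
    for i in range(sz):
        mask[perm[i], perm[i + 1:]] = float('-inf')  # hide positions that come later in this ordering
    content_mask = mask[:-1, :-1].clone()
    mask[torch.eye(sz, dtype=torch.bool)] = float('-inf')  # queries never attend to themselves
    query_mask = mask[1:, :-1]
    return content_mask, query_mask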