123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import numpy as np
- import torch
- from torch import nn
- from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_
- from openrec.modeling.common import Block, PatchEmbed
- from openrec.modeling.encoders.svtrv2_lnconv import Feat2D, LastStage
class ViT(nn.Module):
    """Plain ViT-style encoder.

    The pipeline is patch embedding, a learned positional embedding, a stack of
    ``Block`` layers and a final norm. It is optionally followed by either a
    ``LastStage`` or a ``Feat2D`` head.

    Args:
        img_size: ``[H, W]`` of the input. Together with ``patch_size`` it
            fixes the token grid ``self.feat_sz``. It is also handed to
            ``PatchEmbed``, whose ``num_patches`` sizes ``pos_embed``.
        patch_size: ``[ph, pw]`` patch size, handed to ``PatchEmbed``.
        in_channels: Input channels, handed to ``PatchEmbed``.
        out_channels: Output width of the ``LastStage`` head. It is only used
            (and only reported as ``self.out_channels``) when
            ``last_stage=True``; otherwise ``self.out_channels`` is
            ``embed_dim``.
        embed_dim: Token width used throughout the encoder.
        depth: Number of ``Block`` layers.
        num_heads, mlp_ratio, qkv_bias, qk_scale, act_layer: Forwarded
            unchanged to every ``Block``.
        norm_layer: Forwarded to every ``Block`` and also used to build the
            final norm.
        drop_rate: Dropout applied after adding ``pos_embed``. It is also
            passed to each ``Block`` as ``drop``.
        attn_drop_rate: Passed to each ``Block`` as ``attn_drop``.
        drop_path_rate: Upper end of the per-block ``drop_path`` values, which
            increase linearly from 0 across the ``depth`` blocks.
        last_stage: Append ``LastStage(embed_dim, out_channels,
            last_drop=0.1)`` as ``self.stages``.
        feat2d: Append ``Feat2D()`` as ``self.stages``. Mutually exclusive
            with ``last_stage``.
        use_cls_token: Prepend a learnable token before the blocks and strip
            it again after the final norm.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If both ``last_stage`` and ``feat2d`` are true.
    """

    def __init__(
        self,
        img_size=[32, 128],
        patch_size=[4, 8],
        in_channels=3,
        out_channels=256,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        last_stage=False,
        feat2d=False,
        use_cls_token=False,
        **kwargs,
    ):
        super().__init__()
        # Both heads live in ``self.stages``. Enabling both would silently
        # discard LastStage while ``out_channels`` still reported its width,
        # and the optional head must run exactly once. Reject the
        # combination up front.
        if last_stage and feat2d:
            raise ValueError(
                'ViT: `last_stage` and `feat2d` are mutually exclusive.')

        self.img_size = img_size
        self.embed_dim = embed_dim
        self.use_cls_token = use_cls_token
        self.last_stage = last_stage
        self.feat2d = feat2d
        self.out_channels = out_channels if last_stage else embed_dim
        # Token-grid size (rows, cols), passed to the optional head in forward().
        self.feat_sz = [
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        ]

        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels,
                                      embed_dim)
        num_patches = self.patch_embed.num_patches

        if use_cls_token:
            self.cls_token = nn.Parameter(
                torch.zeros([1, 1, embed_dim], dtype=torch.float32))
            trunc_normal_(self.cls_token, mean=0, std=0.02)
        # One positional slot per patch, plus one for the cls token if used.
        num_tokens = num_patches + 1 if use_cls_token else num_patches
        self.pos_embed = nn.Parameter(
            torch.zeros([1, num_tokens, embed_dim], dtype=torch.float32))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Per-block drop_path values, rising linearly from 0 to drop_path_rate.
        drop_path_rates = np.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=act_layer,
                attn_drop=attn_drop_rate,
                drop_path=rate,
                norm_layer=norm_layer,
            ) for rate in drop_path_rates
        ])
        self.norm = norm_layer(embed_dim)

        if last_stage:
            self.stages = LastStage(embed_dim, out_channels, last_drop=0.1)
        elif feat2d:
            self.stages = Feat2D()

        # Keep this initialisation order unchanged. The random draws happen in
        # this sequence, so under a fixed seed it determines the initial weights.
        trunc_normal_(self.pos_embed, mean=0, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Re-initialise one submodule ``m`` (used via ``self.apply``).

        - ``nn.Linear``: weight ~ truncated normal (std 0.02); bias zeroed.
        - ``nn.LayerNorm``: weight set to 1, bias set to 0. Affine parameters
          that are disabled (``None``) are skipped.
        - ``nn.Conv2d``: Kaiming-normal weight (``fan_out``, ``relu``); the
          bias is left untouched.

        Other module types are left unchanged.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, mean=0, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            # elementwise_affine=False (or bias=False) leaves these as None.
            if m.bias is not None:
                zeros_(m.bias)
            if m.weight is not None:
                ones_(m.weight)
        elif isinstance(m, nn.Conv2d):
            kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names this model asks to exempt from weight decay."""
        return {'pos_embed'}

    def forward(self, x):
        """Encode a batch.

        Args:
            x: Input accepted by ``PatchEmbed``. Presumably ``(B, in_channels,
                H, W)`` with ``(H, W) == img_size`` -- TODO confirm.
                ``PatchEmbed`` must yield a 3-D token tensor whose token count
                matches ``pos_embed`` (minus the cls slot, if used).

        Returns:
            The normalised tokens with the cls token removed. When
            ``last_stage`` or ``feat2d`` is set, returns instead the first
            element of the pair returned by ``self.stages``.
        """
        x = self.patch_embed(x)
        if self.use_cls_token:
            # Broadcast view of the shared token per sample; cat materialises it.
            cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)
        x = self.pos_drop(x + self.pos_embed)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        if self.use_cls_token:
            x = x[:, 1:, :]
        if self.last_stage or self.feat2d:
            # The head returns a pair; only the features are needed here.
            x, _ = self.stages(x, self.feat_sz)
        return x
|