vit.py

import numpy as np
import torch
from torch import nn
from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_

from openrec.modeling.common import Block, PatchEmbed
from openrec.modeling.encoders.svtrv2_lnconv import Feat2D, LastStage


class ViT(nn.Module):

    def __init__(
        self,
        img_size=[32, 128],
        patch_size=[4, 8],
        in_channels=3,
        out_channels=256,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        last_stage=False,
        feat2d=False,
        use_cls_token=False,
        **kwargs,
    ):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.out_channels = embed_dim
        self.use_cls_token = use_cls_token
        # Patch-grid size after embedding: [H // patch_h, W // patch_w].
        self.feat_sz = [
            img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        ]
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels,
                                      embed_dim)
        num_patches = self.patch_embed.num_patches
        if use_cls_token:
            # Learnable cls token; the positional embedding grows by one slot.
            self.cls_token = nn.Parameter(
                torch.zeros([1, 1, embed_dim], dtype=torch.float32),
                requires_grad=True,
            )
            trunc_normal_(self.cls_token, mean=0, std=0.02)
            self.pos_embed = nn.Parameter(
                torch.zeros([1, num_patches + 1, embed_dim],
                            dtype=torch.float32),
                requires_grad=True,
            )
        else:
            self.pos_embed = nn.Parameter(
                torch.zeros([1, num_patches, embed_dim], dtype=torch.float32),
                requires_grad=True,
            )
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Stochastic depth: per-block drop-path rates increase linearly
        # from 0 to drop_path_rate.
        dpr = np.linspace(0, drop_path_rate, depth)
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=act_layer,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            ) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        self.last_stage = last_stage
        self.feat2d = feat2d
        if last_stage:
            self.out_channels = out_channels
            self.stages = LastStage(embed_dim, out_channels, last_drop=0.1)
        if feat2d:
            # Note: if both flags are set, Feat2D replaces LastStage.
            self.stages = Feat2D()
        trunc_normal_(self.pos_embed, mean=0, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, mean=0, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        if isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        if isinstance(m, nn.Conv2d):
            kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameters listed here are conventionally excluded from weight decay.
        return {'pos_embed'}

    def forward(self, x):
        x = self.patch_embed(x)  # (B, N, embed_dim) patch tokens
        if self.use_cls_token:
            x = torch.concat([self.cls_token.tile([x.shape[0], 1, 1]), x], 1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        if self.use_cls_token:
            # Drop the cls token so only patch tokens reach downstream heads.
            x = x[:, 1:, :]
        if self.last_stage:
            x, sz = self.stages(x, self.feat_sz)
        if self.feat2d:
            x, sz = self.stages(x, self.feat_sz)
        return x
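

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original file. It assumes the
    # openrec package providing Block, PatchEmbed, LastStage and Feat2D is
    # importable, and that PatchEmbed returns tokens shaped (B, N, embed_dim),
    # as in standard ViT implementations.
    model = ViT(img_size=[32, 128], patch_size=[4, 8], embed_dim=384, depth=12)
    images = torch.randn(2, 3, 32, 128)  # (batch, channels, height, width)
    tokens = model(images)
    # With 32x128 inputs and 4x8 patches the grid is 8 x 16 = 128 tokens,
    # so the expected output shape is torch.Size([2, 128, 384]).
    print(tokens.shape)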