# common.py

import torch
import torch.nn as nn


class GELU(nn.Module):
    def __init__(self, inplace=True):
        super(GELU, self).__init__()
        # `inplace` is kept only for interface consistency with the other
        # activations; F.gelu has no in-place variant, so it is unused here.
        self.inplace = inplace

    def forward(self, x):
        return torch.nn.functional.gelu(x)
class Swish(nn.Module):
    def __init__(self, inplace=True):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            x.mul_(torch.sigmoid(x))
            return x
        else:
            return x * torch.sigmoid(x)
class Activation(nn.Module):
    def __init__(self, act_type, inplace=True):
        super(Activation, self).__init__()
        act_type = act_type.lower()
        if act_type == 'relu':
            self.act = nn.ReLU(inplace=inplace)
        elif act_type == 'relu6':
            self.act = nn.ReLU6(inplace=inplace)
        elif act_type == 'sigmoid':
            self.act = nn.Sigmoid()
        elif act_type == 'hard_sigmoid':
            self.act = nn.Hardsigmoid(inplace=inplace)
        elif act_type == 'hard_swish':
            self.act = nn.Hardswish(inplace=inplace)
        elif act_type == 'leakyrelu':
            self.act = nn.LeakyReLU(inplace=inplace)
        elif act_type == 'gelu':
            self.act = GELU(inplace=inplace)
        elif act_type == 'swish':
            self.act = Swish(inplace=inplace)
        else:
            raise NotImplementedError(f'Unsupported act_type: {act_type}')

    def forward(self, inputs):
        return self.act(inputs)
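# Usage sketch: Activation is a thin string-to-module factory, e.g.
#   act = Activation('hard_swish', inplace=True)
#   y = act(torch.randn(2, 16))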
def drop_path(x,
              drop_prob: float = 0.0,
              training: bool = False,
              scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0], ) + (1, ) * (
        x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
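# Usage sketch: during training, DropPath zeroes the whole residual branch
# for a random subset of samples and rescales the survivors by 1/keep_prob,
# e.g.
#   dp = DropPath(drop_prob=0.1)
#   out = x + dp(branch(x))  # stochastic depth applied on the residual branch
# In eval mode (or with drop_prob=0.0) it is a no-op.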
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, input):
        return input
class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
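# Usage sketch: Mlp is the standard transformer feed-forward block
# (linear -> activation -> dropout -> linear -> dropout), typically built
# with hidden_features = mlp_ratio * in_features, e.g.
#   ffn = Mlp(in_features=192, hidden_features=768)
#   y = ffn(torch.randn(2, 64, 192))  # (batch, tokens, dim) is preserved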
class Attention(nn.Module):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.0,
                 proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads,
                                   C // self.num_heads).permute(2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[
            2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
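# Shape walkthrough (sketch): for input x of shape (B, N, C) with
# num_heads = h and head_dim = C // h,
#   qkv    -> (3, B, h, N, C // h)
#   attn   =  softmax(q @ k^T * scale), shape (B, h, N, N),
#             where scale defaults to head_dim ** -0.5
#   output =  (attn @ v) merged back to (B, N, C) before the final projection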
class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
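# Usage sketch: Block is a pre-norm transformer encoder layer,
# x + drop_path(attn(norm1(x))) followed by x + drop_path(mlp(norm2(x))), e.g.
#   blk = Block(dim=192, num_heads=3, mlp_ratio=4.0, drop_path=0.1)
#   y = blk(torch.randn(2, 64, 192))  # token shape is preserved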
class PatchEmbed(nn.Module):
    """Image to Patch Embedding."""

    def __init__(self,
                 img_size=[32, 128],
                 patch_size=[4, 4],
                 in_chans=3,
                 embed_dim=768):
        super().__init__()
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] //
                                                        patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
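

# Minimal smoke test (a sketch; the 32x128 input and the small dims below are
# illustrative assumptions, not values required by these modules):
if __name__ == '__main__':
    imgs = torch.randn(2, 3, 32, 128)
    embed = PatchEmbed(img_size=[32, 128], patch_size=[4, 4], in_chans=3,
                       embed_dim=192)
    tokens = embed(imgs)                       # (2, 256, 192): 8 * 32 patches
    blk = Block(dim=192, num_heads=3, drop_path=0.1)
    out = blk(tokens)
    print(tokens.shape, out.shape)             # both torch.Size([2, 256, 192])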