# db_head.py (2.1 KB)
  1. import torch
  2. from torch import nn
  3. import torch.nn.functional as F
  4. class Head(nn.Module):
  5. def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs):
  6. super(Head, self).__init__()
  7. self.conv1 = nn.Conv2d(
  8. in_channels=in_channels,
  9. out_channels=in_channels // 4,
  10. kernel_size=kernel_list[0],
  11. padding=int(kernel_list[0] // 2),
  12. bias=False,
  13. )
  14. self.conv_bn1 = nn.BatchNorm2d(num_features=in_channels // 4, )
  15. self.conv2 = nn.ConvTranspose2d(
  16. in_channels=in_channels // 4,
  17. out_channels=in_channels // 4,
  18. kernel_size=kernel_list[1],
  19. stride=2,
  20. )
  21. self.conv_bn2 = nn.BatchNorm2d(num_features=in_channels // 4, )
  22. self.conv3 = nn.ConvTranspose2d(
  23. in_channels=in_channels // 4,
  24. out_channels=1,
  25. kernel_size=kernel_list[2],
  26. stride=2,
  27. )
  28. def forward(self, x, return_f=False):
  29. x = self.conv1(x)
  30. x = F.relu(self.conv_bn1(x))
  31. x = self.conv2(x)
  32. x = F.relu(self.conv_bn2(x))
  33. if return_f is True:
  34. f = x
  35. x = self.conv3(x)
  36. x = torch.sigmoid(x)
  37. if return_f is True:
  38. return x, f
  39. return x
  40. class DBHead(nn.Module):
  41. """
  42. Differentiable Binarization (DB) for text detection:
  43. see https://arxiv.org/abs/1911.08947
  44. args:
  45. params(dict): super parameters for build DB network
  46. """
  47. def __init__(self, in_channels, k=50, **kwargs):
  48. super(DBHead, self).__init__()
  49. self.k = k
  50. self.binarize = Head(in_channels, **kwargs)
  51. self.thresh = Head(in_channels, **kwargs)
  52. def step_function(self, x, y):
  53. return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
  54. def forward(self, x, data=None):
  55. shrink_maps = self.binarize(x)
  56. if not self.training:
  57. return {'maps': shrink_maps}
  58. threshold_maps = self.thresh(x)
  59. binary_maps = self.step_function(shrink_maps, threshold_maps)
  60. y = torch.concat([shrink_maps, threshold_maps, binary_maps], dim=1)
  61. return {'maps': y}