import torch
from torch import nn
import torch.nn.functional as F


class Head(nn.Module):
    """Prediction head shared by the probability (shrink) and threshold branches.

    A 3x3 convolution reduces the channel count to in_channels // 4, then two
    stride-2 transposed convolutions upsample the map 4x back to the input
    image resolution and project it to a single-channel map in [0, 1].
    """

    def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs):
        super(Head, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels // 4,
            kernel_size=kernel_list[0],
            padding=int(kernel_list[0] // 2),
            bias=False,
        )
        self.conv_bn1 = nn.BatchNorm2d(num_features=in_channels // 4)
        self.conv2 = nn.ConvTranspose2d(
            in_channels=in_channels // 4,
            out_channels=in_channels // 4,
            kernel_size=kernel_list[1],
            stride=2,
        )
        self.conv_bn2 = nn.BatchNorm2d(num_features=in_channels // 4)
        self.conv3 = nn.ConvTranspose2d(
            in_channels=in_channels // 4,
            out_channels=1,
            kernel_size=kernel_list[2],
            stride=2,
        )

    def forward(self, x, return_f=False):
        x = self.conv1(x)
        x = F.relu(self.conv_bn1(x))
        x = self.conv2(x)
        x = F.relu(self.conv_bn2(x))
        if return_f:
            # Keep the intermediate feature map before the final projection.
            f = x
        x = self.conv3(x)
        x = torch.sigmoid(x)
        if return_f:
            return x, f
        return x


class DBHead(nn.Module):
    """
    Differentiable Binarization (DB) head for text detection,
    see https://arxiv.org/abs/1911.08947

    Args:
        in_channels (int): number of channels of the input feature map
        k (int): amplification factor of the differentiable binarization
            step function
    """

    def __init__(self, in_channels, k=50, **kwargs):
        super(DBHead, self).__init__()
        self.k = k
        self.binarize = Head(in_channels, **kwargs)
        self.thresh = Head(in_channels, **kwargs)

    def step_function(self, x, y):
        # Differentiable (approximate) binarization: 1 / (1 + exp(-k * (P - T)))
        return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))

    def forward(self, x, data=None):
        shrink_maps = self.binarize(x)
        if not self.training:
            # At inference only the probability (shrink) map is needed.
            return {'maps': shrink_maps}
        threshold_maps = self.thresh(x)
        binary_maps = self.step_function(shrink_maps, threshold_maps)
        y = torch.cat([shrink_maps, threshold_maps, binary_maps], dim=1)
        return {'maps': y}
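

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module: it assumes the head
    # sits on a 256-channel feature map at 1/4 of the input resolution (a common
    # DBNet neck output); adjust in_channels and the dummy shape to your backbone.
    head = DBHead(in_channels=256, k=50)
    feats = torch.randn(1, 256, 160, 160)  # dummy neck output for a 640x640 image

    head.train()
    out = head(feats)
    print(out['maps'].shape)  # (1, 3, 640, 640): shrink, threshold, binary maps

    head.eval()
    with torch.no_grad():
        out = head(feats)
    print(out['maps'].shape)  # (1, 1, 640, 640): probability (shrink) map only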