moran.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
"""This code is adapted from:
https://github.com/Canjie-Luo/MORAN_v2
"""
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. from torch.nn import functional as F
  8. class MORN(nn.Module):
  9. def __init__(self, in_channels, target_shape=[32, 100], enhance=1):
  10. super(MORN, self).__init__()
  11. self.targetH = target_shape[0]
  12. self.targetW = target_shape[1]
  13. self.enhance = enhance
  14. self.out_channels = in_channels
  15. self.cnn = nn.Sequential(nn.MaxPool2d(2, 2),
  16. nn.Conv2d(in_channels, 64, 3, 1, 1),
  17. nn.BatchNorm2d(64), nn.ReLU(True),
  18. nn.MaxPool2d(2,
  19. 2), nn.Conv2d(64, 128, 3, 1, 1),
  20. nn.BatchNorm2d(128), nn.ReLU(True),
  21. nn.MaxPool2d(2,
  22. 2), nn.Conv2d(128, 64, 3, 1, 1),
  23. nn.BatchNorm2d(64), nn.ReLU(True),
  24. nn.Conv2d(64, 16, 3, 1, 1),
  25. nn.BatchNorm2d(16), nn.ReLU(True),
  26. nn.Conv2d(16, 1, 3, 1, 1), nn.BatchNorm2d(1))
  27. self.pool = nn.MaxPool2d(2, 1)
  28. h_list = np.arange(self.targetH) * 2. / (self.targetH - 1) - 1
  29. w_list = np.arange(self.targetW) * 2. / (self.targetW - 1) - 1
  30. grid = np.meshgrid(w_list, h_list, indexing='ij')
  31. grid = np.stack(grid, axis=-1)
  32. grid = np.transpose(grid, (1, 0, 2))
  33. grid = np.expand_dims(grid, 0)
  34. self.grid = nn.Parameter(
  35. torch.from_numpy(grid).float(),
  36. requires_grad=False,
  37. )
  38. def forward(self, x):
  39. bs = x.shape[0]
  40. grid = self.grid.tile([bs, 1, 1, 1])
  41. grid_x = self.grid[:, :, :, 0].unsqueeze(3).tile([bs, 1, 1, 1])
  42. grid_y = self.grid[:, :, :, 1].unsqueeze(3).tile([bs, 1, 1, 1])
  43. x_small = F.upsample(x,
  44. size=(self.targetH, self.targetW),
  45. mode='bilinear')
  46. offsets = self.cnn(x_small)
  47. offsets_posi = F.relu(offsets, inplace=False)
  48. offsets_nega = F.relu(-offsets, inplace=False)
  49. offsets_pool = self.pool(offsets_posi) - self.pool(offsets_nega)
  50. offsets_grid = F.grid_sample(offsets_pool, grid)
  51. offsets_grid = offsets_grid.permute(0, 2, 3, 1).contiguous()
  52. offsets_x = torch.cat([grid_x, grid_y + offsets_grid], 3)
  53. x_rectified = F.grid_sample(x, offsets_x)
  54. for iteration in range(self.enhance):
  55. offsets = self.cnn(x_rectified)
  56. offsets_posi = F.relu(offsets, inplace=False)
  57. offsets_nega = F.relu(-offsets, inplace=False)
  58. offsets_pool = self.pool(offsets_posi) - self.pool(offsets_nega)
  59. offsets_grid += F.grid_sample(offsets_pool,
  60. grid).permute(0, 2, 3,
  61. 1).contiguous()
  62. offsets_x = torch.cat([grid_x, grid_y + offsets_grid], 3)
  63. x_rectified = F.grid_sample(x, offsets_x)
  64. # if debug:
  65. # offsets_mean = torch.mean(offsets_grid.view(x.size(0), -1), 1)
  66. # offsets_max, _ = torch.max(offsets_grid.view(x.size(0), -1), 1)
  67. # offsets_min, _ = torch.min(offsets_grid.view(x.size(0), -1), 1)
  68. # import matplotlib.pyplot as plt
  69. # from colour import Color
  70. # from torchvision import transforms
  71. # import cv2
  72. # alpha = 0.7
  73. # density_range = 256
  74. # color_map = np.empty([self.targetH, self.targetW, 3], dtype=int)
  75. # cmap = plt.get_cmap("rainbow")
  76. # blue = Color("blue")
  77. # hex_colors = list(blue.range_to(Color("red"), density_range))
  78. # rgb_colors = [[rgb * 255 for rgb in color.rgb] for color in hex_colors][::-1]
  79. # to_pil_image = transforms.ToPILImage()
  80. # for i in range(x.size(0)):
  81. # img_small = x_small[i].data.cpu().mul_(0.5).add_(0.5)
  82. # img = to_pil_image(img_small)
  83. # img = np.array(img)
  84. # if len(img.shape) == 2:
  85. # img = cv2.merge([img.copy()]*3)
  86. # img_copy = img.copy()
  87. # v_max = offsets_max.data[i]
  88. # v_min = offsets_min.data[i]
  89. # if self.cuda:
  90. # img_offsets = (offsets_grid[i]).view(1, self.targetH, self.targetW).data.cuda().add_(-v_min).mul_(1./(v_max-v_min))
  91. # else:
  92. # img_offsets = (offsets_grid[i]).view(1, self.targetH, self.targetW).data.cpu().add_(-v_min).mul_(1./(v_max-v_min))
  93. # img_offsets = to_pil_image(img_offsets)
  94. # img_offsets = np.array(img_offsets)
  95. # color_map = np.empty([self.targetH, self.targetW, 3], dtype=int)
  96. # for h_i in range(self.targetH):
  97. # for w_i in range(self.targetW):
  98. # color_map[h_i][w_i] = rgb_colors[int(img_offsets[h_i, w_i]/256.*density_range)]
  99. # color_map = color_map.astype(np.uint8)
  100. # cv2.addWeighted(color_map, alpha, img_copy, 1-alpha, 0, img_copy)
  101. # img_processed = x_rectified[i].data.cpu().mul_(0.5).add_(0.5)
  102. # img_processed = to_pil_image(img_processed)
  103. # img_processed = np.array(img_processed)
  104. # if len(img_processed.shape) == 2:
  105. # img_processed = cv2.merge([img_processed.copy()]*3)
  106. # total_img = np.ones([self.targetH, self.targetW*3+10, 3], dtype=int)*255
  107. # total_img[0:self.targetH, 0:self.targetW] = img
  108. # total_img[0:self.targetH, self.targetW+5:2*self.targetW+5] = img_copy
  109. # total_img[0:self.targetH, self.targetW*2+10:3*self.targetW+10] = img_processed
  110. # total_img = cv2.resize(total_img.astype(np.uint8), (300, 50))
  111. # # cv2.imshow("Input_Offsets_Output", total_img)
  112. # # cv2.waitKey()
  113. # return x_rectified, total_img
  114. return x_rectified