collate_fn.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import numbers
  2. from collections import defaultdict
  3. import numpy as np
  4. import torch
  5. class DictCollator(object):
  6. """data batch."""
  7. def __call__(self, batch):
  8. data_dict = defaultdict(list)
  9. to_tensor_keys = []
  10. for sample in batch:
  11. for k, v in sample.items():
  12. if isinstance(v, (np.ndarray, torch.Tensor, numbers.Number)):
  13. if k not in to_tensor_keys:
  14. to_tensor_keys.append(k)
  15. data_dict[k].append(v)
  16. for k in to_tensor_keys:
  17. data_dict[k] = torch.from_numpy(data_dict[k])
  18. return data_dict
  19. class ListCollator(object):
  20. """data batch."""
  21. def __call__(self, batch):
  22. data_dict = defaultdict(list)
  23. to_tensor_idxs = []
  24. for sample in batch:
  25. for idx, v in enumerate(sample):
  26. if isinstance(v, (np.ndarray, torch.Tensor, numbers.Number)):
  27. if idx not in to_tensor_idxs:
  28. to_tensor_idxs.append(idx)
  29. data_dict[idx].append(v)
  30. for idx in to_tensor_idxs:
  31. data_dict[idx] = torch.from_numpy(data_dict[idx])
  32. return list(data_dict.values())
  33. class SSLRotateCollate(object):
  34. """
  35. bach: [
  36. [(4*3xH*W), (4,)]
  37. [(4*3xH*W), (4,)]
  38. ...
  39. ]
  40. """
  41. def __call__(self, batch):
  42. output = [np.concatenate(d, axis=0) for d in zip(*batch)]
  43. return output
  44. class DyMaskCollator(object):
  45. """
  46. batch: [
  47. image [batch_size, channel, maxHinbatch, maxWinbatch]
  48. image_mask [batch_size, channel, maxHinbatch, maxWinbatch]
  49. label [batch_size, maxLabelLen]
  50. label_mask [batch_size, maxLabelLen]
  51. ...
  52. ]
  53. """
  54. def __call__(self, batch):
  55. max_width, max_height, max_length = 0, 0, 0
  56. bs, channel = len(batch), batch[0][0].shape[0]
  57. proper_items = []
  58. for item in batch:
  59. if item[0].shape[1] * max_width > 1600 * 320 or item[0].shape[
  60. 2] * max_height > 1600 * 320:
  61. continue
  62. max_height = item[0].shape[
  63. 1] if item[0].shape[1] > max_height else max_height
  64. max_width = item[0].shape[
  65. 2] if item[0].shape[2] > max_width else max_width
  66. max_length = len(
  67. item[1]) if len(item[1]) > max_length else max_length
  68. proper_items.append(item)
  69. images, image_masks = np.zeros(
  70. (len(proper_items), channel, max_height, max_width),
  71. dtype='float32'), np.zeros(
  72. (len(proper_items), 1, max_height, max_width), dtype='float32')
  73. labels, label_masks = np.zeros((len(proper_items), max_length),
  74. dtype='int64'), np.zeros(
  75. (len(proper_items), max_length),
  76. dtype='int64')
  77. for i in range(len(proper_items)):
  78. _, h, w = proper_items[i][0].shape
  79. images[i][:, :h, :w] = proper_items[i][0]
  80. image_masks[i][:, :h, :w] = 1
  81. l = len(proper_items[i][1])
  82. labels[i][:l] = proper_items[i][1]
  83. label_masks[i][:l] = 1
  84. return images, image_masks, labels, label_masks