# extract_batchsize.py
  1. import torch
  2. import numpy as np
  3. import copy
  4. def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
  5. """ """
  6. pos_lists_, pos_masks_, label_lists_ = [], [], []
  7. img_bs = batch_size
  8. ngpu = int(batch_size / img_bs)
  9. img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
  10. pos_lists_split, pos_masks_split, label_lists_split = [], [], []
  11. for i in range(ngpu):
  12. pos_lists_split.append([])
  13. pos_masks_split.append([])
  14. label_lists_split.append([])
  15. for i in range(img_ids.shape[0]):
  16. img_id = img_ids[i]
  17. gpu_id = int(img_id / img_bs)
  18. img_id = img_id % img_bs
  19. pos_list = pos_lists[i].copy()
  20. pos_list[:, 0] = img_id
  21. pos_lists_split[gpu_id].append(pos_list)
  22. pos_masks_split[gpu_id].append(pos_masks[i].copy())
  23. label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
  24. # repeat or delete
  25. for i in range(ngpu):
  26. vp_len = len(pos_lists_split[i])
  27. if vp_len <= tcl_bs:
  28. for j in range(0, tcl_bs - vp_len):
  29. pos_list = pos_lists_split[i][j].copy()
  30. pos_lists_split[i].append(pos_list)
  31. pos_mask = pos_masks_split[i][j].copy()
  32. pos_masks_split[i].append(pos_mask)
  33. label_list = copy.deepcopy(label_lists_split[i][j])
  34. label_lists_split[i].append(label_list)
  35. else:
  36. for j in range(0, vp_len - tcl_bs):
  37. c_len = len(pos_lists_split[i])
  38. pop_id = np.random.permutation(c_len)[0]
  39. pos_lists_split[i].pop(pop_id)
  40. pos_masks_split[i].pop(pop_id)
  41. label_lists_split[i].pop(pop_id)
  42. # merge
  43. for i in range(ngpu):
  44. pos_lists_.extend(pos_lists_split[i])
  45. pos_masks_.extend(pos_masks_split[i])
  46. label_lists_.extend(label_lists_split[i])
  47. return pos_lists_, pos_masks_, label_lists_
  48. def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
  49. pad_num, tcl_bs):
  50. label_list = label_list.numpy()
  51. batch, _, _, _ = label_list.shape
  52. pos_list = pos_list.numpy()
  53. pos_mask = pos_mask.numpy()
  54. pos_list_t = []
  55. pos_mask_t = []
  56. label_list_t = []
  57. for i in range(batch):
  58. for j in range(max_text_nums):
  59. if pos_mask[i, j].any():
  60. pos_list_t.append(pos_list[i][j])
  61. pos_mask_t.append(pos_mask[i][j])
  62. label_list_t.append(label_list[i][j])
  63. pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
  64. label_list_t, tcl_bs)
  65. label = []
  66. tt = [l.tolist() for l in label_list]
  67. for i in range(tcl_bs):
  68. k = 0
  69. for j in range(max_text_length):
  70. if tt[i][j][0] != pad_num:
  71. k += 1
  72. else:
  73. break
  74. label.append(k)
  75. label = torch.tensor(label)
  76. label = label.long()
  77. pos_list = torch.tensor(pos_list)
  78. pos_mask = torch.tensor(pos_mask)
  79. label_list = torch.squeeze(torch.tensor(label_list), dim=2)
  80. label_list = label_list.int()
  81. return pos_list, pos_mask, label_list, label