cppd_label_encode.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import random
  2. import numpy as np
  3. from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
  4. class CPPDLabelEncode(BaseRecLabelEncode):
  5. """Convert between text-label and text-index."""
  6. def __init__(
  7. self,
  8. max_text_length,
  9. character_dict_path=None,
  10. use_space_char=False,
  11. ch=False,
  12. # ch_7000=7000,
  13. ignore_index=100,
  14. use_sos=False,
  15. pos_len=False,
  16. **kwargs):
  17. self.use_sos = use_sos
  18. super(CPPDLabelEncode,
  19. self).__init__(max_text_length, character_dict_path,
  20. use_space_char)
  21. self.ch = ch
  22. self.ignore_index = ignore_index
  23. self.pos_len = pos_len
  24. def __call__(self, data):
  25. text = data['label']
  26. if self.ch:
  27. text, text_node_index, text_node_num = self.encodech(text)
  28. if text is None:
  29. return None
  30. if len(text) > self.max_text_len:
  31. return None
  32. data['length'] = np.array(len(text))
  33. # text.insert(0, 0)
  34. if self.pos_len:
  35. text_pos_node = [i_ for i_ in range(len(text), -1, -1)
  36. ] + [100] * (self.max_text_len - len(text))
  37. else:
  38. text_pos_node = [1] * (len(text) + 1) + [0] * (
  39. self.max_text_len - len(text))
  40. text.append(0)
  41. text + [0] * (self.max_text_len - len(text))
  42. text = text + [self.ignore_index
  43. ] * (self.max_text_len + 1 - len(text))
  44. data['label'] = np.array(text)
  45. data['label_node'] = np.array(text_node_num + text_pos_node)
  46. data['label_index'] = np.array(text_node_index)
  47. # data['label_ctc'] = np.array(ctc_text)
  48. return data
  49. else:
  50. text, text_char_node, ch_order = self.encode(text)
  51. if text is None:
  52. return None
  53. if len(text) > self.max_text_len:
  54. return None
  55. data['length'] = np.array(len(text))
  56. # text.insert(0, 0)
  57. if self.pos_len:
  58. text_pos_node = [i_ for i_ in range(len(text), -1, -1)
  59. ] + [100] * (self.max_text_len - len(text))
  60. else:
  61. text_pos_node = [1] * (len(text) + 1) + [0] * (
  62. self.max_text_len - len(text))
  63. text.append(0)
  64. text = text + [self.ignore_index
  65. ] * (self.max_text_len + 1 - len(text))
  66. data['label'] = np.array(text)
  67. data['label_node'] = np.array(text_char_node + text_pos_node)
  68. data['label_order'] = np.array(ch_order)
  69. return data
  70. def add_special_char(self, dict_character):
  71. if self.use_sos:
  72. dict_character = ['<s>', '</s>'] + dict_character
  73. else:
  74. dict_character = ['</s>'] + dict_character
  75. self.num_character = len(dict_character)
  76. return dict_character
  77. def encode(self, text):
  78. """convert text-label into text-index.
  79. input:
  80. text: text labels of each image. [batch_size]
  81. output:
  82. text: concatenated text index for CTCLoss.
  83. [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
  84. length: length of each text. [batch_size]
  85. """
  86. if len(text) == 0:
  87. return None, None, None
  88. if self.lower:
  89. text = text.lower()
  90. text_node = [0 for _ in range(self.num_character)]
  91. text_node[0] = 1
  92. text_list = []
  93. ch_order = []
  94. order = 1
  95. for char in text:
  96. if char not in self.dict:
  97. continue
  98. text_list.append(self.dict[char])
  99. text_node[self.dict[char]] += 1
  100. ch_order.append(
  101. [self.dict[char], text_node[self.dict[char]], order])
  102. order += 1
  103. no_ch_order = []
  104. for char in self.character:
  105. if char not in text:
  106. no_ch_order.append([self.dict[char], 1, 0])
  107. random.shuffle(no_ch_order)
  108. ch_order = ch_order + no_ch_order
  109. ch_order = ch_order[:self.max_text_len + 1]
  110. if len(text_list) == 0 or len(text_list) > self.max_text_len:
  111. return None, None, None
  112. return text_list, text_node, ch_order.sort()
  113. def encodech(self, text):
  114. """convert text-label into text-index.
  115. input:
  116. text: text labels of each image. [batch_size]
  117. output:
  118. text: concatenated text index for CTCLoss.
  119. [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
  120. length: length of each text. [batch_size]
  121. """
  122. if len(text) == 0:
  123. return None, None, None
  124. if self.lower:
  125. text = text.lower()
  126. text_node_dict = {}
  127. text_node_dict.update({0: 1})
  128. character_index = [_ for _ in range(self.num_character)]
  129. text_list = []
  130. for char in text:
  131. if char not in self.dict:
  132. continue
  133. i_c = self.dict[char]
  134. text_list.append(i_c)
  135. if i_c in text_node_dict.keys():
  136. text_node_dict[i_c] += 1
  137. else:
  138. text_node_dict.update({i_c: 1})
  139. for ic in list(text_node_dict.keys()):
  140. character_index.remove(ic)
  141. none_char_index = random.sample(character_index,
  142. 37 - len(list(text_node_dict.keys())))
  143. for ic in none_char_index:
  144. text_node_dict[ic] = 0
  145. text_node_index = sorted(text_node_dict)
  146. text_node_num = [text_node_dict[k] for k in text_node_index]
  147. if len(text_list) == 0 or len(text_list) > self.max_text_len:
  148. return None, None, None
  149. return text_list, text_node_index, text_node_num