# smtr_label_encode.py (4.5 KB)
  1. import copy
  2. import random
  3. import numpy as np
  4. from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
  5. class SMTRLabelEncode(BaseRecLabelEncode):
  6. """Convert between text-label and text-index."""
  7. BOS = '<s>'
  8. EOS = '</s>'
  9. IN_F = '<INF>' # ignore
  10. IN_B = '<INB>' # ignore
  11. PAD = '<pad>'
  12. def __init__(self,
  13. max_text_length,
  14. character_dict_path=None,
  15. use_space_char=False,
  16. sub_str_len=5,
  17. **kwargs):
  18. super(SMTRLabelEncode,
  19. self).__init__(max_text_length, character_dict_path,
  20. use_space_char)
  21. self.substr_len = sub_str_len
  22. self.rang_subs = [i for i in range(1, self.substr_len + 1)]
  23. self.idx_char = [i for i in range(1, self.num_character - 5)]
  24. def __call__(self, data):
  25. text = data['label']
  26. text = self.encode(text)
  27. if text is None:
  28. return None
  29. if len(text) > self.max_text_len:
  30. return None
  31. data['length'] = np.array(len(text))
  32. text_in = [self.dict[self.IN_F]] * (self.substr_len) + text + [
  33. self.dict[self.IN_B]
  34. ] * (self.substr_len)
  35. sub_string_list_pre = []
  36. next_label_pre = []
  37. sub_string_list = []
  38. next_label = []
  39. for i in range(self.substr_len, len(text_in) - self.substr_len):
  40. sub_string_list.append(text_in[i - self.substr_len:i])
  41. next_label.append(text_in[i])
  42. if self.substr_len - i == 0:
  43. sub_string_list_pre.append(text_in[-i:])
  44. else:
  45. sub_string_list_pre.append(text_in[-i:self.substr_len - i])
  46. next_label_pre.append(text_in[-(i + 1)])
  47. sub_string_list.append(
  48. [self.dict[self.IN_F]] *
  49. (self.substr_len - len(text[-self.substr_len:])) +
  50. text[-self.substr_len:])
  51. next_label.append(self.dict[self.EOS])
  52. sub_string_list_pre.append(
  53. text[:self.substr_len] + [self.dict[self.IN_B]] *
  54. (self.substr_len - len(text[:self.substr_len])))
  55. next_label_pre.append(self.dict[self.EOS])
  56. for sstr, l in zip(sub_string_list[self.substr_len:],
  57. next_label[self.substr_len:]):
  58. id_shu = np.random.choice(self.rang_subs, 2)
  59. sstr1 = copy.deepcopy(sstr)
  60. sstr1[id_shu[0] - 1] = random.randint(1, self.num_character - 5)
  61. if sstr1 not in sub_string_list:
  62. sub_string_list.append(sstr1)
  63. next_label.append(l)
  64. sstr[id_shu[1] - 1] = random.randint(1, self.num_character - 5)
  65. for sstr, l in zip(sub_string_list_pre[self.substr_len:],
  66. next_label_pre[self.substr_len:]):
  67. id_shu = np.random.choice(self.rang_subs, 2)
  68. sstr1 = copy.deepcopy(sstr)
  69. sstr1[id_shu[0] - 1] = random.randint(1, self.num_character - 5)
  70. if sstr1 not in sub_string_list_pre:
  71. sub_string_list_pre.append(sstr1)
  72. next_label_pre.append(l)
  73. sstr[id_shu[1] - 1] = random.randint(1, self.num_character - 5)
  74. data['length_subs'] = np.array(len(sub_string_list))
  75. sub_string_list = sub_string_list + [
  76. [self.dict[self.PAD]] * self.substr_len
  77. ] * ((self.max_text_len * 2) + 2 - len(sub_string_list))
  78. next_label = next_label + [self.dict[self.PAD]] * (
  79. (self.max_text_len * 2) + 2 - len(next_label))
  80. data['label_subs'] = np.array(sub_string_list)
  81. data['label_next'] = np.array(next_label)
  82. data['length_subs_pre'] = np.array(len(sub_string_list_pre))
  83. sub_string_list_pre = sub_string_list_pre + [
  84. [self.dict[self.PAD]] * self.substr_len
  85. ] * ((self.max_text_len * 2) + 2 - len(sub_string_list_pre))
  86. next_label_pre = next_label_pre + [self.dict[self.PAD]] * (
  87. (self.max_text_len * 2) + 2 - len(next_label_pre))
  88. data['label_subs_pre'] = np.array(sub_string_list_pre)
  89. data['label_next_pre'] = np.array(next_label_pre)
  90. text = [self.dict[self.BOS]] + text + [self.dict[self.EOS]]
  91. text = text + [self.dict[self.PAD]
  92. ] * (self.max_text_len + 2 - len(text))
  93. data['label'] = np.array(text)
  94. return data
  95. def add_special_char(self, dict_character):
  96. dict_character = [self.EOS] + dict_character + [
  97. self.BOS, self.IN_F, self.IN_B, self.PAD
  98. ]
  99. self.num_character = len(dict_character)
  100. return dict_character