visionlan_label_encode.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from random import sample
  2. import numpy as np
  3. from .ctc_label_encode import BaseRecLabelEncode
  class VisionLANLabelEncode(BaseRecLabelEncode):
      """Encode a text label for VisionLAN, plus a one-character occlusion split.

      Besides the usual index encoding of ``data['label']``, every call picks
      one character position at random and also encodes the character at that
      position ("sub") and the string with that character removed ("res").
      """

      def __init__(self,
                   max_text_length,
                   character_dict_path=None,
                   use_space_char=False,
                   **kwargs):
          # Extra keyword arguments are accepted and ignored (presumably so
          # the transform can be built from a config carrying other keys).
          super(VisionLANLabelEncode,
                self).__init__(max_text_length, character_dict_path,
                               use_space_char)
          # Character -> index lookup over ``self.character``, which is set by
          # the base class (presumably the loaded character dictionary).
          self.dict = {char: i for i, char in enumerate(self.character)}

      def __call__(self, data):
          """Encode ``data['label']`` together with a random occlusion split.

          Args:
              data (dict): sample dict whose ``'label'`` entry is the raw text.

          Returns:
              ``data`` updated in place, or None when the label is empty or
              ``self.encode`` rejects the full label (``data`` is then left
              unmodified). Keys written:

              * ``'label'``: shifted indices of the full text, zero-padded to
                ``self.max_text_len + 1`` entries; ``'length'``: their count.
              * ``'label_res'`` / ``'length_res'``: the same for the text with
                the occluded character removed, padded to ``self.max_text_len``.
              * ``'label_sub'`` / ``'length_sub'``: the same for the occluded
                character alone, padded to ``self.max_text_len``.
              * ``'label_id'``: position of the occluded character in the raw
                (un-encoded) string.
          """
          text = data['label']  # original string
          len_str = len(text)
          if len_str == 0:
              return None

          # Occlude exactly one character, chosen uniformly at random.
          change_id = sample(range(len_str), 1)[0]
          label_sub = text[change_id]  # occluded character
          # Plain slicing also covers the first/last position (empty slices).
          label_res = text[:change_id] + text[change_id + 1:]  # remaining string

          label = self._encode_shifted(text)
          if label is None:
              return None
          # Unlike the full label, a remainder / occluded character that
          # ``encode`` rejects becomes an empty sequence instead of dropping
          # the whole sample.
          label_res = self._encode_shifted(label_res) or []
          label_sub = self._encode_shifted(label_sub) or []

          data['label_res'] = self._pad(label_res, self.max_text_len)
          data['label_sub'] = self._pad(label_sub, self.max_text_len)
          # NOTE(review): index into the raw string. If ``encode`` drops
          # characters (e.g. ones missing from the dictionary), this no longer
          # lines up with positions in ``data['label']`` -- confirm.
          data['label_id'] = change_id
          data['length'] = np.array(len(label))
          data['label'] = self._pad(label, self.max_text_len + 1)
          data['length_res'] = np.array(len(label_res))
          data['length_sub'] = np.array(len(label_sub))
          return data

      def _encode_shifted(self, text):
          """Return ``self.encode(text)`` with every index shifted by +1, or None.

          The +1 shift keeps real indices distinct from the 0 used for padding
          (assuming ``encode`` yields non-negative indices).
          """
          indices = self.encode(text)
          if indices is None:
              return None
          return [i + 1 for i in indices]

      @staticmethod
      def _pad(indices, size):
          """Right-pad ``indices`` with zeros to ``size`` entries, as an ndarray.

          Sequences already ``size`` entries or longer are returned as-is
          (not truncated).
          """
          return np.array(indices + [0] * (size - len(indices)))