'''
This code is adapted from:
https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
'''
import numpy as np

from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
  7. class MGPLabelEncode(BaseRecLabelEncode):
  8. """ Convert between text-label and text-index """
  9. SPACE = '[s]'
  10. GO = '[GO]'
  11. list_token = [GO, SPACE]
  12. def __init__(self,
  13. max_text_length,
  14. character_dict_path=None,
  15. use_space_char=False,
  16. only_char=False,
  17. **kwargs):
  18. super(MGPLabelEncode,
  19. self).__init__(max_text_length, character_dict_path,
  20. use_space_char)
  21. # character (str): set of the possible characters.
  22. # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
  23. self.batch_max_length = max_text_length + len(self.list_token)
  24. self.only_char = only_char
  25. if not only_char:
  26. # transformers==4.2.1
  27. from transformers import BertTokenizer, GPT2Tokenizer
  28. self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
  29. self.wp_tokenizer = BertTokenizer.from_pretrained(
  30. 'bert-base-uncased')
  31. def __call__(self, data):
  32. text = data['label']
  33. char_text, char_len = self.encode(text)
  34. if char_text is None:
  35. return None
  36. data['length'] = np.array(char_len)
  37. data['char_label'] = np.array(char_text)
  38. if self.only_char:
  39. return data
  40. bpe_text = self.bpe_encode(text)
  41. if bpe_text is None:
  42. return None
  43. wp_text = self.wp_encode(text)
  44. data['bpe_label'] = np.array(bpe_text)
  45. data['wp_label'] = wp_text
  46. return data
  47. def add_special_char(self, dict_character):
  48. dict_character = self.list_token + dict_character
  49. return dict_character
  50. def encode(self, text):
  51. """ convert text-label into text-index.
  52. """
  53. if len(text) == 0:
  54. return None, None
  55. if self.lower:
  56. text = text.lower()
  57. length = len(text)
  58. text = [self.GO] + list(text) + [self.SPACE]
  59. text_list = []
  60. for char in text:
  61. if char not in self.dict:
  62. continue
  63. text_list.append(self.dict[char])
  64. if len(text_list) == 0 or len(text_list) > self.batch_max_length:
  65. return None, None
  66. text_list = text_list + [self.dict[self.GO]
  67. ] * (self.batch_max_length - len(text_list))
  68. return text_list, length
  69. def bpe_encode(self, text):
  70. if len(text) == 0:
  71. return None
  72. token = self.bpe_tokenizer(text)['input_ids']
  73. text_list = [1] + token + [2]
  74. if len(text_list) == 0 or len(text_list) > self.batch_max_length:
  75. return None
  76. text_list = text_list + [self.dict[self.GO]
  77. ] * (self.batch_max_length - len(text_list))
  78. return text_list
  79. def wp_encode(self, text):
  80. wp_target = self.wp_tokenizer([text],
  81. padding='max_length',
  82. max_length=self.batch_max_length,
  83. truncation=True,
  84. return_tensors='np')
  85. return wp_target['input_ids'][0]