# ar_label_encode.py (1.1 KB, 36 lines)

import numpy as np

from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode


  3. class ARLabelEncode(BaseRecLabelEncode):
  4. """Convert between text-label and text-index."""
  5. BOS = '<s>'
  6. EOS = '</s>'
  7. PAD = '<pad>'
  8. def __init__(self,
  9. max_text_length,
  10. character_dict_path=None,
  11. use_space_char=False,
  12. **kwargs):
  13. super(ARLabelEncode,
  14. self).__init__(max_text_length, character_dict_path,
  15. use_space_char)
  16. def __call__(self, data):
  17. text = data['label']
  18. text = self.encode(text)
  19. if text is None:
  20. return None
  21. data['length'] = np.array(len(text))
  22. text = [self.dict[self.BOS]] + text + [self.dict[self.EOS]]
  23. text = text + [self.dict[self.PAD]
  24. ] * (self.max_text_len + 2 - len(text))
  25. data['label'] = np.array(text)
  26. return data
  27. def add_special_char(self, dict_character):
  28. dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
  29. return dict_character