ep_label_encode.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. import numpy as np
  2. from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
  3. class EPLabelEncode(BaseRecLabelEncode):
  4. """Convert between text-label and text-index."""
  5. EOS = '</s>'
  6. PAD = '<pad>'
  7. def __init__(self,
  8. max_text_length,
  9. character_dict_path=None,
  10. use_space_char=False,
  11. **kwargs):
  12. super(EPLabelEncode,
  13. self).__init__(max_text_length, character_dict_path,
  14. use_space_char)
  15. def __call__(self, data):
  16. text = data['label']
  17. text = self.encode(text)
  18. if text is None:
  19. return None
  20. if len(text) > self.max_text_len:
  21. return None
  22. data['length'] = np.array(len(text))
  23. text = text + [self.dict[self.EOS]]
  24. text = text + [self.dict[self.PAD]
  25. ] * (self.max_text_len + 1 - len(text))
  26. data['label'] = np.array(text)
  27. return data
  28. def add_special_char(self, dict_character):
  29. dict_character = [self.EOS] + dict_character + [self.PAD]
  30. return dict_character