char_label_encode.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536
  1. import numpy as np
  2. from .ctc_label_encode import BaseRecLabelEncode
  3. class CharLabelEncode(BaseRecLabelEncode):
  4. """Convert between text-label and text-index."""
  5. def __init__(self,
  6. max_text_length,
  7. character_dict_path=None,
  8. use_space_char=False,
  9. **kwargs):
  10. super(CharLabelEncode,
  11. self).__init__(max_text_length, character_dict_path,
  12. use_space_char)
  13. def __call__(self, data):
  14. text = data['label']
  15. text = self.encode(text)
  16. if text is None:
  17. return None
  18. if len(text) > self.max_text_len:
  19. return None
  20. data['length'] = np.array(len(text))
  21. text_char = text + [104] * (self.max_text_len + 1 - len(text))
  22. text.insert(0, 2)
  23. text.append(3)
  24. text = text + [0] * (self.max_text_len + 2 - len(text))
  25. data['label'] = np.array(text)
  26. data['label_char'] = np.array(text_char)
  27. return data
  28. def add_special_char(self, dict_character):
  29. dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
  30. return dict_character