'''
This code is adapted from:
https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
'''
import numpy as np

from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode


class MGPLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index."""

    SPACE = '[s]'
    GO = '[GO]'
    list_token = [GO, SPACE]

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 only_char=False,
                 **kwargs):
        super(MGPLabelEncode,
              self).__init__(max_text_length, character_dict_path,
                             use_space_char)
        # character (str): set of the possible characters.
        # [GO] is the start token of the attention decoder; [s] is the
        # end-of-sentence token.
        self.batch_max_length = max_text_length + len(self.list_token)
        self.only_char = only_char
        if not only_char:
            # transformers==4.2.1
            from transformers import BertTokenizer, GPT2Tokenizer
            self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self.wp_tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased')
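
    # __call__ produces up to three label granularities per sample:
    # character indices always, plus GPT-2 BPE subword ids and BERT
    # WordPiece ids when only_char is False (MGP-STR's multi-granularity
    # targets).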
    def __call__(self, data):
        text = data['label']
        char_text, char_len = self.encode(text)
        if char_text is None:
            return None
        data['length'] = np.array(char_len)
        data['char_label'] = np.array(char_text)
        if self.only_char:
            return data
        bpe_text = self.bpe_encode(text)
        if bpe_text is None:
            return None
        wp_text = self.wp_encode(text)
        data['bpe_label'] = np.array(bpe_text)
        data['wp_label'] = wp_text
        return data

    def add_special_char(self, dict_character):
        # Prepend [GO] and [s] so they occupy indices 0 and 1 of the dict.
        dict_character = self.list_token + dict_character
        return dict_character

    def encode(self, text):
        """Convert text-label into text-index, dropping characters that
        are not in the dictionary and padding with [GO] up to
        batch_max_length.
        """
        if len(text) == 0:
            return None, None
        if self.lower:
            text = text.lower()
        length = len(text)
        text = [self.GO] + list(text) + [self.SPACE]
        text_list = []
        for char in text:
            if char not in self.dict:
                continue
            text_list.append(self.dict[char])
        if len(text_list) == 0 or len(text_list) > self.batch_max_length:
            return None, None
        text_list = text_list + [self.dict[self.GO]
                                 ] * (self.batch_max_length - len(text_list))
        return text_list, length

    def bpe_encode(self, text):
        if len(text) == 0:
            return None
        token = self.bpe_tokenizer(text)['input_ids']
        # Ids 1 and 2 stand in for the [GO] / [s] markers in the BPE
        # branch (mirroring encode); padding reuses the char-dict [GO] id.
        text_list = [1] + token + [2]
        if len(text_list) == 0 or len(text_list) > self.batch_max_length:
            return None
        text_list = text_list + [self.dict[self.GO]
                                 ] * (self.batch_max_length - len(text_list))
        return text_list

    def wp_encode(self, text):
        # BertTokenizer handles [CLS]/[SEP] wrapping, truncation and
        # max-length padding itself, so the ids can be returned directly.
        wp_target = self.wp_tokenizer([text],
                                      padding='max_length',
                                      max_length=self.batch_max_length,
                                      truncation=True,
                                      return_tensors='np')
        return wp_target['input_ids'][0]
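

# Minimal usage sketch (illustrative, not part of the upstream file).
# Assumption: with character_dict_path=None the base class falls back to a
# lowercase-alphanumeric character set, so 'hello' encodes cleanly; the
# non-char branches would additionally download the HuggingFace tokenizers.
if __name__ == '__main__':
    encoder = MGPLabelEncode(max_text_length=25, only_char=True)
    sample = encoder({'label': 'hello'})
    if sample is not None:
        print(sample['length'])      # 5, the raw text length
        print(sample['char_label'])  # [GO], h, e, l, l, o, [s], then [GO] pads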