igtr_label_encode.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. import copy
  2. import random
  3. import numpy as np
  4. from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
  5. class IGTRLabelEncode(BaseRecLabelEncode):
  6. """Convert between text-label and text-index."""
  7. def __init__(self,
  8. max_text_length,
  9. character_dict_path=None,
  10. use_space_char=False,
  11. k=1,
  12. ch=False,
  13. prompt_error=False,
  14. **kwargs):
  15. super(IGTRLabelEncode,
  16. self).__init__(max_text_length, character_dict_path,
  17. use_space_char)
  18. self.ignore_index = self.dict['<pad>']
  19. self.k = k
  20. self.prompt_error = prompt_error
  21. self.ch = ch
  22. rare_file = kwargs.get('rare_file', None)
  23. siml_file = kwargs.get('siml_file', None)
  24. siml_char_dict = {}
  25. siml_char_list = [0 for _ in range(self.num_character)]
  26. if siml_file is not None:
  27. with open(siml_file, 'r') as f:
  28. for lin in f.readlines():
  29. lin_s = lin.strip().split('\t')
  30. char_siml = lin_s[0]
  31. if char_siml in self.dict:
  32. siml_list = []
  33. siml_prob = []
  34. for i in range(1, len(lin_s), 2):
  35. c = lin_s[i]
  36. prob = int(lin_s[i + 1])
  37. if c in self.dict and prob >= 1:
  38. siml_list.append(self.dict[c])
  39. siml_prob.append(prob)
  40. siml_prob = np.array(siml_prob,
  41. dtype=np.float32) / sum(siml_prob)
  42. siml_char_dict[self.dict[char_siml]] = [
  43. siml_list, siml_prob.tolist()
  44. ]
  45. siml_char_list[self.dict[char_siml]] = 1
  46. self.siml_char_dict = siml_char_dict
  47. self.siml_char_list = siml_char_list
  48. rare_char_list = [0 for _ in range(self.num_character)]
  49. if rare_file is not None:
  50. with open(rare_file, 'r') as f:
  51. for lin in f.readlines():
  52. lin_s = lin.strip().split('\t')
  53. # print(lin_s)
  54. char_rare = lin_s[0]
  55. num_appear = int(lin_s[1])
  56. if char_rare in self.dict and num_appear < 1000:
  57. rare_char_list[self.dict[char_rare]] = 1
  58. self.rare_char_list = rare_char_list # [self.dict[char] for char in rare_char_list]
  59. def __call__(self, data):
  60. text = data['label'] # coffee
  61. encoder_result = self.encode(text)
  62. if encoder_result is None:
  63. return None
  64. text, text_char_num, ques_list_s, prompt_list_s = encoder_result
  65. if len(text) > self.max_text_len:
  66. return None
  67. data['length'] = np.array(len(text))
  68. text = [self.dict['<s>']] + text + [self.dict['</s>']]
  69. text = text + [self.dict['<pad>']
  70. ] * (self.max_text_len + 2 - len(text))
  71. data['label'] = np.array(text) # 6
  72. ques_len_list = []
  73. ques2_len_list = []
  74. prompt_len_list = []
  75. prompt_pos_idx_list = []
  76. prompt_char_idx_list = []
  77. ques_pos_idx_list = []
  78. ques1_answer_list = []
  79. ques2_char_idx_list = []
  80. ques2_answer_list = []
  81. ques4_char_num_list = []
  82. train_step = 0
  83. for prompt_list, ques_list in zip(prompt_list_s, ques_list_s):
  84. prompt_len = len(prompt_list) + 1
  85. prompt_len_list.append(prompt_len)
  86. prompt_list = np.array(
  87. [[0, self.dict['<s>'], 0]] + prompt_list +
  88. [[self.max_text_len + 2, self.dict['<pad>'], 0]] *
  89. (self.max_text_len - len(prompt_list)))
  90. prompt_pos_idx_list.append(prompt_list[:, 0])
  91. prompt_char_idx_list.append(prompt_list[:, 1])
  92. ques_len = len(ques_list)
  93. ques_len_list.append(ques_len)
  94. ques_list = np.array(
  95. ques_list + [[self.max_text_len + 2, self.dict['<pad>'], 0]] *
  96. (self.max_text_len + 1 - ques_len))
  97. ques_pos_idx_list.append(ques_list[:, 0])
  98. # what is the first and third char?
  99. # Is the first character 't'? and Is the third character 'f'?
  100. # How many 'c', 's' and 'f' are there in the text image?
  101. ques1_answer_list.append(ques_list[:, 1])
  102. ques2_char_idx = copy.deepcopy(ques_list[:ques_len, :2])
  103. new_ques2_char_idx = []
  104. ques2_answer = []
  105. for q_2, ques2_idx in enumerate(ques2_char_idx.tolist()):
  106. if (train_step == 2 or train_step == 3) and q_2 == ques_len - 1:
  107. new_ques2_char_idx.append(ques2_idx)
  108. ques2_answer.append(1)
  109. continue
  110. if ques2_idx[1] != self.dict['<pad>'] and random.random() > 0.5:
  111. select_idx = random.randint(0, self.num_character - 3)
  112. new_ques2_char_idx.append([ques2_idx[0], select_idx])
  113. if select_idx == ques2_idx[1]:
  114. ques2_answer.append(1)
  115. else:
  116. ques2_answer.append(0)
  117. if self.siml_char_list[
  118. ques2_idx[1]] == 1 and random.random() > 0.5:
  119. select_idx_sim_list = random.sample(
  120. self.siml_char_dict[ques2_idx[1]][0],
  121. min(3, len(self.siml_char_dict[ques2_idx[1]][0])),
  122. )
  123. for select_idx in select_idx_sim_list:
  124. new_ques2_char_idx.append(
  125. [ques2_idx[0], select_idx])
  126. if select_idx == ques2_idx[1]:
  127. ques2_answer.append(1)
  128. else:
  129. ques2_answer.append(0)
  130. else:
  131. new_ques2_char_idx.append(ques2_idx)
  132. ques2_answer.append(1)
  133. ques2_len_list.append(len(new_ques2_char_idx))
  134. ques2_char_idx_new = np.array(
  135. new_ques2_char_idx +
  136. [[self.max_text_len + 2, self.dict['<pad>']]] *
  137. (self.max_text_len * 4 + 1 - len(new_ques2_char_idx)))
  138. ques2_answer = np.array(
  139. ques2_answer + [0] *
  140. (self.max_text_len * 4 + 1 - len(ques2_answer)))
  141. ques2_char_idx_list.append(ques2_char_idx_new)
  142. ques2_answer_list.append(ques2_answer)
  143. ques4_char_num_list.append(ques_list[:, 2])
  144. train_step += 1
  145. data['ques_len_list'] = np.array(ques_len_list, dtype=np.int64)
  146. data['ques2_len_list'] = np.array(ques2_len_list, dtype=np.int64)
  147. data['prompt_len_list'] = np.array(prompt_len_list, dtype=np.int64)
  148. data['prompt_pos_idx_list'] = np.array(prompt_pos_idx_list,
  149. dtype=np.int64)
  150. data['prompt_char_idx_list'] = np.array(prompt_char_idx_list,
  151. dtype=np.int64)
  152. data['ques_pos_idx_list'] = np.array(ques_pos_idx_list, dtype=np.int64)
  153. data['ques1_answer_list'] = np.array(ques1_answer_list, dtype=np.int64)
  154. data['ques2_char_idx_list'] = np.array(ques2_char_idx_list,
  155. dtype=np.int64)
  156. data['ques2_answer_list'] = np.array(ques2_answer_list,
  157. dtype=np.float32)
  158. data['ques3_answer'] = np.array(
  159. text_char_num,
  160. dtype=np.int64) # np.array([1, 0, 2]) # answer 1, 0, 2
  161. data['ques4_char_num_list'] = np.array(ques4_char_num_list)
  162. return data
  163. def add_special_char(self, dict_character):
  164. dict_character = ['</s>'] + dict_character + ['<s>'] + ['<pad>']
  165. self.num_character = len(dict_character)
  166. return dict_character
  167. def encode(self, text):
  168. """
  169. Encodes the given text into a list of character IDs and generates various lists for question and prompt sequences.
  170. Args:
  171. text (str): The input text to be encoded.
  172. Returns:
  173. tuple: A tuple containing:
  174. - text_list (list): A list of character IDs corresponding to the input text.
  175. - char_num (list): A list of character counts for each character ID.
  176. - ques_list (list): A list of question sequences, each sequence is a list of [position, character ID, character count].
  177. - prompt_list (list): A list of prompt sequences, each sequence is a list of [position, character ID, character count].
  178. Notes:
  179. - If the input text is empty, the function returns None.
  180. - The function handles rare and unrare characters differently.
  181. - The function supports both lowercased and original text based on the `self.lower` attribute.
  182. - The function generates additional sequences if the length of the input text is greater than 1.
  183. """
  184. if len(text) == 0:
  185. return None
  186. if self.lower:
  187. text = text.lower()
  188. char_num = [0 for _ in range(self.num_character - 2)]
  189. char_num[0] = 1
  190. text_list = []
  191. qa_text = []
  192. pos_i = 0
  193. rare_char_qa = []
  194. unrare_char_qa = []
  195. for char in text:
  196. if char not in self.dict:
  197. continue
  198. char_id = self.dict[char]
  199. text_list.append(char_id)
  200. qa_text.append([pos_i + 1, char_id, char_num[char_id]])
  201. if self.rare_char_list[char_id] == 1:
  202. rare_char_qa.append([pos_i + 1, char_id, char_num[char_id]])
  203. else:
  204. unrare_char_qa.append([pos_i + 1, char_id, char_num[char_id]])
  205. char_num[char_id] += 1
  206. pos_i += 1
  207. if self.ch:
  208. char_num_ch = []
  209. char_num_ch_none = []
  210. rare_char_num_ch_none = []
  211. for i, num in enumerate(char_num):
  212. if self.rare_char_list[i] == 1:
  213. rare_char_num_ch_none.append([i, num])
  214. if num > 0:
  215. char_num_ch.append([i, num])
  216. else:
  217. char_num_ch_none.append([i, 0])
  218. none_char_index = random.sample(
  219. char_num_ch_none,
  220. min(37 - len(char_num_ch), len(char_num_ch_none)))
  221. if len(rare_char_num_ch_none) > 0:
  222. none_rare_char_index = random.sample(
  223. rare_char_num_ch_none,
  224. min(40 - len(char_num_ch) - len(none_char_index),
  225. len(rare_char_num_ch_none)),
  226. )
  227. char_num_ch = char_num_ch + none_char_index + none_rare_char_index
  228. else:
  229. char_num_ch = char_num_ch + none_char_index
  230. char_num_ch.sort(key=lambda x: x[0])
  231. char_num = char_num_ch
  232. len_ = len(text_list)
  233. if len_ == 0:
  234. return None
  235. ques_list = [
  236. qa_text + [[pos_i + 1, self.dict['</s>'], 0]],
  237. [[pos_i + 1, self.dict['</s>'], 0]],
  238. ]
  239. prompt_list = [qa_text[len_:], qa_text]
  240. if len_ == 1:
  241. ques_list.append([[self.max_text_len + 1, self.dict['</s>'], 0]])
  242. prompt_list.append(
  243. [[self.max_text_len + 2, self.dict['<pad>'], 0]] * 4 + qa_text)
  244. for _ in range(1, self.k):
  245. ques_list.append(
  246. [[self.max_text_len + 2, self.dict['<pad>'], 0]])
  247. prompt_list.append(qa_text[1:])
  248. else:
  249. next_id = random.sample(range(1, len_ + 1), 2)
  250. for slice_id in next_id:
  251. b_i = slice_id - 5 if slice_id - 5 > 0 else 0
  252. if slice_id == len_:
  253. ques_list.append(
  254. [[self.max_text_len + 1, self.dict['</s>'], 0]])
  255. else:
  256. ques_list.append(
  257. qa_text[slice_id:] +
  258. [[self.max_text_len + 1, qa_text[slice_id][1], 0]])
  259. prompt_list.append(
  260. [[self.max_text_len + 2, self.dict['<pad>'], 0]] *
  261. (5 - slice_id + b_i) + qa_text[b_i:slice_id])
  262. shuffle_id1 = random.sample(range(1, len_),
  263. 2) if len_ > 2 else [1, 0]
  264. for slice_id in shuffle_id1:
  265. if slice_id == 0:
  266. ques_list.append(
  267. [[self.max_text_len + 2, self.dict['<pad>'], 0]])
  268. prompt_list.append(qa_text[:0])
  269. else:
  270. ques_list.append(qa_text[slice_id:] +
  271. [[pos_i + 1, self.dict['</s>'], 0]])
  272. prompt_list.append(qa_text[:slice_id])
  273. if len_ > 2:
  274. shuffle_id2 = random.sample(
  275. range(1, len_),
  276. self.k - 4 if len_ - 1 > self.k - 4 else len_ - 1)
  277. if self.k - 4 != len(shuffle_id2):
  278. shuffle_id2 += random.sample(range(1, len_),
  279. self.k - 4 - len(shuffle_id2))
  280. rare_slice_id = len(rare_char_qa)
  281. unrare_slice_id = len(unrare_char_qa)
  282. for slice_id in shuffle_id2:
  283. random.shuffle(qa_text)
  284. if len(rare_char_qa) > 0 and random.random() < 0.5:
  285. ques_list.append(rare_char_qa[:rare_slice_id] +
  286. unrare_char_qa[unrare_slice_id:] +
  287. [[pos_i + 1, self.dict['</s>'], 0]])
  288. if len(unrare_char_qa[:unrare_slice_id]) > 0:
  289. prompt_list1 = random.sample(
  290. unrare_char_qa[:unrare_slice_id],
  291. random.randint(
  292. 1, len(unrare_char_qa[:unrare_slice_id]))
  293. if len(unrare_char_qa[:unrare_slice_id]) > 1
  294. else 1,
  295. )
  296. else:
  297. prompt_list1 = []
  298. if len(rare_char_qa[rare_slice_id:]) > 0:
  299. prompt_list2 = random.sample(
  300. rare_char_qa[rare_slice_id:],
  301. random.randint(
  302. 1,
  303. len(rare_char_qa[rare_slice_id:])
  304. if len(rare_char_qa[rare_slice_id:]) > 1
  305. else 1,
  306. ),
  307. )
  308. else:
  309. prompt_list2 = []
  310. prompt_list.append(prompt_list1 + prompt_list2)
  311. random.shuffle(rare_char_qa)
  312. random.shuffle(unrare_char_qa)
  313. rare_slice_id = random.randint(
  314. 1,
  315. len(rare_char_qa)) if len(rare_char_qa) > 1 else 1
  316. unrare_slice_id = random.randint(
  317. 1, len(unrare_char_qa)
  318. ) if len(unrare_char_qa) > 1 else 1
  319. else:
  320. ques_list.append(qa_text[slice_id:] +
  321. [[pos_i + 1, self.dict['</s>'], 0]])
  322. prompt_list.append(qa_text[:slice_id])
  323. else:
  324. ques_list.append(qa_text[1:] +
  325. [[pos_i + 1, self.dict['</s>'], 0]])
  326. prompt_list.append(qa_text[:1])
  327. ques_list.append(qa_text[:1] +
  328. [[pos_i + 1, self.dict['</s>'], 0]])
  329. prompt_list.append(qa_text[1:])
  330. ques_list += [[[self.max_text_len + 2, self.dict['<pad>'], 0]]
  331. ] * (self.k - 6)
  332. prompt_list += [qa_text[:0]] * (self.k - 6)
  333. return text_list, char_num, ques_list, prompt_list