eval_rec_all_en.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. import csv
  2. import os
  3. import sys
  4. import numpy as np
  5. __dir__ = os.path.dirname(os.path.abspath(__file__))
  6. sys.path.append(__dir__)
  7. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
  8. from tools.data import build_dataloader
  9. from tools.engine.config import Config
  10. from tools.engine.trainer import Trainer
  11. from tools.utility import ArgsParser
  12. def parse_args():
  13. parser = ArgsParser()
  14. args = parser.parse_args()
  15. return args
  16. def main():
  17. FLAGS = parse_args()
  18. cfg = Config(FLAGS.config)
  19. FLAGS = vars(FLAGS)
  20. opt = FLAGS.pop('opt')
  21. cfg.merge_dict(FLAGS)
  22. cfg.merge_dict(opt)
  23. msr = False
  24. if 'RatioDataSet' in cfg.cfg['Eval']['dataset']['name']:
  25. msr = True
  26. if cfg.cfg['Global']['output_dir'][-1] == '/':
  27. cfg.cfg['Global']['output_dir'] = cfg.cfg['Global']['output_dir'][:-1]
  28. if cfg.cfg['Global']['pretrained_model'] is None:
  29. cfg.cfg['Global'][
  30. 'pretrained_model'] = cfg.cfg['Global']['output_dir'] + '/best.pth'
  31. cfg.cfg['Global']['use_amp'] = False
  32. cfg.cfg['PostProcess']['with_ratio'] = True
  33. cfg.cfg['Metric']['with_ratio'] = True
  34. cfg.cfg['Metric']['max_len'] = 25
  35. cfg.cfg['Metric']['max_ratio'] = 12
  36. cfg.cfg['Eval']['dataset']['transforms'][-1]['KeepKeys'][
  37. 'keep_keys'].append('real_ratio')
  38. trainer = Trainer(cfg, mode='eval')
  39. best_model_dict = trainer.status.get('metrics', {})
  40. trainer.logger.info('metric in ckpt ***************')
  41. for k, v in best_model_dict.items():
  42. trainer.logger.info('{}:{}'.format(k, v))
  43. data_dirs_list = [
  44. [
  45. '../test/IIIT5k/', '../test/SVT/', '../test/IC13_857/',
  46. '../test/IC15_1811/', '../test/SVTP/', '../test/CUTE80/'
  47. ],
  48. [
  49. '../u14m/curve/', '../u14m/multi_oriented/', '../u14m/artistic/',
  50. '../u14m/contextless/', '../u14m/salient/', '../u14m/multi_words/',
  51. '../u14m/general/'
  52. ], ['../OST/weak/', '../OST/heavy/'],
  53. ['../wordart_test/', '../test/IC13_1015/', '../test/IC15_2077/']
  54. ]
  55. cfg = cfg.cfg
  56. file_csv = open(
  57. cfg['Global']['output_dir'] + '/' +
  58. cfg['Global']['output_dir'].split('/')[-1] +
  59. '_eval_all_length_ratio.csv', 'w')
  60. csv_w = csv.writer(file_csv)
  61. cfg['Eval']['dataset']['name'] = cfg['Eval']['dataset']['name'] + 'Test'
  62. for data_dirs in data_dirs_list:
  63. acc_each = []
  64. acc_each_real = []
  65. acc_each_lower = []
  66. acc_each_ingore_space = []
  67. acc_each_ingore_space_lower = []
  68. acc_each_ignore_space_symbol = []
  69. acc_each_lower_ignore_space_symbol = []
  70. acc_each_num = []
  71. acc_each_dis = []
  72. each_len = {}
  73. each_ratio = {}
  74. for datadir in data_dirs:
  75. config_each = cfg.copy()
  76. if msr:
  77. config_each['Eval']['dataset']['data_dir_list'] = [datadir]
  78. else:
  79. config_each['Eval']['dataset']['data_dir'] = datadir
  80. valid_dataloader = build_dataloader(config_each, 'Eval',
  81. trainer.logger)
  82. trainer.logger.info(
  83. f'{datadir} valid dataloader has {len(valid_dataloader)} iters'
  84. )
  85. trainer.valid_dataloader = valid_dataloader
  86. metric = trainer.eval()
  87. acc_each.append(metric['acc'] * 100)
  88. acc_each_real.append(metric['acc_real'] * 100)
  89. acc_each_lower.append(metric['acc_lower'] * 100)
  90. acc_each_ingore_space.append(metric['acc_ignore_space'] * 100)
  91. acc_each_ingore_space_lower.append(
  92. metric['acc_ignore_space_lower'] * 100)
  93. acc_each_ignore_space_symbol.append(
  94. metric['acc_ignore_space_symbol'] * 100)
  95. acc_each_lower_ignore_space_symbol.append(
  96. metric['acc_ignore_space_lower_symbol'] * 100)
  97. acc_each_dis.append(metric['norm_edit_dis'])
  98. acc_each_num.append(metric['num_samples'])
  99. trainer.logger.info('metric eval ***************')
  100. csv_w.writerow([datadir])
  101. for k, v in metric.items():
  102. trainer.logger.info('{}:{}'.format(k, v))
  103. if 'each' in k:
  104. csv_w.writerow([k] + v)
  105. if 'each_len' in k:
  106. each_len[k] = each_len.get(k, []) + [np.array(v)]
  107. if 'each_ratio' in k:
  108. each_ratio[k] = each_ratio.get(k, []) + [np.array(v)]
  109. data_name = [
  110. data_n[:-1].split('/')[-1]
  111. if data_n[-1] == '/' else data_n.split('/')[-1]
  112. for data_n in data_dirs
  113. ]
  114. csv_w.writerow(['-'] + data_name + ['arithmetic_avg'] +
  115. ['weighted_avg'])
  116. csv_w.writerow([''] + acc_each_num)
  117. avg1 = np.array(acc_each) * np.array(acc_each_num) / sum(acc_each_num)
  118. csv_w.writerow(['acc'] + acc_each + [sum(acc_each) / len(acc_each)] +
  119. [avg1.sum().tolist()])
  120. print(acc_each + [sum(acc_each) / len(acc_each)] +
  121. [avg1.sum().tolist()])
  122. avg1 = np.array(acc_each_dis) * np.array(acc_each_num) / sum(
  123. acc_each_num)
  124. csv_w.writerow(['norm_edit_dis'] + acc_each_dis +
  125. [sum(acc_each_dis) / len(acc_each)] +
  126. [avg1.sum().tolist()])
  127. avg1 = np.array(acc_each_real) * np.array(acc_each_num) / sum(
  128. acc_each_num)
  129. csv_w.writerow(['acc_real'] + acc_each_real +
  130. [sum(acc_each_real) / len(acc_each_real)] +
  131. [avg1.sum().tolist()])
  132. avg1 = np.array(acc_each_lower) * np.array(acc_each_num) / sum(
  133. acc_each_num)
  134. csv_w.writerow(['acc_lower'] + acc_each_lower +
  135. [sum(acc_each_lower) / len(acc_each_lower)] +
  136. [avg1.sum().tolist()])
  137. avg1 = np.array(acc_each_ingore_space) * np.array(acc_each_num) / sum(
  138. acc_each_num)
  139. csv_w.writerow(
  140. ['acc_ignore_space'] + acc_each_ingore_space +
  141. [sum(acc_each_ingore_space) / len(acc_each_ingore_space)] +
  142. [avg1.sum().tolist()])
  143. avg1 = np.array(acc_each_ingore_space_lower) * np.array(
  144. acc_each_num) / sum(acc_each_num)
  145. csv_w.writerow(['acc_ignore_space_lower'] +
  146. acc_each_ingore_space_lower + [
  147. sum(acc_each_ingore_space_lower) /
  148. len(acc_each_ingore_space_lower)
  149. ] + [avg1.sum().tolist()])
  150. avg1 = np.array(acc_each_ignore_space_symbol) * np.array(
  151. acc_each_num) / sum(acc_each_num)
  152. csv_w.writerow(['acc_ignore_space_symbol'] +
  153. acc_each_ignore_space_symbol + [
  154. sum(acc_each_ignore_space_symbol) /
  155. len(acc_each_ignore_space_symbol)
  156. ] + [avg1.sum().tolist()])
  157. avg1 = np.array(acc_each_lower_ignore_space_symbol) * np.array(
  158. acc_each_num) / sum(acc_each_num)
  159. csv_w.writerow(['acc_ignore_space_lower_symbol'] +
  160. acc_each_lower_ignore_space_symbol + [
  161. sum(acc_each_lower_ignore_space_symbol) /
  162. len(acc_each_lower_ignore_space_symbol)
  163. ] + [avg1.sum().tolist()])
  164. sum_all = np.array(each_len['each_len_num']).sum(0)
  165. for k, v in each_len.items():
  166. if k != 'each_len_num':
  167. v_sum_weight = (np.array(v) *
  168. np.array(each_len['each_len_num'])).sum(0)
  169. sum_all_pad = np.where(sum_all == 0, 1., sum_all)
  170. v_all = v_sum_weight / sum_all_pad
  171. v_all = np.where(sum_all == 0, 0., v_all)
  172. csv_w.writerow([k] + v_all.tolist())
  173. else:
  174. csv_w.writerow([k] + sum_all.tolist())
  175. sum_all = np.array(each_ratio['each_ratio_num']).sum(0)
  176. for k, v in each_ratio.items():
  177. if k != 'each_ratio_num':
  178. v_sum_weight = (np.array(v) *
  179. np.array(each_ratio['each_ratio_num'])).sum(0)
  180. sum_all_pad = np.where(sum_all == 0, 1., sum_all)
  181. v_all = v_sum_weight / sum_all_pad
  182. v_all = np.where(sum_all == 0, 0., v_all)
  183. csv_w.writerow([k] + v_all.tolist())
  184. else:
  185. csv_w.writerow([k] + sum_all.tolist())
  186. file_csv.close()
  187. if __name__ == '__main__':
  188. main()