# eval_rec_all_long_simple.py (4.6 KB)
  1. import csv
  2. import os
  3. import sys
  4. __dir__ = os.path.dirname(os.path.abspath(__file__))
  5. sys.path.append(__dir__)
  6. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
  7. import numpy as np
  8. from tools.data import build_dataloader
  9. from tools.engine.config import Config
  10. from tools.engine.trainer import Trainer
  11. from tools.utility import ArgsParser
  12. def parse_args():
  13. parser = ArgsParser()
  14. args = parser.parse_args()
  15. return args
  16. def main():
  17. FLAGS = parse_args()
  18. cfg = Config(FLAGS.config)
  19. FLAGS = vars(FLAGS)
  20. opt = FLAGS.pop('opt')
  21. cfg.merge_dict(FLAGS)
  22. cfg.merge_dict(opt)
  23. cfg.cfg['Global']['use_amp'] = False
  24. if cfg.cfg['Global']['output_dir'][-1] == '/':
  25. cfg.cfg['Global']['output_dir'] = cfg.cfg['Global']['output_dir'][:-1]
  26. cfg.cfg['Global']['max_text_length'] = 200
  27. cfg.cfg['Architecture']['Decoder']['max_len'] = 200
  28. cfg.cfg['Metric']['name'] = 'RecMetricLong'
  29. if cfg.cfg['Global']['pretrained_model'] is None:
  30. cfg.cfg['Global'][
  31. 'pretrained_model'] = cfg.cfg['Global']['output_dir'] + '/best.pth'
  32. trainer = Trainer(cfg, mode='eval')
  33. best_model_dict = trainer.status.get('metrics', {})
  34. trainer.logger.info('metric in ckpt ***************')
  35. for k, v in best_model_dict.items():
  36. trainer.logger.info('{}:{}'.format(k, v))
  37. data_dirs_list = [
  38. [
  39. '../ltb/ultra_long_26_35_list.txt',
  40. '../ltb/ultra_long_36_55_list.txt',
  41. '../ltb/ultra_long_56_list.txt',
  42. ],
  43. ]
  44. cfg = cfg.cfg
  45. cfg['Eval']['dataset']['name'] = 'SimpleDataSet'
  46. file_csv = open(
  47. cfg['Global']['output_dir'] + '/' +
  48. cfg['Global']['output_dir'].split('/')[-1] +
  49. '_result1_1_test_all_long_simple_bi_bs1.csv', 'w')
  50. csv_w = csv.writer(file_csv)
  51. for data_dirs in data_dirs_list:
  52. acc_each = []
  53. acc_each_num = []
  54. acc_each_dis = []
  55. each_long = {}
  56. for datadir in data_dirs:
  57. config_each = cfg.copy()
  58. config_each['Eval']['dataset']['label_file_list'] = [datadir]
  59. valid_dataloader = build_dataloader(config_each, 'Eval',
  60. trainer.logger)
  61. trainer.logger.info(
  62. f'{datadir} valid dataloader has {len(valid_dataloader)} iters'
  63. )
  64. trainer.valid_dataloader = valid_dataloader
  65. metric = trainer.eval()
  66. acc_each.append(metric['acc'] * 100)
  67. acc_each_dis.append(metric['norm_edit_dis'])
  68. acc_each_num.append(metric['all_num'])
  69. trainer.logger.info('metric eval ***************')
  70. for k, v in metric.items():
  71. trainer.logger.info('{}:{}'.format(k, v))
  72. if 'each' in k:
  73. csv_w.writerow([k] + v[26:])
  74. each_long[k] = each_long.get(k, []) + [np.array(v[26:])]
  75. avg1 = np.array(acc_each) * np.array(acc_each_num) / sum(acc_each_num)
  76. csv_w.writerow(acc_each + [avg1.sum().tolist()] +
  77. [sum(acc_each) / len(acc_each)])
  78. print(acc_each + [avg1.sum().tolist()] +
  79. [sum(acc_each) / len(acc_each)])
  80. avg1 = np.array(acc_each_dis) * np.array(acc_each_num) / sum(
  81. acc_each_num)
  82. csv_w.writerow(acc_each_dis + [avg1.sum().tolist()] +
  83. [sum(acc_each_dis) / len(acc_each)])
  84. sum_all = np.array(each_long['each_len_num']).sum(0)
  85. for k, v in each_long.items():
  86. if k != 'each_len_num':
  87. v_sum_weight = (np.array(v) *
  88. np.array(each_long['each_len_num'])).sum(0)
  89. sum_all_pad = np.where(sum_all == 0, 1., sum_all)
  90. v_all = v_sum_weight / sum_all_pad
  91. v_all = np.where(sum_all == 0, 0., v_all)
  92. csv_w.writerow([k] + v_all.tolist())
  93. v_26_40 = (v_all[:10] * sum_all[:10]) / sum_all[:10].sum()
  94. csv_w.writerow([k + '26_35'] + [v_26_40.sum().tolist()] +
  95. [sum_all[:10].sum().tolist()])
  96. v_41_55 = (v_all[10:30] *
  97. sum_all[10:30]) / sum_all[10:30].sum()
  98. csv_w.writerow([k + '36_55'] + [v_41_55.sum().tolist()] +
  99. [sum_all[10:30].sum().tolist()])
  100. v_56_70 = (v_all[30:] * sum_all[30:]) / sum_all[30:].sum()
  101. csv_w.writerow([k + '56'] + [v_56_70.sum().tolist()] +
  102. [sum_all[30:].sum().tolist()])
  103. else:
  104. csv_w.writerow([k] + sum_all.tolist())
  105. file_csv.close()
  106. if __name__ == '__main__':
  107. main()