trainer.py

import datetime
import os
import random
import time

import numpy as np
from tqdm import tqdm
import torch
import torch.distributed

from tools.data import build_dataloader
from tools.utils.ckpt import load_ckpt, save_ckpt
from tools.utils.logging import get_logger
from tools.utils.stats import TrainingStats
from tools.utils.utility import AverageMeter

__all__ = ['Trainer']


def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters()
                        if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}
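
# Example (hypothetical call, for illustration only):
#   >>> get_parameter_number(torch.nn.Linear(10, 2))
#   {'Total': 22, 'Trainable': 22}   # 10 * 2 weights + 2 bias terms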


class Trainer(object):

    def __init__(self, cfg, mode='train', task='rec'):
        self.cfg = cfg.cfg
        self.task = task
        self.local_rank = (int(os.environ['LOCAL_RANK'])
                           if 'LOCAL_RANK' in os.environ else 0)
        self.set_device(self.cfg['Global']['device'])
        mode = mode.lower()
        assert mode in [
            'train_eval',
            'train',
            'eval',
            'test',
        ], 'mode should be train_eval, train, eval or test'
        if torch.cuda.device_count() > 1 and 'train' in mode:
            torch.distributed.init_process_group(backend='nccl')
            torch.cuda.set_device(self.device)
            self.cfg['Global']['distributed'] = True
        else:
            self.cfg['Global']['distributed'] = False
            self.local_rank = 0

        self.cfg['Global']['output_dir'] = self.cfg['Global'].get(
            'output_dir', 'output')
        os.makedirs(self.cfg['Global']['output_dir'], exist_ok=True)

        self.writer = None
        if self.local_rank == 0 and self.cfg['Global'][
                'use_tensorboard'] and 'train' in mode:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(self.cfg['Global']['output_dir'])

        self.logger = get_logger(
            'openrec' if task == 'rec' else 'opendet',
            os.path.join(self.cfg['Global']['output_dir'], 'train.log')
            if 'train' in mode else None,
        )
        cfg.print_cfg(self.logger.info)

        if self.cfg['Global']['device'] == 'gpu' and self.device.type == 'cpu':
            self.logger.info('cuda is not available, auto switch to cpu')

        self.set_random_seed(self.cfg['Global'].get('seed', 48))

        # build data loader
        self.train_dataloader = None
        if 'train' in mode:
            cfg.save(
                os.path.join(self.cfg['Global']['output_dir'], 'config.yml'),
                self.cfg)
            self.train_dataloader = build_dataloader(self.cfg,
                                                     'Train',
                                                     self.logger,
                                                     task=task)
            self.logger.info(
                f'train dataloader has {len(self.train_dataloader)} iters')

        self.valid_dataloader = None
        if 'eval' in mode and self.cfg['Eval']:
            self.valid_dataloader = build_dataloader(self.cfg,
                                                     'Eval',
                                                     self.logger,
                                                     task=task)
            self.logger.info(
                f'valid dataloader has {len(self.valid_dataloader)} iters')

        if task == 'rec':
            self._init_rec_model()
        elif task == 'det':
            self._init_det_model()
        else:
            raise NotImplementedError

        self.logger.info(get_parameter_number(model=self.model))
        self.model = self.model.to(self.device)

        use_sync_bn = self.cfg['Global'].get('use_sync_bn', False)
        if use_sync_bn:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)
            self.logger.info('convert_sync_batchnorm')

        from openrec.optimizer import build_optimizer

        self.optimizer, self.lr_scheduler = None, None
        if self.train_dataloader is not None:
            # build optim
            self.optimizer, self.lr_scheduler = build_optimizer(
                self.cfg['Optimizer'],
                self.cfg['LRScheduler'],
                epochs=self.cfg['Global']['epoch_num'],
                step_each_epoch=len(self.train_dataloader),
                model=self.model,
            )

        self.grad_clip_val = self.cfg['Global'].get('grad_clip_val', 0)

        self.status = load_ckpt(self.model, self.cfg, self.optimizer,
                                self.lr_scheduler)

        if self.cfg['Global']['distributed']:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, [self.local_rank], find_unused_parameters=False)

        # amp
        self.scaler = (torch.cuda.amp.GradScaler() if self.cfg['Global'].get(
            'use_amp', False) else None)

        self.logger.info(
            f'run with torch {torch.__version__} and device {self.device}')

    def _init_rec_model(self):
        from openrec.losses import build_loss as build_rec_loss
        from openrec.metrics import build_metric as build_rec_metric
        from openrec.modeling import build_model as build_rec_model
        from openrec.postprocess import build_post_process as build_rec_post_process

        # build post process
        self.post_process_class = build_rec_post_process(
            self.cfg['PostProcess'], self.cfg['Global'])
        # build model
        # for rec algorithm
        char_num = self.post_process_class.get_character_num()
        self.cfg['Architecture']['Decoder']['out_channels'] = char_num
        self.model = build_rec_model(self.cfg['Architecture'])
        # build loss
        self.loss_class = build_rec_loss(self.cfg['Loss'])
        # build metric
        self.eval_class = build_rec_metric(self.cfg['Metric'])

    def _init_det_model(self):
        from opendet.losses import build_loss as build_det_loss
        from opendet.metrics import build_metric as build_det_metric
        from opendet.modeling import build_model as build_det_model
        from opendet.postprocess import build_post_process as build_det_post_process

        # build post process
        self.post_process_class = build_det_post_process(
            self.cfg['PostProcess'], self.cfg['Global'])
        # build det model
        self.model = build_det_model(self.cfg['Architecture'])
        # build loss
        self.loss_class = build_det_loss(self.cfg['Loss'])
        # build metric
        self.eval_class = build_det_metric(self.cfg['Metric'])
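
    # Note: both _init_rec_model and _init_det_model leave the same four
    # attributes on the trainer (post_process_class, model, loss_class,
    # eval_class); the rec path additionally derives the decoder's
    # out_channels from the character set of the post-processor.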

    def load_params(self, params):
        self.model.load_state_dict(params)

    def set_random_seed(self, seed):
        torch.manual_seed(seed)  # set the random seed for the CPU
        if self.device.type == 'cuda':
            torch.backends.cudnn.benchmark = True
            torch.cuda.manual_seed(seed)  # set the random seed for the current GPU
            torch.cuda.manual_seed_all(seed)  # set the random seed for all GPUs
        random.seed(seed)
        np.random.seed(seed)

    def set_device(self, device):
        if device == 'gpu' and torch.cuda.is_available():
            device = torch.device(f'cuda:{self.local_rank}')
        else:
            device = torch.device('cpu')
        self.device = device
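
    # set_device('gpu') selects torch.device(f'cuda:{local_rank}') when CUDA is
    # available and falls back to torch.device('cpu') otherwise; the fallback is
    # reported later in __init__, once the logger has been created.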

    def train(self):
        cal_metric_during_train = self.cfg['Global'].get(
            'cal_metric_during_train', False)
        log_smooth_window = self.cfg['Global']['log_smooth_window']
        epoch_num = self.cfg['Global']['epoch_num']
        print_batch_step = self.cfg['Global']['print_batch_step']
        eval_epoch_step = self.cfg['Global'].get('eval_epoch_step', 1)

        start_eval_epoch = 0
        if self.valid_dataloader is not None:
            if type(eval_epoch_step) == list and len(eval_epoch_step) >= 2:
                start_eval_epoch = eval_epoch_step[0]
                eval_epoch_step = eval_epoch_step[1]
                if len(self.valid_dataloader) == 0:
                    start_eval_epoch = 1e111
                    self.logger.info(
                        'No Images in eval dataset, evaluation during training will be disabled'
                    )
            self.logger.info(
                f'During the training process, after the {start_eval_epoch}th epoch, '
                f'an evaluation is run every {eval_epoch_step} epoch')
        else:
            start_eval_epoch = 1e111

        eval_batch_step = self.cfg['Global']['eval_batch_step']
        global_step = self.status.get('global_step', 0)
        start_eval_step = 0
        if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
            start_eval_step = eval_batch_step[0]
            eval_batch_step = eval_batch_step[1]
            if len(self.valid_dataloader) == 0:
                self.logger.info(
                    'No Images in eval dataset, evaluation during training '
                    'will be disabled')
                start_eval_step = 1e111
            self.logger.info(
                'During the training process, after the {}th iteration, '
                'an evaluation is run every {} iterations'.format(
                    start_eval_step, eval_batch_step))

        save_epoch_step = self.cfg['Global'].get('save_epoch_step', [0, 1])
        start_save_epoch = save_epoch_step[0]
        save_epoch_step = save_epoch_step[1]

        start_epoch = self.status.get('epoch', 1)
        self.best_metric = self.status.get('metrics', {})
        if self.eval_class.main_indicator not in self.best_metric:
            self.best_metric[self.eval_class.main_indicator] = 0

        train_stats = TrainingStats(log_smooth_window, ['lr'])
        self.model.train()

        total_samples = 0
        train_reader_cost = 0.0
        train_batch_cost = 0.0
        reader_start = time.time()
        eta_meter = AverageMeter()

        for epoch in range(start_epoch, epoch_num + 1):
            if self.train_dataloader.dataset.need_reset:
                self.train_dataloader = build_dataloader(self.cfg,
                                                         'Train',
                                                         self.logger,
                                                         epoch=epoch,
                                                         task=self.task)
            for idx, batch in enumerate(self.train_dataloader):
                batch_tensor = [t.to(self.device) for t in batch]
                batch_numpy = [t.numpy() for t in batch]
                self.optimizer.zero_grad()
                train_reader_cost += time.time() - reader_start
                # use amp
                if self.scaler:
                    with torch.cuda.amp.autocast(
                            enabled=self.device.type == 'cuda'):
                        preds = self.model(batch_tensor[0],
                                           data=batch_tensor[1:])
                    loss = self.loss_class(preds, batch_tensor)
                    self.scaler.scale(loss['loss']).backward()
                    if self.grad_clip_val > 0:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            max_norm=self.grad_clip_val)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    preds = self.model(batch_tensor[0], data=batch_tensor[1:])
                    loss = self.loss_class(preds, batch_tensor)
                    avg_loss = loss['loss']
                    avg_loss.backward()
                    if self.grad_clip_val > 0:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            max_norm=self.grad_clip_val)
                    self.optimizer.step()

                if cal_metric_during_train:  # only rec and cls need
                    post_result = self.post_process_class(preds,
                                                          batch_numpy,
                                                          training=True)
                    self.eval_class(post_result, batch_numpy, training=True)
                    metric = self.eval_class.get_metric()
                    train_stats.update(metric)

                train_batch_time = time.time() - reader_start
                train_batch_cost += train_batch_time
                eta_meter.update(train_batch_time)
                global_step += 1
                total_samples += len(batch[0])

                self.lr_scheduler.step()

                # logger
                stats = {
                    k: float(v)
                    if v.shape == [] else v.detach().cpu().numpy().mean()
                    for k, v in loss.items()
                }
                stats['lr'] = self.lr_scheduler.get_last_lr()[0]
                train_stats.update(stats)

                if self.writer is not None:
                    for k, v in train_stats.get().items():
                        self.writer.add_scalar(f'TRAIN/{k}', v, global_step)

                if self.local_rank == 0 and (
                    (global_step > 0 and global_step % print_batch_step == 0)
                        or (idx >= len(self.train_dataloader) - 1)):
                    logs = train_stats.log()
                    eta_sec = (
                        (epoch_num + 1 - epoch) * len(self.train_dataloader) -
                        idx - 1) * eta_meter.avg
                    eta_sec_format = str(
                        datetime.timedelta(seconds=int(eta_sec)))
                    strs = (
                        f'epoch: [{epoch}/{epoch_num}], global_step: {global_step}, {logs}, '
                        f'avg_reader_cost: {train_reader_cost / print_batch_step:.5f} s, '
                        f'avg_batch_cost: {train_batch_cost / print_batch_step:.5f} s, '
                        f'avg_samples: {total_samples / print_batch_step}, '
                        f'ips: {total_samples / train_batch_cost:.5f} samples/s, '
                        f'eta: {eta_sec_format}')
                    self.logger.info(strs)
                    total_samples = 0
                    train_reader_cost = 0.0
                    train_batch_cost = 0.0
                reader_start = time.time()

                # eval iter step
                if (global_step > start_eval_step and
                    (global_step - start_eval_step) % eval_batch_step == 0
                    ) and self.local_rank == 0:
                    self.eval_step(global_step, epoch)

            # eval epoch step
            if self.local_rank == 0 and epoch > start_eval_epoch and (
                    epoch - start_eval_epoch) % eval_epoch_step == 0:
                self.eval_step(global_step, epoch)

            if self.local_rank == 0:
                save_ckpt(self.model,
                          self.cfg,
                          self.optimizer,
                          self.lr_scheduler,
                          epoch,
                          global_step,
                          self.best_metric,
                          is_best=False,
                          prefix=None)
                if epoch > start_save_epoch and (
                        epoch - start_save_epoch) % save_epoch_step == 0:
                    save_ckpt(self.model,
                              self.cfg,
                              self.optimizer,
                              self.lr_scheduler,
                              epoch,
                              global_step,
                              self.best_metric,
                              is_best=False,
                              prefix='epoch_' + str(epoch))

        best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in self.best_metric.items()])}"
        self.logger.info(best_str)
        if self.writer is not None:
            self.writer.close()
        if torch.cuda.device_count() > 1:
            torch.distributed.barrier()
            torch.distributed.destroy_process_group()
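
    # Per-epoch flow in train(): optional step-based evaluation (eval_batch_step),
    # optional epoch-based evaluation (eval_epoch_step), an unconditional
    # end-of-epoch checkpoint (prefix=None), and a periodic 'epoch_<n>' checkpoint
    # controlled by save_epoch_step. Checkpoints with is_best=True are written
    # from eval_step() below.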

    def eval_step(self, global_step, epoch):
        cur_metric = self.eval()
        cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}"
        self.logger.info(cur_metric_str)

        # logger metric
        if self.writer is not None:
            for k, v in cur_metric.items():
                if isinstance(v, (float, int)):
                    self.writer.add_scalar(f'EVAL/{k}', cur_metric[k],
                                           global_step)

        if (cur_metric[self.eval_class.main_indicator] >=
                self.best_metric[self.eval_class.main_indicator]):
            self.best_metric.update(cur_metric)
            self.best_metric['best_epoch'] = epoch
            if self.writer is not None:
                self.writer.add_scalar(
                    f'EVAL/best_{self.eval_class.main_indicator}',
                    self.best_metric[self.eval_class.main_indicator],
                    global_step,
                )
            save_ckpt(self.model,
                      self.cfg,
                      self.optimizer,
                      self.lr_scheduler,
                      epoch,
                      global_step,
                      self.best_metric,
                      is_best=True,
                      prefix=None)
        best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in self.best_metric.items()])}"
        self.logger.info(best_str)

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            total_frame = 0.0
            total_time = 0.0
            pbar = tqdm(
                total=len(self.valid_dataloader),
                desc='eval model:',
                position=0,
                leave=True,
            )
            sum_images = 0
            for idx, batch in enumerate(self.valid_dataloader):
                batch_tensor = [t.to(self.device) for t in batch]
                batch_numpy = [t.numpy() for t in batch]
                start = time.time()
                if self.scaler:
                    with torch.cuda.amp.autocast(
                            enabled=self.device.type == 'cuda'):
                        preds = self.model(batch_tensor[0],
                                           data=batch_tensor[1:])
                else:
                    preds = self.model(batch_tensor[0], data=batch_tensor[1:])
                total_time += time.time() - start
                # Obtain usable results from post-processing methods
                # Evaluate the results of the current batch
                post_result = self.post_process_class(preds, batch_numpy)
                self.eval_class(post_result, batch_numpy)

                pbar.update(1)
                total_frame += len(batch[0])
                sum_images += 1
            # Get final metric, e.g. acc or hmean
            metric = self.eval_class.get_metric()

        pbar.close()
        self.model.train()
        metric['fps'] = total_frame / total_time
        return metric
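
    # Note: the reported 'fps' is images per second of model forward time only;
    # data loading and post-processing are excluded from total_time.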

    def test_dataloader(self):
        starttime = time.time()
        count = 0
        try:
            for data in self.train_dataloader:
                count += 1
                if count % 1 == 0:
                    batch_time = time.time() - starttime
                    starttime = time.time()
                    self.logger.info(
                        f'reader: {count}, {data[0].shape}, {batch_time}')
        except Exception:
            import traceback
            self.logger.info(traceback.format_exc())
        self.logger.info(f'finish reader: {count}, Success!')
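

# Minimal usage sketch (hypothetical paths and config wrapper; the real entry
# points live in the project's training scripts). Trainer expects an object
# exposing .cfg (the parsed config dict), .save(path, cfg) and .print_cfg(print_fn):
#
#   cfg = Config('configs/rec/example_config.yml')   # hypothetical Config helper
#   trainer = Trainer(cfg, mode='train_eval', task='rec')
#   trainer.train()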