visionlan_postprocess.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import numpy as np
  2. import torch
  3. import torch.nn.functional as F
  4. from .ctc_postprocess import BaseRecLabelDecode
  5. class VisionLANLabelDecode(BaseRecLabelDecode):
  6. """Convert between text-label and text-index."""
  7. def __init__(self,
  8. character_dict_path=None,
  9. use_space_char=False,
  10. **kwargs):
  11. super(VisionLANLabelDecode, self).__init__(character_dict_path,
  12. use_space_char)
  13. self.max_text_length = kwargs.get('max_text_length', 25)
  14. self.nclass = len(self.character) + 1
  15. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  16. """convert text-index into text-label."""
  17. result_list = []
  18. ignored_tokens = self.get_ignored_tokens()
  19. batch_size = len(text_index)
  20. for batch_idx in range(batch_size):
  21. selection = np.ones(len(text_index[batch_idx]), dtype=bool)
  22. if is_remove_duplicate:
  23. selection[1:] = text_index[batch_idx][1:] != text_index[
  24. batch_idx][:-1]
  25. for ignored_token in ignored_tokens:
  26. selection &= text_index[batch_idx] != ignored_token
  27. char_list = [
  28. self.character[text_id - 1]
  29. for text_id in text_index[batch_idx][selection]
  30. ]
  31. if text_prob is not None:
  32. conf_list = text_prob[batch_idx][selection]
  33. else:
  34. conf_list = [1] * len(selection)
  35. if len(conf_list) == 0:
  36. conf_list = [0]
  37. text = ''.join(char_list)
  38. result_list.append((text, np.mean(conf_list).tolist()))
  39. return result_list
  40. def __call__(self, preds, batch=None, *args, **kwargs):
  41. if len(preds) == 2: # eval mode
  42. net_out, length = preds
  43. if batch is not None:
  44. label = batch[1]
  45. else: # train mode
  46. net_out = preds[0]
  47. label, length = batch[1], batch[5]
  48. net_out = torch.cat([t[:l] for t, l in zip(net_out, length)],
  49. dim=0)
  50. text = []
  51. if not isinstance(net_out, torch.Tensor):
  52. net_out = torch.tensor(net_out, dtype=torch.float32)
  53. net_out = F.softmax(net_out, dim=1)
  54. for i in range(0, length.shape[0]):
  55. preds_idx = (net_out[int(length[:i].sum()):int(length[:i].sum() +
  56. length[i])].topk(1)
  57. [1][:, 0].tolist())
  58. preds_text = ''.join([
  59. self.character[idx - 1]
  60. if idx > 0 and idx <= len(self.character) else ''
  61. for idx in preds_idx
  62. ])
  63. preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum() +
  64. length[i])].topk(
  65. 1)[0][:, 0]
  66. preds_prob = torch.exp(
  67. torch.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
  68. text.append((preds_text, float(preds_prob)))
  69. if batch is None:
  70. return text
  71. label = self.decode(label)
  72. return text, label