utility.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import logging
  2. import os
  3. import cv2
  4. import numpy as np
  5. import importlib.util
  6. import sys
  7. import subprocess
  8. def get_check_global_params(mode):
  9. check_params = [
  10. "use_gpu",
  11. "max_text_length",
  12. "image_shape",
  13. "image_shape",
  14. "character_type",
  15. "loss_type",
  16. ]
  17. if mode == "train_eval":
  18. check_params = check_params + [
  19. "train_batch_size_per_card",
  20. "test_batch_size_per_card",
  21. ]
  22. elif mode == "test":
  23. check_params = check_params + ["test_batch_size_per_card"]
  24. return check_params
  25. def _check_image_file(path):
  26. img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"}
  27. return any([path.lower().endswith(e) for e in img_end])
  28. def get_image_file_list(img_file):
  29. imgs_lists = []
  30. if img_file is None or not os.path.exists(img_file):
  31. raise Exception("not found any img file in {}".format(img_file))
  32. if os.path.isfile(img_file) and _check_image_file(img_file):
  33. imgs_lists.append(img_file)
  34. elif os.path.isdir(img_file):
  35. for single_file in os.listdir(img_file):
  36. file_path = os.path.join(img_file, single_file)
  37. if os.path.isfile(file_path) and _check_image_file(file_path):
  38. imgs_lists.append(file_path)
  39. if len(imgs_lists) == 0:
  40. raise Exception("not found any img file in {}".format(img_file))
  41. imgs_lists = sorted(imgs_lists)
  42. return imgs_lists
  43. def binarize_img(img):
  44. if len(img.shape) == 3 and img.shape[2] == 3:
  45. gray = cv2.cvtColor(img,
  46. cv2.COLOR_BGR2GRAY) # conversion to grayscale image
  47. # use cv2 threshold binarization
  48. _, gray = cv2.threshold(gray, 0, 255,
  49. cv2.THRESH_BINARY + cv2.THRESH_OTSU)
  50. img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  51. return img
  52. def alpha_to_color(img, alpha_color=(255, 255, 255)):
  53. if len(img.shape) == 3 and img.shape[2] == 4:
  54. B, G, R, A = cv2.split(img)
  55. alpha = A / 255
  56. R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
  57. G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
  58. B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
  59. img = cv2.merge((B, G, R))
  60. return img
  61. def check_and_read(img_path):
  62. if os.path.basename(img_path)[-3:].lower() == "gif":
  63. gif = cv2.VideoCapture(img_path)
  64. ret, frame = gif.read()
  65. if not ret:
  66. logger = logging.getLogger("openrec")
  67. logger.info("Cannot read {}. This gif image maybe corrupted.")
  68. return None, False
  69. if len(frame.shape) == 2 or frame.shape[-1] == 1:
  70. frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
  71. imgvalue = frame[:, :, ::-1]
  72. return imgvalue, True, False
  73. elif os.path.basename(img_path)[-3:].lower() == "pdf":
  74. import fitz
  75. from PIL import Image
  76. imgs = []
  77. with fitz.open(img_path) as pdf:
  78. for pg in range(0, pdf.page_count):
  79. page = pdf[pg]
  80. mat = fitz.Matrix(2, 2)
  81. pm = page.get_pixmap(matrix=mat, alpha=False)
  82. # if width or height > 2000 pixels, don't enlarge the image
  83. if pm.width > 2000 or pm.height > 2000:
  84. pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
  85. img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
  86. img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
  87. imgs.append(img)
  88. return imgs, False, True
  89. return None, False, False
  90. def load_vqa_bio_label_maps(label_map_path):
  91. with open(label_map_path, "r", encoding="utf-8") as fin:
  92. lines = fin.readlines()
  93. old_lines = [line.strip() for line in lines]
  94. lines = ["O"]
  95. for line in old_lines:
  96. # "O" has already been in lines
  97. if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
  98. continue
  99. lines.append(line)
  100. labels = ["O"]
  101. for line in lines[1:]:
  102. labels.append("B-" + line)
  103. labels.append("I-" + line)
  104. label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
  105. id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
  106. return label2id_map, id2label_map
  107. def check_install(module_name, install_name):
  108. spec = importlib.util.find_spec(module_name)
  109. if spec is None:
  110. print(f"Warnning! The {module_name} module is NOT installed")
  111. print(
  112. f"Try install {module_name} module automatically. You can also try to install manually by pip install {install_name}."
  113. )
  114. python = sys.executable
  115. try:
  116. subprocess.check_call(
  117. [python, "-m", "pip", "install", install_name],
  118. stdout=subprocess.DEVNULL, )
  119. print(f"The {module_name} module is now installed")
  120. except subprocess.CalledProcessError as exc:
  121. raise Exception(
  122. f"Install {module_name} failed, please install manually")
  123. else:
  124. print(f"{module_name} has been installed.")
  125. class AverageMeter:
  126. def __init__(self):
  127. self.reset()
  128. def reset(self):
  129. """reset"""
  130. self.val = 0
  131. self.avg = 0
  132. self.sum = 0
  133. self.count = 0
  134. def update(self, val, n=1):
  135. """update"""
  136. self.val = val
  137. self.sum += val * n
  138. self.count += n
  139. self.avg = self.sum / self.count