demo_gradio.py
# @Author: OpenOCR
# @Contact: 784990967@qq.com
import os
import gradio as gr  # gradio==4.20.0
# NOTE(review): set before cv2/numpy imports below — presumably must precede
# the OCR backend's allocator initialization; keep this ordering. TODO confirm.
os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
import cv2
import numpy as np
import json
import time
from PIL import Image
from tools.infer_e2e import OpenOCR, check_and_download_font, draw_ocr_box_txt
  12. def initialize_ocr(model_type, drop_score):
  13. return OpenOCR(mode=model_type, drop_score=drop_score)
  14. # Default model type
  15. model_type = 'mobile'
  16. drop_score = 0.4
  17. text_sys = initialize_ocr(model_type, drop_score)
  18. # warm up 5 times
  19. if True:
  20. img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
  21. for i in range(5):
  22. res = text_sys(img_numpy=img)
  23. font_path = './simfang.ttf'
  24. font_path = check_and_download_font(font_path)
def main(input_image,
         model_type_select,
         det_input_size_textbox=960,
         rec_drop_score=0.4,
         mask_thresh=0.3,
         box_thresh=0.6,
         unclip_ratio=1.5,
         det_score_mode='slow'):
    """Run end-to-end OCR on one image and package results for the Gradio UI.

    Args:
        input_image: numpy image from gr.Image (assumed RGB channel order —
            TODO confirm; it is reversed below before inference).
        model_type_select: 'mobile' or 'server'; switching rebuilds the pipeline.
        det_input_size_textbox: longest-side input size for the detector.
        rec_drop_score: recognition-confidence threshold for drawing/filtering.
        mask_thresh: binarization threshold for the detection probability mask.
        box_thresh: text-box confidence threshold.
        unclip_ratio: box expansion ratio used when unclipping detected boxes.
        det_score_mode: 'slow' or 'fast' box-score computation mode.

    Returns:
        Tuple of (JSON string of per-box results, elapsed seconds,
        annotated image, uint8 mask scaled to 0/255).
    """
    global text_sys, model_type
    # Update OCR model if the model type changes
    # NOTE(review): rec_drop_score reaches text_sys only when the model type
    # changes; moving the slider alone does not rebuild the pipeline — confirm
    # this is intended.
    if model_type_select != model_type:
        model_type = model_type_select
        text_sys = initialize_ocr(model_type, rec_drop_score)
    # Reverse channel order (RGB -> BGR, presumably what the pipeline expects).
    img = input_image[:, :, ::-1]
    starttime = time.time()
    results, time_dict, mask = text_sys(
        img_numpy=img,
        return_mask=True,
        det_input_size=int(det_input_size_textbox),
        thresh=mask_thresh,
        box_thresh=box_thresh,
        unclip_ratio=unclip_ratio,
        score_mode=det_score_mode)
    elapse = time.time() - starttime
    # results[0] holds the per-box dicts for the (single) input image.
    save_pred = json.dumps(results[0], ensure_ascii=False)
    # Convert back for PIL-based annotation drawing.
    image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    boxes = [res['points'] for res in results[0]]
    txts = [res['transcription'] for res in results[0]]
    scores = [res['score'] for res in results[0]]
    draw_img = draw_ocr_box_txt(
        image,
        boxes,
        txts,
        scores,
        drop_score=rec_drop_score,
        font_path=font_path,
    )
    # Binarize the detector probability map for display (assumes mask is
    # (1, 1, H, W) — TODO confirm against text_sys's return_mask contract).
    mask = mask[0, 0, :, :] > mask_thresh
    return save_pred, elapse, draw_img, mask.astype('uint8') * 255
  64. def get_all_file_names_including_subdirs(dir_path):
  65. all_file_names = []
  66. for root, dirs, files in os.walk(dir_path):
  67. for file_name in files:
  68. all_file_names.append(os.path.join(root, file_name))
  69. file_names_only = [os.path.basename(file) for file in all_file_names]
  70. return file_names_only
  71. def list_image_paths(directory):
  72. image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff')
  73. image_paths = []
  74. for root, dirs, files in os.walk(directory):
  75. for file in files:
  76. if file.lower().endswith(image_extensions):
  77. relative_path = os.path.relpath(os.path.join(root, file),
  78. directory)
  79. full_path = os.path.join(directory, relative_path)
  80. image_paths.append(full_path)
  81. image_paths = sorted(image_paths)
  82. return image_paths
  83. def find_file_in_current_dir_and_subdirs(file_name):
  84. for root, dirs, files in os.walk('.'):
  85. if file_name in files:
  86. relative_path = os.path.join(root, file_name)
  87. return relative_path
# Example images shipped with the repo, shown in the Gradio Examples gallery.
e2e_img_example = list_image_paths('./OCR_e2e_img')

if __name__ == '__main__':
    # Keep image previews from growing past 320px tall.
    css = '.image-container img { width: 100%; max-height: 320px;}'
    with gr.Blocks(css=css) as demo:
        gr.HTML("""
<h1 style='text-align: center;'><a href="https://github.com/Topdu/OpenOCR">OpenOCR</a></h1>
<p style='text-align: center;'>准确高效的通用 OCR 系统 (由<a href="https://fvl.fudan.edu.cn">FVL实验室</a> <a href="https://github.com/Topdu/OpenOCR">OCR Team</a> 创建) <a href="https://github.com/Topdu/OpenOCR/tree/main?tab=readme-ov-file#quick-start">[本地快速部署]</a></p>"""
                )
        with gr.Row():
            # Left column: input image, example gallery, run button, tunables.
            with gr.Column(scale=1):
                input_image = gr.Image(label='Input image',
                                       elem_classes=['image-container'])
                examples = gr.Examples(examples=e2e_img_example,
                                       inputs=input_image,
                                       label='Examples')
                downstream = gr.Button('Run')
                # Parameter-tuning widgets
                with gr.Column():
                    with gr.Row():
                        det_input_size_textbox = gr.Number(
                            label='Detection Input Size',
                            value=960,
                            info='检测网络输入尺寸的最长边,默认为960。')
                        det_score_mode_dropdown = gr.Dropdown(
                            ['slow', 'fast'],
                            value='slow',
                            label='Detection Score Mode',
                            info='文本框的置信度计算模式,默认为 slow。slow 模式计算速度较慢,但准确度较高。fast 模式计算速度较快,但准确度较低。'
                        )
                    with gr.Row():
                        rec_drop_score_slider = gr.Slider(
                            0.0,
                            1.0,
                            value=0.4,
                            step=0.01,
                            label='Recognition Drop Score',
                            info='识别置信度阈值,默认值为0.4。低于该阈值的识别结果和对应的文本框被丢弃。')
                        mask_thresh_slider = gr.Slider(
                            0.0,
                            1.0,
                            value=0.3,
                            step=0.01,
                            label='Mask Threshold',
                            info='Mask 阈值,用于二值化 mask,默认值为0.3。如果存在文本截断时,请调低该值。')
                    with gr.Row():
                        box_thresh_slider = gr.Slider(
                            0.0,
                            1.0,
                            value=0.6,
                            step=0.01,
                            label='Box Threshold',
                            info='文本框置信度阈值,默认值为0.6。如果存在文本被漏检时,请调低该值。')
                        unclip_ratio_slider = gr.Slider(
                            1.5,
                            2.0,
                            value=1.5,
                            step=0.05,
                            label='Unclip Ratio',
                            info='文本框解析时的膨胀系数,默认值为1.5。值越大文本框越大。')
                    # Model selection component
                    model_type_dropdown = gr.Dropdown(
                        ['mobile', 'server'],
                        value='mobile',
                        label='Model Type',
                        info='选择 OCR 模型类型:高效率模型mobile,高精度模型server。')
            # Right column: detection mask, annotated output, raw JSON, latency.
            with gr.Column(scale=1):
                img_mask = gr.Image(label='mask',
                                    interactive=False,
                                    elem_classes=['image-container'])
                img_output = gr.Image(label=' ',
                                      interactive=False,
                                      elem_classes=['image-container'])
                output = gr.Textbox(label='Result')
                confidence = gr.Textbox(label='Latency')
        # Wire the Run button; the inputs list must match main()'s parameter
        # order, and outputs must match its 4-tuple return.
        downstream.click(fn=main,
                         inputs=[
                             input_image, model_type_dropdown,
                             det_input_size_textbox, rec_drop_score_slider,
                             mask_thresh_slider, box_thresh_slider,
                             unclip_ratio_slider, det_score_mode_dropdown
                         ],
                         outputs=[
                             output,
                             confidence,
                             img_output,
                             img_mask,
                         ])
    # share=True asks Gradio to open a public tunnel URL in addition to the
    # local server.
    demo.launch(share=True)