44import cv2
55import numpy as np
66
7- from rapid_table_det .predictor import DbNet , ObjectDetector , PPLCNet
7+ from rapid_table_det .predictor import DbNet , ObjectDetector , PPLCNet , YoloSeg , YoloDet
88from rapid_table_det .utils import LoadImage
99
10- MODEL_URLS = {
11- "onnx_tiny" : {
12- "obj_det" : "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/obj_det_quantized.zip" ,
13- "edge_det" : "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/edge_det_quantized.zip" ,
14- "cls_det" : "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/cls_det.zip" ,
15- },
16- "onnx" : {
17- "obj_det" : "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/obj_det.zip" ,
18- "edge_det" : "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/edge_det.zip" ,
19- "cls_det" : "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/cls_det.zip" ,
20- },
21- }
2210root_dir = Path (__file__ ).resolve ().parent
2311model_dir = os .path .join (root_dir , "models" )
2412
2513
2614class TableDetector :
2715 def __init__ (
2816 self ,
17+ obj_model = "yolo" ,
18+ edge_model = "yolo" ,
2919 obj_model_path = os .path .join (model_dir , "obj_det_quantized.onnx" ),
30- edge_model_path = os .path .join (model_dir , "edge_det_quantized .onnx" ),
20+ edge_model_path = os .path .join (model_dir , "yolo_edge_det_s .onnx" ),
3121 cls_model_path = os .path .join (model_dir , "cls_det.onnx" ),
32- use_obj_det = True ,
33- use_edge_det = True ,
34- use_cls_det = True ,
3522 ):
36- self .use_obj_det = use_obj_det
37- self .use_edge_det = use_edge_det
38- self .use_cls_det = use_cls_det
3923 self .img_loader = LoadImage ()
40-
41- if self .use_obj_det :
24+ if obj_model == "yolo" :
25+ self .obj_detector = YoloDet (obj_model_path )
26+ else :
4227 self .obj_detector = ObjectDetector (obj_model_path )
43- if self .use_edge_det :
28+ if edge_model == "yolo" :
29+ self .dbnet = YoloSeg (edge_model_path )
30+ else :
4431 self .dbnet = DbNet (edge_model_path )
45- if self .use_cls_det :
46- self .pplcnet = PPLCNet (cls_model_path )
32+ self .pplcnet = PPLCNet (cls_model_path )
4733
48- def __call__ (self , img , det_accuracy = 0.7 ):
34+ def __call__ (
35+ self ,
36+ img ,
37+ det_accuracy = 0.7 ,
38+ use_obj_det = True ,
39+ use_edge_det = True ,
40+ use_cls_det = True ,
41+ ):
4942 img = self .img_loader (img )
5043 img = cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB )
5144 img_mask = img .copy ()
5245 h , w = img .shape [:- 1 ]
53- img_box = np .array ([0 , 0 , w , h ])
54- lb , lt , rb , rt = self .get_box_points (img_box )
55- # 初始化默认值
56- obj_det_res , edge_box , pred_label = (
57- [[1.0 , img_box ]],
58- img_box .reshape ([- 1 , 2 ]),
59- 0 ,
60- )
46+ obj_det_res , pred_label = self .init_default_output (h , w )
6147 result = []
6248 obj_det_elapse , edge_elapse , rotate_det_elapse = 0 , 0 , 0
63- if self . use_obj_det :
49+ if use_obj_det :
6450 obj_det_res , obj_det_elapse = self .obj_detector (img , score = det_accuracy )
6551 for i in range (len (obj_det_res )):
6652 det_res = obj_det_res [i ]
6753 score , box = det_res
6854 xmin , ymin , xmax , ymax = box
6955 edge_box = box .reshape ([- 1 , 2 ])
7056 lb , lt , rb , rt = self .get_box_points (box )
71- if self . use_edge_det :
57+ if use_edge_det :
7258 xmin_edge , ymin_edge , xmax_edge , ymax_edge = self .pad_box_points (
7359 h , w , xmax , xmin , ymax , ymin , 10
7460 )
@@ -77,30 +63,16 @@ def __call__(self, img, det_accuracy=0.7):
7763 edge_elapse += tmp_edge_elapse
7864 if edge_box is None :
7965 continue
80- edge_box [:, 0 ] += xmin_edge
81- edge_box [:, 1 ] += ymin_edge
82- lt , lb , rt , rb = (
83- lt + [xmin_edge , ymin_edge ],
84- lb + [xmin_edge , ymin_edge ],
85- rt + [xmin_edge , ymin_edge ],
86- rb + [xmin_edge , ymin_edge ],
66+ lb , lt , rb , rt = self .adjust_edge_points_axis (
67+ edge_box , lb , lt , rb , rt , xmin_edge , ymin_edge
8768 )
88- if self . use_cls_det :
69+ if use_cls_det :
8970 xmin_cls , ymin_cls , xmax_cls , ymax_cls = self .pad_box_points (
9071 h , w , xmax , xmin , ymax , ymin , 5
9172 )
92- cls_box = edge_box .copy ()
9373 cls_img = img_mask [ymin_cls :ymax_cls , xmin_cls :xmax_cls , :]
94- cls_box [:, 0 ] = cls_box [:, 0 ] - xmin_cls
95- cls_box [:, 1 ] = cls_box [:, 1 ] - ymin_cls
96- # 画框增加先验信息,辅助方向label识别
97- cv2 .polylines (
98- cls_img ,
99- [np .array (cls_box ).astype (np .int32 ).reshape ((- 1 , 1 , 2 ))],
100- True ,
101- color = (255 , 0 , 255 ),
102- thickness = 5 ,
103- )
74+ # 增加先验信息
75+ self .add_pre_info_for_cls (cls_img , edge_box , xmin_cls , ymin_cls )
10476 pred_label , tmp_rotate_det_elapse = self .pplcnet (cls_img )
10577 rotate_det_elapse += tmp_rotate_det_elapse
10678 lb1 , lt1 , rb1 , rt1 = self .get_real_rotated_points (
@@ -118,6 +90,40 @@ def __call__(self, img, det_accuracy=0.7):
11890 elapse = [obj_det_elapse , edge_elapse , rotate_det_elapse ]
11991 return result , elapse
12092
93+ def init_default_output (self , h , w ):
94+ img_box = np .array ([0 , 0 , w , h ])
95+ # 初始化默认值
96+ obj_det_res , edge_box , pred_label = (
97+ [[1.0 , img_box ]],
98+ img_box .reshape ([- 1 , 2 ]),
99+ 0 ,
100+ )
101+ return obj_det_res , pred_label
102+
103+ def add_pre_info_for_cls (self , cls_img , edge_box , xmin_cls , ymin_cls ):
104+ cls_box = edge_box .copy ()
105+ cls_box [:, 0 ] = cls_box [:, 0 ] - xmin_cls
106+ cls_box [:, 1 ] = cls_box [:, 1 ] - ymin_cls
107+ # 画框增加先验信息,辅助方向label识别
108+ cv2 .polylines (
109+ cls_img ,
110+ [np .array (cls_box ).astype (np .int32 ).reshape ((- 1 , 1 , 2 ))],
111+ True ,
112+ color = (255 , 0 , 255 ),
113+ thickness = 5 ,
114+ )
115+
116+ def adjust_edge_points_axis (self , edge_box , lb , lt , rb , rt , xmin_edge , ymin_edge ):
117+ edge_box [:, 0 ] += xmin_edge
118+ edge_box [:, 1 ] += ymin_edge
119+ lt , lb , rt , rb = (
120+ lt + [xmin_edge , ymin_edge ],
121+ lb + [xmin_edge , ymin_edge ],
122+ rt + [xmin_edge , ymin_edge ],
123+ rb + [xmin_edge , ymin_edge ],
124+ )
125+ return lb , lt , rb , rt
126+
121127 def get_box_points (self , img_box ):
122128 x1 , y1 , x2 , y2 = img_box
123129 lt = np .array ([x1 , y1 ]) # 左上角
0 commit comments