Commit 9dfe95f

feat: add yolo model
1 parent eb54901 · commit 9dfe95f

4 files changed, +207 -93 lines changed

demo_onnx.py

Lines changed: 30 additions & 28 deletions
@@ -1,33 +1,35 @@
 from rapid_table_det.inference import TableDetector
 
-img_path = f"tests/test_files/chip2.jpg"
+# img_path = f"tests/test_files/chip2.jpg"
+img_path = f"images/weixin.png"
 table_det = TableDetector(
-    obj_model_path="rapid_table_det/models/obj_det.onnx",
-    edge_model_path="rapid_table_det/models/edge_det.onnx",
+    obj_model_path="rapid_table_det/models/yolo_obj_det_l.onnx",
+    edge_model_path="rapid_table_det/models/yolo_edge_det_s.onnx",
 )
-result, elapse = table_det(img_path)
-obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
-print(
-    f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
-)
-# Visualize the output
-import os
-import cv2
-from rapid_table_det.utils import img_loader, visuallize, extract_table_img
+if __name__ == "__main__":
+    result, elapse = table_det(img_path)
+    obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
+    print(
+        f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
+    )
+    # Visualize the output
+    import os
+    import cv2
+    from rapid_table_det.utils import img_loader, visuallize, extract_table_img
 
-img = img_loader(img_path)
-file_name_with_ext = os.path.basename(img_path)
-file_name, file_ext = os.path.splitext(file_name_with_ext)
-out_dir = "rapid_table_det/outputs"
-if not os.path.exists(out_dir):
-    os.makedirs(out_dir)
-extract_img = img.copy()
-for i, res in enumerate(result):
-    box = res["box"]
-    lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
-    # Draw the detection box and the top-left orientation marker
-    img = visuallize(img, box, lt, rt, rb, lb)
-    # Extract the table image via perspective transform
-    wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
-    cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
-cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
+    img = img_loader(img_path)
+    file_name_with_ext = os.path.basename(img_path)
+    file_name, file_ext = os.path.splitext(file_name_with_ext)
+    out_dir = "rapid_table_det/outputs"
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+    extract_img = img.copy()
+    for i, res in enumerate(result):
+        box = res["box"]
+        lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
+        # Draw the detection box and the top-left orientation marker
+        img = visuallize(img, box, lt, rt, rb, lb)
+        # Extract the table image via perspective transform
+        wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
+        cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
+    cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)

rapid_table_det/inference.py

Lines changed: 61 additions & 55 deletions
@@ -4,71 +4,57 @@
 import cv2
 import numpy as np
 
-from rapid_table_det.predictor import DbNet, ObjectDetector, PPLCNet
+from rapid_table_det.predictor import DbNet, ObjectDetector, PPLCNet, YoloSeg, YoloDet
 from rapid_table_det.utils import LoadImage
 
-MODEL_URLS = {
-    "onnx_tiny": {
-        "obj_det": "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/obj_det_quantized.zip",
-        "edge_det": "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/edge_det_quantized.zip",
-        "cls_det": "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/cls_det.zip",
-    },
-    "onnx": {
-        "obj_det": "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/obj_det.zip",
-        "edge_det": "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/edge_det.zip",
-        "cls_det": "https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/cls_det.zip",
-    },
-}
 root_dir = Path(__file__).resolve().parent
 model_dir = os.path.join(root_dir, "models")
 
 
 class TableDetector:
     def __init__(
         self,
+        obj_model="yolo",
+        edge_model="yolo",
         obj_model_path=os.path.join(model_dir, "obj_det_quantized.onnx"),
-        edge_model_path=os.path.join(model_dir, "edge_det_quantized.onnx"),
+        edge_model_path=os.path.join(model_dir, "yolo_edge_det_s.onnx"),
         cls_model_path=os.path.join(model_dir, "cls_det.onnx"),
-        use_obj_det=True,
-        use_edge_det=True,
-        use_cls_det=True,
     ):
-        self.use_obj_det = use_obj_det
-        self.use_edge_det = use_edge_det
-        self.use_cls_det = use_cls_det
         self.img_loader = LoadImage()
-
-        if self.use_obj_det:
+        if obj_model == "yolo":
+            self.obj_detector = YoloDet(obj_model_path)
+        else:
             self.obj_detector = ObjectDetector(obj_model_path)
-        if self.use_edge_det:
+        if edge_model == "yolo":
+            self.dbnet = YoloSeg(edge_model_path)
+        else:
             self.dbnet = DbNet(edge_model_path)
-        if self.use_cls_det:
-            self.pplcnet = PPLCNet(cls_model_path)
+        self.pplcnet = PPLCNet(cls_model_path)
 
-    def __call__(self, img, det_accuracy=0.7):
+    def __call__(
+        self,
+        img,
+        det_accuracy=0.7,
+        use_obj_det=True,
+        use_edge_det=True,
+        use_cls_det=True,
+    ):
         img = self.img_loader(img)
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         img_mask = img.copy()
         h, w = img.shape[:-1]
-        img_box = np.array([0, 0, w, h])
-        lb, lt, rb, rt = self.get_box_points(img_box)
-        # Initialize default values
-        obj_det_res, edge_box, pred_label = (
-            [[1.0, img_box]],
-            img_box.reshape([-1, 2]),
-            0,
-        )
+        obj_det_res, pred_label = self.init_default_output(h, w)
         result = []
         obj_det_elapse, edge_elapse, rotate_det_elapse = 0, 0, 0
-        if self.use_obj_det:
+        if use_obj_det:
             obj_det_res, obj_det_elapse = self.obj_detector(img, score=det_accuracy)
         for i in range(len(obj_det_res)):
             det_res = obj_det_res[i]
             score, box = det_res
             xmin, ymin, xmax, ymax = box
             edge_box = box.reshape([-1, 2])
             lb, lt, rb, rt = self.get_box_points(box)
-            if self.use_edge_det:
+            if use_edge_det:
                 xmin_edge, ymin_edge, xmax_edge, ymax_edge = self.pad_box_points(
                     h, w, xmax, xmin, ymax, ymin, 10
                 )
@@ -77,30 +63,16 @@ def __call__(self, img, det_accuracy=0.7):
                 edge_elapse += tmp_edge_elapse
                 if edge_box is None:
                     continue
-                edge_box[:, 0] += xmin_edge
-                edge_box[:, 1] += ymin_edge
-                lt, lb, rt, rb = (
-                    lt + [xmin_edge, ymin_edge],
-                    lb + [xmin_edge, ymin_edge],
-                    rt + [xmin_edge, ymin_edge],
-                    rb + [xmin_edge, ymin_edge],
+                lb, lt, rb, rt = self.adjust_edge_points_axis(
+                    edge_box, lb, lt, rb, rt, xmin_edge, ymin_edge
                 )
-            if self.use_cls_det:
+            if use_cls_det:
                 xmin_cls, ymin_cls, xmax_cls, ymax_cls = self.pad_box_points(
                     h, w, xmax, xmin, ymax, ymin, 5
                 )
-                cls_box = edge_box.copy()
                 cls_img = img_mask[ymin_cls:ymax_cls, xmin_cls:xmax_cls, :]
-                cls_box[:, 0] = cls_box[:, 0] - xmin_cls
-                cls_box[:, 1] = cls_box[:, 1] - ymin_cls
-                # Draw the box as prior information to aid orientation label classification
-                cv2.polylines(
-                    cls_img,
-                    [np.array(cls_box).astype(np.int32).reshape((-1, 1, 2))],
-                    True,
-                    color=(255, 0, 255),
-                    thickness=5,
-                )
+                # Add prior information
+                self.add_pre_info_for_cls(cls_img, edge_box, xmin_cls, ymin_cls)
                 pred_label, tmp_rotate_det_elapse = self.pplcnet(cls_img)
                 rotate_det_elapse += tmp_rotate_det_elapse
             lb1, lt1, rb1, rt1 = self.get_real_rotated_points(
@@ -118,6 +90,40 @@ def __call__(self, img, det_accuracy=0.7):
         elapse = [obj_det_elapse, edge_elapse, rotate_det_elapse]
         return result, elapse
 
+    def init_default_output(self, h, w):
+        img_box = np.array([0, 0, w, h])
+        # Initialize default values
+        obj_det_res, edge_box, pred_label = (
+            [[1.0, img_box]],
+            img_box.reshape([-1, 2]),
+            0,
+        )
+        return obj_det_res, pred_label
+
+    def add_pre_info_for_cls(self, cls_img, edge_box, xmin_cls, ymin_cls):
+        cls_box = edge_box.copy()
+        cls_box[:, 0] = cls_box[:, 0] - xmin_cls
+        cls_box[:, 1] = cls_box[:, 1] - ymin_cls
+        # Draw the box as prior information to aid orientation label classification
+        cv2.polylines(
+            cls_img,
+            [np.array(cls_box).astype(np.int32).reshape((-1, 1, 2))],
+            True,
+            color=(255, 0, 255),
+            thickness=5,
+        )
+
+    def adjust_edge_points_axis(self, edge_box, lb, lt, rb, rt, xmin_edge, ymin_edge):
+        edge_box[:, 0] += xmin_edge
+        edge_box[:, 1] += ymin_edge
+        lt, lb, rt, rb = (
+            lt + [xmin_edge, ymin_edge],
+            lb + [xmin_edge, ymin_edge],
+            rt + [xmin_edge, ymin_edge],
+            rb + [xmin_edge, ymin_edge],
+        )
+        return lb, lt, rb, rt
+
     def get_box_points(self, img_box):
         x1, y1, x2, y2 = img_box
         lt = np.array([x1, y1])  # top-left corner
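
A minimal usage sketch (assumed, not part of the commit) of the relocated switches: with this change, `use_obj_det` / `use_edge_det` / `use_cls_det` are passed per call instead of at construction time, so individual stages can be skipped image by image. The image path and model paths below are reused from the demo diff above.

from rapid_table_det.inference import TableDetector

table_det = TableDetector(
    obj_model_path="rapid_table_det/models/yolo_obj_det_l.onnx",
    edge_model_path="rapid_table_det/models/yolo_edge_det_s.onnx",
)

# Full pipeline: object detection + edge refinement + orientation classification.
result, elapse = table_det("images/weixin.png", det_accuracy=0.7)

# Skip the orientation classifier for inputs already known to be upright.
result, elapse = table_det("images/weixin.png", use_cls_det=False)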
