Skip to content

Commit 4c4cad0

Browse files
committed
feat: add yolo model and auto download
1 parent 9dfe95f commit 4c4cad0

File tree

13 files changed

+647
-594
lines changed

13 files changed

+647
-594
lines changed

README.md

Lines changed: 50 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -10,63 +10,82 @@
1010
</div>
1111

1212
### 最近更新
13+
1314
- **2024.10.15**
1415
- 完成初版代码,包含目标检测,语义分割,角点方向识别三个模块
16+
- **2024.11.2**
17+
- 新增基于 yolo11 训练的目标检测模型和边缘检测模型,支持模型自动下载,减小包体积,各模块可自由组合
1518

1619
### 简介
20+
1721
💡✨ 强大且高效的表格检测,支持论文、期刊、杂志、发票、收据、签到单等各种表格。
1822

19-
🚀 支持高精度 Paddle 版本和量化 ONNX 版本,单图 CPU 推理仅需 1.5 秒,Paddle-GPU(V100) 仅需 0.2 秒。
23+
🚀 支持基于 paddle 和 yolo 训练的模型版本,在兼顾速度和精度的配置下,单图 CPU 推理仅需 1 秒,Paddle-GPU(V100) 仅需 0.2 秒。
2024

2125
🛠️ 支持三个模块自由组合,独立训练调优,提供 ONNX 转换脚本和微调训练方案。
2226

2327
🌟 whl 包轻松集成使用,为下游 OCR、表格识别和数据采集提供强力支撑。
2428

25-
📚参考项目 [百度表格检测大赛第2名方案](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL) 的实现方案,补充大量真实场景数据再训练
29+
📚参考项目 [百度表格检测大赛第2名方案](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL)
30+
的实现方案,补充大量真实场景数据再训练
2631
![img.png](readme_resource/structure.png)
2732
👇🏻训练数据集见文末「致谢」部分, 作者天天上班摸鱼搞开源,希望大家点个⭐️支持一下
2833

2934
### 使用建议
35+
3036
📚 文档场景: 无透视旋转,只使用目标检测\
3137
📷 拍照场景小角度旋转(-90~90): 默认左上角,不使用角点方向识别\
3238
🔍 使用在线体验找到适合你场景的模型组合
33-
### 在线体验
3439

40+
### 在线体验
3541

3642
### 效果展示
43+
3744
![res_show.jpg](readme_resource/res_show.jpg)![res_show2.jpg](readme_resource/res_show2.jpg)
45+
3846
### 安装
39-
为简化使用,已经将最小的量化模型打包到 rapid_table_det 中,需要更高精度或gpu推理,请自行下载对应模型
40-
🪜下载模型 [modescope模型仓](https://www.modelscope.cn/models/jockerK/TableExtractor) [release assets](https://github.com/Joker1212/RapidTableDetection/releases/tag/v0.0.0)
47+
48+
🪜模型会自动下载,也可以自行前往仓库下载 [modelscope模型仓](https://www.modelscope.cn/models/jockerK/TableExtractor)
49+
4150
``` python {linenos=table}
4251
# 建议使用清华源安装 https://pypi.tuna.tsinghua.edu.cn/simple
4352
pip install rapid-table-det
4453
```
4554

4655
#### 参数说明
47-
cpu和gpu的初始化完全一致
48-
```
49-
table_det = TableDetector(
50-
# 目标检测表格模型
51-
obj_model_path="models/obj_det_paddle(obj_det.onnx)",
52-
# 边角检测表格模型(从复杂环境得到表格多边形框)
53-
edge_model_path="models/edge_det_paddle(edge_det.onnx)",
54-
# 角点方向识别
55-
cls_model_path="models/cls_det_paddle(cls_det.onnx)",
56-
# 文档场景已经由版面识别模型提取设置为False
57-
use_obj_det=True,
58-
# 只有90,180,270大角度旋转且无透视时候设置为False
59-
use_edge_det=True,
60-
# 小角度(-90~90)旋转设置为False
61-
use_cls_det=True,
62-
)
63-
```
56+
57+
默认值
58+
use_cuda=False (是否启用gpu加速推理), \
59+
obj_model_type="yolo_obj_det", \
60+
edge_model_type= "yolo_edge_det", \
61+
cls_model_type= "paddle_cls_det"
62+
63+
| `model_type` | 任务类型 | 训练来源 | 大小 | 单表格耗时 |
64+
|:---------------------|:-------|:-------------------------------------|:-------|:----------------------|
65+
| **yolo_obj_det** | 表格目标检测 | `yolo11-l` | `100m` | `cpu:500ms, gpu:0.2` |
66+
| `paddle_obj_det` | 表格目标检测 | `paddle yoloe-plus-x` | `380m` | `cpu:500ms, gpu:0.2` |
67+
| `paddle_obj_det_s` | 表格目标检测 | `paddle yoloe-plus-x + quantization` | `95m` | `cpu:1000ms, gpu:0.2` |
68+
| **yolo_edge_det** | 语义分割 | `yolo11-l-segment` | `108m` | `cpu:500ms, gpu:0.2` |
69+
| `yolo_edge_det_s` | 语义分割 | `yolo11-s-segment` | `11m` | `cpu:100ms, gpu:0.2` |
70+
| `paddle_edge_det` | 语义分割 | `paddle-dbnet` | `99m` | `cpu:600ms, gpu:0.2` |
71+
| `paddle_edge_det_s` | 语义分割 | `paddle-dbnet + quantization` | `25m` | `cpu:500ms, gpu:0.2` |
72+
| **paddle_cls_det** | 方向分类 | `paddle pplcnet` | `6.5m` | `cpu:70ms, gpu:0.2` |
73+
74+
75+
执行参数
76+
det_accuracy=0.7,
77+
use_obj_det=True,
78+
use_edge_det=True,
79+
use_cls_det=True,
6480

6581
### 快速使用
82+
6683
``` python {linenos=table}
6784
from rapid_table_det.inference import TableDetector
85+
86+
img_path = f"images/weixin.png"
6887
table_det = TableDetector()
69-
img_path = f"tests/test_files/chip.jpg"
88+
7089
result, elapse = table_det(img_path)
7190
obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
7291
print(
@@ -75,7 +94,8 @@ print(
7594
# 输出可视化
7695
# import os
7796
# import cv2
78-
# from rapid_table_det.utils import img_loader, visuallize, extract_table_img
97+
# from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
98+
#
7999
# img = img_loader(img_path)
80100
# file_name_with_ext = os.path.basename(img_path)
81101
# file_name, file_ext = os.path.splitext(file_name_with_ext)
@@ -94,68 +114,28 @@ print(
94114
# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
95115

96116
```
97-
### gpu版本使用
98-
必须下载模型,指定模型位置!
99-
``` python {linenos=table}
100-
# 建议使用清华源安装 https://pypi.tuna.tsinghua.edu.cn/simple
101-
pip install rapid-table-det-paddle (默认安装gpu版本,可以自行覆盖安装cpu版本paddlepaddle)
102-
```
103-
```python
104-
from rapid_table_det_paddle.inference import TableDetector
105-
106-
img_path = f"tests/test_files/chip.jpg"
107-
108-
table_det = TableDetector(
109-
obj_model_path="models/obj_det_paddle",
110-
edge_model_path="models/edge_det_paddle",
111-
cls_model_path="models/cls_det_paddle",
112-
use_obj_det=True,
113-
use_edge_det=True,
114-
use_cls_det=True,
115-
)
116-
result, elapse = table_det(img_path)
117-
obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
118-
print(
119-
f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
120-
)
121-
# 一张图片中可能有多个表格
122-
# img = img_loader(img_path)
123-
# file_name_with_ext = os.path.basename(img_path)
124-
# file_name, file_ext = os.path.splitext(file_name_with_ext)
125-
# out_dir = "rapid_table_det_paddle/outputs"
126-
# if not os.path.exists(out_dir):
127-
# os.makedirs(out_dir)
128-
# extract_img = img.copy()
129-
# for i, res in enumerate(result):
130-
# box = res["box"]
131-
# lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
132-
# # 带识别框和左上角方向位置
133-
# img = visuallize(img, box, lt, rt, rb, lb)
134-
# # 透视变换提取表格图片
135-
# wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
136-
# cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
137-
# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
138-
139-
```
140-
141117

142118
## FAQ (Frequently Asked Questions)
143119

144120
1. **问:如何微调模型适应特定场景?**
145-
- 答:直接参考这个项目,有非常详细的可视化操作步骤,可以得到paddle的推理模型 [百度表格检测大赛](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL)
121+
-
122+
答:直接参考这个项目,有非常详细的可视化操作步骤,可以得到paddle的推理模型 [百度表格检测大赛](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL)
146123

147124
2. **问:如何导出onnx**
148-
- 答:在本项目tools下,有onnx_transform.ipynb文件,可以照步骤执行(因为pp-yoloe导出onnx有bug一直没修,这里我自己写了一个fix_onnx脚本改动onnx模型节点来临时解决了)
125+
- 答:在本项目tools下,有onnx_transform.ipynb文件,可以照步骤执行(
126+
因为pp-yoloe导出onnx有bug一直没修,这里我自己写了一个fix_onnx脚本改动onnx模型节点来临时解决了)
149127

150128
3. **问:图片有扭曲可以修正吗?**
151129
- 答:本项目只解决旋转和透视场景的表格提取,对于扭曲的场景,需要先进行扭曲修正
152130

153131
### 致谢
132+
154133
[百度表格检测大赛第2名方案](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL) \
155134
[WTW 自然场景表格数据集](https://tianchi.aliyun.com/dataset/108587) \
156135
[FinTabNet PDF文档表格数据集](https://developer.ibm.com/exchanges/data/all/fintabnet/) \
157136
[TableBank 表格数据集](https://doc-analysis.github.io/tablebank-page/) \
158137
[TableGeneration 表格自动生成工具](https://github.com/WenmuZhou/TableGeneration)
138+
159139
### 贡献指南
160140

161141
欢迎提交请求。对于重大更改,请先打开issue讨论您想要改变的内容。

demo_onnx.py

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,33 @@
11
from rapid_table_det.inference import TableDetector
22

3-
# img_path = f"tests/test_files/chip2.jpg"
43
img_path = f"images/weixin.png"
54
table_det = TableDetector(
6-
obj_model_path="rapid_table_det/models/yolo_obj_det_l.onnx",
7-
edge_model_path="rapid_table_det/models/yolo_edge_det_s.onnx",
5+
obj_model_type="paddle_obj_det_s", edge_model_type="paddle_edge_det_s"
86
)
9-
if __name__ == "__main__":
10-
result, elapse = table_det(img_path)
11-
obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
12-
print(
13-
f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
14-
)
15-
# 输出可视化
16-
import os
17-
import cv2
18-
from rapid_table_det.utils import img_loader, visuallize, extract_table_img
197

20-
img = img_loader(img_path)
21-
file_name_with_ext = os.path.basename(img_path)
22-
file_name, file_ext = os.path.splitext(file_name_with_ext)
23-
out_dir = "rapid_table_det/outputs"
24-
if not os.path.exists(out_dir):
25-
os.makedirs(out_dir)
26-
extract_img = img.copy()
27-
for i, res in enumerate(result):
28-
box = res["box"]
29-
lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
30-
# 带识别框和左上角方向位置
31-
img = visuallize(img, box, lt, rt, rb, lb)
32-
# 透视变换提取表格图片
33-
wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
34-
cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
35-
cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
8+
result, elapse = table_det(img_path)
9+
obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
10+
print(
11+
f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
12+
)
13+
# 输出可视化
14+
# import os
15+
# import cv2
16+
# from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
17+
#
18+
# img = img_loader(img_path)
19+
# file_name_with_ext = os.path.basename(img_path)
20+
# file_name, file_ext = os.path.splitext(file_name_with_ext)
21+
# out_dir = "rapid_table_det/outputs"
22+
# if not os.path.exists(out_dir):
23+
# os.makedirs(out_dir)
24+
# extract_img = img.copy()
25+
# for i, res in enumerate(result):
26+
# box = res["box"]
27+
# lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
28+
# # 带识别框和左上角方向位置
29+
# img = visuallize(img, box, lt, rt, rb, lb)
30+
# # 透视变换提取表格图片
31+
# wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
32+
# cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
33+
# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)

rapid_table_det/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,3 @@
22
# @Author: Jocker1212
33
# @Contact: xinyijianggo@gmail.com
44
from .inference import TableDetector
5-
from .utils import img_loader, visuallize, extract_table_img
6-
7-
#
8-
__all__ = ["TableDetector", "img_loader", "visuallize", "extract_table_img"]

rapid_table_det/inference.py

Lines changed: 79 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,75 @@
11
import os
22
from pathlib import Path
3+
from typing import Union
34

45
import cv2
56
import numpy as np
67

7-
from rapid_table_det.predictor import DbNet, ObjectDetector, PPLCNet, YoloSeg, YoloDet
8-
from rapid_table_det.utils import LoadImage
8+
from .predictor import DbNet, PaddleYoloEDet, PPLCNet, YoloSeg, YoloDet
9+
from .utils.download_model import DownloadModel
10+
11+
from .utils.logger import get_logger
12+
from .utils.load_image import LoadImage
913

1014
root_dir = Path(__file__).resolve().parent
1115
model_dir = os.path.join(root_dir, "models")
1216

17+
ROOT_DIR = Path(__file__).resolve().parent
18+
logger = get_logger("rapid_layout")
19+
20+
# Base URL of the hosted ONNX model files. It already ends with "/", so file
# names are appended directly; inserting another "/" would produce
# ".../models//<name>.onnx".
ROOT_URL = "https://www.modelscope.cn/models/jockerK/TableExtractor/resolve/master/rapid_table_det/models/"
# Maps every supported ``*_model_type`` value to the download URL of its ONNX file.
KEY_TO_MODEL_URL = {
    "yolo_obj_det": f"{ROOT_URL}yolo_obj_det.onnx",
    "yolo_edge_det": f"{ROOT_URL}yolo_edge_det.onnx",
    "yolo_edge_det_s": f"{ROOT_URL}yolo_edge_det_s.onnx",
    "paddle_obj_det": f"{ROOT_URL}paddle_obj_det.onnx",
    "paddle_obj_det_s": f"{ROOT_URL}paddle_obj_det_s.onnx",
    "paddle_edge_det": f"{ROOT_URL}paddle_edge_det.onnx",
    "paddle_edge_det_s": f"{ROOT_URL}paddle_edge_det_s.onnx",
    "paddle_cls_det": f"{ROOT_URL}paddle_cls_det.onnx",
}
31+
1332

1433
class TableDetector:
1534
def __init__(
1635
self,
17-
obj_model="yolo",
18-
edge_model="yolo",
19-
obj_model_path=os.path.join(model_dir, "obj_det_quantized.onnx"),
20-
edge_model_path=os.path.join(model_dir, "yolo_edge_det_s.onnx"),
21-
cls_model_path=os.path.join(model_dir, "cls_det.onnx"),
36+
use_cuda=False,
37+
use_dml=False,
38+
obj_model_path=None,
39+
edge_model_path=None,
40+
cls_model_path=None,
41+
obj_model_type="yolo_obj_det",
42+
edge_model_type="yolo_edge_det",
43+
cls_model_type="paddle_cls_det",
2244
):
2345
self.img_loader = LoadImage()
24-
if obj_model == "yolo":
25-
self.obj_detector = YoloDet(obj_model_path)
46+
obj_det_config = {
47+
"model_path": self.get_model_path(obj_model_type, obj_model_path),
48+
"use_cuda": use_cuda,
49+
"use_dml": use_dml,
50+
}
51+
edge_det_config = {
52+
"model_path": self.get_model_path(edge_model_type, edge_model_path),
53+
"use_cuda": use_cuda,
54+
"use_dml": use_dml,
55+
}
56+
cls_det_config = {
57+
"model_path": self.get_model_path(cls_model_type, cls_model_path),
58+
"use_cuda": use_cuda,
59+
"use_dml": use_dml,
60+
}
61+
if "yolo" in obj_model_type:
62+
self.obj_detector = YoloDet(obj_det_config)
2663
else:
27-
self.obj_detector = ObjectDetector(obj_model_path)
28-
if edge_model == "yolo":
29-
self.dbnet = YoloSeg(edge_model_path)
64+
self.obj_detector = PaddleYoloEDet(obj_det_config)
65+
if "yolo" in edge_model_type:
66+
self.dbnet = YoloSeg(edge_det_config)
3067
else:
31-
self.dbnet = DbNet(edge_model_path)
32-
self.pplcnet = PPLCNet(cls_model_path)
68+
self.dbnet = DbNet(edge_det_config)
69+
if "yolo" in cls_model_type:
70+
self.pplcnet = PPLCNet(cls_det_config)
71+
else:
72+
self.pplcnet = PPLCNet(cls_det_config)
3373

3474
def __call__(
3575
self,
@@ -101,6 +141,16 @@ def init_default_output(self, h, w):
101141
return obj_det_res, pred_label
102142

103143
def add_pre_info_for_cls(self, cls_img, edge_box, xmin_cls, ymin_cls):
144+
"""
145+
Args:
146+
cls_img:
147+
edge_box:
148+
xmin_cls:
149+
ymin_cls:
150+
151+
Returns: 带边缘划线的图片,给方向分类提供先验信息
152+
153+
"""
104154
cls_box = edge_box.copy()
105155
cls_box[:, 0] = cls_box[:, 0] - xmin_cls
106156
cls_box[:, 1] = cls_box[:, 1] - ymin_cls
@@ -166,3 +216,18 @@ def pad_box_points(self, h, w, xmax, xmin, ymax, ymin, pad):
166216
ymax_edge = min(ymax + pad, h)
167217
xmax_edge = min(xmax + pad, w)
168218
return xmin_edge, ymin_edge, xmax_edge, ymax_edge
219+
220+
@staticmethod
def get_model_path(
    model_type: str, model_path: Union[str, Path, None]
) -> Union[str, Path]:
    """Resolve the model file to load for ``model_type``.

    Args:
        model_type: Key into ``KEY_TO_MODEL_URL`` naming the model to use.
        model_path: Explicit model location. When it is not ``None`` it is
            returned unchanged and nothing is downloaded.

    Returns:
        ``model_path`` when one is given, otherwise the value returned by
        ``DownloadModel.download`` for the URL registered under ``model_type``.

    Raises:
        ValueError: If ``model_path`` is ``None`` and ``model_type`` has no
            registered download URL.
    """
    if model_path is not None:
        return model_path

    model_url = KEY_TO_MODEL_URL.get(model_type)
    if not model_url:
        # There is no bundled default model to fall back to: the old code
        # logged "using the default download model None" and returned None,
        # deferring the failure to predictor construction. Fail here with a
        # message that names the problem instead.
        raise ValueError(
            f"Unsupported model_type {model_type!r}: pass an explicit model "
            f"path or use one of {sorted(KEY_TO_MODEL_URL)}"
        )
    return DownloadModel.download(model_url)

0 commit comments

Comments
 (0)