Skip to content

Commit a5167be

Browse files
committed
fix: correct yolo cls model preprocess
1 parent 37fa544 commit a5167be

File tree

4 files changed

+50
-23
lines changed

4 files changed

+50
-23
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,5 @@ long1.jpg
158158
.DS_Store
159159
*.npy
160160
outputs/
161+
/tests/test_files/standard_dataset/
162+
/lineless_table_rec/images/

README.md

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
### 最近更新
1616
- **2024.10.13**
1717
- 补充最新paddlex-SLANet-plus 测评结果(已集成模型到[RapidTable](https://github.com/RapidAI/RapidTable)仓库)
18-
- **2024.10.17**
19-
- 补充最新surya 表格识别测评结果
2018
- **2024.10.22**
21-
- 补充复杂背景多表格检测提取方案[RapidTableDet](https://github.com/RapidAI/RapidTableDetection)
19+
- 补充复杂背景多表格检测提取方案[RapidTableDet](https://github.com/RapidAI/RapidTableDetection)
20+
- **2024.10.29**
21+
- 使用yolo11重新训练表格分类器,修正wired_table_rec v2逻辑坐标还原错误,并更新测评
22+
2223
### 简介
2324
💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。
2425

@@ -57,10 +58,10 @@
5758
| [deepdoctection(rag-flow)](https://github.com/deepdoctection/deepdoctection?tab=readme-ov-file) | 0.59975 | 0.69918 |
5859
| [ppstructure_table_master](https://github.com/PaddlePaddle/PaddleOCR/tree/main/ppstructure) | 0.61606 | 0.73892 |
5960
| [ppstructure_table_engine](https://github.com/PaddlePaddle/PaddleOCR/tree/main/ppstructure) | 0.67924 | 0.78653 |
60-
| table_cls + wired_table_rec v1 + lineless_table_rec | 0.68507 | 0.75140 |
6161
| [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy) | 0.67310 | 0.81210 |
6262
| [RapidTable(SLANet)](https://github.com/RapidAI/RapidTable) | 0.71654 | 0.81067 |
63-
| table_cls + wired_table_rec v2 + lineless_table_rec | 0.73702 | 0.80210 |
63+
| table_cls + wired_table_rec v1 + lineless_table_rec | 0.75288 | 0.82574 |
64+
| table_cls + wired_table_rec v2 + lineless_table_rec | 0.77676 | 0.84580 |
6465
| [RapidTable(SLANet-plus)](https://github.com/RapidAI/RapidTable) | **0.84481** | **0.91369** |
6566

6667
### 使用建议
@@ -87,6 +88,8 @@ from wired_table_rec import WiredTableRecognition
8788
lineless_engine = LinelessTableRecognition()
8889
wired_engine = WiredTableRecognition()
8990
table_cls = TableCls()
91+
# 分类精度降低,但耗时减少 3/5(0.2s->0.08s)
92+
# table_cls = TableCls(mode="q")
9093
img_path = f'images/img14.jpg'
9194

9295
cls,elasp = table_cls(img_path)
@@ -158,7 +161,8 @@ for i, res in enumerate(result):
158161
- [x] 图片小角度偏移修正方法补充
159162
- [x] 增加数据集数量,增加更多评测对比
160163
- [x] 补充复杂场景表格检测和提取,解决旋转和透视导致的低识别率
161-
- [ ] 优化表格分类器,优化无线表格模型
164+
- [x] 优化表格分类器
165+
- [ ] 优化无线表格模型
162166

163167
### 处理流程
164168

table_cls/main.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from PIL import Image
77

8-
from .utils import InputType, LoadImage, OrtInferSession, ResizePad
8+
from .utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop
99

1010
cur_dir = Path(__file__).resolve().parent
1111
q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
@@ -64,10 +64,15 @@ class YoloCls:
6464
def __init__(self, model_path):
    """Create the inference session and the constants used by ``preprocess``.

    Args:
        model_path: Location of the classifier model file, passed straight
            to ``OrtInferSession``.
    """
    self.table_cls = OrtInferSession(model_path)
    # Class index -> human-readable table type.
    self.cls = {0: "wireless", 1: "wired"}
    # Per-channel normalisation constants. With zero mean and unit std the
    # normalisation step in ``preprocess`` leaves pixel values unchanged.
    self.mean = np.zeros(3, dtype=np.float32)
    self.std = np.ones(3, dtype=np.float32)
6769

6870
def preprocess(self, img):
    """Turn a BGR image into a normalised float32 batch of one in NCHW layout.

    Steps: BGR -> RGB, resize and center-crop to 640 x 640, scale pixel
    values by 1/255, apply ``self.mean`` / ``self.std``, reorder HWC -> CHW
    and prepend a batch axis.

    Args:
        img: Image array accepted by ``cv2.cvtColor`` with
            ``cv2.COLOR_BGR2RGB`` (i.e. treated as BGR channel order).

    Returns:
        np.ndarray: Array with a leading batch dimension of size 1.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    square = resize_and_center_crop(rgb, 640)

    tensor = np.array(square, dtype=np.float32) / 255
    # In-place ops keep the tensor float32 regardless of the dtype of the
    # normalisation constants.
    tensor -= self.mean
    tensor /= self.std

    chw = tensor.transpose(2, 0, 1)  # HWC to CHW
    # Add batch dimension, only one image
    return np.expand_dims(chw, axis=0)

table_cls/utils.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -180,17 +180,33 @@ def verify_exist(file_path: Union[str, Path]):
180180
raise LoadImageError(f"{file_path} does not exist.")
181181

182182

183-
def ResizePad(img, target_size):
184-
h, w = img.shape[:2]
185-
m = max(h, w)
186-
ratio = target_size / m
187-
new_w, new_h = int(ratio * w), int(ratio * h)
188-
img = cv2.resize(img, (new_w, new_h), cv2.INTER_LINEAR)
189-
top = (target_size - new_h) // 2
190-
bottom = (target_size - new_h) - top
191-
left = (target_size - new_w) // 2
192-
right = (target_size - new_w) - left
193-
img1 = cv2.copyMakeBorder(
194-
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
195-
)
196-
return img1, new_w, new_h, left, top
183+
def resize_and_center_crop(image: np.ndarray, target_size: int) -> np.ndarray:
    """
    Resize the image so that its shorter side equals ``target_size``, then
    crop the central ``target_size`` x ``target_size`` window.

    Args:
        image (np.ndarray): Input image as a NumPy array with shape (height, width, channels).
        target_size (int): Length of the shorter side after resizing, which is
            also the side length of the square output.

    Returns:
        (np.ndarray): Resized and center-cropped image whose first two
            dimensions are both ``target_size``.

    Raises:
        ValueError: If the image has a zero-sized height or width, or if
            ``target_size`` is not positive.
    """
    # Read the spatial size of the input image.
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"Cannot resize an empty image of shape {image.shape}.")
    if target_size <= 0:
        raise ValueError(f"target_size must be positive, got {target_size}.")

    # Scale factor that maps the shorter side onto target_size.
    scale = target_size / min(h, w)

    # int() truncates, and floating-point rounding can make
    # ``min(h, w) * scale`` land just below target_size. Clamping keeps both
    # sides at least target_size so the crop below is always full-sized.
    new_h = max(target_size, int(h * scale))
    new_w = max(target_size, int(w * scale))

    # Resize; note that cv2.resize takes the size as (width, height).
    resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    # Top-left corner of the centered crop window.
    top = (new_h - target_size) // 2
    left = (new_w - target_size) // 2

    # Crop the central square.
    return resized_image[top : top + target_size, left : left + target_size]

0 commit comments

Comments
 (0)