Skip to content

Commit b844584

Browse files
committed
fix: fix wired unet model rec
1 parent 7fd2549 commit b844584

File tree

7 files changed

+76
-30
lines changed

7 files changed

+76
-30
lines changed

demo_all.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from table_cls import TableCls
2+
from wired_table_rec.main import WiredTableInput, WiredTableRecognition
3+
from lineless_table_rec.main import LinelessTableInput, LinelessTableRecognition
4+
5+
if __name__ == "__main__":
6+
# Init
7+
wired_input = WiredTableInput()
8+
lineless_input = LinelessTableInput()
9+
wired_engine = WiredTableRecognition(wired_input)
10+
lineless_engine = LinelessTableRecognition(lineless_input)
11+
# 默认小yolo模型(0.1s),可切换为精度更高yolox(0.25s),更快的qanything(0.07s)模型或paddle模型(0.03s)
12+
table_cls = TableCls()
13+
img_path = f"tests/test_files/table.jpg"
14+
15+
cls, elasp = table_cls(img_path)
16+
if cls == "wired":
17+
table_engine = wired_engine
18+
else:
19+
table_engine = lineless_engine
20+
21+
table_results = table_engine(img_path, enhance_box_line=False)
22+
# 使用RapidOCR输入
23+
# ocr_engine = RapidOCR()
24+
# ocr_result, _ = ocr_engine(img_path)
25+
# table_results = table_engine(img_path, ocr_result=ocr_result)
26+
27+
# Visualize table rec result
28+
# save_dir = Path("outputs")
29+
# save_dir.mkdir(parents=True, exist_ok=True)
30+
#
31+
# save_html_path = f"outputs/{Path(img_path).stem}.html"
32+
# save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
33+
# save_logic_path = (
34+
# f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}"
35+
# )
36+
37+
#
38+
# vis_table = VisTable()
39+
# vis_imged = vis_table(
40+
# img_path, table_results, save_html_path, save_drawed_path, save_logic_path
41+
# )

demo_lineless.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from rapidocr_onnxruntime import RapidOCR
77

88
from lineless_table_rec import LinelessTableRecognition
9-
from lineless_table_rec.main import RapidTableInput
9+
from lineless_table_rec.main import LinelessTableInput
1010
from lineless_table_rec.utils.utils import VisTable
1111

1212
output_dir = Path("outputs")
1313
output_dir.mkdir(parents=True, exist_ok=True)
14-
input_args = RapidTableInput()
14+
input_args = LinelessTableInput()
1515
table_engine = LinelessTableRecognition(input_args)
1616
ocr_engine = RapidOCR()
1717
viser = VisTable()

demo_wired.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from rapidocr_onnxruntime import RapidOCR
77

88
from wired_table_rec import WiredTableRecognition
9-
from wired_table_rec.main import RapidTableInput, ModelType
9+
from wired_table_rec.main import WiredTableInput
1010
from wired_table_rec.utils.utils import VisTable
1111

1212
output_dir = Path("outputs")
1313
output_dir.mkdir(parents=True, exist_ok=True)
14-
input_args = RapidTableInput(model_type=ModelType.CYCLE_CENTER_NET.value)
14+
input_args = WiredTableInput()
1515
table_engine = WiredTableRecognition(input_args)
1616
ocr_engine = RapidOCR()
1717
viser = VisTable()

lineless_table_rec/main.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
match_ocr_cell,
2525
plot_html_table,
2626
sorted_ocr_boxes,
27+
box_4_1_poly_to_box_4_2,
2728
)
2829

2930

@@ -41,23 +42,23 @@ class ModelType(Enum):
4142

4243

4344
@dataclass
44-
class RapidTableInput:
45+
class LinelessTableInput:
4546
model_type: Optional[str] = ModelType.LORE.value
4647
model_path: Union[str, Path, None, Dict[str, str]] = None
4748
use_cuda: bool = False
4849
device: str = "cpu"
4950

5051

5152
@dataclass
52-
class RapidTableOutput:
53+
class LinelessTableOutput:
5354
pred_html: Optional[str] = None
5455
cell_bboxes: Optional[np.ndarray] = None
5556
logic_points: Optional[np.ndarray] = None
5657
elapse: Optional[float] = None
5758

5859

5960
class LinelessTableRecognition:
60-
def __init__(self, config: RapidTableInput):
61+
def __init__(self, config: LinelessTableInput):
6162
self.model_type = config.model_type
6263
if self.model_type not in KEY_TO_MODEL_URL:
6364
model_list = ",".join(KEY_TO_MODEL_URL)
@@ -78,7 +79,7 @@ def __call__(
7879
content: InputType,
7980
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
8081
**kwargs,
81-
) -> RapidTableOutput:
82+
) -> LinelessTableOutput:
8283
s = time.perf_counter()
8384
rec_again = True
8485
need_ocr = True
@@ -92,7 +93,7 @@ def __call__(
9293
sorted_polygons, idx_list = sorted_ocr_boxes(
9394
[box_4_2_poly_to_box_4_1(box) for box in polygons]
9495
)
95-
return RapidTableOutput(
96+
return LinelessTableOutput(
9697
"",
9798
sorted_polygons,
9899
logi_points[idx_list],
@@ -121,6 +122,10 @@ def __call__(
121122
# 将同一个识别框中的ocr结果排序并同行合并
122123
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
123124
# 渲染为html
125+
polygons = [
126+
box_4_1_poly_to_box_4_2(t_box_ocr["t_box"])
127+
for t_box_ocr in t_rec_ocr_list
128+
]
124129
logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
125130
cell_box_det_map = {
126131
i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
@@ -132,13 +137,13 @@ def __call__(
132137
_, idx_list = sorted_ocr_boxes(
133138
[t_box_ocr["t_box"] for t_box_ocr in t_rec_ocr_list]
134139
)
135-
polygons = polygons.reshape(-1, 8)
140+
polygons = np.array(polygons).reshape(-1, 8)
136141
logi_points = np.array(logi_points)
137142
elapse = time.perf_counter() - s
138143
except Exception:
139144
logging.warning(traceback.format_exc())
140-
return RapidTableOutput("", None, None, 0.0)
141-
return RapidTableOutput(pred_html, polygons, logi_points, elapse)
145+
return LinelessTableOutput("", None, None, 0.0)
146+
return LinelessTableOutput(pred_html, polygons, logi_points, elapse)
142147

143148
def transform_res(
144149
self,

tests/test_lineless_table_rec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pathlib import Path
66
import pytest
77

8-
from lineless_table_rec.main import RapidTableInput, ModelType
8+
from lineless_table_rec.main import LinelessTableInput, ModelType
99

1010
cur_dir = Path(__file__).resolve().parent
1111
root_dir = cur_dir.parent
@@ -16,7 +16,7 @@
1616
from lineless_table_rec import LinelessTableRecognition
1717

1818
test_file_dir = cur_dir / "test_files"
19-
input_args = RapidTableInput(model_type=ModelType.LORE.value)
19+
input_args = LinelessTableInput(model_type=ModelType.LORE.value)
2020
table_recog = LinelessTableRecognition(input_args)
2121

2222

tests/test_wired_table_rec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from bs4 import BeautifulSoup
99
from rapidocr_onnxruntime import RapidOCR
1010

11-
from wired_table_rec.main import RapidTableInput, ModelType
11+
from wired_table_rec.main import WiredTableInput, ModelType
1212
from wired_table_rec.utils.utils import rescale_size
1313
from wired_table_rec.utils.utils_table_recover import (
1414
plot_html_table,
@@ -26,7 +26,7 @@
2626
from wired_table_rec import WiredTableRecognition
2727

2828
test_file_dir = cur_dir / "test_files" / "wired"
29-
input_args = RapidTableInput(model_type=ModelType.UNET.value)
29+
input_args = WiredTableInput(model_type=ModelType.UNET.value)
3030
table_recog = WiredTableRecognition(input_args)
3131
ocr_engine = RapidOCR()
3232

wired_table_rec/main.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,23 @@ class ModelType(Enum):
4141

4242

4343
@dataclass
44-
class RapidTableInput:
44+
class WiredTableInput:
4545
model_type: Optional[str] = ModelType.UNET.value
4646
model_path: Union[str, Path, None, Dict[str, str]] = None
4747
use_cuda: bool = False
4848
device: str = "cpu"
4949

5050

5151
@dataclass
52-
class RapidTableOutput:
52+
class WiredTableOutput:
5353
pred_html: Optional[str] = None
5454
cell_bboxes: Optional[np.ndarray] = None
5555
logic_points: Optional[np.ndarray] = None
5656
elapse: Optional[float] = None
5757

5858

5959
class WiredTableRecognition:
60-
def __init__(self, config: RapidTableInput):
60+
def __init__(self, config: WiredTableInput):
6161
self.model_type = config.model_type
6262
if self.model_type not in KEY_TO_MODEL_URL:
6363
model_list = ",".join(KEY_TO_MODEL_URL)
@@ -85,7 +85,7 @@ def __call__(
8585
img: InputType,
8686
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
8787
**kwargs,
88-
) -> RapidTableOutput:
88+
) -> WiredTableOutput:
8989
s = time.perf_counter()
9090
rec_again = True
9191
need_ocr = True
@@ -100,7 +100,7 @@ def __call__(
100100
polygons, rotated_polygons = self.table_structure(img, **kwargs)
101101
if polygons is None:
102102
logging.warning("polygons is None.")
103-
return RapidTableOutput("", None, None, 0.0)
103+
return WiredTableOutput("", None, None, 0.0)
104104

105105
try:
106106
table_res, logi_points = self.table_recover(
@@ -115,7 +115,7 @@ def __call__(
115115
sorted_polygons, idx_list = sorted_ocr_boxes(
116116
[box_4_2_poly_to_box_4_1(box) for box in polygons]
117117
)
118-
return RapidTableOutput(
118+
return WiredTableOutput(
119119
"",
120120
sorted_polygons,
121121
logi_points[idx_list],
@@ -137,14 +137,14 @@ def __call__(
137137
for i, t_box_ocr in enumerate(t_rec_ocr_list)
138138
}
139139
pred_html = plot_html_table(logi_points, cell_box_det_map)
140-
polygons = polygons.reshape(-1, 8)
140+
polygons = np.array(polygons).reshape(-1, 8)
141141
logi_points = np.array(logi_points)
142142
elapse = time.perf_counter() - s
143143

144144
except Exception:
145145
logging.warning(traceback.format_exc())
146-
return RapidTableOutput("", None, None, 0.0)
147-
return RapidTableOutput(pred_html, polygons, logi_points, elapse)
146+
return WiredTableOutput("", None, None, 0.0)
147+
return WiredTableOutput(pred_html, polygons, logi_points, elapse)
148148

149149
def transform_res(
150150
self,
@@ -276,12 +276,12 @@ def main():
276276
raise ModuleNotFoundError(
277277
"Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
278278
) from exc
279-
280-
table_rec = WiredTableRecognition()
279+
input_args = WiredTableInput()
280+
table_rec = WiredTableRecognition(input_args)
281281
ocr_result, _ = ocr_engine(args.img_path)
282-
table_str, elapse = table_rec(args.img_path, ocr_result)
283-
print(table_str)
284-
print(f"cost: {elapse:.5f}")
282+
table_results = table_rec(args.img_path, ocr_result)
283+
print(table_results.pred_html)
284+
print(f"cost: {table_results.elapse:.5f}")
285285

286286

287287
if __name__ == "__main__":

0 commit comments

Comments
 (0)