11import copy
22import math
3- from typing import Optional , Dict , Any
3+ from typing import Optional , Dict , Any , Tuple
44
55import cv2
66import numpy as np
77from skimage import measure
8-
8+ import matplotlib . pyplot as plt
99from wired_table_rec .utils import OrtInferSession , resize_img
1010from wired_table_rec .utils_table_line_rec import (
1111 get_table_line ,
@@ -31,22 +31,31 @@ def __init__(self, model_path: Optional[str] = None):
3131
3232 self .session = OrtInferSession (model_path )
3333
34- def __call__ (self , img : np .ndarray , ** kwargs ) -> Optional [np .ndarray ]:
34+ def __call__ (
35+ self , img : np .ndarray , ** kwargs
36+ ) -> Tuple [Optional [np .ndarray ], Optional [np .ndarray ]]:
3537 img_info = self .preprocess (img )
3638 pred = self .infer (img_info )
37- polygons = self .postprocess (img , pred , ** kwargs )
39+ polygons , rotated_polygons = self .postprocess (img , pred , ** kwargs )
3840 if polygons .size == 0 :
39- return None
41+ return None , None
4042 polygons = polygons .reshape (polygons .shape [0 ], 4 , 2 )
4143 polygons [:, 3 , :], polygons [:, 1 , :] = (
4244 polygons [:, 1 , :].copy (),
4345 polygons [:, 3 , :].copy (),
4446 )
47+ rotated_polygons = rotated_polygons .reshape (rotated_polygons .shape [0 ], 4 , 2 )
48+ rotated_polygons [:, 3 , :], rotated_polygons [:, 1 , :] = (
49+ rotated_polygons [:, 1 , :].copy (),
50+ rotated_polygons [:, 3 , :].copy (),
51+ )
4552 _ , idx = sorted_ocr_boxes (
46- [box_4_2_poly_to_box_4_1 (poly_box ) for poly_box in polygons ], threhold = 0.4
53+ [box_4_2_poly_to_box_4_1 (poly_box ) for poly_box in rotated_polygons ],
54+ threhold = 0.4 ,
4755 )
4856 polygons = polygons [idx ]
49- return polygons
57+ rotated_polygons = rotated_polygons [idx ]
58+ return polygons , rotated_polygons
5059
5160 def preprocess (self , img ) -> Dict [str , Any ]:
5261 scale = (self .inp_height , self .inp_width )
@@ -86,7 +95,8 @@ def postprocess(self, img, pred, **kwargs):
8695 extend_line = (
8796 kwargs .get ("extend_line" , enhance_box_line ) if kwargs else enhance_box_line
8897 ) # 是否进行线段延长使得端点连接
89-
98+ # 是否进行旋转修正
99+ rotated_fix = kwargs .get ("rotated_fix" ) if kwargs else True
90100 ori_shape = img .shape
91101 pred = np .uint8 (pred )
92102 hpred = copy .deepcopy (pred ) # 横线
@@ -120,8 +130,109 @@ def postprocess(self, img, pred, **kwargs):
120130 colboxes += rboxes_col_
121131 if extend_line :
122132 rowboxes , colboxes = final_adjust_lines (rowboxes , colboxes )
123- tmp = np .zeros (img .shape [:2 ], dtype = "uint8" )
124- tmp = draw_lines (tmp , rowboxes + colboxes , color = 255 , lineW = 2 )
133+ line_img = np .zeros (img .shape [:2 ], dtype = "uint8" )
134+ line_img = draw_lines (line_img , rowboxes + colboxes , color = 255 , lineW = 2 )
135+ rotated_angle = self .cal_rotate_angle (line_img )
136+ if rotated_fix and abs (rotated_angle ) > 0.3 :
137+ rotated_line_img = self .rotate_image (line_img , rotated_angle )
138+ rotated_polygons = self .cal_region_boxes (rotated_line_img )
139+ polygons = self .unrotate_polygons (
140+ rotated_polygons , rotated_angle , line_img .shape
141+ )
142+ else :
143+ polygons = self .cal_region_boxes (line_img )
144+ rotated_polygons = polygons .copy ()
145+ return polygons , rotated_polygons
146+
147+ def find_max_corners (self , line_img ):
148+ # 找到所有轮廓
149+ contours , _ = cv2 .findContours (
150+ line_img , cv2 .RETR_EXTERNAL , cv2 .CHAIN_APPROX_SIMPLE
151+ )
152+
153+ # 如果没有找到轮廓,返回空列表
154+ if not contours :
155+ return []
156+
157+ # 找到面积最大的轮廓
158+ max_contour = max (contours , key = cv2 .contourArea )
159+ # 计算最大轮廓的最小外接矩形
160+ rect = cv2 .minAreaRect (max_contour )
161+
162+ # 获取最小外接矩形的四个角点
163+ box = cv2 .boxPoints (rect )
164+ box = np .int0 (box )
165+ #
166+ # 对角点进行排序
167+ # 计算中心点
168+ center = np .mean (box , axis = 0 )
169+
170+ # 计算每个点与中心点的角度
171+ angles = np .arctan2 (box [:, 1 ] - center [1 ], box [:, 0 ] - center [0 ])
172+
173+ # 按角度排序
174+ sorted_indices = np .argsort (angles )
175+ sorted_box = box [sorted_indices ]
176+
177+ # 确保顺序为左上、右上、右下、左下
178+ top_left = sorted_box [0 ]
179+ top_right = sorted_box [1 ]
180+ bottom_right = sorted_box [2 ]
181+ bottom_left = sorted_box [3 ]
182+
183+ # 创建一个纯黑色背景图像
184+ black_img = np .zeros_like (line_img )
185+
186+ # 可视化最大轮廓和四个角点
187+ plt .figure (figsize = (10 , 10 ))
188+ plt .imshow (black_img , cmap = "gray" )
189+ plt .title ("Max Contour and Corners on Black Background" )
190+
191+ # 绘制最大轮廓
192+ max_contour = max_contour .reshape (- 1 , 2 )
193+ plt .plot (max_contour [:, 0 ], max_contour [:, 1 ], "b-" , linewidth = 2 )
194+
195+ # 绘制四个角点
196+ plt .scatter (
197+ [top_left [0 ], top_right [0 ], bottom_right [0 ], bottom_left [0 ]],
198+ [top_left [1 ], top_right [1 ], bottom_right [1 ], bottom_left [1 ]],
199+ c = "g" ,
200+ s = 100 ,
201+ marker = "o" ,
202+ )
203+
204+ plt .axis ("off" )
205+ plt .show ()
206+
207+ return [top_left , top_right , bottom_right , bottom_left ]
208+
209+ def extend_image_and_adjust_coordinates (self , img , corners , polygons ):
210+ # 计算扩展边界
211+ min_x = min (point [0 ] for point in corners )
212+ min_y = min (point [1 ] for point in corners )
213+ max_x = max (point [0 ] for point in corners )
214+ max_y = max (point [1 ] for point in corners )
215+
216+ # 计算扩展的宽度和高度
217+ left = - min_x if min_x < 0 else 0
218+ top = - min_y if min_y < 0 else 0
219+ right = max_x - img .shape [1 ] if max_x > img .shape [1 ] else 0
220+ bottom = max_y - img .shape [0 ] if max_y > img .shape [0 ] else 0
221+
222+ # 扩展图像
223+ new_width = img .shape [1 ] + left + right
224+ new_height = img .shape [0 ] + top + bottom
225+ extended_img = np .zeros ((new_height , new_width ), dtype = img .dtype )
226+ extended_img [top : top + img .shape [0 ], left : left + img .shape [1 ]] = img
227+
228+ # 调整角点和多边形坐标
229+ adjusted_corners = [(point [0 ] + left , point [1 ] + top ) for point in corners ]
230+ adjusted_polygons = polygons .copy ()
231+ adjusted_polygons [:, 0 ::2 ] += left
232+ adjusted_polygons [:, 1 ::2 ] += top
233+ return extended_img , adjusted_corners , adjusted_polygons
234+
235+ def cal_region_boxes (self , tmp ):
125236 labels = measure .label (tmp < 255 , connectivity = 2 ) # 8连通区域标记
126237 regions = measure .regionprops (labels )
127238 ceilboxes = min_area_rect_box (
@@ -133,3 +244,52 @@ def postprocess(self, img, pred, **kwargs):
133244 adjust_box = False ,
134245 ) # 最后一个参数改为False
135246 return np .array (ceilboxes )
247+
248+ def cal_rotate_angle (self , tmp ):
249+ # 计算最外侧的旋转框
250+ contours , _ = cv2 .findContours (tmp , cv2 .RETR_EXTERNAL , cv2 .CHAIN_APPROX_SIMPLE )
251+ if not contours :
252+ return 0
253+ largest_contour = max (contours , key = cv2 .contourArea )
254+ rect = cv2 .minAreaRect (largest_contour )
255+ # 计算旋转角度
256+ angle = rect [2 ]
257+ if angle < - 45 :
258+ angle += 90
259+ elif angle > 45 :
260+ angle -= 90
261+ return angle
262+
263+ def rotate_image (self , image , angle ):
264+ # 获取图像的中心点
265+ (h , w ) = image .shape [:2 ]
266+ center = (w // 2 , h // 2 )
267+
268+ # 计算旋转矩阵
269+ M = cv2 .getRotationMatrix2D (center , angle , 1.0 )
270+
271+ # 进行旋转
272+ rotated_image = cv2 .warpAffine (
273+ image , M , (w , h ), flags = cv2 .INTER_NEAREST , borderMode = cv2 .BORDER_REPLICATE
274+ )
275+
276+ return rotated_image
277+
278+ def unrotate_polygons (
279+ self , polygons : np .ndarray , angle : float , img_shape : tuple
280+ ) -> np .ndarray :
281+ # 将多边形旋转回原始位置
282+ (h , w ) = img_shape
283+ center = (w // 2 , h // 2 )
284+ M_inv = cv2 .getRotationMatrix2D (center , - angle , 1.0 )
285+
286+ # 将 (N, 8) 转换为 (N, 4, 2)
287+ polygons_reshaped = polygons .reshape (- 1 , 4 , 2 )
288+
289+ # 批量逆旋转
290+ unrotated_polygons = cv2 .transform (polygons_reshaped , M_inv )
291+
292+ # 将 (N, 4, 2) 转换回 (N, 8)
293+ unrotated_polygons = unrotated_polygons .reshape (- 1 , 8 )
294+
295+ return unrotated_polygons
0 commit comments