@@ -66,57 +66,49 @@ def get_benchmark_cols(
6666 ) -> Tuple [np .ndarray , List [float ], int ]:
6767 longest_col = max (rows .values (), key = lambda x : len (x ))
6868 longest_col_points = polygons [longest_col ]
69- longest_x = longest_col_points [:, 0 , 0 ]
70-
69+ longest_x_start = list (longest_col_points [:, 0 , 0 ])
70+ longest_x_end = list (longest_col_points [:, 2 , 0 ])
71+ min_x = longest_x_start [0 ]
72+ max_x = longest_x_end [- 1 ]
7173 theta = 15
72- for row_value in rows .values ():
73- cur_row = polygons [row_value ][:, 0 , 0 ]
74-
75- range_res = {}
76- for idx , cur_v in enumerate (cur_row ):
77- start_idx , end_idx = None , None
78- for i , v in enumerate (longest_x ):
79- if cur_v - theta <= v <= cur_v + theta :
80- break
8174
82- if cur_v > v :
83- start_idx = i
84- continue
75+ # Update the column boundaries according to the starting x coordinate of the current column
def update_longest_col(col_x_list, cur_v, min_x_, max_x_):
    """Merge one x coordinate into the list of column boundaries.

    ``col_x_list`` is modified in place. A coordinate that lies within
    ``theta`` (read from the enclosing scope) of an existing entry is
    treated as already present and changes nothing.

    Args:
        col_x_list: Column-boundary x coordinates, expected to be in
            ascending order; edited in place.
        cur_v: The x coordinate to merge.
        min_x_: Current leftmost boundary.
        max_x_: Current right edge (the last column ends here).

    Returns:
        The ``(min_x_, max_x_)`` pair after the merge.
    """
    for i, v in enumerate(col_x_list):
        if cur_v - theta <= v <= cur_v + theta:
            # An existing boundary already represents this coordinate.
            return min_x_, max_x_
        if cur_v < v:
            # First boundary to the right of cur_v: insert in front of it.
            if cur_v < min_x_:
                col_x_list.insert(0, cur_v)
                return cur_v, max_x_
            col_x_list.insert(i, cur_v)
            return min_x_, max_x_

    # Reaching this point means cur_v lies to the right of every known
    # boundary. The previous implementation silently dropped such values
    # (the loop simply ran out), so a trailing column could be lost and
    # max_x_ was never extended.
    if cur_v > max_x_:
        # Only promote the old right edge to a boundary when the jump is
        # larger than the tolerance and the edge is not a near-duplicate of
        # the last boundary; otherwise this is just jitter of the same edge.
        if cur_v - max_x_ > theta and (
            not col_x_list or max_x_ - col_x_list[-1] > theta
        ):
            col_x_list.append(max_x_)
        return min_x_, cur_v
    if max_x_ - cur_v > theta:
        # A boundary strictly inside the last column splits it.
        col_x_list.append(cur_v)
    return min_x_, max_x_
8593
86- if cur_v < v :
87- end_idx = i
88- break
94+ for row_value in rows .values ():
95+ cur_row_start = list (polygons [row_value ][:, 0 , 0 ])
96+ cur_row_end = list (polygons [row_value ][:, 2 , 0 ])
97+ for idx , (cur_v_start , cur_v_end ) in enumerate (
98+ zip (cur_row_start , cur_row_end )
99+ ):
100+ min_x , max_x = update_longest_col (
101+ longest_x_start , cur_v_start , min_x , max_x
102+ )
103+ min_x , max_x = update_longest_col (
104+ longest_x_start , cur_v_end , min_x , max_x
105+ )
89106
90- range_res [idx ] = [start_idx , end_idx ]
91-
92- sorted_res = dict (
93- sorted (range_res .items (), key = lambda x : x [0 ], reverse = True )
94- )
95- for k , v in sorted_res .items ():
96- # bugfix: https://github.com/RapidAI/TableStructureRec/discussions/55
97- # 最长列不包含第一列和最后一列的场景需要兼容
98- if all (v ) or v [1 ] == 0 :
99- longest_x = np .insert (longest_x , v [1 ], cur_row [k ])
100- longest_col_points = np .insert (
101- longest_col_points , v [1 ], polygons [row_value [k ]], axis = 0
102- )
103- elif v [0 ] and v [0 ] + 1 == len (longest_x ):
104- longest_x = np .append (longest_x , cur_row [k ])
105- longest_col_points = np .append (
106- longest_col_points ,
107- polygons [row_value [k ]][np .newaxis , :, :],
108- axis = 0 ,
109- )
110- # 求出最右侧所有cell的宽,其中最小的作为最后一列宽度
111- rightmost_idxs = [v [- 1 ] for v in rows .values ()]
112- rightmost_boxes = polygons [rightmost_idxs ]
113- min_width = min ([self .compute_L2 (v [3 , :], v [0 , :]) for v in rightmost_boxes ])
114-
115- each_col_widths = (longest_x [1 :] - longest_x [:- 1 ]).tolist ()
116- each_col_widths .append (min_width )
117-
118- col_nums = longest_x .shape [0 ]
119- return longest_col_points , each_col_widths , col_nums
107+ longest_x_start = np .array (longest_x_start )
108+ each_col_widths = (longest_x_start [1 :] - longest_x_start [:- 1 ]).tolist ()
109+ each_col_widths .append (max_x - longest_x_start [- 1 ])
110+ col_nums = longest_x_start .shape [0 ]
111+ return longest_x_start , each_col_widths , col_nums
120112
121113 def get_benchmark_rows (
122114 self , rows : Dict [int , List ], polygons : np .ndarray
@@ -160,7 +152,7 @@ def get_merge_cells(
160152 box_width = self .compute_L2 (box [3 , :], box [0 , :])
161153
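162154 # Not necessarily starting from 0; the start position should be determined from both the already-assigned values and the x-coordinate position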
163- loc_col_idx = np .argmin (np .abs (longest_col [:, 0 , 0 ] - box [0 , 0 ]))
155+ loc_col_idx = np .argmin (np .abs (longest_col - box [0 , 0 ]))
164156 col_start = max (sum (one_col_result .values ()), loc_col_idx )
165157
166158 # Compute how many cells to merge in the column direction
0 commit comments