@@ -41,23 +41,23 @@ class ModelType(Enum):
4141
4242
4343@dataclass
44- class RapidTableInput :
44+ class WiredTableInput :
4545 model_type : Optional [str ] = ModelType .UNET .value
4646 model_path : Union [str , Path , None , Dict [str , str ]] = None
4747 use_cuda : bool = False
4848 device : str = "cpu"
4949
5050
5151@dataclass
52- class RapidTableOutput :
52+ class WiredTableOutput :
5353 pred_html : Optional [str ] = None
5454 cell_bboxes : Optional [np .ndarray ] = None
5555 logic_points : Optional [np .ndarray ] = None
5656 elapse : Optional [float ] = None
5757
5858
5959class WiredTableRecognition :
60- def __init__ (self , config : RapidTableInput ):
60+ def __init__ (self , config : WiredTableInput ):
6161 self .model_type = config .model_type
6262 if self .model_type not in KEY_TO_MODEL_URL :
6363 model_list = "," .join (KEY_TO_MODEL_URL )
@@ -85,7 +85,7 @@ def __call__(
8585 img : InputType ,
8686 ocr_result : Optional [List [Union [List [List [float ]], str , str ]]] = None ,
8787 ** kwargs ,
88- ) -> RapidTableOutput :
88+ ) -> WiredTableOutput :
8989 s = time .perf_counter ()
9090 rec_again = True
9191 need_ocr = True
@@ -100,7 +100,7 @@ def __call__(
100100 polygons , rotated_polygons = self .table_structure (img , ** kwargs )
101101 if polygons is None :
102102 logging .warning ("polygons is None." )
103- return RapidTableOutput ("" , None , None , 0.0 )
103+ return WiredTableOutput ("" , None , None , 0.0 )
104104
105105 try :
106106 table_res , logi_points = self .table_recover (
@@ -115,7 +115,7 @@ def __call__(
115115 sorted_polygons , idx_list = sorted_ocr_boxes (
116116 [box_4_2_poly_to_box_4_1 (box ) for box in polygons ]
117117 )
118- return RapidTableOutput (
118+ return WiredTableOutput (
119119 "" ,
120120 sorted_polygons ,
121121 logi_points [idx_list ],
@@ -137,14 +137,14 @@ def __call__(
137137 for i , t_box_ocr in enumerate (t_rec_ocr_list )
138138 }
139139 pred_html = plot_html_table (logi_points , cell_box_det_map )
140- polygons = polygons .reshape (- 1 , 8 )
140+ polygons = np . array ( polygons ) .reshape (- 1 , 8 )
141141 logi_points = np .array (logi_points )
142142 elapse = time .perf_counter () - s
143143
144144 except Exception :
145145 logging .warning (traceback .format_exc ())
146- return RapidTableOutput ("" , None , None , 0.0 )
147- return RapidTableOutput (pred_html , polygons , logi_points , elapse )
146+ return WiredTableOutput ("" , None , None , 0.0 )
147+ return WiredTableOutput (pred_html , polygons , logi_points , elapse )
148148
149149 def transform_res (
150150 self ,
@@ -276,12 +276,12 @@ def main():
276276 raise ModuleNotFoundError (
277277 "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
278278 ) from exc
279-
280- table_rec = WiredTableRecognition ()
279+ input_args = WiredTableInput ()
280+ table_rec = WiredTableRecognition (input_args )
281281 ocr_result , _ = ocr_engine (args .img_path )
282- table_str , elapse = table_rec (args .img_path , ocr_result )
283- print (table_str )
284- print (f"cost: { elapse :.5f} " )
282+ table_results = table_rec (args .img_path , ocr_result )
283+ print (table_results . pred_html )
284+ print (f"cost: { table_results . elapse :.5f} " )
285285
286286
287287if __name__ == "__main__" :
0 commit comments