1313
1414
1515class TableCls :
16- def __init__ (self , model = "yolo" ):
17- if model == "yolo" :
18- self .table_engine = YoloCls ()
16+ def __init__ (self , model_type = "yolo" , model_path = yolo_cls_model_path ):
17+ if model_type == "yolo" :
18+ self .table_engine = YoloCls (model_path )
1919 else :
20- self .table_engine = QanythingCls ()
20+ model_path = q_cls_model_path
21+ self .table_engine = QanythingCls (model_path )
2122 self .load_img = LoadImage ()
2223
2324 def __call__ (self , content : InputType ):
@@ -30,8 +31,8 @@ def __call__(self, content: InputType):
3031
3132
3233class QanythingCls :
33- def __init__ (self ):
34- self .table_cls = OrtInferSession (q_cls_model_path )
34+ def __init__ (self , model_path ):
35+ self .table_cls = OrtInferSession (model_path )
3536 self .inp_h = 224
3637 self .inp_w = 224
3738 self .mean = np .array ([0.485 , 0.456 , 0.406 ], dtype = np .float32 )
@@ -60,8 +61,8 @@ def __call__(self, img):
6061
6162
6263class YoloCls :
63- def __init__ (self ):
64- self .table_cls = OrtInferSession (yolo_cls_model_path )
64+ def __init__ (self , model_path ):
65+ self .table_cls = OrtInferSession (model_path )
6566 self .cls = {0 : "wireless" , 1 : "wired" }
6667
6768 def preprocess (self , img ):
0 commit comments