33
44import cv2
55import numpy as np
6- import onnxruntime
76from PIL import Image
87
9- from .utils import InputType , LoadImage
8+ from .utils import InputType , LoadImage , OrtInferSession , ResizePad
109
1110cur_dir = Path (__file__ ).resolve ().parent
12- table_cls_model_path = cur_dir / "models" / "table_cls.onnx"
11+ q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
12+ yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx"
1313
1414
class TableCls:
    """Front-end that loads an image and classifies the table it contains.

    The actual model work is delegated to an engine object that exposes
    ``preprocess(img)`` and is callable on a list of preprocessed inputs.
    """

    def __init__(self, model="yolo"):
        """Create the classification engine and the image loader.

        Args:
            model: ``"yolo"`` selects ``YoloCls``; every other value
                selects ``QanythingCls``.
        """
        # Pick the engine class first, then instantiate it exactly once.
        engine_cls = YoloCls if model == "yolo" else QanythingCls
        self.table_engine = engine_cls()
        self.load_img = LoadImage()

    def __call__(self, content: InputType):
        """Classify ``content`` and report how long the whole call took.

        Args:
            content: Anything accepted by the ``LoadImage`` instance.

        Returns:
            A ``(prediction, elapsed)`` tuple: whatever the engine returned
            and the wall time in seconds covering loading, preprocessing
            and inference.
        """
        started = time.perf_counter()
        image = self.load_img(content)
        model_input = self.table_engine.preprocess(image)
        # The engine is called with a one-element list holding the input.
        prediction = self.table_engine([model_input])
        elapsed = time.perf_counter() - started
        return prediction, elapsed
30+
31+
32+ class QanythingCls :
33+ def __init__ (self ):
34+ self .table_cls = OrtInferSession (q_cls_model_path )
2335 self .inp_h = 224
2436 self .inp_w = 224
2537 self .mean = np .array ([0.485 , 0.456 , 0.406 ], dtype = np .float32 )
2638 self .std = np .array ([0.229 , 0.224 , 0.225 ], dtype = np .float32 )
2739 self .cls = {0 : "wired" , 1 : "wireless" }
28- self .load_img = LoadImage ()
2940
30- def _preprocess (self , image ):
31- img = Image .fromarray (np .uint8 (image ))
41+ def preprocess (self , img ):
42+ img = cv2 .cvtColor (img .copy (), cv2 .COLOR_BGR2RGB )
43+ img = cv2 .cvtColor (img , cv2 .COLOR_BGR2GRAY )
44+ img = np .stack ((img ,) * 3 , axis = - 1 )
45+ img = Image .fromarray (np .uint8 (img ))
3246 img = img .resize ((self .inp_h , self .inp_w ))
3347 img = np .array (img , dtype = np .float32 ) / 255.0
3448 img -= self .mean
@@ -37,15 +51,27 @@ def _preprocess(self, image):
3751 img = np .expand_dims (img , axis = 0 ) # Add batch dimension, only one image
3852 return img
3953
40- def __call__ (self , content : InputType ):
41- ss = time .perf_counter ()
42- img = self .load_img (content )
43- gray_img = cv2 .cvtColor (img , cv2 .COLOR_BGR2GRAY )
44- gray_img = np .stack ((gray_img ,) * 3 , axis = - 1 )
45- gray_img = self ._preprocess (gray_img )
46- output = self .table_cls .run (None , {"input" : gray_img })
54+ def __call__ (self , img ):
55+ output = self .table_cls (img )
4756 predict = np .exp (output [0 ] - np .max (output [0 ], axis = 1 , keepdims = True ))
4857 predict /= np .sum (predict , axis = 1 , keepdims = True )
4958 predict_cla = np .argmax (predict , axis = 1 )[0 ]
50- table_elapse = time .perf_counter () - ss
51- return self .cls [predict_cla ], table_elapse
59+ return self .cls [predict_cla ]
60+
61+
class YoloCls:
    """Table classifier backed by the ``yolo_cls.onnx`` model."""

    def __init__(self):
        """Open the ONNX session and define the index-to-label mapping."""
        self.table_cls = OrtInferSession(yolo_cls_model_path)
        self.cls = {0: "wireless", 1: "wired"}

    def preprocess(self, img):
        """Turn a loaded image into the array fed to the session.

        Args:
            img: Image passed to ``ResizePad`` with a target size of 640.
                Presumably an HWC array; confirm against ``ResizePad``.

        Returns:
            A float32 array scaled by 1/255, with axes reordered by
            ``(2, 0, 1)`` and a leading axis of length one.
        """
        # Only the first value returned by ResizePad is used.
        padded, *_ = ResizePad(img, 640)
        scaled = np.array(padded, dtype=np.float32) / 255.0
        channels_first = np.transpose(scaled, (2, 0, 1))  # HWC to CHW
        return channels_first[np.newaxis, ...]  # batch of one image

    def __call__(self, img):
        """Run the session on ``img`` and return the label of the top class.

        Args:
            img: Input forwarded unchanged to the inference session.

        Returns:
            The entry of ``self.cls`` for the arg-max over axis 1 of the
            session's first output, taken from the first row.
        """
        output = self.table_cls(img)
        scores = output[0]
        best = np.argmax(scores, axis=1)[0]
        return self.cls[best]
0 commit comments