33
44import cv2
55import numpy as np
6- import onnxruntime
76from PIL import Image
87
9- from .utils import InputType , LoadImage
8+ from .utils import InputType , LoadImage , OrtInferSession , ResizePad
109
1110cur_dir = Path (__file__ ).resolve ().parent
12- table_cls_model_path = cur_dir / "models" / "table_cls.onnx"
11+ q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
12+ yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx"
1313
1414
1515class TableCls :
16- def __init__ (self , device = "cpu" ):
17- providers = (
18- ["CUDAExecutionProvider" ] if device == "cuda" else ["CPUExecutionProvider" ]
19- )
20- self .table_cls = onnxruntime .InferenceSession (
21- table_cls_model_path , providers = providers
22- )
16+ def __init__ (self , model_type = "yolo" , model_path = yolo_cls_model_path ):
17+ if model_type == "yolo" :
18+ self .table_engine = YoloCls (model_path )
19+ else :
20+ model_path = q_cls_model_path
21+ self .table_engine = QanythingCls (model_path )
22+ self .load_img = LoadImage ()
23+
24+ def __call__ (self , content : InputType ):
25+ ss = time .perf_counter ()
26+ img = self .load_img (content )
27+ img = self .table_engine .preprocess (img )
28+ predict_cla = self .table_engine ([img ])
29+ table_elapse = time .perf_counter () - ss
30+ return predict_cla , table_elapse
31+
32+
33+ class QanythingCls :
34+ def __init__ (self , model_path ):
35+ self .table_cls = OrtInferSession (model_path )
2336 self .inp_h = 224
2437 self .inp_w = 224
2538 self .mean = np .array ([0.485 , 0.456 , 0.406 ], dtype = np .float32 )
2639 self .std = np .array ([0.229 , 0.224 , 0.225 ], dtype = np .float32 )
2740 self .cls = {0 : "wired" , 1 : "wireless" }
28- self .load_img = LoadImage ()
2941
30- def _preprocess (self , image ):
31- img = Image .fromarray (np .uint8 (image ))
42+ def preprocess (self , img ):
43+ img = cv2 .cvtColor (img .copy (), cv2 .COLOR_BGR2RGB )
44+ img = cv2 .cvtColor (img , cv2 .COLOR_BGR2GRAY )
45+ img = np .stack ((img ,) * 3 , axis = - 1 )
46+ img = Image .fromarray (np .uint8 (img ))
3247 img = img .resize ((self .inp_h , self .inp_w ))
3348 img = np .array (img , dtype = np .float32 ) / 255.0
3449 img -= self .mean
@@ -37,15 +52,27 @@ def _preprocess(self, image):
3752 img = np .expand_dims (img , axis = 0 ) # Add batch dimension, only one image
3853 return img
3954
40- def __call__ (self , content : InputType ):
41- ss = time .perf_counter ()
42- img = self .load_img (content )
43- gray_img = cv2 .cvtColor (img , cv2 .COLOR_BGR2GRAY )
44- gray_img = np .stack ((gray_img ,) * 3 , axis = - 1 )
45- gray_img = self ._preprocess (gray_img )
46- output = self .table_cls .run (None , {"input" : gray_img })
55+ def __call__ (self , img ):
56+ output = self .table_cls (img )
4757 predict = np .exp (output [0 ] - np .max (output [0 ], axis = 1 , keepdims = True ))
4858 predict /= np .sum (predict , axis = 1 , keepdims = True )
4959 predict_cla = np .argmax (predict , axis = 1 )[0 ]
50- table_elapse = time .perf_counter () - ss
51- return self .cls [predict_cla ], table_elapse
60+ return self .cls [predict_cla ]
61+
62+
63+ class YoloCls :
64+ def __init__ (self , model_path ):
65+ self .table_cls = OrtInferSession (model_path )
66+ self .cls = {0 : "wireless" , 1 : "wired" }
67+
68+ def preprocess (self , img ):
69+ img , * _ = ResizePad (img , 640 )
70+ img = np .array (img , dtype = np .float32 ) / 255.0
71+ img = img .transpose (2 , 0 , 1 ) # HWC to CHW
72+ img = np .expand_dims (img , axis = 0 ) # Add batch dimension, only one image
73+ return img
74+
75+ def __call__ (self , img ):
76+ output = self .table_cls (img )
77+ predict_cla = np .argmax (output [0 ], axis = 1 )[0 ]
78+ return self .cls [predict_cla ]
0 commit comments