|
5 | 5 | import numpy as np |
6 | 6 | from PIL import Image |
7 | 7 |
|
8 | | -from .utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop |
| 8 | +from .utils import InputType, LoadImage, OrtInferSession |
9 | 9 |
|
10 | 10 | cur_dir = Path(__file__).resolve().parent |
11 | 11 | q_cls_model_path = cur_dir / "models" / "table_cls.onnx" |
12 | 12 | yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx" |
| 13 | +yolo_cls_x_model_path = cur_dir / "models" / "yolo_cls_x.onnx" |
13 | 14 |
|
14 | 15 |
|
15 | 16 | class TableCls: |
16 | 17 | def __init__(self, model_type="yolo", model_path=yolo_cls_model_path): |
17 | 18 | if model_type == "yolo": |
18 | 19 | self.table_engine = YoloCls(model_path) |
| 20 | + elif model_type == "yolox": |
| 21 | + self.table_engine = YoloCls(yolo_cls_x_model_path) |
19 | 22 | else: |
20 | 23 | model_path = q_cls_model_path |
21 | 24 | self.table_engine = QanythingCls(model_path) |
@@ -64,15 +67,11 @@ class YoloCls: |
64 | 67 | def __init__(self, model_path): |
65 | 68 | self.table_cls = OrtInferSession(model_path) |
66 | 69 | self.cls = {0: "wireless", 1: "wired"} |
67 | | - self.mean = np.array([0, 0, 0], dtype=np.float32) |
68 | | - self.std = np.array([1, 1, 1], dtype=np.float32) |
69 | 70 |
|
70 | 71 | def preprocess(self, img): |
71 | 72 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
72 | | - img = resize_and_center_crop(img, 640) |
| 73 | + img = cv2.resize(img, (640, 640)) |
73 | 74 | img = np.array(img, dtype=np.float32) / 255 |
74 | | - img -= self.mean |
75 | | - img /= self.std |
76 | 75 | img = img.transpose(2, 0, 1) # HWC to CHW |
77 | 76 | img = np.expand_dims(img, axis=0) # Add batch dimension, only one image |
78 | 77 | return img |
|
0 commit comments