Skip to content

Commit a4339fc

Browse files
committed
fix: adjust resize mode
1 parent 29c965f commit a4339fc

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

table_cls/main.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@
55
import numpy as np
66
from PIL import Image
77

8-
from .utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop
8+
from .utils import InputType, LoadImage, OrtInferSession
99

1010
cur_dir = Path(__file__).resolve().parent
1111
q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
1212
yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx"
13+
yolo_cls_x_model_path = cur_dir / "models" / "yolo_cls_x.onnx"
1314

1415

1516
class TableCls:
1617
def __init__(self, model_type="yolo", model_path=yolo_cls_model_path):
1718
if model_type == "yolo":
1819
self.table_engine = YoloCls(model_path)
20+
elif model_type == "yolox":
21+
self.table_engine = YoloCls(yolo_cls_x_model_path)
1922
else:
2023
model_path = q_cls_model_path
2124
self.table_engine = QanythingCls(model_path)
@@ -64,15 +67,11 @@ class YoloCls:
6467
def __init__(self, model_path):
6568
self.table_cls = OrtInferSession(model_path)
6669
self.cls = {0: "wireless", 1: "wired"}
67-
self.mean = np.array([0, 0, 0], dtype=np.float32)
68-
self.std = np.array([1, 1, 1], dtype=np.float32)
6970

7071
def preprocess(self, img):
7172
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
72-
img = resize_and_center_crop(img, 640)
73+
img = cv2.resize(img, (640, 640))
7374
img = np.array(img, dtype=np.float32) / 255
74-
img -= self.mean
75-
img /= self.std
7675
img = img.transpose(2, 0, 1) # HWC to CHW
7776
img = np.expand_dims(img, axis=0) # Add batch dimension, only one image
7877
return img

0 commit comments

Comments
 (0)