Skip to content

Commit b2cba70

Browse files
Added support for pre-trained HuggingFace models
1 parent 2a7513a commit b2cba70

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import torch
2+
from torch import Tensor
3+
4+
# We need to add lib into sys.path, see:
5+
# https://github.com/aws/sagemaker-python-sdk/blob/93af78b2120b33859505f8b26976c1fd243c44b7/src/sagemaker/workflow/_repack_model.py#L79
6+
import os
7+
import sys
8+
sys.path.append(os.path.join(os.path.dirname(__file__), "lib"))
9+
10+
import sagemaker_ssh_helper
11+
sagemaker_ssh_helper.setup_and_start_ssh()
12+
13+
14+
# noinspection DuplicatedCode
15+
class MyModel:
16+
number_t: Tensor = torch.tensor(0)
17+
18+
def __init__(self, number: str):
19+
self.number_t = torch.tensor(int(number))
20+
21+
def predict(self, input_data: Tensor):
22+
# TODO: remove .tolist() and use the same script as for PyTorch
23+
return [(self.number_t + input_data[0]).tolist()]
24+
25+
26+
def model_fn(model_dir):
27+
with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
28+
return MyModel(f.readline().decode('latin1'))
29+
30+
31+
def predict_fn(input_data, model):
32+
return model.predict(input_data)
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)