
Commit 91c37e0

Preparing DJL support, added support for SageMaker defaults, improved MME support for TF, added support for pre-trained HuggingFace models
1 parent 89346f9 commit 91c37e0
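
For context on the DJL support being prepared here: DJL Serving drives a Python entry point like the model.py added in this commit through a serving.properties file packaged next to it. A minimal packaging sketch, assuming the DeepSpeed engine and two-way tensor parallelism; the directory name and the degree value are illustrative, and the property-to-environment-variable mapping is how DJL is expected to surface the setting to model.py:

from pathlib import Path

# Hypothetical layout: model.py and serving.properties side by side in the model package
model_dir = Path("model_gpt_j")
model_dir.mkdir(exist_ok=True)
(model_dir / "serving.properties").write_text(
    "engine=DeepSpeed\n"                   # run the entry point under the DeepSpeed engine
    "option.tensor_parallel_degree=2\n"    # expected to reach model.py as TENSOR_PARALLEL_DEGREE
)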

File tree

3 files changed: 304 additions & 139 deletions

.idea/vcs.xml

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default.
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
# Original from:
# https://github.com/aws/amazon-sagemaker-examples/blob/main/advanced_functionality/pytorch_deploy_large_GPT_model/GPT-J-6B-model-parallel-inference-DJL.ipynb
import logging

# We need to add lib into sys.path, see:
# https://github.com/aws/sagemaker-python-sdk/blob/93af78b2120b33859505f8b26976c1fd243c44b7/src/sagemaker/workflow/_repack_model.py#L79
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "lib"))

import sagemaker_ssh_helper

sagemaker_ssh_helper.setup_and_start_ssh()

from typing import Optional

import deepspeed
import torch
from djl_python import Input, Output
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

predictor = None
def get_model():
    model_name = 'EleutherAI/gpt-j-6B'
    tensor_parallel = int(os.getenv('TENSOR_PARALLEL_DEGREE', '1'))
    local_rank = int(os.getenv('LOCAL_RANK', '0'))
    logging.info(f"Loading model with tensor_parallel={tensor_parallel} and local_rank={local_rank}")
    model = AutoModelForCausalLM.from_pretrained(model_name, revision="float32", torch_dtype=torch.float32)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # After init_inference() below, transformers logs a warning because the pipeline does not
    # recognize DeepSpeed's InferenceEngine wrapper:
    #   WARN PyProcess [1,0]<stderr>: The model 'InferenceEngine' is not supported for
    #   text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel',
    #   'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM',
    #   'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM',
    #   'BloomForCausalLM', 'CamembertForCausalLM', 'CodeGenForCausalLM', 'CTRLLMHeadModel',
    #   'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'GitForCausalLM',
    #   'GPT2LMHeadModel', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM',
    #   'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM',
    #   'MegatronBertForCausalLM', 'MvpForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM',
    #   'PegasusForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel',
    #   'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM',
    #   'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM',
    #   'Speech2Text2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'XGLMForCausalLM',
    #   'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM',
    #   'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel'].
    # The warning is harmless here: generation still runs through the wrapped GPTJForCausalLM.
    model = deepspeed.init_inference(model,
                                     mp_size=tensor_parallel,
                                     dtype=model.dtype,
                                     replace_method='auto',
                                     replace_with_kernel_inject=True)
    generator = pipeline(task='text-generation', model=model, tokenizer=tokenizer, device=local_rank)
    return generator
def handle(inputs: Input) -> Optional[Output]:
    global predictor
    if not predictor:
        predictor = get_model()

    if inputs.is_empty():
        # The model server makes an empty call to warm up the model on startup
        return None

    import subprocess
    # Take the command from the Python Debug Server dialog in PyCharm
    subprocess.check_call("pip install pydevd-pycharm~=222.4459.20".split())

    # The next command is a patch for https://youtrack.jetbrains.com/issue/PY-40552
    subprocess.check_call("sed -i~ -e s~s.replace~str(s).replace~ "
                          "/usr/local/lib/python3.9/dist-packages/_pydevd_bundle/pydevd_xml.py".split())

    logging.info("Connecting to remote debug server")
    import pydevd_pycharm
    pydevd_pycharm.settrace('127.0.0.1', port=12345, stdoutToServer=True, stderrToServer=True)
    logging.info("Connection complete")

    data = inputs.get_as_string()
    # min_new_tokens/max_new_tokens bound the length of the generated continuation
    result = predictor(data, do_sample=True, min_new_tokens=200, max_new_tokens=256)
    return Output().add(result)
if __name__ == '__main__':
    logging.basicConfig(stream=sys.stdout,
                        format="%(message)s",
                        level=logging.INFO)
    predictor = get_model()
    result = predictor("Hello world!", do_sample=True, min_new_tokens=200, max_new_tokens=256)
    print(result)
    sys.exit(0)
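
To exercise this entry point end to end, the model package (model.py, serving.properties, and the lib/ directory holding sagemaker_ssh_helper) would be uploaded as a model.tar.gz and deployed on a DJL DeepSpeed inference container. A rough sketch with the SageMaker Python SDK; the bucket, role, container version, and instance type below are placeholders, not values from this commit:

import boto3
import sagemaker
from sagemaker import image_uris
from sagemaker.model import Model

session = sagemaker.Session()
role = "arn:aws:iam::123456789012:role/SageMakerRole"  # placeholder execution role

# DJL DeepSpeed inference container image
image = image_uris.retrieve(framework="djl-deepspeed",
                            region=session.boto_region_name,
                            version="0.21.0")

model = Model(image_uri=image,
              model_data="s3://my-bucket/gpt-j-6b/model.tar.gz",  # placeholder tarball location
              role=role,
              sagemaker_session=session)
model.deploy(initial_instance_count=1,
             instance_type="ml.g5.12xlarge")  # multi-GPU instance for tensor parallelism

# Plain-text request/response, matching inputs.get_as_string() and Output().add() above
smr = boto3.client("sagemaker-runtime")
response = smr.invoke_endpoint(EndpointName=model.endpoint_name,
                               ContentType="text/plain",
                               Body="Hello world!")
print(response["Body"].read().decode())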

0 commit comments