Skip to content

Commit a8c1207

Browse files
committed
Modernized the fine-tuning job
1 parent 150fa35 commit a8c1207

File tree

8 files changed

+475
-500
lines changed

8 files changed

+475
-500
lines changed

workshops/fine-tuning-with-sagemakerai-and-bedrock/task_05_fmops/05.00_fmops_examples.ipynb

Lines changed: 48 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@
102102
"outputs": [],
103103
"source": [
104104
"sagemaker_session = sagemaker.session.Session()\n",
105-
"role = sagemaker.get_execution_role()"
105+
"role = sagemaker.get_execution_role()\n",
106+
"region = sagemaker_session.boto_session.region_name"
106107
]
107108
},
108109
{
@@ -114,7 +115,7 @@
114115
"\n",
115116
"We define appropriate paths in S3 to store model files, define the model we will be working with, and define the model endpoint name.\n",
116117
"\n",
117-
"In this lab, we are working with [DeepSeek-R1-Distill-Llama-8B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B). It is easy to fine-tune as we will see in the next lab, and is small enough to fit on a reasonably sized GPU-accelerated hosting endpoint."
118+
"In this lab, we are working with [Qwen3-4B-Instruct-2507](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507). It is easy to fine-tune as we will see in the next lab, and is small enough to fit on a reasonably sized GPU-accelerated hosting endpoint."
118119
]
119120
},
120121
{
@@ -138,12 +139,14 @@
138139
"metadata": {},
139140
"outputs": [],
140141
"source": [
141-
"model_id = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
142+
"model_id = \"Qwen/Qwen3-4B-Instruct-2507\"\n",
142143
"model_id_filesafe = model_id.replace(\"/\",\"_\").replace(\".\", \"_\")\n",
143144
"model_name_safe = model_id.split('/')[-1].replace('.', '-').replace('_', '-')\n",
144145
"endpoint_name = f\"Example-{model_name_safe}\"\n",
145146
"instance_count = 1\n",
146-
"instance_type = \"ml.g5.2xlarge\""
147+
"instance_type = \"ml.g5.2xlarge\"\n",
148+
"health_check_timeout = 1800\n",
149+
"data_download_timeout = 3600"
147150
]
148151
},
149152
{
@@ -170,15 +173,14 @@
170173
"source": [
171174
"mlflow_tracking_server_arn = \"<REPLACE WITH YOUR ARN>\"\n",
172175
"\n",
173-
"if not mlflow_tracking_server_arn:\n",
174-
" try:\n",
175-
" response = boto3.client('sagemaker').describe_mlflow_tracking_server(\n",
176-
" TrackingServerName='genai-mlflow-tracker'\n",
177-
" )\n",
178-
" mlflow_tracking_server_arn = response['TrackingServerArn']\n",
179-
" print(f\"MLflow Tracking Server ARN: {mlflow_tracking_server_arn}\")\n",
180-
" except botocore.exceptions.ClientError:\n",
181-
" print(\"No MLflow Tracking Server Found, please input a value for mlflow_tracking_server_arn\")\n",
176+
"try:\n",
177+
" response = boto3.client('sagemaker').describe_mlflow_tracking_server(\n",
178+
" TrackingServerName='genai-mlflow-tracker'\n",
179+
" )\n",
180+
" mlflow_tracking_server_arn = response['TrackingServerArn']\n",
181+
" print(f\"MLflow Tracking Server ARN: {mlflow_tracking_server_arn}\")\n",
182+
"except botocore.exceptions.ClientError:\n",
183+
" print(\"No MLflow Tracking Server Found, please input a value for mlflow_tracking_server_arn\")\n",
182184
"\n",
183185
"os.environ[\"mlflow_tracking_server_arn\"] = mlflow_tracking_server_arn"
184186
]
@@ -188,7 +190,7 @@
188190
"metadata": {},
189191
"source": [
190192
"### 4. Model Deployment\n",
191-
"There are several approaches to deploying a model to a SageMaker AI managed endpoint. In this section, we explore the most direct option which downloads a model directly from HuggingFace to the managed endpoint via SageMaker JumpStart. We are still using DeepSeek-R1-Distill-Llama-8B, but we have not fine-tuned it. The purpose of this section is to illustrate the components required to customize a model deployment on SageMaker before fine-tuning it."
193+
"There are several approaches to deploying a model to a SageMaker AI managed endpoint. In this section, we explore the most direct option which downloads a model directly from HuggingFace to the managed endpoint via SageMaker JumpStart. We are still using Qwen3-4B-Instruct-2507, but we have not fine-tuned it. The purpose of this section is to illustrate the components required to customize a model deployment on SageMaker before fine-tuning it."
192194
]
193195
},
194196
{
@@ -207,12 +209,8 @@
207209
"metadata": {},
208210
"outputs": [],
209211
"source": [
210-
"# Create and deploy model\n",
211-
"image_uri = sagemaker.image_uris.retrieve(\n",
212-
" framework=\"djl-lmi\",\n",
213-
" region=sagemaker_session.boto_session.region_name,\n",
214-
" version=\"latest\"\n",
215-
")"
212+
"inference_image_uri = f\"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.33.0-lmi15.0.0-cu128\"\n",
213+
"print(f\"using image to host: {inference_image_uri}\")"
216214
]
217215
},
218216
{
@@ -242,7 +240,7 @@
242240
" 'OPTION_MAX_MODEL_LEN': '4096'\n",
243241
"}\n",
244242
"model = HuggingFaceModel(\n",
245-
" image_uri=image_uri,\n",
243+
" image_uri=inference_image_uri,\n",
246244
" env=model_config,\n",
247245
" role=role\n",
248246
")"
@@ -276,8 +274,6 @@
276274
"with mlflow.start_run(run_name=\"example_model_deployment\"):\n",
277275
" deployment_start_time = time.time()\n",
278276
"\n",
279-
" health_check_timeout = 1800\n",
280-
" data_download_timeout = 3600\n",
281277
"\n",
282278
" # Log deployment parameters\n",
283279
" mlflow.log_params({\n",
@@ -297,7 +293,7 @@
297293
" instance_type=instance_type,\n",
298294
" container_startup_health_check_timeout=health_check_timeout,\n",
299295
" model_data_download_timeout=data_download_timeout,\n",
300-
" endpoint_name=endpoint_name\n",
296+
" endpoint_name=f\"{endpoint_name}\"\n",
301297
" )\n",
302298
"\n",
303299
" # Log deployment metrics\n",
@@ -339,7 +335,7 @@
339335
"from sagemaker.deserializers import JSONDeserializer\n",
340336
"\n",
341337
"predictor = Predictor(\n",
342-
" endpoint_name=endpoint_name,\n",
338+
" endpoint_name=f\"{endpoint_name}\",\n",
343339
" serializer=JSONSerializer(),\n",
344340
" deserializer=JSONDeserializer()\n",
345341
")\n",
@@ -436,7 +432,7 @@
436432
"metadata": {},
437433
"source": [
438434
"### 5. Qualitative Model Evaluation\n",
439-
"Let's test the default DeepSeek-R1-Distill-Llama-8B using MLFlow's LLM-as-a-Judge capability. We'll use [Anthropic's Claude 3 Haiku](https://www.anthropic.com/news/claude-3-haiku) model on [Amazon Bedrock](https://aws.amazon.com/bedrock/) as the judge. We'll also wrap our model endpoint invocation in a method making it easier to call in the evaluation. \n",
435+
"Let's test the default Qwen3-4B-Instruct-2507 using MLFlow's LLM-as-a-Judge capability. We'll use [Anthropic's Claude 3 Haiku](https://www.anthropic.com/news/claude-3-haiku) model on [Amazon Bedrock](https://aws.amazon.com/bedrock/) as the judge. We'll also wrap our model endpoint invocation in a method making it easier to call in the evaluation. \n",
440436
"\n",
441437
"This particular endpoint is the [cross-region inference endpoint](https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html) name for Claude 3 Haiku.\n",
442438
"\n",
@@ -498,7 +494,7 @@
498494
"cell_type": "markdown",
499495
"metadata": {},
500496
"source": [
501-
"Now use Managed MLFlow 3.0 on Amazon SageMaker AI's `EvaluationExample` object to provide examples of good and bad model responses. This synthetic data will be used to evaluate our Example DeepSeek-R1_Distill_Llama-8B along several qualitative metrics. We create these qualitative metrics using `make_genai_metric`."
497+
"Now use Managed MLFlow 3.0 on Amazon SageMaker AI's `EvaluationExample` object to provide examples of good and bad model responses. This synthetic data will be used to evaluate our Example Qwen3-4B-Instruct-2507 along with several qualitative metrics. We create these qualitative metrics using `make_genai_metric`."
502498
]
503499
},
504500
{
@@ -914,7 +910,7 @@
914910
"cell_type": "markdown",
915911
"metadata": {},
916912
"source": [
917-
"In the next workshop we fine-tune DeepSeek-R1-Distill-Llama-8B to become a medical expert. To accomplish this, we execute a fine-tuning job using Managed MLflow on SageMaker AI. We get our data from the [FreedomIntelligence/medical-o1-reasoning-SFT](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) dataset, available on HuggingFace.\n",
913+
"In the next workshop we fine-tune Qwen3-4B-Instruct-2507 to become a medical expert. To accomplish this, we execute a fine-tuning job using Managed MLflow on SageMaker AI. We get our data from the [FreedomIntelligence/medical-o1-reasoning-SFT](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) dataset, available on HuggingFace.\n",
918914
"\n",
919915
"In this lab, we show a small example of what fine-tuning looks like for a single record of the dataset."
920916
]
@@ -931,56 +927,25 @@
931927
" \"Response\": \"Cystometry in this case of stress urinary incontinence would most likely reveal a normal post-void residual volume, as stress incontinence typically does not involve issues with bladder emptying. Additionally, since stress urinary incontinence is primarily related to physical exertion and not an overactive bladder, you would not expect to see any involuntary detrusor contractions during the test.\"\n",
932928
"}\n",
933929
"\n",
930+
"SYSTEM_PROMPT = \"\"\"You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. \n",
931+
"Below is an instruction that describes a task, paired with an input that provides further context. \n",
932+
"Write a response that appropriately completes the request.\n",
933+
"Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.\"\"\"\n",
934934
"\n",
935-
"PROMPT_TEMPLATE = \"\"\"\n",
936-
"<|begin_of_text|>\n",
937-
" <|start_header_id|>system<|end_header_id|>\n",
938-
" You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. \n",
939-
" Below is an instruction that describes a task, paired with an input that provides further context. \n",
940-
" Write a response that appropriately completes the request.\n",
941-
" Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.\n",
942-
" <|eot_id|>\n",
943-
" <|start_header_id|>user<|end_header_id|>\n",
944-
" {{question}}\n",
945-
" <|eot_id|>\n",
946-
" <|start_header_id|>assistant<|end_header_id|>\n",
947-
" {{complex_cot}}\n",
948-
" {{answer}}\n",
949-
"<|eot_id|>\n",
950-
"\"\"\"\n",
935+
"def convert_to_messages(sample, system_prompt=\"\"):\n",
936+
" \n",
937+
" messages = [\n",
938+
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
939+
" {\"role\": \"user\", \"content\": sample[\"Question\"]},\n",
940+
" {\"role\": \"assistant\", \"content\": f\"{sample[\"Complex_CoT\"]}\\n\\n{sample[\"Response\"]}\"}\n",
941+
" ]\n",
942+
"\n",
943+
" sample[\"messages\"] = messages\n",
944+
" \n",
945+
" return sample\n",
951946
"\n",
952-
"# Template dataset to add prompt to each sample\n",
953-
"def template_dataset(sample):\n",
954-
" try:\n",
955-
" sample[\"text\"] = PROMPT_TEMPLATE.format(question=sample[\"Question\"],\n",
956-
" complex_cot=sample[\"Complex_CoT\"],\n",
957-
" answer=sample[\"Response\"])\n",
958-
" return sample\n",
959-
" except KeyError as e:\n",
960-
" print(f\"KeyError in template_dataset: {str(e)}\")\n",
961-
" # Provide default values for missing fields\n",
962-
" missing_key = str(e).strip(\"'\")\n",
963-
" if missing_key == \"Question\":\n",
964-
" sample[\"text\"] = PROMPT_TEMPLATE.format(\n",
965-
" question=\"[Missing question]\",\n",
966-
" complex_cot=sample.get(\"Complex_CoT\", \"[Missing CoT]\"),\n",
967-
" answer=sample.get(\"Response\", \"[Missing response]\")\n",
968-
" )\n",
969-
" elif missing_key == \"Complex_CoT\":\n",
970-
" sample[\"text\"] = PROMPT_TEMPLATE.format(\n",
971-
" question=sample[\"Question\"],\n",
972-
" complex_cot=\"[Missing CoT]\",\n",
973-
" answer=sample.get(\"Response\", \"[Missing response]\")\n",
974-
" )\n",
975-
" elif missing_key == \"Response\":\n",
976-
" sample[\"text\"] = PROMPT_TEMPLATE.format(\n",
977-
" question=sample[\"Question\"],\n",
978-
" complex_cot=sample.get(\"Complex_CoT\", \"[Missing CoT]\"),\n",
979-
" answer=\"[Missing response]\"\n",
980-
" )\n",
981-
" return sample\n",
982947
"\n",
983-
"PROCESSED_SAMPLE = template_dataset(FINE_TUNING_DATA_SAMPLE)\n",
948+
"PROCESSED_SAMPLE = convert_to_messages(FINE_TUNING_DATA_SAMPLE)\n",
984949
"print(PROCESSED_SAMPLE)"
985950
]
986951
},
@@ -1097,8 +1062,15 @@
10971062
"4. Creating and applying Guardrails to our model\n",
10981063
"5. Tracing model calls using MLFlow tracing\n",
10991064
"\n",
1100-
"Next, we show how to actually perform fine-tuning on this DeepSeek model to improve the model's performance in this domain. Moreover, we'll orchestrate all of these steps into a fine-tuning pipeline powered by Managed MLFlow and SageMaker AI Pipelines."
1065+
"Next, we show how to actually perform fine-tuning on this Qwen3 model to improve the model's performance in this domain. Moreover, we'll orchestrate all of these steps into a fine-tuning pipeline powered by Managed MLFlow and SageMaker AI Pipelines."
11011066
]
1067+
},
1068+
{
1069+
"cell_type": "code",
1070+
"execution_count": null,
1071+
"metadata": {},
1072+
"outputs": [],
1073+
"source": []
11021074
}
11031075
],
11041076
"metadata": {

workshops/fine-tuning-with-sagemakerai-and-bedrock/task_05_fmops/05.01_fine-tuning-pipeline.ipynb

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,15 @@
114114
"sagemaker_session = sagemaker.session.Session()\n",
115115
"role = sagemaker.get_execution_role()\n",
116116
"instance_type = \"ml.m5.xlarge\"\n",
117-
"pipeline_name = \"AIM405-deepseek-finetune-pipeline\"\n",
117+
"pipeline_name = \"AIM405-qwen3-finetune-pipeline\"\n",
118118
"bucket_name = sagemaker_session.default_bucket()\n",
119119
"default_prefix = sagemaker_session.default_bucket_prefix\n",
120120
"if default_prefix:\n",
121121
" input_path = f'{default_prefix}/datasets/llm-fine-tuning-modeltrainer-sft'\n",
122122
"else:\n",
123123
" input_path = f'datasets/llm-fine-tuning-modeltrainer-sft'\n",
124124
"\n",
125-
"model_id = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
125+
"model_id = \"Qwen/Qwen3-4B-Instruct-2507\"\n",
126126
"model_id_filesafe = model_id.replace(\"/\",\"_\").replace(\".\", \"_\")"
127127
]
128128
},
@@ -157,15 +157,14 @@
157157
"source": [
158158
"mlflow_tracking_server_arn = \"<REPLACE WITH YOUR ARN>\"\n",
159159
"\n",
160-
"if not mlflow_tracking_server_arn:\n",
161-
" try:\n",
162-
" response = boto3.client('sagemaker').describe_mlflow_tracking_server(\n",
163-
" TrackingServerName='genai-mlflow-tracker'\n",
164-
" )\n",
165-
" mlflow_tracking_server_arn = response['TrackingServerArn']\n",
166-
" print(f\"MLflow Tracking Server ARN: {mlflow_tracking_server_arn}\")\n",
167-
" except ClientError:\n",
168-
" print(\"No MLflow Tracking Server Found, please input a value for mlflow_tracking_server_arn\")\n",
160+
"try:\n",
161+
" response = boto3.client('sagemaker').describe_mlflow_tracking_server(\n",
162+
" TrackingServerName='genai-mlflow-tracker'\n",
163+
" )\n",
164+
" mlflow_tracking_server_arn = response['TrackingServerArn']\n",
165+
" print(f\"MLflow Tracking Server ARN: {mlflow_tracking_server_arn}\")\n",
166+
"except ClientError:\n",
167+
" print(\"No MLflow Tracking Server Found, please input a value for mlflow_tracking_server_arn\")\n",
169168
"\n",
170169
"os.environ[\"mlflow_tracking_server_arn\"] = mlflow_tracking_server_arn\n",
171170
"os.environ[\"pipeline_name\"] = pipeline_name"
@@ -520,16 +519,19 @@
520519
" test_dataset_s3_path=preprocessing_step[2],\n",
521520
" train_config_s3_path=train_config_s3_path,\n",
522521
" role=role,\n",
523-
" model_id=model_s3_destination,\n",
522+
" model_id=model_s3_destination\n",
524523
")\n",
525524
"run_id=training_step[0]\n",
526525
"model_artifacts_s3_path=training_step[2]\n",
527-
"output_path=training_step[3]\n",
526+
"# output_path=training_step[3]\n",
528527
"\n",
529528
"deploy_step = deploy_step.deploy(\n",
529+
" tracking_server_arn=mlflow_tracking_server_arn,\n",
530530
" model_artifacts_s3_path=model_artifacts_s3_path,\n",
531-
" output_path=output_path,\n",
531+
" # output_path=output_path,\n",
532532
" model_id=model_s3_destination,\n",
533+
" experiment_name=pipeline_name,\n",
534+
" run_id=run_id,\n",
533535
")\n",
534536
"endpoint_name=deploy_step\n",
535537
"\n",
@@ -574,7 +576,7 @@
574576
" run_id=run_id, # Assuming training_step returns run_id as first output\n",
575577
" model_artifacts_s3_path=model_artifacts_s3_path, # Assuming training_step returns artifacts path as second output\n",
576578
" model_id=model_id,\n",
577-
" model_name=f\"Fine-Tuned-Medical-DeepSeek\",\n",
579+
" model_name=f\"Fine-Tuned-Medical-Qwen3-4B-Instruct-2507\",\n",
578580
" endpoint_name=endpoint_name,\n",
579581
" evaluation_score=quantitative_eval_step[\"rougeL_f\"], # Get the evaluation score\n",
580582
" pipeline_name=pipeline_name,\n",
@@ -728,8 +730,7 @@
728730
"\n",
729731
"# Clean up endpoint\n",
730732
"try:\n",
731-
" model_name_safe = model_id.split('/')[-1].replace('.', '-').replace('_', '-')\n",
732-
" endpoint_name = f\"{model_name_safe}-sft-djl\"\n",
733+
" endpoint_name = f\"{model_id.replace('/', '-').replace('_', '-')}-sft-djl\"\n",
733734
" \n",
734735
" print(f\"Cleaning up endpoint: {endpoint_name}\")\n",
735736
" if delete_endpoint_with_retry(endpoint_name):\n",
Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
1-
awscli==1.42.25
2-
transformers==4.50.2
1+
transformers==4.52.2
32
peft==0.14.0
43
accelerate==1.3.0
54
bitsandbytes==0.45.1
6-
datasets==3.5.0
5+
datasets==3.2.0
76
evaluate==0.4.3
87
huggingface_hub[hf_transfer]==0.33.4
9-
mlflow
8+
mlflow==2.22.2
109
safetensors>=0.5.2
11-
sagemaker==2.244.0
10+
sagemaker==2.252.0
1211
sagemaker-mlflow==0.1.0
1312
sentencepiece==0.2.0
1413
scikit-learn==1.6.1
1514
tokenizers>=0.21.0
16-
trl==0.9.6
17-
psutil
18-
py7zr
19-
pynvml
20-
xtarfile
21-
rouge-score
15+
trl==0.18.0
16+
psutil==7.1.0
17+
py7zr==1.0.0
18+
pynvml==13.0.1
19+
xtarfile==0.2.1
20+
rouge-score==0.1.2

0 commit comments

Comments
 (0)