Skip to content

Commit 3d777b3

Browse files
added utils for task 2
1 parent 89b499b commit 3d777b3

File tree

1 file changed

+109
-0
lines changed
  • workshops/fine-tuning-with-sagemakerai-and-bedrock/task_02_customize_foundation_model

1 file changed

+109
-0
lines changed
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import matplotlib.pyplot as plt
2+
from typing import List, Dict
3+
import boto3
4+
5+
def _count_words(element: Dict) -> int:
    """Return the whitespace-delimited word count of a single dataset element."""
    try:
        # Chat-style record: add up the words of every message's content.
        return sum(len(msg["content"].split()) for msg in element["messages"])
    except KeyError:
        # Flat record: prefer "content", otherwise use "text".
        key = "content" if "content" in element else "text"
        return len(element[key].split())


def calculate_message_lengths(dataset: List[Dict]) -> List[int]:
    """
    Calculate the length, in words, of the content/text of each dataset element.

    The format is detected per element, so a dataset may mix chat-style
    records ({"messages": [{"content": ...}, ...]}) with flat records
    ({"content": ...} or {"text": ...}). Words are counted by splitting on
    whitespace.

    Args:
        dataset: Iterable of dictionaries containing messages or text

    Returns:
        List of word counts, one per element, in input order

    Raises:
        KeyError: If an element has no usable "messages", "content" or
            "text" entry.
    """
    return [_count_words(element) for element in dataset]
24+
25+
def plot_length_distribution(train_dataset: List[Dict],
                             validation_dataset: List[Dict],
                             bins: int = 20,
                             figsize: tuple = (10, 6)) -> None:
    """
    Show one histogram of word counts over the training and validation data.

    The lengths of both datasets are computed with
    calculate_message_lengths and pooled into a single series before
    plotting; the two datasets are not drawn separately.

    Args:
        train_dataset: Training dataset
        validation_dataset: Validation dataset
        bins: Number of histogram bins
        figsize: Figure size as (width, height)
    """
    # Pool the per-element word counts, training first, then validation.
    lengths_train = calculate_message_lengths(train_dataset)
    lengths_validation = calculate_message_lengths(validation_dataset)
    all_lengths = [*lengths_train, *lengths_validation]

    # Draw the histogram on a new figure of the requested size.
    plt.figure(figsize=figsize)
    plt.hist(all_lengths, bins=bins, alpha=0.7, color="blue")

    # Axis labels and title.
    plt.title("Distribution of Input Lengths")
    plt.xlabel("Prompt Lengths (words)")
    plt.ylabel("Frequency")

    plt.show()
56+
57+
58+
def get_last_job_name(job_name_prefix):
    """
    Return the name of the newest completed training job with the given prefix.

    The SageMaker Search API is queried for training jobs whose name
    contains ``job_name_prefix`` and whose status is "Completed", sorted by
    creation time in descending order. Because the server-side filter is a
    "Contains" match, each page is narrowed client-side to names that
    actually start with the prefix. Paging stops at the first page that
    yields such a name.

    Args:
        job_name_prefix: Prefix the training job name must start with

    Returns:
        Name of the first matching job in the descending-by-creation-time
        results

    Raises:
        ValueError: If no completed training job starts with the prefix
    """
    client = boto3.client('sagemaker')

    # The query is the same for every page; only NextToken is added later.
    name_filter = {
        'Name': 'TrainingJobName',
        'Operator': 'Contains',
        'Value': job_name_prefix
    }
    status_filter = {
        'Name': 'TrainingJobStatus',
        'Operator': 'Equals',
        'Value': "Completed"
    }
    request = {
        'Resource': 'TrainingJob',
        'SearchExpression': {'Filters': [name_filter, status_filter]},
        'SortBy': 'CreationTime',
        'SortOrder': 'Descending',
        'MaxResults': 100
    }

    while True:
        page = client.search(**request)

        # Results arrive newest first, so the first prefix match is the answer.
        page_names = [
            result['TrainingJob']['TrainingJobName']
            for result in page['Results']
        ]
        latest = next(
            (name for name in page_names if name.startswith(job_name_prefix)),
            None,
        )
        if latest is not None:
            return latest

        # No match on this page: go on to the next one, if there is any.
        token = page.get('NextToken')
        if not token:
            raise ValueError(f"No completed training jobs found starting with prefix '{job_name_prefix}'")
        request['NextToken'] = token

0 commit comments

Comments
 (0)