1+ import matplotlib .pyplot as plt
2+ from typing import List , Dict
3+ import boto3
4+
def _count_element_words(element: Dict) -> int:
    """Return the whitespace-delimited word count of a single dataset element."""
    # Chat-style record: sum the words of every message's content.
    if "messages" in element:
        return sum(len(msg["content"].split()) for msg in element["messages"])
    # Flat record: prefer "content", then fall back to "text".
    for key in ("content", "text"):
        if key in element:
            return len(element[key].split())
    raise KeyError(
        "dataset element has none of the keys 'messages', 'content' or 'text'"
    )


def calculate_message_lengths(dataset: List[Dict]) -> List[int]:
    """
    Calculate the word count of the content/text of each element in the dataset.

    The format is resolved independently for every element, so a dataset may
    mix record shapes. For each element the first matching key wins:

    1. ``"messages"``: a list of dicts; the words of every message's
       ``"content"`` string are summed.
    2. ``"content"``: a string.
    3. ``"text"``: a string.

    Words are counted with ``str.split()``, i.e. split on runs of whitespace.

    Args:
        dataset: List of dictionaries containing messages or text

    Returns:
        List of word counts, one per element, in dataset order. An empty
        dataset yields an empty list.

    Raises:
        KeyError: If an element has none of the supported keys, or a message
            inside ``"messages"`` has no ``"content"`` key.
    """
    return [_count_element_words(element) for element in dataset]
24+
def plot_length_distribution(train_dataset: List[Dict],
                             validation_dataset: List[Dict],
                             bins: int = 20,
                             figsize: tuple = (10, 6)) -> None:
    """
    Show a single histogram of word counts pooled from both datasets.

    The word counts of the training and validation datasets are computed with
    ``calculate_message_lengths`` and concatenated, so the two splits are not
    distinguished in the plot. The figure is displayed with ``plt.show()``.

    Args:
        train_dataset: Training dataset
        validation_dataset: Validation dataset
        bins: Number of histogram bins
        figsize: Figure size as (width, height)
    """
    # Training counts first, then validation counts, in one flat list.
    pooled_lengths = [
        *calculate_message_lengths(train_dataset),
        *calculate_message_lengths(validation_dataset),
    ]

    # Draw the histogram on a fresh figure of the requested size.
    plt.figure(figsize=figsize)
    plt.hist(pooled_lengths, bins=bins, alpha=0.7, color="blue")

    # Axis labels and title.
    plt.xlabel("Prompt Lengths (words)")
    plt.ylabel("Frequency")
    plt.title("Distribution of Input Lengths")

    plt.show()
56+
57+
def get_last_job_name(job_name_prefix: str) -> str:
    """
    Return the name of the newest completed training job with a given prefix.

    Queries the SageMaker ``search`` API for training jobs whose name contains
    ``job_name_prefix`` and whose status is ``Completed``, sorted by creation
    time in descending order. Because the server-side filter is "contains",
    each result is re-checked locally with ``str.startswith``. Pages are
    fetched until the first match is found or the results are exhausted.

    Args:
        job_name_prefix: Prefix the training job name must start with.

    Returns:
        The name of the first matching job in descending creation-time order.

    Raises:
        ValueError: If no completed training job starts with the prefix.
    """
    sagemaker_client = boto3.client("sagemaker")

    # The query is the same for every page; only NextToken changes.
    search_params = {
        "Resource": "TrainingJob",
        "SearchExpression": {
            "Filters": [
                {
                    "Name": "TrainingJobName",
                    "Operator": "Contains",
                    "Value": job_name_prefix,
                },
                {
                    "Name": "TrainingJobStatus",
                    "Operator": "Equals",
                    "Value": "Completed",
                },
            ]
        },
        "SortBy": "CreationTime",
        "SortOrder": "Descending",
        "MaxResults": 100,
    }

    while True:
        search_response = sagemaker_client.search(**search_params)

        # Results arrive newest first, so the first prefix match is the answer.
        for result in search_response["Results"]:
            job_name = result["TrainingJob"]["TrainingJobName"]
            if job_name.startswith(job_name_prefix):
                return job_name

        # No match on this page: continue only if the service reports more.
        next_token = search_response.get("NextToken")
        if not next_token:
            break
        search_params["NextToken"] = next_token

    raise ValueError(
        f"No completed training jobs found starting with prefix '{job_name_prefix}'"
    )
0 commit comments