11import os
22import time
3- from typing import Optional
3+ from typing import Optional , cast
44
55import click
6- import pandas as pd
6+
7+ import polars as pl
78from rich .columns import Columns
89from rich .console import Console
910from rich .live import Live
@@ -91,6 +92,11 @@ def cli():
9192 type = str ,
9293 help = "Enable pipeline parallelism, accepts 'True' or 'False', default to 'True' for supported models" ,
9394)
95+ @click .option (
96+ "--enforce-eager" ,
97+ type = str ,
98+ help = "Always use eager-mode PyTorch, accepts 'True' or 'False', default to 'False' for custom models if not set" ,
99+ )
94100@click .option (
95101 "--json-mode" ,
96102 is_flag = True ,
@@ -113,14 +119,17 @@ def launch(
113119 log_dir : Optional [str ] = None ,
114120 model_weights_parent_dir : Optional [str ] = None ,
115121 pipeline_parallelism : Optional [str ] = None ,
122+ enforce_eager : Optional [str ] = None ,
116123 json_mode : bool = False ,
117124) -> None :
118125 """
119126 Launch a model on the cluster
120127 """
121128
122129 if isinstance (pipeline_parallelism , str ):
123- pipeline_parallelism = pipeline_parallelism .lower () == "true"
130+ pipeline_parallelism = (
131+ "True" if pipeline_parallelism .lower () == "true" else "False"
132+ )
124133
125134 launch_script_path = os .path .join (
126135 os .path .dirname (os .path .dirname (os .path .realpath (__file__ ))), "launch_server.sh"
@@ -129,15 +138,15 @@ def launch(
129138
130139 models_df = utils .load_models_df ()
131140
132- if model_name in models_df ["model_name" ].values :
141+ if model_name in models_df ["model_name" ].to_list () :
133142 default_args = utils .load_default_args (models_df , model_name )
134143 for arg in default_args :
135144 if arg in locals () and locals ()[arg ] is not None :
136145 default_args [arg ] = locals ()[arg ]
137146 renamed_arg = arg .replace ("_" , "-" )
138147 launch_cmd += f" --{ renamed_arg } { default_args [arg ]} "
139148 else :
140- model_args = models_df .columns . tolist ()
149+ model_args = models_df .columns
141150 model_args .remove ("model_name" )
142151 model_args .remove ("model_type" )
143152 for arg in model_args :
@@ -265,45 +274,58 @@ def shutdown(slurm_job_id: int) -> None:
265274 is_flag = True ,
266275 help = "Output in JSON string" ,
267276)
268- def list (model_name : Optional [str ] = None , json_mode : bool = False ) -> None :
277+ def list_models (model_name : Optional [str ] = None , json_mode : bool = False ) -> None :
269278 """
270279 List all available models, or get default setup of a specific model
271280 """
272281
273- def list_model (model_name : str , models_df : pd .DataFrame , json_mode : bool ):
274- if model_name not in models_df ["model_name" ].values :
282+ def list_model (model_name : str , models_df : pl .DataFrame , json_mode : bool ):
283+ if model_name not in models_df ["model_name" ].to_list () :
275284 raise ValueError (f"Model name { model_name } not found in available models" )
276285
277286 excluded_keys = {"venv" , "log_dir" }
278- model_row = models_df .loc [ models_df ["model_name" ] == model_name ]
287+ model_row = models_df .filter ( models_df ["model_name" ] == model_name )
279288
280289 if json_mode :
281- filtered_model_row = model_row .drop (columns = excluded_keys , errors = "ignore" )
282- click .echo (filtered_model_row .to_json ( orient = "records" ) )
290+ filtered_model_row = model_row .drop (excluded_keys , strict = False )
291+ click .echo (filtered_model_row .to_dicts ()[ 0 ] )
283292 return
284293 table = utils .create_table (key_title = "Model Config" , value_title = "Value" )
285- for _ , row in model_row .iterrows ():
294+ for row in model_row .to_dicts ():
286295 for key , value in row .items ():
287296 if key not in excluded_keys :
288297 table .add_row (key , str (value ))
289298 CONSOLE .print (table )
290299
291- def list_all (models_df : pd .DataFrame , json_mode : bool ):
300+ def list_all (models_df : pl .DataFrame , json_mode : bool ):
292301 if json_mode :
293- click .echo (models_df ["model_name" ].to_json ( orient = "records" ))
302+ click .echo (models_df ["model_name" ].to_list ( ))
294303 return
295304 panels = []
296305 model_type_colors = {
297306 "LLM" : "cyan" ,
298307 "VLM" : "bright_blue" ,
299308 "Text Embedding" : "purple" ,
309+ "Reward Modeling" : "bright_magenta" ,
300310 }
301- custom_order = ["LLM" , "VLM" , "Text Embedding" ]
302- models_df ["model_type" ] = pd .Categorical (
303- models_df ["model_type" ], categories = custom_order , ordered = True
311+
312+ models_df = models_df .with_columns (
313+ pl .when (pl .col ("model_type" ) == "LLM" )
314+ .then (0 )
315+ .when (pl .col ("model_type" ) == "VLM" )
316+ .then (1 )
317+ .when (pl .col ("model_type" ) == "Text Embedding" )
318+ .then (2 )
319+ .when (pl .col ("model_type" ) == "Reward Modeling" )
320+ .then (3 )
321+ .otherwise (- 1 )
322+ .alias ("model_type_order" )
304323 )
305- models_df = models_df .sort_values (by = "model_type" )
306- for _ , row in models_df .iterrows ():
324+
325+ models_df = models_df .sort ("model_type_order" )
326+ models_df = models_df .drop ("model_type_order" )
327+
328+ for row in models_df .to_dicts ():
307329 panel_color = model_type_colors .get (row ["model_type" ], "white" )
308330 styled_text = (
309331 f"[magenta]{ row ['model_family' ]} [/magenta]-{ row ['model_variant' ]} "
@@ -336,10 +358,22 @@ def metrics(slurm_job_id: int, log_dir: Optional[str] = None) -> None:
336358
337359 with Live (refresh_per_second = 1 , console = CONSOLE ) as live :
338360 while True :
339- out_logs = utils .read_slurm_log (slurm_job_name , slurm_job_id , "out" , log_dir )
340- metrics = utils .get_latest_metric (out_logs )
361+ out_logs = utils .read_slurm_log (
362+ slurm_job_name , slurm_job_id , "out" , log_dir
363+ )
364+ # if out_logs is a string, then it is an error message
365+ if isinstance (out_logs , str ):
366+ live .update (out_logs )
367+ break
368+ out_logs = cast (list , out_logs )
369+ latest_metrics = utils .get_latest_metric (out_logs )
370+ # if latest_metrics is a string, then it is an error message
371+ if isinstance (latest_metrics , str ):
372+ live .update (latest_metrics )
373+ break
374+ latest_metrics = cast (dict , latest_metrics )
341375 table = utils .create_table (key_title = "Metric" , value_title = "Value" )
342- for key , value in metrics .items ():
376+ for key , value in latest_metrics .items ():
343377 table .add_row (key , value )
344378
345379 live .update (table )
0 commit comments