@@ -40,15 +40,29 @@ def __init__(self, params: dict[str, Any]):
4040 self .model_weights_path = str (
4141 Path (self .params ["model_weights_parent_dir" ], self .params ["model_name" ])
4242 )
43+ self .env_str = self ._generate_env_str ()
44+
45+ def _generate_env_str (self ) -> str :
46+ """Generate the environment variables string for the Slurm script.
47+
48+ Returns
49+ -------
50+ str
51+ Formatted environment variables string for container or shell export commands.
52+ """
4353 env_dict : dict [str , str ] = self .params .get ("env" , {})
44- # Create string of environment variables
45- self .env_str = ""
46- for key , val in env_dict .items ():
47- if len (self .env_str ) == 0 :
48- self .env_str = "--env "
49- else :
50- self .env_str += ","
51- self .env_str += key + "=" + val
54+
55+ if not env_dict :
56+ return ""
57+
58+ if self .use_container :
59+ # Format for container: --env KEY1=VAL1,KEY2=VAL2
60+ env_pairs = [f"{ key } ={ val } " for key , val in env_dict .items ()]
61+ return f"--env { ',' .join (env_pairs )} "
62+ else :
63+ # Format for shell: export KEY1=VAL1\nexport KEY2=VAL2
64+ export_lines = [f"export { key } ={ val } " for key , val in env_dict .items ()]
65+ return "\n " .join (export_lines )
5266
5367 def _generate_script_content (self ) -> str :
5468 """Generate the complete Slurm script content.
@@ -94,7 +108,12 @@ def _generate_server_setup(self) -> str:
94108 server_script = ["\n " ]
95109 if self .use_container :
96110 server_script .append ("\n " .join (SLURM_SCRIPT_TEMPLATE ["container_setup" ]))
97- server_script .append ("\n " .join (SLURM_SCRIPT_TEMPLATE ["env_vars" ]))
111+ server_script .append ("\n " .join (SLURM_SCRIPT_TEMPLATE ["container_env_vars" ]))
112+ else :
113+ server_script .append (
114+ SLURM_SCRIPT_TEMPLATE ["activate_venv" ].format (venv = self .params ["venv" ])
115+ )
116+ server_script .append (self .env_str )
98117 server_script .append (
99118 SLURM_SCRIPT_TEMPLATE ["imports" ].format (src_dir = self .params ["src_dir" ])
100119 )
@@ -111,6 +130,11 @@ def _generate_server_setup(self) -> str:
111130 env_str = self .env_str ,
112131 ),
113132 )
133+ else :
134+ server_setup_str = server_setup_str .replace (
135+ "CONTAINER_PLACEHOLDER" ,
136+ "\\ " ,
137+ )
114138 else :
115139 server_setup_str = "\n " .join (
116140 SLURM_SCRIPT_TEMPLATE ["server_setup" ]["single_node" ]
@@ -144,10 +168,7 @@ def _generate_launch_cmd(self) -> str:
144168 env_str = self .env_str ,
145169 )
146170 )
147- else :
148- launcher_script .append (
149- SLURM_SCRIPT_TEMPLATE ["activate_venv" ].format (venv = self .params ["venv" ])
150- )
171+
151172 launcher_script .append (
152173 "\n " .join (SLURM_SCRIPT_TEMPLATE ["launch_cmd" ]).format (
153174 model_weights_path = self .model_weights_path ,
0 commit comments