@@ -226,13 +226,12 @@ def base_patches(test_paths, mock_truediv, debug_helper):
226226 "pathlib.Path.parent" , return_value = debug_helper .config_file .parent .parent
227227 ),
228228 patch ("pathlib.Path.__truediv__" , side_effect = mock_truediv ),
229- patch ("pathlib.Path.iterdir" , return_value = []), # Mock empty directory listing
229+ patch ("pathlib.Path.iterdir" , return_value = []),
230230 patch ("json.dump" ),
231231 patch ("pathlib.Path.touch" ),
232232 patch ("vec_inf.client._utils.Path" , return_value = test_paths ["weights_dir" ]),
233- patch (
234- "pathlib.Path.home" , return_value = Path ("/home/user" )
235- ), # Mock home directory
233+ patch ("pathlib.Path.home" , return_value = Path ("/home/user" )),
234+ patch ("pathlib.Path.rename" ),
236235 ]
237236
238237
@@ -246,25 +245,19 @@ def apply_base_patches(base_patches):
246245 yield
247246
248247
249- def test_launch_command_success (runner , mock_launch_output , path_exists , debug_helper ):
248+ def test_launch_command_success (
249+ runner , mock_launch_output , path_exists , debug_helper , mock_truediv , test_paths , base_patches
250+ ):
250251 """Test successful model launch with minimal required arguments."""
251- test_log_dir = Path ("/tmp/test_vec_inf_logs" )
252+ with ExitStack () as stack :
253+ # Apply all base patches
254+ for patch_obj in base_patches :
255+ stack .enter_context (patch_obj )
256+
257+ # Apply specific patches for this test
258+ mock_run = stack .enter_context (patch ("vec_inf.client._utils.run_bash_command" ))
259+ stack .enter_context (patch ("pathlib.Path.exists" , new = path_exists ))
252260
253- with (
254- patch ("vec_inf.client._utils.run_bash_command" ) as mock_run ,
255- patch ("pathlib.Path.mkdir" ),
256- patch ("builtins.open" , debug_helper .tracked_mock_open ),
257- patch ("pathlib.Path.open" , debug_helper .tracked_mock_open ),
258- patch ("pathlib.Path.exists" , new = path_exists ),
259- patch ("pathlib.Path.expanduser" , return_value = test_log_dir ),
260- patch ("pathlib.Path.resolve" , return_value = debug_helper .config_file .parent ),
261- patch (
262- "pathlib.Path.parent" , return_value = debug_helper .config_file .parent .parent
263- ),
264- patch ("json.dump" ),
265- patch ("pathlib.Path.touch" ),
266- patch ("pathlib.Path.__truediv__" , return_value = test_log_dir ),
267- ):
268261 expected_job_id = "14933053"
269262 mock_run .return_value = mock_launch_output (expected_job_id )
270263
@@ -277,25 +270,18 @@ def test_launch_command_success(runner, mock_launch_output, path_exists, debug_h
277270
278271
279272def test_launch_command_with_json_output (
280- runner , mock_launch_output , path_exists , debug_helper
273+ runner , mock_launch_output , path_exists , debug_helper , mock_truediv , test_paths , base_patches
281274):
282275 """Test JSON output format for launch command."""
283- test_log_dir = Path ("/tmp/test_vec_inf_logs" )
284- with (
285- patch ("vec_inf.client._utils.run_bash_command" ) as mock_run ,
286- patch ("pathlib.Path.mkdir" ),
287- patch ("builtins.open" , debug_helper .tracked_mock_open ),
288- patch ("pathlib.Path.open" , debug_helper .tracked_mock_open ),
289- patch ("pathlib.Path.exists" , new = path_exists ),
290- patch ("pathlib.Path.expanduser" , return_value = test_log_dir ),
291- patch ("pathlib.Path.resolve" , return_value = debug_helper .config_file .parent ),
292- patch (
293- "pathlib.Path.parent" , return_value = debug_helper .config_file .parent .parent
294- ),
295- patch ("json.dump" ),
296- patch ("pathlib.Path.touch" ),
297- patch ("pathlib.Path.__truediv__" , return_value = test_log_dir ),
298- ):
276+ with ExitStack () as stack :
277+ # Apply all base patches
278+ for patch_obj in base_patches :
279+ stack .enter_context (patch_obj )
280+
281+ # Apply specific patches for this test
282+ mock_run = stack .enter_context (patch ("vec_inf.client._utils.run_bash_command" ))
283+ stack .enter_context (patch ("pathlib.Path.exists" , new = path_exists ))
284+
299285 expected_job_id = "14933051"
300286 mock_run .return_value = mock_launch_output (expected_job_id )
301287
@@ -319,7 +305,7 @@ def test_launch_command_with_json_output(
319305 assert output .get ("slurm_job_id" ) == expected_job_id
320306 assert output .get ("model_name" ) == "Meta-Llama-3.1-8B"
321307 assert output .get ("model_type" ) == "LLM"
322- assert str (test_log_dir ) in output .get ("log_dir" , "" )
308+ assert str (test_paths [ "log_dir" ] ) in output .get ("log_dir" , "" )
323309
324310
325311def test_launch_command_no_model_weights_parent_dir (runner , debug_helper , base_patches ):
0 commit comments