|
1 | 1 | """Tests for the Vector Inference API client.""" |
2 | 2 |
|
| 3 | +import subprocess |
3 | 4 | from unittest.mock import MagicMock, patch |
4 | 5 |
|
5 | 6 | import pytest |
@@ -518,3 +519,188 @@ def test_batch_launch_models_with_custom_config_integration(): |
518 | 519 | assert result.slurm_job_id == "12345678" |
519 | 520 | assert result.slurm_job_name == "BATCH-model1-model2" |
520 | 521 | assert result.model_names == ["model1", "model2"] |
| 522 | + |
| 523 | + |
def test_fetch_running_jobs_success_with_matching_jobs():
    """Check that every job named ``*-vec-inf`` is returned by fetch_running_jobs."""
    # Canned ``scontrol`` output, keyed by the job ID given as the last argument.
    job_details = {
        "12345": "JobId=12345 JobName=test-model-vec-inf User=user",
        "67890": "JobId=67890 JobName=other-model-vec-inf User=user",
    }
    # Canned ``squeue`` listing, one running job per line.
    queue_listing = "12345 RUNNING gpu\n67890 RUNNING gpu\n"

    def fake_run(cmd, **kwargs):
        program = cmd[0]
        completed = MagicMock()
        if program == "squeue":
            completed.stdout = queue_listing
            completed.returncode = 0
        elif program == "scontrol":
            completed.stdout = job_details.get(cmd[-1], "")
            completed.returncode = 0
        return completed

    api_client = VecInfClient()
    with patch("vec_inf.client.api.subprocess.run", side_effect=fake_run):
        found = api_client.fetch_running_jobs()

    # Both jobs carry the suffix, so both IDs come back in squeue order.
    assert found == ["12345", "67890"]
| 551 | + |
| 552 | + |
def test_fetch_running_jobs_no_matching_jobs():
    """Check that fetch_running_jobs yields nothing when no job name matches."""
    api_client = VecInfClient()

    # Two running jobs, neither of whose names ends with -vec-inf.
    queue_listing = "12345 RUNNING gpu\n67890 RUNNING gpu\n"
    job_details = {
        "12345": "JobId=12345 JobName=test-model User=user",
        "67890": "JobId=67890 JobName=other-job User=user",
    }

    def fake_run(cmd, **kwargs):
        if cmd[0] == "squeue":
            return MagicMock(stdout=queue_listing, returncode=0)
        if cmd[0] == "scontrol":
            # The job ID is the final argument of the scontrol command.
            return MagicMock(stdout=job_details.get(cmd[-1], ""), returncode=0)
        return MagicMock()

    with patch("vec_inf.client.api.subprocess.run", side_effect=fake_run):
        found = api_client.fetch_running_jobs()

    assert found == []
| 580 | + |
| 581 | + |
def test_fetch_running_jobs_empty_squeue():
    """Check that an empty squeue listing produces an empty result."""

    def fake_run(cmd, **kwargs):
        completed = MagicMock()
        if cmd[0] != "squeue":
            return completed
        # squeue succeeds but reports no jobs at all.
        completed.stdout = ""
        completed.returncode = 0
        return completed

    api_client = VecInfClient()
    with patch("vec_inf.client.api.subprocess.run", side_effect=fake_run):
        found = api_client.fetch_running_jobs()

    assert found == []
| 600 | + |
| 601 | + |
def test_fetch_running_jobs_mixed_jobs():
    """Check that only the matching jobs survive when the queue is mixed."""
    api_client = VecInfClient()

    # Three running jobs; the middle one does not carry the -vec-inf suffix.
    queue_listing = "12345 RUNNING gpu\n67890 RUNNING gpu\n11111 RUNNING gpu\n"
    job_details = {
        "12345": "JobId=12345 JobName=test-model-vec-inf User=user",
        "67890": "JobId=67890 JobName=other-job User=user",
        "11111": "JobId=11111 JobName=another-model-vec-inf User=user",
    }

    def fake_run(cmd, **kwargs):
        program, last_arg = cmd[0], cmd[-1]
        completed = MagicMock()
        if program == "squeue":
            completed.stdout, completed.returncode = queue_listing, 0
        elif program == "scontrol":
            # scontrol is looked up by the job ID passed as the last argument.
            completed.stdout, completed.returncode = job_details.get(last_arg, ""), 0
        return completed

    with patch("vec_inf.client.api.subprocess.run", side_effect=fake_run):
        found = api_client.fetch_running_jobs()

    assert found == ["12345", "11111"]
| 630 | + |
| 631 | + |
def test_fetch_running_jobs_scontrol_failure():
    """Check that a job is skipped when its scontrol lookup fails."""
    queue_listing = "12345 RUNNING gpu\n67890 RUNNING gpu\n"
    # Only job 12345 has scontrol output; the lookup for 67890 raises instead.
    job_details = {
        "12345": "JobId=12345 JobName=test-model-vec-inf User=user",
    }

    def fake_run(cmd, **kwargs):
        completed = MagicMock()
        program = cmd[0]
        if program == "squeue":
            completed.stdout = queue_listing
            completed.returncode = 0
        elif program == "scontrol":
            requested_id = cmd[-1]
            if requested_id not in job_details:
                # Stand in for a failing scontrol call on the unknown job.
                raise subprocess.CalledProcessError(1, cmd)
            completed.stdout = job_details[requested_id]
            completed.returncode = 0
        return completed

    api_client = VecInfClient()
    with patch("vec_inf.client.api.subprocess.run", side_effect=fake_run):
        found = api_client.fetch_running_jobs()

    # The job whose lookup failed is dropped; the other is still reported.
    assert found == ["12345"]
| 663 | + |
| 664 | + |
def test_fetch_running_jobs_squeue_failure():
    """Check that a failing squeue call surfaces as SlurmJobError."""

    def fake_run(cmd, **kwargs):
        if cmd[0] != "squeue":
            return MagicMock()
        # Stand in for squeue exiting non-zero.
        raise subprocess.CalledProcessError(1, cmd, stderr="squeue: error")

    api_client = VecInfClient()
    run_patch = patch("vec_inf.client.api.subprocess.run", side_effect=fake_run)
    expect_error = pytest.raises(SlurmJobError, match="Error running slurm command")

    with run_patch, expect_error:
        api_client.fetch_running_jobs()
| 681 | + |
| 682 | + |
def test_fetch_running_jobs_job_name_not_found():
    """Check that scontrol output lacking a JobName field yields no match."""
    # Canned command output, keyed by executable name; scontrol omits JobName.
    canned_stdout = {
        "squeue": "12345 RUNNING gpu\n",
        "scontrol": "JobId=12345 User=user State=RUNNING",
    }

    def fake_run(cmd, **kwargs):
        program = cmd[0]
        if program in canned_stdout:
            return MagicMock(stdout=canned_stdout[program], returncode=0)
        return MagicMock()

    api_client = VecInfClient()
    with patch("vec_inf.client.api.subprocess.run", side_effect=fake_run):
        found = api_client.fetch_running_jobs()

    # No JobName to compare against, so the job is not reported.
    assert found == []
0 commit comments