From 9bb96bfaabe7cf0f92652b118dfc56db2dea1a0a Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Tue, 16 Dec 2025 08:34:26 -0800 Subject: [PATCH 1/3] feat: Add support to trainer object for model parameter in Evaluator --- .../train/evaluate/base_evaluator.py | 21 ++++- .../train/evaluate/benchmark_evaluator.py | 8 -- .../train/evaluate/test_base_evaluator.py | 85 +++++++++++++++++++ .../evaluate/test_benchmark_evaluator.py | 52 ++++++------ 4 files changed, 128 insertions(+), 38 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py index 620b7ffe34..a1a593b810 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py @@ -15,9 +15,11 @@ from sagemaker.core.resources import ModelPackageGroup from sagemaker.core.shapes import VpcConfig +from sagemaker.core.utils.utils import Unassigned if TYPE_CHECKING: from sagemaker.core.helper.session_helper import Session + from sagemaker.train.base_trainer import BaseTrainer # Module-level logger _logger = logging.getLogger(__name__) @@ -278,7 +280,7 @@ def _validate_mlflow_arn_format(cls, v: Optional[str]) -> Optional[str]: return v @validator('model') - def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any]: + def _resolve_model_info(cls, v: Union[str, "BaseTrainer", Any], values: dict) -> Union[str, Any]: """Resolve model information from various input types. This validator uses the common model resolution utility to extract: @@ -289,7 +291,7 @@ def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any The resolved information is stored in private attributes for use by subclasses. Args: - v (Union[str, Any]): Model identifier (JumpStart ID, ModelPackage, or ARN). + v (Union[str, Any]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer). values (dict): Dictionary of already-validated fields. 
Returns: @@ -302,12 +304,25 @@ def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any import os try: + # Handle BaseTrainer type + if hasattr(v, '__class__') and v.__class__.__name__ == 'BaseTrainer' or hasattr(v, '_latest_training_job'): + if hasattr(v._latest_training_job, 'output_model_package_arn'): + arn = v._latest_training_job.output_model_package_arn + if not isinstance(arn, Unassigned): + model_to_resolve = arn + else: + raise ValueError("BaseTrainer must have completed training job to be used for evaluation") + else: + raise ValueError("BaseTrainer must have completed training job to be used for evaluation") + else: + model_to_resolve = v + # Get the session for resolution (may not be created yet due to validator order) session = values.get('sagemaker_session') # Resolve model information model_info = _resolve_base_model( - base_model=v, + base_model=model_to_resolve, sagemaker_session=session ) diff --git a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py index 4ca685b811..b8b26e119f 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py @@ -300,18 +300,10 @@ class BenchMarkEvaluator(BaseEvaluator): """ benchmark: _Benchmark - dataset: Union[str, Any] # Required field, must come before optional fields subtasks: Optional[Union[str, List[str]]] = None evaluate_base_model: bool = True _hyperparameters: Optional[Any] = None - @validator('dataset', pre=True) - def _resolve_dataset(cls, v): - """Resolve dataset to string (S3 URI or ARN) and validate format. - - Uses BaseEvaluator's common validation logic to avoid code duplication. - """ - return BaseEvaluator._validate_and_resolve_dataset(v) @validator('benchmark') def _validate_benchmark_model_compatibility(cls, v, values): diff --git a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py index 86c09489a0..cd94e2c7c8 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py @@ -20,6 +20,8 @@ from sagemaker.core.shapes import VpcConfig from sagemaker.core.resources import ModelPackageGroup, Artifact from sagemaker.core.shapes import ArtifactSource, ArtifactSourceType +from sagemaker.core.utils.utils import Unassigned +from sagemaker.train.base_trainer import BaseTrainer from sagemaker.train.evaluate.base_evaluator import BaseEvaluator @@ -1291,3 +1293,86 @@ def test_with_all_optional_params(self, mock_resolve, mock_session, mock_model_i assert evaluator.networking == vpc_config assert evaluator.kms_key_id == "arn:aws:kms:us-west-2:123456789012:key/12345" assert evaluator.region == DEFAULT_REGION + + +class TestBaseTrainerHandling: + """Tests for BaseTrainer model handling.""" + + @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") + def test_base_trainer_with_valid_training_job(self, mock_resolve, mock_session, mock_model_info_with_package): + """Test BaseTrainer with valid completed training job.""" + mock_resolve.return_value = mock_model_info_with_package + + # Create mock BaseTrainer with completed training job + mock_trainer = MagicMock(spec=BaseTrainer) + mock_training_job = MagicMock() + mock_training_job.output_model_package_arn = DEFAULT_MODEL_PACKAGE_ARN + mock_trainer._latest_training_job = mock_training_job + + evaluator 
= BaseEvaluator( + model=mock_trainer, + s3_output_path=DEFAULT_S3_OUTPUT, + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, + sagemaker_session=mock_session, + ) + + # Verify model resolution was called with the training job's model package ARN + mock_resolve.assert_called_once_with( + base_model=DEFAULT_MODEL_PACKAGE_ARN, + sagemaker_session=mock_session + ) + assert evaluator.model == mock_trainer + + @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") + def test_base_trainer_with_unassigned_arn(self, mock_resolve, mock_session): + """Test BaseTrainer with Unassigned output_model_package_arn raises error.""" + # Create mock BaseTrainer with Unassigned ARN + mock_trainer = MagicMock(spec=BaseTrainer) + mock_training_job = MagicMock() + mock_training_job.output_model_package_arn = Unassigned() + mock_trainer._latest_training_job = mock_training_job + + with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"): + BaseEvaluator( + model=mock_trainer, + s3_output_path=DEFAULT_S3_OUTPUT, + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, + sagemaker_session=mock_session, + ) + + @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") + def test_base_trainer_without_training_job(self, mock_resolve, mock_session): + """Test BaseTrainer without _latest_training_job falls through to normal processing.""" + # Create mock BaseTrainer without _latest_training_job attribute + mock_trainer = MagicMock() + mock_trainer.__class__.__name__ = 'BaseTrainer' + # Don't set _latest_training_job attribute at all + + # This should fail during model resolution, not in BaseTrainer handling + with pytest.raises(ValidationError, match="Failed to resolve model"): + BaseEvaluator( + model=mock_trainer, + s3_output_path=DEFAULT_S3_OUTPUT, + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, + sagemaker_session=mock_session, + ) + + def test_base_trainer_without_output_model_package_arn_attribute(self, mock_session): + """Test BaseTrainer with training job but missing output_model_package_arn attribute.""" + + # Create a custom class that doesn't have output_model_package_arn + class MockTrainingJobWithoutArn: + pass + + # Create mock BaseTrainer with _latest_training_job but no output_model_package_arn + mock_trainer = MagicMock() + mock_trainer.__class__.__name__ = 'BaseTrainer' + mock_trainer._latest_training_job = MockTrainingJobWithoutArn() + + with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"): + BaseEvaluator( + model=mock_trainer, + s3_output_path=DEFAULT_S3_OUTPUT, + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, + sagemaker_session=mock_session, + ) diff --git a/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py index e9c74e3f2b..858bb12d32 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py @@ -121,7 +121,7 @@ def test_benchmark_evaluator_initialization_minimal(mock_artifact, mock_resolve) evaluator = BenchMarkEvaluator( benchmark=_Benchmark.MMLU, model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -130,7 +130,6 @@ def test_benchmark_evaluator_initialization_minimal(mock_artifact, mock_resolve) assert evaluator.benchmark == _Benchmark.MMLU assert evaluator.model == DEFAULT_MODEL - assert evaluator.dataset == 
DEFAULT_DATASET assert evaluator.evaluate_base_model is True assert evaluator.subtasks == "ALL" @@ -158,7 +157,7 @@ def test_benchmark_evaluator_subtask_defaults_to_all(mock_artifact, mock_resolve evaluator = BenchMarkEvaluator( benchmark=_Benchmark.MMLU, model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -188,7 +187,7 @@ def test_benchmark_evaluator_subtask_validation_invalid(mock_artifact, mock_reso benchmark=_Benchmark.MMLU, subtasks=["invalid_subtask"], model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -216,7 +215,7 @@ def test_benchmark_evaluator_no_subtask_for_unsupported_benchmark(mock_artifact, benchmark=_Benchmark.GPQA, subtasks="some_subtask", model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -250,14 +249,13 @@ def test_benchmark_evaluator_dataset_resolution_from_object(mock_artifact, mock_ evaluator = BenchMarkEvaluator( benchmark=_Benchmark.MMLU, model=DEFAULT_MODEL, - dataset=mock_dataset, s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, sagemaker_session=mock_session, ) - assert evaluator.dataset == mock_dataset.arn + # Dataset field is commented out, so no assertion needed @patch('sagemaker.train.common_utils.model_resolution._resolve_base_model') @@ -284,7 +282,7 @@ def test_benchmark_evaluator_evaluate_method_exists(mock_artifact, mock_resolve) benchmark=_Benchmark.MMLU, subtasks=["abstract_algebra"], model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -327,7 +325,7 @@ def test_benchmark_evaluator_evaluate_invalid_subtask_override(mock_artifact, mo evaluator = BenchMarkEvaluator( benchmark=_Benchmark.MMLU, model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -378,7 +376,7 @@ def test_benchmark_evaluator_missing_required_fields(): BenchMarkEvaluator( benchmark=_Benchmark.MMLU, model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, sagemaker_session=mock_session, ) @@ -408,7 +406,7 @@ def test_benchmark_evaluator_resolve_subtask_for_evaluation(mock_artifact, mock_ benchmark=_Benchmark.MMLU, subtasks="abstract_algebra", # Use a specific subtask instead of "ALL" model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -458,7 +456,7 @@ def test_benchmark_evaluator_hyperparameters_property(mock_artifact, mock_resolv evaluator = BenchMarkEvaluator( benchmark=_Benchmark.MMLU, model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -512,7 +510,7 @@ def test_benchmark_evaluator_get_benchmark_template_additions(mock_artifact, moc benchmark=_Benchmark.MMLU, subtasks=["abstract_algebra"], model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, 
mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -559,7 +557,7 @@ def test_benchmark_evaluator_mmmu_nova_validation(mock_artifact, mock_resolve, m BenchMarkEvaluator( benchmark=_Benchmark.MMMU, model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -596,7 +594,7 @@ def test_benchmark_evaluator_llm_judge_nova_validation(mock_artifact, mock_resol BenchMarkEvaluator( benchmark=_Benchmark.LLM_JUDGE, model="nova-pro", - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -629,7 +627,7 @@ def test_benchmark_evaluator_subtask_list_validation(mock_artifact, mock_resolve benchmark=_Benchmark.MMLU, subtasks=["abstract_algebra", "anatomy"], model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -643,7 +641,7 @@ def test_benchmark_evaluator_subtask_list_validation(mock_artifact, mock_resolve benchmark=_Benchmark.MMLU, subtasks=[], model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -675,7 +673,7 @@ def test_benchmark_evaluator_resolve_subtask_list(mock_artifact, mock_resolve): benchmark=_Benchmark.MMLU, subtasks=["abstract_algebra", "anatomy"], model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -729,7 +727,7 @@ def test_benchmark_evaluator_template_additions_with_list_subtasks(mock_artifact benchmark=_Benchmark.MMLU, subtasks=["abstract_algebra", "anatomy"], model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -761,7 +759,7 @@ def test_benchmark_evaluator_with_subtask_list(mock_resolve): evaluator = BenchMarkEvaluator( model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + benchmark=_Benchmark.MMLU, subtasks=['abstract_algebra', 'anatomy'], s3_output_path=DEFAULT_S3_OUTPUT, @@ -788,7 +786,7 @@ def test_benchmark_evaluator_with_subtask_string(mock_resolve): evaluator = BenchMarkEvaluator( model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + benchmark=_Benchmark.MMLU, subtasks='abstract_algebra', s3_output_path=DEFAULT_S3_OUTPUT, @@ -817,7 +815,7 @@ def test_benchmark_evaluator_invalid_subtask(mock_resolve): with pytest.raises(ValidationError, match="Invalid subtask"): BenchMarkEvaluator( model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + benchmark=_Benchmark.MMLU, subtasks=['invalid_subtask'], s3_output_path=DEFAULT_S3_OUTPUT, @@ -843,7 +841,7 @@ def test_benchmark_evaluator_no_subtask_available(mock_resolve): # IFEVAL doesn't support subtasks evaluator = BenchMarkEvaluator( model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + benchmark=_Benchmark.IFEVAL, s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, @@ -874,7 +872,7 @@ def test_benchmark_evaluator_with_networking(mock_resolve): evaluator = BenchMarkEvaluator( model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + benchmark=_Benchmark.MMLU, s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, @@ -903,7 +901,7 @@ def 
test_benchmark_evaluator_with_kms_key(mock_resolve): evaluator = BenchMarkEvaluator( model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + benchmark=_Benchmark.MMLU, s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, @@ -951,7 +949,7 @@ def test_benchmark_evaluator_uses_metric_key_for_nova(mock_artifact, mock_resolv evaluator = BenchMarkEvaluator( benchmark=_Benchmark.MMLU, model="nova-pro", - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, @@ -1001,7 +999,7 @@ def test_benchmark_evaluator_uses_evaluation_metric_key_for_non_nova(mock_artifa evaluator = BenchMarkEvaluator( benchmark=_Benchmark.MMLU, model=DEFAULT_MODEL, - dataset=DEFAULT_DATASET, + s3_output_path=DEFAULT_S3_OUTPUT, mlflow_resource_arn=DEFAULT_MLFLOW_ARN, model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, From 6dc76a3859aefe6e18d637ba35ce3ce770cb2061 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Tue, 16 Dec 2025 15:15:16 -0800 Subject: [PATCH 2/3] feat: Evaluator handshake with trainer --- .../train/common_utils/model_resolution.py | 17 +++- .../train/evaluate/base_evaluator.py | 27 ++---- .../train/evaluate/benchmark_evaluator.py | 2 +- .../src/sagemaker/train/rlaif_trainer.py | 2 +- .../src/sagemaker/train/rlvr_trainer.py | 2 +- .../src/sagemaker/train/sft_trainer.py | 2 +- .../common_utils/test_model_resolution.py | 73 ++++++++++++++++ .../train/evaluate/test_base_evaluator.py | 83 ------------------- 8 files changed, 99 insertions(+), 109 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py index 1c3f09a43e..2ce6ea7198 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py @@ -13,6 +13,8 @@ from dataclasses import dataclass from enum import Enum import re +from sagemaker.train.base_trainer import BaseTrainer +from sagemaker.core.utils.utils import Unassigned class _ModelType(Enum): @@ -65,14 +67,14 @@ def __init__(self, sagemaker_session=None): def resolve_model_info( self, - base_model: Union[str, 'ModelPackage'], + base_model: Union[str, BaseTrainer, 'ModelPackage'], hub_name: Optional[str] = None ) -> _ModelInfo: """ Resolve model information from various input types. 
Args: - base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN + base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN or BaseTrainer object with a completed job hub_name: Optional hub name for JumpStart models (defaults to SageMakerPublicHub) Returns: @@ -88,6 +90,17 @@ def resolve_model_info( return self._resolve_model_package_arn(base_model) else: return self._resolve_jumpstart_model(base_model, hub_name or self.DEFAULT_HUB_NAME) + # Handle BaseTrainer type + elif isinstance(base_model, BaseTrainer): + if hasattr(base_model, '_latest_training_job') and hasattr(base_model._latest_training_job, + 'output_model_package_arn'): + arn = base_model._latest_training_job.output_model_package_arn + if not isinstance(arn, Unassigned): + return self._resolve_model_package_arn(arn) + else: + raise ValueError("BaseTrainer must have completed training job to be used for evaluation") + else: + raise ValueError("BaseTrainer must have completed training job to be used for evaluation") else: # Not a string, so assume it's a ModelPackage object # Check if it has the expected attributes of a ModelPackage diff --git a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py index a1a593b810..6a87fa96eb 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py @@ -13,14 +13,13 @@ from pydantic import BaseModel, validator -from sagemaker.core.resources import ModelPackageGroup +from sagemaker.core.resources import ModelPackageGroup, ModelPackage from sagemaker.core.shapes import VpcConfig -from sagemaker.core.utils.utils import Unassigned if TYPE_CHECKING: from sagemaker.core.helper.session_helper import Session - from sagemaker.train.base_trainer import BaseTrainer +from sagemaker.train.base_trainer import BaseTrainer # Module-level logger _logger = logging.getLogger(__name__) @@ -55,6 +54,7 @@ class BaseEvaluator(BaseModel): - JumpStart model ID (str): e.g., 'llama3-2-1b-instruct' - ModelPackage object: A fine-tuned model package - ModelPackage ARN (str): e.g., 'arn:aws:sagemaker:region:account:model-package/name/version' + - BaseTrainer object: A completed training job (i.e., it must have _latest_training_job with output_model_package_arn populated) base_eval_name (Optional[str]): Optional base name for evaluation jobs. This name is used as the PipelineExecutionDisplayName when creating the SageMaker pipeline execution. The actual display name will be "{base_eval_name}-{timestamp}". This parameter can @@ -88,7 +88,7 @@ class BaseEvaluator(BaseModel): region: Optional[str] = None sagemaker_session: Optional[Any] = None - model: Union[str, Any] + model: Union[str, BaseTrainer, ModelPackage] base_eval_name: Optional[str] = None s3_output_path: str mlflow_resource_arn: Optional[str] = None @@ -280,7 +280,7 @@ def _validate_mlflow_arn_format(cls, v: Optional[str]) -> Optional[str]: return v @validator('model') - def _resolve_model_info(cls, v: Union[str, "BaseTrainer", Any], values: dict) -> Union[str, Any]: + def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: dict) -> Union[str, Any]: """Resolve model information from various input types. This validator uses the common model resolution utility to extract: @@ -291,7 +291,7 @@ def _resolve_model_info(cls, v: Union[str, "BaseTrainer", Any], values: dict) -> The resolved information is stored in private attributes for use by subclasses. 
Args: - v (Union[str, Any]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer). + v (Union[str, BaseTrainer, ModelPackage]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer). values (dict): Dictionary of already-validated fields. Returns: @@ -304,25 +304,12 @@ def _resolve_model_info(cls, v: Union[str, "BaseTrainer", Any], values: dict) -> import os try: - # Handle BaseTrainer type - if hasattr(v, '__class__') and v.__class__.__name__ == 'BaseTrainer' or hasattr(v, '_latest_training_job'): - if hasattr(v._latest_training_job, 'output_model_package_arn'): - arn = v._latest_training_job.output_model_package_arn - if not isinstance(arn, Unassigned): - model_to_resolve = arn - else: - raise ValueError("BaseTrainer must have completed training job to be used for evaluation") - else: - raise ValueError("BaseTrainer must have completed training job to be used for evaluation") - else: - model_to_resolve = v - # Get the session for resolution (may not be created yet due to validator order) session = values.get('sagemaker_session') # Resolve model information model_info = _resolve_base_model( - base_model=model_to_resolve, + base_model=v, sagemaker_session=session ) diff --git a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py index b8b26e119f..5d37e53f8c 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py @@ -303,7 +303,7 @@ class BenchMarkEvaluator(BaseEvaluator): subtasks: Optional[Union[str, List[str]]] = None evaluate_base_model: bool = True _hyperparameters: Optional[Any] = None - + @validator('benchmark') def _validate_benchmark_model_compatibility(cls, v, values): diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index 1bf5c02813..ebadc6bfda 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -286,7 +286,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati except TimeoutExceededError as e: logger.error("Error: %s", e) - self.latest_training_job = training_job + self._latest_training_job = training_job return training_job def _process_hyperparameters(self): diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index f00c7aac36..b28c9d865c 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -274,5 +274,5 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, except TimeoutExceededError as e: logger.error("Error: %s", e) - self.latest_training_job = training_job + self._latest_training_job = training_job return training_job diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 57d2c52a06..b2688dce5d 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -268,7 +268,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati except TimeoutExceededError as e: logger.error("Error: %s", e) - self.latest_training_job = training_job + self._latest_training_job = training_job return training_job diff --git a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py 
b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py index d0cc5990a8..31a827e3f0 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py @@ -24,6 +24,8 @@ _ModelResolver, _resolve_base_model, ) +from sagemaker.train.base_trainer import BaseTrainer +from sagemaker.core.utils.utils import Unassigned class TestModelType: @@ -557,3 +559,74 @@ def test_resolve_base_model_with_hub_name(self, mock_resolver_class): _resolve_base_model("test-model", hub_name="CustomHub") mock_resolver.resolve_model_info.assert_called_once_with("test-model", "CustomHub") + + +class TestBaseTrainerHandling: + """Tests for BaseTrainer model handling in _resolve_base_model.""" + + def test_base_trainer_with_valid_training_job(self): + """Test BaseTrainer with valid completed training job.""" + # Create concrete BaseTrainer subclass for testing + class TestTrainer(BaseTrainer): + def train(self, input_data_config, wait=True, logs=True): + pass + + mock_trainer = TestTrainer() + mock_training_job = MagicMock() + mock_training_job.output_model_package_arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-package/1" + mock_trainer._latest_training_job = mock_training_job + + with patch('sagemaker.train.common_utils.model_resolution._ModelResolver._resolve_model_package_arn') as mock_resolve_arn: + mock_resolve_arn.return_value = MagicMock() + + result = _resolve_base_model(mock_trainer) + + # Verify model package ARN resolution was called + mock_resolve_arn.assert_called_once_with( + "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-package/1" + ) + + def test_base_trainer_with_unassigned_arn(self): + """Test BaseTrainer with Unassigned output_model_package_arn raises error.""" + # Create concrete BaseTrainer subclass for testing + class TestTrainer(BaseTrainer): + def train(self, input_data_config, wait=True, logs=True): + pass + + mock_trainer = TestTrainer() + mock_training_job = MagicMock() + mock_training_job.output_model_package_arn = Unassigned() + mock_trainer._latest_training_job = mock_training_job + + with pytest.raises(ValueError, match="BaseTrainer must have completed training job"): + _resolve_base_model(mock_trainer) + + def test_base_trainer_without_training_job(self): + """Test BaseTrainer without _latest_training_job raises error.""" + # Create concrete BaseTrainer subclass for testing + class TestTrainer(BaseTrainer): + def train(self, input_data_config, wait=True, logs=True): + pass + + mock_trainer = TestTrainer() + # Don't set _latest_training_job attribute at all + + with pytest.raises(ValueError, match="BaseTrainer must have completed training job"): + _resolve_base_model(mock_trainer) + + def test_base_trainer_without_output_model_package_arn_attribute(self): + """Test BaseTrainer with training job but missing output_model_package_arn attribute.""" + # Create concrete BaseTrainer subclass for testing + class TestTrainer(BaseTrainer): + def train(self, input_data_config, wait=True, logs=True): + pass + + # Create a simple object without output_model_package_arn + class TrainingJobWithoutArn: + pass + + mock_trainer = TestTrainer() + mock_trainer._latest_training_job = TrainingJobWithoutArn() + + with pytest.raises(ValueError, match="BaseTrainer must have completed training job"): + _resolve_base_model(mock_trainer) diff --git a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py 
b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py index cd94e2c7c8..c9b2e0a255 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py @@ -1293,86 +1293,3 @@ def test_with_all_optional_params(self, mock_resolve, mock_session, mock_model_i assert evaluator.networking == vpc_config assert evaluator.kms_key_id == "arn:aws:kms:us-west-2:123456789012:key/12345" assert evaluator.region == DEFAULT_REGION - - -class TestBaseTrainerHandling: - """Tests for BaseTrainer model handling.""" - - @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") - def test_base_trainer_with_valid_training_job(self, mock_resolve, mock_session, mock_model_info_with_package): - """Test BaseTrainer with valid completed training job.""" - mock_resolve.return_value = mock_model_info_with_package - - # Create mock BaseTrainer with completed training job - mock_trainer = MagicMock(spec=BaseTrainer) - mock_training_job = MagicMock() - mock_training_job.output_model_package_arn = DEFAULT_MODEL_PACKAGE_ARN - mock_trainer._latest_training_job = mock_training_job - - evaluator = BaseEvaluator( - model=mock_trainer, - s3_output_path=DEFAULT_S3_OUTPUT, - mlflow_resource_arn=DEFAULT_MLFLOW_ARN, - sagemaker_session=mock_session, - ) - - # Verify model resolution was called with the training job's model package ARN - mock_resolve.assert_called_once_with( - base_model=DEFAULT_MODEL_PACKAGE_ARN, - sagemaker_session=mock_session - ) - assert evaluator.model == mock_trainer - - @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") - def test_base_trainer_with_unassigned_arn(self, mock_resolve, mock_session): - """Test BaseTrainer with Unassigned output_model_package_arn raises error.""" - # Create mock BaseTrainer with Unassigned ARN - mock_trainer = MagicMock(spec=BaseTrainer) - mock_training_job = MagicMock() - mock_training_job.output_model_package_arn = Unassigned() - mock_trainer._latest_training_job = mock_training_job - - with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"): - BaseEvaluator( - model=mock_trainer, - s3_output_path=DEFAULT_S3_OUTPUT, - mlflow_resource_arn=DEFAULT_MLFLOW_ARN, - sagemaker_session=mock_session, - ) - - @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") - def test_base_trainer_without_training_job(self, mock_resolve, mock_session): - """Test BaseTrainer without _latest_training_job falls through to normal processing.""" - # Create mock BaseTrainer without _latest_training_job attribute - mock_trainer = MagicMock() - mock_trainer.__class__.__name__ = 'BaseTrainer' - # Don't set _latest_training_job attribute at all - - # This should fail during model resolution, not in BaseTrainer handling - with pytest.raises(ValidationError, match="Failed to resolve model"): - BaseEvaluator( - model=mock_trainer, - s3_output_path=DEFAULT_S3_OUTPUT, - mlflow_resource_arn=DEFAULT_MLFLOW_ARN, - sagemaker_session=mock_session, - ) - - def test_base_trainer_without_output_model_package_arn_attribute(self, mock_session): - """Test BaseTrainer with training job but missing output_model_package_arn attribute.""" - - # Create a custom class that doesn't have output_model_package_arn - class MockTrainingJobWithoutArn: - pass - - # Create mock BaseTrainer with _latest_training_job but no output_model_package_arn - mock_trainer = MagicMock() - mock_trainer.__class__.__name__ = 'BaseTrainer' - 
mock_trainer._latest_training_job = MockTrainingJobWithoutArn() - - with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"): - BaseEvaluator( - model=mock_trainer, - s3_output_path=DEFAULT_S3_OUTPUT, - mlflow_resource_arn=DEFAULT_MLFLOW_ARN, - sagemaker_session=mock_session, - ) From 512c22d78bf242a7d9f3c0caafdd35e19d6a5688 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Wed, 17 Dec 2025 09:21:58 -0800 Subject: [PATCH 3/3] fix: update evaluate_base_model as False, minor change to README --- sagemaker-train/README.rst | 2 +- .../src/sagemaker/train/evaluate/benchmark_evaluator.py | 2 +- .../src/sagemaker/train/evaluate/custom_scorer_evaluator.py | 2 +- .../src/sagemaker/train/evaluate/llm_as_judge_evaluator.py | 2 +- .../tests/unit/train/evaluate/test_benchmark_evaluator.py | 4 ++-- .../tests/unit/train/evaluate/test_custom_scorer_evaluator.py | 4 ++-- .../tests/unit/train/evaluate/test_llm_as_judge_evaluator.py | 4 ++-- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sagemaker-train/README.rst b/sagemaker-train/README.rst index 90e306c42d..c1a1195902 100644 --- a/sagemaker-train/README.rst +++ b/sagemaker-train/README.rst @@ -47,7 +47,7 @@ Table of Contents Installing the SageMaker Python SDK Train ------------------------------------ +----------------------------------------- You can install from source by cloning this repository and running a pip install command in the root directory of the repository: diff --git a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py index 5d37e53f8c..d6bad422c6 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py @@ -301,7 +301,7 @@ class BenchMarkEvaluator(BaseEvaluator): benchmark: _Benchmark subtasks: Optional[Union[str, List[str]]] = None - evaluate_base_model: bool = True + evaluate_base_model: bool = False _hyperparameters: Optional[Any] = None diff --git a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py index 290a6f80ba..78d297006c 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py @@ -137,7 +137,7 @@ class CustomScorerEvaluator(BaseEvaluator): _hyperparameters: Optional[Any] = None # Template-required fields - evaluate_base_model: bool = True + evaluate_base_model: bool = False @validator('dataset', pre=True) def _resolve_dataset(cls, v): diff --git a/sagemaker-train/src/sagemaker/train/evaluate/llm_as_judge_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/llm_as_judge_evaluator.py index 98e1c50c48..8438b65688 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/llm_as_judge_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/llm_as_judge_evaluator.py @@ -123,7 +123,7 @@ class LLMAsJudgeEvaluator(BaseEvaluator): custom_metrics: Optional[str] = None # Template-required fields - evaluate_base_model: bool = True + evaluate_base_model: bool = False @validator('dataset', pre=True) def _resolve_dataset(cls, v): diff --git a/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py index 858bb12d32..d87a435ba0 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py +++ 
b/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py @@ -130,7 +130,7 @@ def test_benchmark_evaluator_initialization_minimal(mock_artifact, mock_resolve) assert evaluator.benchmark == _Benchmark.MMLU assert evaluator.model == DEFAULT_MODEL - assert evaluator.evaluate_base_model is True + assert evaluator.evaluate_base_model is False assert evaluator.subtasks == "ALL" @@ -525,7 +525,7 @@ def test_benchmark_evaluator_get_benchmark_template_additions(mock_artifact, moc assert additions['strategy'] == 'zs_cot' assert additions['evaluation_metric'] == 'accuracy' assert additions['subtask'] == 'abstract_algebra' - assert additions['evaluate_base_model'] is True + assert additions['evaluate_base_model'] is False @patch('sagemaker.train.common_utils.recipe_utils._is_nova_model') diff --git a/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py index 1f37632903..9267cc7f73 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py @@ -85,7 +85,7 @@ def test_custom_scorer_evaluator_initialization_minimal(mock_artifact, mock_reso assert evaluator.evaluator == _BuiltInMetric.PRIME_MATH assert evaluator.dataset == DEFAULT_DATASET assert evaluator.model == DEFAULT_MODEL - assert evaluator.evaluate_base_model is True + assert evaluator.evaluate_base_model is False @patch('sagemaker.train.common_utils.model_resolution._resolve_base_model') @@ -952,7 +952,7 @@ def test_custom_scorer_evaluator_get_custom_scorer_template_additions_builtin( assert additions['task'] == 'gen_qa' assert additions['strategy'] == 'gen_qa' assert additions['evaluation_metric'] == 'all' - assert additions['evaluate_base_model'] is True + assert additions['evaluate_base_model'] is False assert additions['evaluator_arn'] is None assert additions['preset_reward_function'] == 'prime_math' assert 'temperature' in additions diff --git a/sagemaker-train/tests/unit/train/evaluate/test_llm_as_judge_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_llm_as_judge_evaluator.py index 5af23f7960..60f89b6b69 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_llm_as_judge_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_llm_as_judge_evaluator.py @@ -67,7 +67,7 @@ def test_llm_as_judge_evaluator_initialization_minimal(mock_artifact, mock_resol assert evaluator.evaluator_model == DEFAULT_EVALUATOR_MODEL assert evaluator.dataset == DEFAULT_DATASET assert evaluator.model == DEFAULT_MODEL - assert evaluator.evaluate_base_model is True + assert evaluator.evaluate_base_model is False assert evaluator.builtin_metrics is None assert evaluator.custom_metrics is None @@ -472,7 +472,7 @@ def test_llm_as_judge_evaluator_get_llmaj_template_additions(mock_artifact, mock assert additions['top_p'] == '1.0' # pipeline_name is no longer in template additions - it's resolved dynamically in execution.py assert 'pipeline_name' not in additions - assert additions['evaluate_base_model'] is True + assert additions['evaluate_base_model'] is False # Verify S3 upload was called mock_s3_upload.assert_called_once()
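
For reference, below is a minimal usage sketch of the trainer-to-evaluator handshake this series introduces. It is not part of the patches: the import paths, the pre-built `trainer` and `session` objects, and the bare `evaluate()` call are assumptions; the constructor fields mirror those exercised in the unit tests above.

# Sketch only: assumes `trainer` is an already-configured concrete BaseTrainer
# subclass (e.g. the SFT trainer) and `session` is a SageMaker session.
# Import path for BenchMarkEvaluator/_Benchmark is assumed, not taken from the diff.
from sagemaker.train.evaluate.benchmark_evaluator import BenchMarkEvaluator, _Benchmark

# 1. Run training. On completion the trainer now stores the job on
#    trainer._latest_training_job (patch 2), whose output_model_package_arn is
#    what the evaluator resolves.
training_job = trainer.train(training_dataset="s3://my-bucket/train.jsonl")

# 2. Pass the trainer object itself as `model`. BaseEvaluator's model validator
#    (via _resolve_base_model) extracts output_model_package_arn from the
#    completed job; a trainer without a completed job raises
#    "BaseTrainer must have completed training job to be used for evaluation".
evaluator = BenchMarkEvaluator(
    benchmark=_Benchmark.MMLU,
    subtasks=["abstract_algebra"],
    model=trainer,  # handshake: trainer object instead of a JumpStart ID or model package ARN
    s3_output_path="s3://my-bucket/eval-output/",
    mlflow_resource_arn="arn:aws:sagemaker:us-west-2:111122223333:mlflow-tracking-server/my-server",
    model_package_group="arn:aws:sagemaker:us-west-2:111122223333:model-package-group/my-group",
    sagemaker_session=session,
    # evaluate_base_model defaults to False as of patch 3; opt in explicitly if needed.
    evaluate_base_model=False,
)

# 3. Kick off the evaluation pipeline (exact evaluate() signature not shown in the diff).
evaluator.evaluate()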