Skip to content
Merged
11 changes: 11 additions & 0 deletions sagemaker-train/src/sagemaker/train/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,14 @@
"qwen.qwen3-32b-v1:0",
"qwen.qwen3-coder-30b-a3b-v1:0"
]

# Allowed evaluator models for LLM as Judge evaluator with region restrictions.
# Maps a Bedrock model ID to the list of AWS regions where it may be used as
# the judge model. Consumed by the LLMAsJudgeEvaluator validator, which
# rejects any evaluator_model not present here and any session region not in
# the model's list.
# NOTE(review): region lists presumably mirror Bedrock model availability at
# the time of writing — confirm against current Bedrock docs before extending.
_ALLOWED_EVALUATOR_MODELS = {
    "anthropic.claude-3-5-sonnet-20240620-v1:0": ["us-west-2", "us-east-1", "ap-northeast-1"],
    "anthropic.claude-3-5-sonnet-20241022-v2:0": ["us-west-2"],
    "anthropic.claude-3-haiku-20240307-v1:0": ["us-west-2", "us-east-1", "ap-northeast-1", "eu-west-1"],
    "anthropic.claude-3-5-haiku-20241022-v1:0": ["us-west-2"],
    "meta.llama3-1-70b-instruct-v1:0": ["us-west-2"],
    "mistral.mistral-large-2402-v1:0": ["us-west-2", "us-east-1", "eu-west-1"],
    # Only model in the allowlist NOT offered in us-west-2.
    "amazon.nova-pro-v1:0": ["us-east-1"]
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .base_evaluator import BaseEvaluator
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train.constants import _ALLOWED_EVALUATOR_MODELS

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -144,6 +145,30 @@ def _validate_model_compatibility(cls, v, values):
)

return v

@validator('evaluator_model')
def _validate_evaluator_model(cls, v, values):
    """Reject evaluator models that are unknown or unavailable in the session's region.

    Args:
        v: The candidate evaluator model ID.
        values: Previously-validated field values; ``sagemaker_session`` is
            read from here when present.

    Returns:
        The validated model ID, unchanged.

    Raises:
        ValueError: If ``v`` is not in the allowlist, or the session's
            region is not among the model's supported regions.
    """
    allowed_regions = _ALLOWED_EVALUATOR_MODELS.get(v)
    if allowed_regions is None:
        raise ValueError(
            f"Invalid evaluator_model '{v}'. "
            f"Allowed models are: {list(_ALLOWED_EVALUATOR_MODELS.keys())}"
        )

    # Region check is best-effort: skipped when no session was supplied or
    # when the session object does not expose a region name.
    session = values.get('sagemaker_session')
    if session and hasattr(session, 'boto_region_name'):
        current_region = session.boto_region_name
        if current_region not in allowed_regions:
            raise ValueError(
                f"Evaluator model '{v}' is not available in region '{current_region}'. "
                f"Available regions for this model: {allowed_regions}"
            )

    return v

def _process_builtin_metrics(self, metrics: Optional[List[str]]) -> List[str]:
"""Process builtin metrics by removing 'Builtin.' prefix if present.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,3 +751,116 @@ def test_llm_as_judge_evaluator_with_mlflow_names(mock_artifact, mock_resolve):

assert evaluator.mlflow_experiment_name == "my-experiment"
assert evaluator.mlflow_run_name == "my-run"


@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
@patch('sagemaker.core.resources.Artifact')
def test_llm_as_judge_evaluator_valid_evaluator_models(mock_artifact, mock_resolve):
    """Every allowlisted judge model offered in us-west-2 validates cleanly."""
    # nova-pro is intentionally absent: per _ALLOWED_EVALUATOR_MODELS it is
    # available only in us-east-1, while this test pins the session region
    # to us-west-2.
    supported_models = (
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
        "anthropic.claude-3-5-sonnet-20241022-v2:0",
        "anthropic.claude-3-haiku-20240307-v1:0",
        "anthropic.claude-3-5-haiku-20241022-v1:0",
        "meta.llama3-1-70b-instruct-v1:0",
        "mistral.mistral-large-2402-v1:0",
    )

    resolved = Mock()
    resolved.base_model_name = DEFAULT_MODEL
    resolved.base_model_arn = DEFAULT_BASE_MODEL_ARN
    resolved.source_model_package_arn = None
    mock_resolve.return_value = resolved

    mock_artifact.get_all.return_value = iter([])
    created_artifact = Mock()
    created_artifact.artifact_arn = DEFAULT_ARTIFACT_ARN
    mock_artifact.create.return_value = created_artifact

    session = Mock()
    session.boto_region_name = "us-west-2"  # region supported by every model above
    session.boto_session = Mock()
    session.get_caller_identity_arn.return_value = DEFAULT_ROLE

    for judge_model in supported_models:
        evaluator = LLMAsJudgeEvaluator(
            model=DEFAULT_MODEL,
            evaluator_model=judge_model,
            dataset=DEFAULT_DATASET,
            builtin_metrics=["Correctness"],
            s3_output_path=DEFAULT_S3_OUTPUT,
            mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
            model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
            sagemaker_session=session,
        )
        assert evaluator.evaluator_model == judge_model


@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
@patch('sagemaker.core.resources.Artifact')
def test_llm_as_judge_evaluator_invalid_evaluator_model(mock_artifact, mock_resolve):
    """An evaluator model outside the allowlist is rejected at construction."""
    resolved = Mock()
    resolved.base_model_name = DEFAULT_MODEL
    resolved.base_model_arn = DEFAULT_BASE_MODEL_ARN
    resolved.source_model_package_arn = None
    mock_resolve.return_value = resolved

    mock_artifact.get_all.return_value = iter([])
    created_artifact = Mock()
    created_artifact.artifact_arn = DEFAULT_ARTIFACT_ARN
    mock_artifact.create.return_value = created_artifact

    session = Mock()
    session.boto_region_name = DEFAULT_REGION
    session.boto_session = Mock()
    session.get_caller_identity_arn.return_value = DEFAULT_ROLE

    with pytest.raises(ValidationError) as exc_info:
        LLMAsJudgeEvaluator(
            model=DEFAULT_MODEL,
            evaluator_model="invalid-model",
            dataset=DEFAULT_DATASET,
            builtin_metrics=["Correctness"],
            s3_output_path=DEFAULT_S3_OUTPUT,
            mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
            model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
            sagemaker_session=session,
        )
    assert "Invalid evaluator_model 'invalid-model'" in str(exc_info.value)


@patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session')
@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
@patch('sagemaker.core.resources.Artifact')
def test_llm_as_judge_evaluator_region_restriction(mock_artifact, mock_resolve, mock_get_session):
    """A region outside the model's availability list fails validation."""
    resolved = Mock()
    resolved.base_model_name = DEFAULT_MODEL
    resolved.base_model_arn = DEFAULT_BASE_MODEL_ARN
    resolved.source_model_package_arn = None
    mock_resolve.return_value = resolved

    mock_artifact.get_all.return_value = iter([])
    created_artifact = Mock()
    created_artifact.artifact_arn = DEFAULT_ARTIFACT_ARN
    mock_artifact.create.return_value = created_artifact

    session = Mock()
    # nova-pro is only offered in us-east-1, so eu-central-1 must be rejected.
    session.boto_region_name = "eu-central-1"
    session.boto_session = Mock()
    session.get_caller_identity_arn.return_value = DEFAULT_ROLE
    mock_get_session.return_value = session

    with pytest.raises(ValidationError) as exc_info:
        LLMAsJudgeEvaluator(
            model=DEFAULT_MODEL,
            evaluator_model="amazon.nova-pro-v1:0",
            dataset=DEFAULT_DATASET,
            builtin_metrics=["Correctness"],
            s3_output_path=DEFAULT_S3_OUTPUT,
            mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
            model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
            sagemaker_session=session,
        )
    assert "not available in region" in str(exc_info.value)
Loading