@@ -1552,6 +1552,8 @@ def test_model_initialization_with_config_name(
15521552
15531553 model = JumpStartModel (model_id = model_id , config_name = "neuron-inference" )
15541554
1555+ assert model .config_name == "neuron-inference"
1556+
15551557 model .deploy ()
15561558
15571559 mock_model_deploy .assert_called_once_with (
@@ -1594,6 +1596,8 @@ def test_model_set_deployment_config(
15941596
15951597 model = JumpStartModel (model_id = model_id )
15961598
1599+ assert model .config_name is None
1600+
15971601 model .deploy ()
15981602
15991603 mock_model_deploy .assert_called_once_with (
@@ -1612,6 +1616,8 @@ def test_model_set_deployment_config(
16121616 mock_get_model_specs .side_effect = get_prototype_spec_with_configs
16131617 model .set_deployment_config ("neuron-inference" )
16141618
1619+ assert model .config_name == "neuron-inference"
1620+
16151621 model .deploy ()
16161622
16171623 mock_model_deploy .assert_called_once_with (
@@ -1654,6 +1660,8 @@ def test_model_unset_deployment_config(
16541660
16551661 model = JumpStartModel (model_id = model_id , config_name = "neuron-inference" )
16561662
1663+ assert model .config_name == "neuron-inference"
1664+
16571665 model .deploy ()
16581666
16591667 mock_model_deploy .assert_called_once_with (
@@ -1789,7 +1797,6 @@ def test_model_retrieve_deployment_config(
17891797 ):
17901798 model_id , _ = "pytorch-eqa-bert-base-cased" , "*"
17911799
1792- mock_get_init_kwargs .side_effect = lambda * args , ** kwargs : get_mock_init_kwargs (model_id )
17931800 mock_verify_model_region_and_return_specs .side_effect = (
17941801 lambda * args , ** kwargs : get_base_spec_with_prototype_configs_with_missing_benchmarks ()
17951802 )
@@ -1804,15 +1811,23 @@ def test_model_retrieve_deployment_config(
18041811 )
18051812 mock_model_deploy .return_value = default_predictor
18061813
1814+ expected = get_base_deployment_configs ()[0 ]
1815+ config_name = expected .get ("DeploymentConfigName" )
1816+ mock_get_init_kwargs .side_effect = lambda * args , ** kwargs : get_mock_init_kwargs (
1817+ model_id , config_name
1818+ )
1819+
18071820 mock_session .return_value = sagemaker_session
18081821
18091822 model = JumpStartModel (model_id = model_id )
18101823
1811- expected = get_base_deployment_configs ()[0 ]
1812- model .set_deployment_config (expected .get ("DeploymentConfigName" ))
1824+ model .set_deployment_config (config_name )
18131825
18141826 self .assertEqual (model .deployment_config , expected )
18151827
1828+ mock_get_init_kwargs .reset_mock ()
1829+ mock_get_init_kwargs .side_effect = lambda * args , ** kwargs : get_mock_init_kwargs (model_id )
1830+
18161831 # Unset
18171832 model .set_deployment_config (None )
18181833 self .assertIsNone (model .deployment_config )
0 commit comments