@@ -32,24 +32,34 @@ def test_node_should_be_modified_fw_constructor_no_fw_version():
3232 fw_constructors = (
3333 "TensorFlow()" ,
3434 "sagemaker.tensorflow.TensorFlow()" ,
35+ "sagemaker.tensorflow.estimator.TensorFlow()" ,
3536 "TensorFlowModel()" ,
3637 "sagemaker.tensorflow.TensorFlowModel()" ,
38+ "sagemaker.tensorflow.model.TensorFlowModel()" ,
3739 "MXNet()" ,
3840 "sagemaker.mxnet.MXNet()" ,
41+ "sagemaker.mxnet.estimator.MXNet()" ,
3942 "MXNetModel()" ,
4043 "sagemaker.mxnet.MXNetModel()" ,
44+ "sagemaker.mxnet.model.MXNetModel()" ,
4145 "Chainer()" ,
4246 "sagemaker.chainer.Chainer()" ,
47+ "sagemaker.chainer.estimator.Chainer()" ,
4348 "ChainerModel()" ,
4449 "sagemaker.chainer.ChainerModel()" ,
50+ "sagemaker.chainer.model.ChainerModel()" ,
4551 "PyTorch()" ,
4652 "sagemaker.pytorch.PyTorch()" ,
53+ "sagemaker.pytorch.estimator.PyTorch()" ,
4754 "PyTorchModel()" ,
4855 "sagemaker.pytorch.PyTorchModel()" ,
56+ "sagemaker.pytorch.model.PyTorchModel()" ,
4957 "SKLearn()" ,
5058 "sagemaker.sklearn.SKLearn()" ,
59+ "sagemaker.sklearn.estimator.SKLearn()" ,
5160 "SKLearnModel()" ,
5261 "sagemaker.sklearn.SKLearnModel()" ,
62+ "sagemaker.sklearn.model.SKLearnModel()" ,
5363 )
5464
5565 modifier = framework_version .FrameworkVersionEnforcer ()
@@ -63,24 +73,34 @@ def test_node_should_be_modified_fw_constructor_with_fw_version():
6373 fw_constructors = (
6474 "TensorFlow(framework_version='2.2')" ,
6575 "sagemaker.tensorflow.TensorFlow(framework_version='2.2')" ,
76+ "sagemaker.tensorflow.estimator.TensorFlow(framework_version='2.2')" ,
6677 "TensorFlowModel(framework_version='1.10')" ,
6778 "sagemaker.tensorflow.TensorFlowModel(framework_version='1.10')" ,
79+ "sagemaker.tensorflow.model.TensorFlowModel(framework_version='1.10')" ,
6880 "MXNet(framework_version='1.6')" ,
6981 "sagemaker.mxnet.MXNet(framework_version='1.6')" ,
82+ "sagemaker.mxnet.estimator.MXNet(framework_version='1.6')" ,
7083 "MXNetModel(framework_version='1.6')" ,
7184 "sagemaker.mxnet.MXNetModel(framework_version='1.6')" ,
85+ "sagemaker.mxnet.model.MXNetModel(framework_version='1.6')" ,
7286 "PyTorch(framework_version='1.4')" ,
7387 "sagemaker.pytorch.PyTorch(framework_version='1.4')" ,
88+ "sagemaker.pytorch.estimator.PyTorch(framework_version='1.4')" ,
7489 "PyTorchModel(framework_version='1.4')" ,
7590 "sagemaker.pytorch.PyTorchModel(framework_version='1.4')" ,
91+ "sagemaker.pytorch.model.PyTorchModel(framework_version='1.4')" ,
7692 "Chainer(framework_version='5.0')" ,
7793 "sagemaker.chainer.Chainer(framework_version='5.0')" ,
94+ "sagemaker.chainer.estimator.Chainer(framework_version='5.0')" ,
7895 "ChainerModel(framework_version='5.0')" ,
7996 "sagemaker.chainer.ChainerModel(framework_version='5.0')" ,
97+ "sagemaker.chainer.model.ChainerModel(framework_version='5.0')" ,
8098 "SKLearn(framework_version='0.20.0')" ,
8199 "sagemaker.sklearn.SKLearn(framework_version='0.20.0')" ,
100+ "sagemaker.sklearn.estimator.SKLearn(framework_version='0.20.0')" ,
82101 "SKLearnModel(framework_version='0.20.0')" ,
83102 "sagemaker.sklearn.SKLearnModel(framework_version='0.20.0')" ,
103+ "sagemaker.sklearn.model.SKLearnModel(framework_version='0.20.0')" ,
84104 )
85105
86106 modifier = framework_version .FrameworkVersionEnforcer ()
@@ -97,51 +117,36 @@ def test_node_should_be_modified_random_function_call():
97117
98118
99119def test_modify_node_tf ():
100- classes = (
101- "TensorFlow" "sagemaker.tensorflow.TensorFlow" ,
102- "TensorFlowModel" ,
103- "sagemaker.tensorflow.TensorFlowModel" ,
104- )
105- _test_modify_node (classes , "1.11.0" )
120+ _test_modify_node ("TensorFlow" , "1.11.0" )
106121
107122
108123def test_modify_node_mx ():
109- classes = ("MXNet" , "sagemaker.mxnet.MXNet" , "MXNetModel" , "sagemaker.mxnet.MXNetModel" )
110- _test_modify_node (classes , "1.2.0" )
124+ _test_modify_node ("MXNet" , "1.2.0" )
111125
112126
113127def test_modify_node_chainer ():
114- classes = (
115- "Chainer" ,
116- "sagemaker.chainer.Chainer" ,
117- "ChainerModel" ,
118- "sagemaker.chainer.ChainerModel" ,
119- )
120- _test_modify_node (classes , "4.1.0" )
128+ _test_modify_node ("Chainer" , "4.1.0" )
121129
122130
123131def test_modify_node_pt ():
124- classes = (
125- "PyTorch" ,
126- "sagemaker.pytorch.PyTorch" ,
127- "PyTorchModel" ,
128- "sagemaker.pytorch.PyTorchModel" ,
129- )
130- _test_modify_node (classes , "0.4.0" )
132+ _test_modify_node ("PyTorch" , "0.4.0" )
131133
132134
133135def test_modify_node_sklearn ():
134- classes = (
135- "SKLearn" ,
136- "sagemaker.sklearn.SKLearn" ,
137- "SKLearnModel" ,
138- "sagemaker.sklearn.SKLearnModel" ,
139- )
140- _test_modify_node (classes , "0.20.0" )
136+ _test_modify_node ("SKLearn" , "0.20.0" )
141137
142138
143- def _test_modify_node (classes , default_version ):
139+ def _test_modify_node (framework , default_version ):
144140 modifier = framework_version .FrameworkVersionEnforcer ()
141+
142+ classes = (
143+ "{}" .format (framework ),
144+ "sagemaker.{}.{}" .format (framework .lower (), framework ),
145+ "sagemaker.{}.estimator.{}" .format (framework .lower (), framework ),
146+ "{}Model" .format (framework ),
147+ "sagemaker.{}.{}Model" .format (framework .lower (), framework ),
148+ "sagemaker.{}.model.{}Model" .format (framework .lower (), framework ),
149+ )
145150 for cls in classes :
146151 node = ast_call ("{}()" .format (cls ))
147152 modifier .modify_node (node )
0 commit comments