2323import smdebug_rulesconfig as rule_configs # noqa: F401 # pylint: disable=unused-import
2424
2525
# Name of the ECR repository that holds the SageMaker Debugger rule container.
RULES_ECR_REPO_NAME = "sagemaker-debugger-rules"

# For each supported AWS region, the registry (account) ID used as the host
# prefix of the ECR image URI, keyed by repository name.
SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP = {
    "eu-north-1": {RULES_ECR_REPO_NAME: "314864569078"},
    "me-south-1": {RULES_ECR_REPO_NAME: "986000313247"},
    "ap-south-1": {RULES_ECR_REPO_NAME: "904829902805"},
    "eu-west-3": {RULES_ECR_REPO_NAME: "447278800020"},
    "us-east-2": {RULES_ECR_REPO_NAME: "915447279597"},
    "eu-west-1": {RULES_ECR_REPO_NAME: "929884845733"},
    "eu-central-1": {RULES_ECR_REPO_NAME: "482524230118"},
    "sa-east-1": {RULES_ECR_REPO_NAME: "818342061345"},
    "ap-east-1": {RULES_ECR_REPO_NAME: "199566480951"},
    "us-east-1": {RULES_ECR_REPO_NAME: "503895931360"},
    "ap-northeast-2": {RULES_ECR_REPO_NAME: "578805364391"},
    "eu-west-2": {RULES_ECR_REPO_NAME: "250201462417"},
    "ap-northeast-1": {RULES_ECR_REPO_NAME: "430734990657"},
    "us-west-2": {RULES_ECR_REPO_NAME: "895741380848"},
    "us-west-1": {RULES_ECR_REPO_NAME: "685455198987"},
    "ap-southeast-1": {RULES_ECR_REPO_NAME: "972752614525"},
    "ap-southeast-2": {RULES_ECR_REPO_NAME: "184798709955"},
    "ca-central-1": {RULES_ECR_REPO_NAME: "519511493484"},
}


def get_rule_container_image_uri(region):
    """Return the rule container image URI for the given AWS region.

    Args:
        region (str): AWS region name, for example 'us-west-2'. It must be
            one of the keys of ``SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP``.

    Returns:
        str: The image URI, formatted as
            '<registry_id>.dkr.ecr.<region>.amazonaws.com/<repo_name>:latest'.

    Raises:
        ValueError: If ``region`` has no entry in
            ``SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP``.
    """
    region_registries = SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP.get(region)
    if region_registries is None:
        # Fail with an actionable message rather than an AttributeError raised
        # by calling .get() on None.
        raise ValueError(
            "Unsupported region: {}. Rule containers are available in: {}.".format(
                region, ", ".join(sorted(SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP))
            )
        )

    registry_id = region_registries[RULES_ECR_REPO_NAME]
    return "{}.dkr.ecr.{}.amazonaws.com/{}:latest".format(
        registry_id, region, RULES_ECR_REPO_NAME
    )
62+
63+
2664class Rule (object ):
2765 """Rules analyze tensors emitted during the training of a model. They
2866 monitor conditions that are critical for the success of a training job.
@@ -40,7 +78,7 @@ def __init__(
4078 name ,
4179 image_uri ,
4280 instance_type ,
43- container_local_path ,
81+ container_local_output_path ,
4482 s3_output_path ,
4583 volume_size_in_gb ,
4684 rule_parameters ,
@@ -58,7 +96,7 @@ def __init__(
5896 image_uri (str): The URI of the image to be used by the debugger rule.
5997 instance_type (str): Type of EC2 instance to use, for example,
6098 'ml.c4.xlarge'.
61- container_local_path (str): The path in the container .
99+ container_local_output_path (str): The local path to store the Rule output .
62100 s3_output_path (str): The location in S3 to store the output.
63101 volume_size_in_gb (int): Size in GB of the EBS volume
64102 to use for storing data.
@@ -68,7 +106,7 @@ def __init__(
68106 """
69107 self .name = name
70108 self .instance_type = instance_type
71- self .container_local_path = container_local_path
109+ self .container_local_output_path = container_local_output_path
72110 self .s3_output_path = s3_output_path
73111 self .volume_size_in_gb = volume_size_in_gb
74112 self .rule_parameters = rule_parameters
@@ -80,10 +118,8 @@ def sagemaker(
80118 cls ,
81119 base_config ,
82120 name = None ,
83- instance_type = None ,
84- container_local_path = None ,
121+ container_local_output_path = None ,
85122 s3_output_path = None ,
86- volume_size_in_gb = None ,
87123 other_trials_s3_input_paths = None ,
88124 rule_parameters = None ,
89125 collections_to_save = None ,
@@ -98,13 +134,8 @@ def sagemaker(
98134 built-in list of rules. For example, 'rule_configs.dead_relu()'.
99135 name (str): The name of the debugger rule. If one is not provided,
100136 the name of the base_config will be used.
101- instance_type (str): Type of EC2 instance to use, for example,
102- 'ml.c4.xlarge'. If one is not provided, the instance type from
103- the base_config will be used.
104- container_local_path (str): The path in the container.
137+ container_local_output_path (str): The path in the container.
105138 s3_output_path (str): The location in S3 to store the output.
106- volume_size_in_gb (int): Size in GB of the EBS volume
107- to use for storing data.
108139 other_trials_s3_input_paths ([str]): S3 input paths for other trials.
109140 rule_parameters (dict): A dictionary of parameters for the rule.
110141 collections_to_save ([sagemaker.debugger.CollectionConfig]): A list
@@ -113,9 +144,22 @@ def sagemaker(
113144 Returns:
114145 sagemaker.debugger.Rule: The instance of the built-in Rule.
115146 """
116- other_trials_params = {}
147+ merged_rule_params = {}
148+
149+ if rule_parameters is not None and rule_parameters .get ("rule_to_invoke" ) is not None :
150+ raise RuntimeError (
151+ """You cannot provide a 'rule_to_invoke' for SageMaker rules.
152+ Please either remove the rule_to_invoke or use a custom rule.
153+ """
154+ )
155+
117156 if other_trials_s3_input_paths is not None :
118- other_trials_params ["other_trials_s3_input_paths" ] = other_trials_s3_input_paths
157+ for index , s3_input_path in enumerate (other_trials_s3_input_paths ):
158+ merged_rule_params ["other_trial_{}" .format (str (index ))] = s3_input_path
159+
160+ default_rule_params = base_config ["DebugRuleConfiguration" ].get ("RuleParameters" , {})
161+ merged_rule_params .update (default_rule_params )
162+ merged_rule_params .update (rule_parameters or {})
119163
120164 base_config_collections = []
121165 for config in base_config .get ("CollectionConfigurations" , []):
@@ -133,16 +177,11 @@ def sagemaker(
133177 return cls (
134178 name = name or base_config ["DebugRuleConfiguration" ].get ("RuleConfigurationName" ),
135179 image_uri = "DEFAULT_RULE_EVALUATOR_IMAGE" ,
136- instance_type = instance_type or "t3.medium" ,
137- # TODO-reinvent-2019 [akarpur]: Remove t3.medium from line above,
138- # uncomment line below when 1P package updated
139- # or base_config["DebugRuleConfiguration"].get("InstanceType"),
140- container_local_path = container_local_path ,
180+ instance_type = None ,
181+ container_local_output_path = container_local_output_path ,
141182 s3_output_path = s3_output_path ,
142- volume_size_in_gb = volume_size_in_gb ,
143- rule_parameters = other_trials_params .update (
144- rule_parameters or base_config ["DebugRuleConfiguration" ].get ("RuleParameters" , {})
145- ),
183+ volume_size_in_gb = None ,
184+ rule_parameters = merged_rule_params ,
146185 collections_to_save = collections_to_save or base_config_collections ,
147186 )
148187
@@ -154,7 +193,7 @@ def custom(
154193 instance_type ,
155194 source = None ,
156195 rule_to_invoke = None ,
157- container_local_path = None ,
196+ container_local_output_path = None ,
158197 s3_output_path = None ,
159198 volume_size_in_gb = None ,
160199 other_trials_s3_input_paths = None ,
@@ -175,7 +214,7 @@ def custom(
175214 you must also provide rule_to_invoke.
176215 rule_to_invoke (str): The name of the rule to invoke within the source.
177216 If provided, you must also provide source.
178- container_local_path (str): The path in the container.
217+ container_local_output_path (str): The path in the container.
179218 s3_output_path (str): The location in S3 to store the output.
180219 volume_size_in_gb (int): Size in GB of the EBS volume
181220 to use for storing data.
@@ -192,25 +231,28 @@ def custom(
192231 "If you provide a source, you must also provide a rule to invoke (and vice versa)."
193232 )
194233
195- source_params = {}
234+ merged_rule_params = {}
235+
196236 if source is not None and rule_to_invoke is not None :
197- source_params ["source_s3_uri" ] = source
198- source_params ["rule_to_invoke" ] = rule_to_invoke
237+ merged_rule_params ["source_s3_uri" ] = source
238+ merged_rule_params ["rule_to_invoke" ] = rule_to_invoke
199239
200240 other_trials_params = {}
201241 if other_trials_s3_input_paths is not None :
202- other_trials_params ["other_trials_s3_input_paths" ] = other_trials_s3_input_paths
242+ for index , s3_input_path in enumerate (other_trials_s3_input_paths ):
243+ other_trials_params ["other_trial_{}" .format (str (index ))] = s3_input_path
203244
204- combined_rule_params = source_params .update (other_trials_params ) or {}
245+ merged_rule_params .update (other_trials_params )
246+ merged_rule_params .update (rule_parameters or {})
205247
206248 return cls (
207249 name = name ,
208250 image_uri = image_uri ,
209251 instance_type = instance_type ,
210- container_local_path = container_local_path ,
252+ container_local_output_path = container_local_output_path ,
211253 s3_output_path = s3_output_path ,
212254 volume_size_in_gb = volume_size_in_gb ,
213- rule_parameters = combined_rule_params . update ( rule_parameters or {}) ,
255+ rule_parameters = merged_rule_params ,
214256 collections_to_save = collections_to_save or [],
215257 )
216258
@@ -221,21 +263,19 @@ def to_debugger_rule_config_dict(self):
221263 Returns:
222264 dict: A portion of an API request as a dictionary.
223265 """
224- if self .instance_type is None or self .volume_size_in_gb is None :
225- raise RuntimeError (
226- """Cannot create a dictionary if the instance type and volume size are not provided.
227- Please set the instance type and volume size for this Rule object."""
228- )
229-
230266 debugger_rule_config_request = {
231267 "RuleConfigurationName" : self .name ,
232268 "RuleEvaluatorImage" : self .image_uri ,
233- "InstanceType" : self .instance_type ,
234- "VolumeSizeInGB" : self .volume_size_in_gb ,
235269 }
236270
237- if self .container_local_path is not None :
238- debugger_rule_config_request ["LocalPath" ] = self .container_local_path
271+ if self .instance_type is not None :
272+ debugger_rule_config_request ["InstanceType" ] = self .instance_type
273+
274+ if self .volume_size_in_gb is not None :
275+ debugger_rule_config_request ["VolumeSizeInGB" ] = self .volume_size_in_gb
276+
277+ if self .container_local_output_path is not None :
278+ debugger_rule_config_request ["LocalPath" ] = self .container_local_output_path
239279
240280 if self .s3_output_path is not None :
241281 debugger_rule_config_request ["S3OutputPath" ] = self .s3_output_path
@@ -254,7 +294,7 @@ class DebuggerHookConfig(object):
254294 def __init__ (
255295 self ,
256296 s3_output_path ,
257- container_local_path = None ,
297+ container_local_output_path = None ,
258298 hook_parameters = None ,
259299 collection_configs = None ,
260300 ):
@@ -264,13 +304,13 @@ def __init__(
264304
265305 Args:
266306 s3_output_path (str): The location in S3 to store the output.
267- container_local_path (str): The path in the container.
307+ container_local_output_path (str): The path in the container.
268308 hook_parameters (dict): A dictionary of parameters.
269309 collection_configs ([sagemaker.debugger.CollectionConfig]): A list
270310 of CollectionConfig objects to be provided to the API.
271311 """
272312 self .s3_output_path = s3_output_path
273- self .container_local_path = container_local_path
313+ self .container_local_output_path = container_local_output_path
274314 self .hook_parameters = hook_parameters
275315 self .collection_configs = collection_configs
276316
@@ -283,8 +323,8 @@ def to_request_dict(self):
283323 """
284324 debugger_hook_config_request = {"S3OutputPath" : self .s3_output_path }
285325
286- if self .container_local_path is not None :
287- debugger_hook_config_request ["LocalPath" ] = self .container_local_path
326+ if self .container_local_output_path is not None :
327+ debugger_hook_config_request ["LocalPath" ] = self .container_local_output_path
288328
289329 if self .hook_parameters is not None :
290330 debugger_hook_config_request ["HookParameters" ] = self .hook_parameters
@@ -301,17 +341,17 @@ class TensorBoardOutputConfig(object):
301341 """TensorBoardOutputConfig provides options to customize
302342 debugging visualization using TensorBoard."""
303343
304- def __init__ (self , s3_output_path , container_local_path = None ):
344+ def __init__ (self , s3_output_path , container_local_output_path = None ):
305345 """Initialize an instance of TensorBoardOutputConfig.
306346 TensorBoardOutputConfig provides options to customize
307347 debugging visualization using TensorBoard.
308348
309349 Args:
310350 s3_output_path (str): The location in S3 to store the output.
311- container_local_path (str): The path in the container.
351+ container_local_output_path (str): The path in the container.
312352 """
313353 self .s3_output_path = s3_output_path
314- self .container_local_path = container_local_path
354+ self .container_local_output_path = container_local_output_path
315355
316356 def to_request_dict (self ):
317357 """Generates a request dictionary using the parameters provided
@@ -322,16 +362,16 @@ def to_request_dict(self):
322362 """
323363 tensorboard_output_config_request = {"S3OutputPath" : self .s3_output_path }
324364
325- if self .container_local_path is not None :
326- tensorboard_output_config_request ["LocalPath" ] = self .container_local_path
365+ if self .container_local_output_path is not None :
366+ tensorboard_output_config_request ["LocalPath" ] = self .container_local_output_path
327367
328368 return tensorboard_output_config_request
329369
330370
331371class CollectionConfig (object ):
332372 """CollectionConfig object for SageMaker Debugger."""
333373
334- def __init__ (self , name , parameters ):
374+ def __init__ (self , name , parameters = None ):
335375 """Initialize a ``CollectionConfig`` object.
336376
337377 Args:
@@ -359,7 +399,7 @@ def __ne__(self, other):
359399 return self .name != other .name or self .parameters != other .parameters
360400
361401 def __hash__ (self ):
362- return hash ((self .name , tuple (sorted (self .parameters .items ()))))
402+ return hash ((self .name , tuple (sorted (( self .parameters or {}) .items ()))))
363403
364404 def to_request_dict (self ):
365405 """Generates a request dictionary using the parameters provided
@@ -368,9 +408,9 @@ def to_request_dict(self):
368408 Returns:
369409 dict: A portion of an API request as a dictionary.
370410 """
371- collection_config_request = {
372- "CollectionName" : self . name ,
373- "CollectionParameters" : self . parameters ,
374- }
411+ collection_config_request = {"CollectionName" : self . name }
412+
413+ if self . parameters is not None :
414+ collection_config_request [ "CollectionParameters" ] = self . parameters
375415
376416 return collection_config_request
0 commit comments