@@ -271,13 +271,61 @@ def test_spark_processor_base_extend_processing_args(
 serialized_configuration = BytesIO("test".encode("utf-8"))


+@pytest.mark.parametrize(
+    "config, expected",
+    [
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": "s3://configbucket/someprefix/",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": "s3://configbucket/someprefix",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+    ],
+)
 @patch("sagemaker.spark.processing.BytesIO")
 @patch("sagemaker.spark.processing.S3Uploader.upload_string_as_file_body")
-def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, sagemaker_session):
-    desired_s3_uri = "s3://bucket/None/input/conf/configuration.json"
+def test_stage_configuration(mock_s3_upload, mock_bytesIO, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        configuration_location=config["configuration_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
+    desired_s3_uri = expected
     mock_bytesIO.return_value = serialized_configuration

-    result = py_spark_processor._stage_configuration({})
+    result = spark_processor._stage_configuration({})

     mock_s3_upload.assert_called_with(
         body=serialized_configuration,
@@ -290,23 +338,121 @@ def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, sagemaker_session):
 @pytest.mark.parametrize(
     "config, expected",
     [
-        ({"submit_deps": None, "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["s3"], "input_channel_name": None}, ValueError),
-        ({"submit_deps": ["other"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
         (
-            {"submit_deps": ["s3", "s3"], "input_channel_name": "channelName"},
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": None,
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3"],
+                "input_channel_name": None,
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["other"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3", "s3"],
+                "input_channel_name": "channelName",
+            },
             (None, "s3://bucket,s3://bucket"),
         ),
         (
-            {"submit_deps": ["jar"], "input_channel_name": "channelName"},
-            (processing_input, "s3://bucket"),
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": "s3://codebucket/someprefix/",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName",
+            ),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": "s3://codebucket/someprefix",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName",
+            ),
         ),
     ],
 )
 @patch("sagemaker.spark.processing.S3Uploader")
-def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, config, expected):
+def test_stage_submit_deps(mock_s3_uploader, jar_file, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        dependency_location=config["dependency_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
     submit_deps_dict = {
         None: None,
         "s3": "s3://bucket",
@@ -320,21 +466,20 @@ def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, config, expected):

     if expected is ValueError:
         with pytest.raises(expected) as e:
-            py_spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])
+            spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])

         assert isinstance(e.value, expected)
     else:
-        input_channel, spark_opt = py_spark_processor._stage_submit_deps(
+        input_channel, spark_opt = spark_processor._stage_submit_deps(
             submit_deps, config["input_channel_name"]
         )

         if expected[0] is None:
             assert input_channel is None
             assert spark_opt == expected[1]
         else:
-            expected_source = "s3://bucket/None/input/channelName"
-            assert input_channel.source == expected_source
-            assert spark_opt == "/opt/ml/processing/input/channelName"
+            assert input_channel.source == expected[0]
+            assert spark_opt == expected[1]


 @pytest.mark.parametrize(
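
A note on the expected values in both parametrizations: the `None` path segment presumably comes from the processor's job name, which is not yet set when the tests call the staging helpers directly, and the staging prefix is either the session's default bucket or the user-supplied `configuration_location`/`dependency_location`. A minimal sketch of that URI scheme, using a hypothetical `build_staging_uri` helper (not SageMaker SDK API; the real logic lives inside `_stage_configuration` and `_stage_submit_deps`):

    # Hypothetical helper illustrating the URI pattern the cases above assert.
    def build_staging_uri(custom_location, default_bucket, job_name, suffix):
        # Fall back to the session's default bucket when no custom location is
        # given; a trailing slash on the custom location is tolerated.
        base = custom_location.rstrip("/") if custom_location else f"s3://{default_bucket}"
        return f"{base}/{job_name}/input/{suffix}"

    # job_name is None under the mocked session, hence the literal "None" segment:
    assert (
        build_staging_uri(None, "bucket", None, "conf/configuration.json")
        == "s3://bucket/None/input/conf/configuration.json"
    )
    assert (
        build_staging_uri("s3://configbucket/someprefix/", "bucket", None, "conf/configuration.json")
        == "s3://configbucket/someprefix/None/input/conf/configuration.json"
    )
    assert (
        build_staging_uri(None, "bucket", None, "channelName")
        == "s3://bucket/None/input/channelName"
    )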