|
16 | 16 | import pytest |
17 | 17 | import yaml |
18 | 18 | import logging |
19 | | -from mock import Mock, MagicMock, patch |
| 19 | +from mock import Mock, MagicMock, patch, call |
20 | 20 |
|
21 | 21 | from sagemaker.config.config import ( |
22 | 22 | load_local_mode_config, |
23 | 23 | load_sagemaker_config, |
24 | 24 | logger, |
| 25 | + non_repeating_log_factory, |
25 | 26 | _DEFAULT_ADMIN_CONFIG_FILE_PATH, |
26 | 27 | _DEFAULT_USER_CONFIG_FILE_PATH, |
27 | 28 | _DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH, |
@@ -349,6 +350,26 @@ def test_logging_when_default_admin_not_found_and_default_user_config_not_found( |
349 | 350 | logger.propagate = False |
350 | 351 |
|
351 | 352 |
|
| 353 | +@patch("sagemaker.config.config.log_info_function") |
| 354 | +def test_load_config_without_repeating_log(log_info): |
| 355 | + |
| 356 | + load_sagemaker_config(repeat_log=False) |
| 357 | + assert log_info.call_count == 2 |
| 358 | + log_info.assert_has_calls( |
| 359 | + [ |
| 360 | + call( |
| 361 | + "Not applying SDK defaults from location: %s", |
| 362 | + _DEFAULT_ADMIN_CONFIG_FILE_PATH, |
| 363 | + ), |
| 364 | + call( |
| 365 | + "Not applying SDK defaults from location: %s", |
| 366 | + _DEFAULT_USER_CONFIG_FILE_PATH, |
| 367 | + ), |
| 368 | + ], |
| 369 | + any_order=True, |
| 370 | + ) |
| 371 | + |
| 372 | + |
352 | 373 | def test_logging_when_default_admin_not_found_and_overriden_user_config_not_found( |
353 | 374 | get_data_dir, caplog |
354 | 375 | ): |
@@ -421,3 +442,19 @@ def test_load_local_mode_config(mock_load_config): |
421 | 442 |
|
422 | 443 | def test_load_local_mode_config_when_config_file_is_not_found(): |
423 | 444 | assert load_local_mode_config() is None |
| 445 | + |
| 446 | + |
| 447 | +@pytest.mark.parametrize( |
| 448 | + "method_name", |
| 449 | + ["info", "warning", "debug"], |
| 450 | +) |
| 451 | +def test_non_repeating_log_factory(method_name): |
| 452 | + tmp_logger = logging.getLogger("test-logger") |
| 453 | + mock = MagicMock() |
| 454 | + setattr(tmp_logger, method_name, mock) |
| 455 | + |
| 456 | + log_function = non_repeating_log_factory(tmp_logger, method_name) |
| 457 | + log_function("foo") |
| 458 | + log_function("foo") |
| 459 | + |
| 460 | + mock.assert_called_once() |
0 commit comments