Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

project = "sagemaker-core"
copyright = (
"%s, Amazon Web Services, Inc. or its affiliates. All rights reserved."
% datetime.now().year
"%s, Amazon Web Services, Inc. or its affiliates. All rights reserved." % datetime.now().year
)
author = "Amazon Web Services"

Expand Down Expand Up @@ -65,4 +64,3 @@

# autosectionlabel
autosectionlabel_prefix_document = True

3 changes: 2 additions & 1 deletion integ/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def test_training_and_inference(self):
)
endpoint: Endpoint = Endpoint.create(
endpoint_name=key,
endpoint_config_name=endpoint_config, # Pass `EndpointConfig` object created above
# Pass `EndpointConfig` object created above
endpoint_config_name=endpoint_config,
)
endpoint.wait_for_status("InService")

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Repository = "https://github.com/aws/sagemaker-core.git"

[tool.black]
line-length = 100
exclude = '\.ipynb$'

[tool.setuptools.dynamic]
version = { attr = "sagemaker_core._version.__version__"}
2 changes: 1 addition & 1 deletion src/sagemaker_core/main/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@


class Base(BaseModel):
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True)
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True, extra="forbid")

@classmethod
def get_sagemaker_client(cls, session=None, region_name=None, service_name="sagemaker"):
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker_core/main/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class Base(BaseModel):
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True)
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True, extra="forbid")


class ActionSource(Base):
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker_core/main/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,11 @@ def __next__(self) -> T:

if self.custom_key_mapping:
init_data = {self.custom_key_mapping.get(k, k): v for k, v in init_data.items()}

# Filter out the fields that are not in the resource class
fields = self.resource_cls.__annotations__
init_data = {k: v for k, v in init_data.items() if k in fields}

resource_object = self.resource_cls(**init_data)

# If the resource object has refresh method, refresh and return it
Expand Down
12 changes: 8 additions & 4 deletions src/sagemaker_core/tools/resources_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,8 @@ def generate_create_method(self, resource_name: str, **kwargs) -> str:
resource_name=resource_name,
create_args=create_args,
resource_lower=resource_lower,
service_name="sagemaker", # TODO: change service name based on the service - runtime, sagemaker, etc.
# TODO: change service name based on the service - runtime, sagemaker, etc.
service_name="sagemaker",
operation_input_args=operation_input_args,
operation=operation,
get_args=get_args,
Expand All @@ -908,7 +909,8 @@ def generate_create_method(self, resource_name: str, **kwargs) -> str:
resource_name=resource_name,
create_args=create_args,
resource_lower=resource_lower,
service_name="sagemaker", # TODO: change service name based on the service - runtime, sagemaker, etc.
# TODO: change service name based on the service - runtime, sagemaker, etc.
service_name="sagemaker",
operation_input_args=operation_input_args,
operation=operation,
get_args=get_args,
Expand Down Expand Up @@ -1074,7 +1076,8 @@ def generate_import_method(self, resource_name: str) -> str:
resource_name=resource_name,
import_args=import_args,
resource_lower=resource_lower,
service_name="sagemaker", # TODO: change service name based on the service - runtime, sagemaker, etc.
# TODO: change service name based on the service - runtime, sagemaker, etc.
service_name="sagemaker",
operation_input_args=operation_input_args,
operation=operation,
get_args=get_args,
Expand Down Expand Up @@ -1378,7 +1381,8 @@ def generate_get_method(self, resource_name: str) -> str:
formatted_method = GET_METHOD_TEMPLATE.format(
docstring=docstring,
resource_name=resource_name,
service_name="sagemaker", # TODO: change service name based on the service - runtime, sagemaker, etc.
# TODO: change service name based on the service - runtime, sagemaker, etc.
service_name="sagemaker",
describe_args=describe_args,
resource_lower=resource_lower,
operation_input_args=operation_input_args,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker_core/tools/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ def {method_name}(

RESOURCE_BASE_CLASS_TEMPLATE = """
class Base(BaseModel):
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True)
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True, extra="forbid")

@classmethod
def get_sagemaker_client(cls, session = None, region_name = None, service_name = 'sagemaker'):
Expand Down Expand Up @@ -709,7 +709,7 @@ def wrapper(*args, **kwargs):

SHAPE_BASE_CLASS_TEMPLATE = """
class {class_name}:
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True)
model_config = ConfigDict(protected_namespaces=(), validate_assignment=True, extra="forbid")
"""

SHAPE_CLASS_TEMPLATE = '''
Expand Down
3 changes: 2 additions & 1 deletion tst/generated/test_resources.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import importlib, inspect
import importlib
import inspect
import unittest
from unittest.mock import patch

Expand Down
2 changes: 1 addition & 1 deletion workflow_helper/compute_boto_api_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def main():
"""
This function computes the number of APIs covered and uncovered by sagemaker core to the ones in Botocore.
"""
configure_logging("ERROR") # Disable other log messages
configure_logging("ERROR") # Disable other log messages
resources_extractor = ResourcesExtractor()
# Print the number of unsupported Botocore API and supported Botocore API
print(len(resources_extractor.actions), len(resources_extractor.actions_under_resource))
Expand Down