Skip to content

Error rework #489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ private void generateOperationExecutor(PythonWriter writer) {

var transportRequest = context.applicationProtocol().requestType();
var transportResponse = context.applicationProtocol().responseType();
var errorSymbol = CodegenUtils.getServiceError(context.settings());
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());
var configSymbol = CodegenUtils.getConfigSymbol(context.settings());

Expand All @@ -141,63 +140,20 @@ private void generateOperationExecutor(PythonWriter writer) {
writer.addStdlibImport("dataclasses", "replace");

writer.addDependency(SmithyPythonDependency.SMITHY_CORE);
writer.addImport("smithy_core.exceptions", "SmithyRetryException");
writer.addImports("smithy_core.interceptors",
Set.of("Interceptor",
"InterceptorChain",
"InputContext",
"OutputContext",
"RequestContext",
"ResponseContext"));
writer.addImports("smithy_core.interfaces.retries", Set.of("RetryErrorInfo", "RetryErrorType"));
writer.addImport("smithy_core.interfaces.exceptions", "HasFault");
writer.addImport("smithy_core.types", "TypedProperties");
writer.addImport("smithy_core.serializers", "SerializeableShape");
writer.addImport("smithy_core.deserializers", "DeserializeableShape");
writer.addImport("smithy_core.schemas", "APIOperation");

writer.indent();
writer.write("""
def _classify_error(
self,
*,
error: Exception,
context: ResponseContext[Any, $1T, $2T | None]
) -> RetryErrorInfo:
logger.debug("Classifying error: %s", error)
""", transportRequest, transportResponse);
writer.indent();

if (context.applicationProtocol().isHttpProtocol()) {
writer.addDependency(SmithyPythonDependency.SMITHY_HTTP);
writer.write("""
if not isinstance(error, HasFault) and not context.transport_response:
return RetryErrorInfo(error_type=RetryErrorType.TRANSIENT)

if context.transport_response:
if context.transport_response.status in [429, 503]:
retry_after = None
retry_header = context.transport_response.fields["retry-after"]
if retry_header and retry_header.values:
retry_after = float(retry_header.values[0])
return RetryErrorInfo(error_type=RetryErrorType.THROTTLING, retry_after_hint=retry_after)

if context.transport_response.status >= 500:
return RetryErrorInfo(error_type=RetryErrorType.SERVER_ERROR)

""");
}

writer.write("""
error_type = RetryErrorType.CLIENT_ERROR
if isinstance(error, HasFault) and error.fault == "server":
error_type = RetryErrorType.SERVER_ERROR

return RetryErrorInfo(error_type=error_type)

""");
writer.dedent();

if (hasStreaming) {
writer.addStdlibImports("typing", Set.of("Any", "Awaitable"));
writer.addStdlibImport("asyncio");
Expand Down Expand Up @@ -302,46 +258,54 @@ def _classify_error(
}
writer.addStdlibImport("typing", "Any");
writer.addStdlibImport("asyncio", "iscoroutine");
writer.addImports("smithy_core.exceptions", Set.of("SmithyError", "CallError", "RetryError"));
writer.pushState();
writer.putContext("request", transportRequest);
writer.putContext("response", transportResponse);
writer.putContext("plugin", pluginSymbol);
writer.putContext("config", configSymbol);
writer.write(
"""
async def _execute_operation[Input: SerializeableShape, Output: DeserializeableShape](
self,
input: Input,
plugins: list[$1T],
serialize: Callable[[Input, $5T], Awaitable[$2T]],
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
config: $5T,
plugins: list[${plugin:T}],
serialize: Callable[[Input, ${config:T}], Awaitable[${request:T}]],
deserialize: Callable[[${response:T}, ${config:T}], Awaitable[Output]],
config: ${config:T},
operation: APIOperation[Input, Output],
request_future: Future[RequestContext[Any, $2T]] | None = None,
response_future: Future[$3T] | None = None,
request_future: Future[RequestContext[Any, ${request:T}]] | None = None,
response_future: Future[${response:T}] | None = None,
) -> Output:
try:
return await self._handle_execution(
input, plugins, serialize, deserialize, config, operation,
request_future, response_future,
)
except Exception as e:
# Make sure every exception that we throw is an instance of SmithyError so
# customers can reliably catch everything we throw.
if not isinstance(e, SmithyError):
wrapped = CallError(str(e))
wrapped.__cause__ = e
e = wrapped

if request_future is not None and not request_future.done():
request_future.set_exception($4T(e))
request_future.set_exception(e)
if response_future is not None and not response_future.done():
response_future.set_exception($4T(e))

# Make sure every exception that we throw is an instance of $4T so
# customers can reliably catch everything we throw.
if not isinstance(e, $4T):
raise $4T(e) from e
response_future.set_exception(e)
raise

async def _handle_execution[Input: SerializeableShape, Output: DeserializeableShape](
self,
input: Input,
plugins: list[$1T],
serialize: Callable[[Input, $5T], Awaitable[$2T]],
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
config: $5T,
plugins: list[${plugin:T}],
serialize: Callable[[Input, ${config:T}], Awaitable[${request:T}]],
deserialize: Callable[[${response:T}, ${config:T}], Awaitable[Output]],
config: ${config:T},
operation: APIOperation[Input, Output],
request_future: Future[RequestContext[Any, $2T]] | None,
response_future: Future[$3T] | None,
request_future: Future[RequestContext[Any, ${request:T}]] | None,
response_future: Future[${response:T}] | None,
) -> Output:
operation_name = operation.schema.id.name
logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input)
Expand All @@ -350,11 +314,16 @@ def _classify_error(
plugin(config)

input_context = InputContext(request=input, properties=TypedProperties({"config": config}))
transport_request: $2T | None = None
output_context: OutputContext[Input, Output, $2T | None, $3T | None] | None = None
transport_request: ${request:T} | None = None
output_context: OutputContext[
Input,
Output,
${request:T} | None,
${response:T} | None
] | None = None

client_interceptors = cast(
list[Interceptor[Input, Output, $2T, $3T]], list(config.interceptors)
list[Interceptor[Input, Output, ${request:T}, ${response:T}]], list(config.interceptors)
)
interceptor_chain = InterceptorChain(client_interceptors)

Expand Down Expand Up @@ -413,12 +382,9 @@ def _classify_error(
try:
retry_token = retry_strategy.refresh_retry_token_for_retry(
token_to_renew=retry_token,
error_info=self._classify_error(
error=output_context.response,
context=output_context,
)
error=output_context.response,
)
except SmithyRetryException:
except RetryError:
raise output_context.response
logger.debug(
"Retry needed. Attempting request #%s in %.4f seconds.",
Expand Down Expand Up @@ -455,24 +421,20 @@ await sleep(retry_token.retry_delay)

async def _handle_attempt[Input: SerializeableShape, Output: DeserializeableShape](
self,
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
interceptor: Interceptor[Input, Output, $2T, $3T],
context: RequestContext[Input, $2T],
config: $5T,
deserialize: Callable[[${response:T}, ${config:T}], Awaitable[Output]],
interceptor: Interceptor[Input, Output, ${request:T}, ${response:T}],
context: RequestContext[Input, ${request:T}],
config: ${config:T},
operation: APIOperation[Input, Output],
request_future: Future[RequestContext[Input, $2T]] | None,
) -> OutputContext[Input, Output, $2T, $3T | None]:
transport_response: $3T | None = None
request_future: Future[RequestContext[Input, ${request:T}]] | None,
) -> OutputContext[Input, Output, ${request:T}, ${response:T} | None]:
transport_response: ${response:T} | None = None
try:
# Step 7a: Invoke read_before_attempt
interceptor.read_before_attempt(context)

""",
pluginSymbol,
transportRequest,
transportResponse,
errorSymbol,
configSymbol);
""");
writer.popState();

boolean supportsAuth = !ServiceIndex.of(model).getAuthSchemes(service).isEmpty();
writer.pushState(new ResolveIdentitySection());
Expand Down Expand Up @@ -873,8 +835,8 @@ private void writeSharedOperationInit(PythonWriter writer, OperationShape operat
.orElse("The operation's input.");

writer.write("""
$L
""",docs);
$L
""", docs);
writer.write("");
writer.write(":param input: $L", inputDocs);
writer.write("");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,7 @@ public static Symbol getPluginSymbol(PythonSettings settings) {
/**
* Gets the service error symbol.
*
* <p>This error is the top-level error for the client. Every error surfaced by
* the client MUST be a subclass of this so that customers can reliably catch all
* exceptions it raises. The client implementation will wrap any errors that aren't
* already subclasses.
* <p>This error is the top-level error for modeled client errors.
*
* @param settings The client settings, used to account for module configuration.
* @return Returns the symbol for the client's error class.
Expand All @@ -105,40 +102,6 @@ public static Symbol getServiceError(PythonSettings settings) {
.build();
}

/**
* Gets the service API error symbol.
*
* <p>This error is the parent class for all errors returned over the wire by the
* service, including unknown errors.
*
* @param settings The client settings, used to account for module configuration.
* @return Returns the symbol for the client's API error class.
*/
public static Symbol getApiError(PythonSettings settings) {
return Symbol.builder()
.name("ApiError")
.namespace(String.format("%s.models", settings.moduleName()), ".")
.definitionFile(String.format("./src/%s/models.py", settings.moduleName()))
.build();
}

/**
* Gets the unknown API error symbol.
*
* <p> This error is the parent class for all errors returned over the wire by
* the service which aren't in the model.
*
* @param settings The client settings, used to account for module configuration.
* @return Returns the symbol for unknown API errors.
*/
public static Symbol getUnknownApiError(PythonSettings settings) {
return Symbol.builder()
.name("UnknownApiError")
.namespace(String.format("%s.models", settings.moduleName()), ".")
.definitionFile(String.format("./src/%s/models.py", settings.moduleName()))
.build();
}

/**
* Gets the symbol for the http auth params.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import java.util.Locale;
import java.util.logging.Logger;
import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider;
import software.amazon.smithy.codegen.core.ReservedWords;
import software.amazon.smithy.codegen.core.ReservedWordsBuilder;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.codegen.core.SymbolProvider;
Expand Down Expand Up @@ -84,6 +83,10 @@ public PythonSymbolProvider(Model model, PythonSettings settings) {
var reservedMemberNamesBuilder = new ReservedWordsBuilder()
.loadWords(PythonSymbolProvider.class.getResource("reserved-member-names.txt"), this::escapeWord);

// Reserved words that only apply to error members.
var reservedErrorMembers = new ReservedWordsBuilder()
.loadWords(PythonSymbolProvider.class.getResource("reserved-error-member-names.txt"), this::escapeWord);

escaper = ReservedWordSymbolProvider.builder()
.nameReservedWords(reservedClassNames)
.memberReservedWords(reservedMemberNamesBuilder.build())
Expand All @@ -92,13 +95,8 @@ public PythonSymbolProvider(Model model, PythonSettings settings) {
.escapePredicate((shape, symbol) -> !StringUtils.isEmpty(symbol.getDefinitionFile()))
.buildEscaper();

// Reserved words that only apply to error members.
ReservedWords reservedErrorMembers = reservedMemberNamesBuilder
.put("code", "code_")
.build();

errorMemberEscaper = ReservedWordSymbolProvider.builder()
.memberReservedWords(reservedErrorMembers)
.memberReservedWords(reservedErrorMembers.build())
.escapePredicate((shape, symbol) -> !StringUtils.isEmpty(symbol.getDefinitionFile()))
.buildEscaper();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
*/
package software.amazon.smithy.python.codegen.generators;

import java.util.Set;
import software.amazon.smithy.codegen.core.WriterDelegator;
import software.amazon.smithy.python.codegen.CodegenUtils;
import software.amazon.smithy.python.codegen.PythonSettings;
Expand All @@ -30,38 +29,15 @@ public void run() {
var serviceError = CodegenUtils.getServiceError(settings);
writers.useFileWriter(serviceError.getDefinitionFile(), serviceError.getNamespace(), writer -> {
writer.addDependency(SmithyPythonDependency.SMITHY_CORE);
writer.addImport("smithy_core.exceptions", "SmithyException");
writer.addImport("smithy_core.exceptions", "ModeledError");
writer.write("""
class $L(SmithyException):
""\"Base error for all errors in the service.""\"
pass
""", serviceError.getName());
});

var apiError = CodegenUtils.getApiError(settings);
writers.useFileWriter(apiError.getDefinitionFile(), apiError.getNamespace(), writer -> {
writer.addStdlibImports("typing", Set.of("Literal", "ClassVar"));
var unknownApiError = CodegenUtils.getUnknownApiError(settings);

writer.write("""
@dataclass
class $1L($2T):
""\"Base error for all API errors in the service.""\"
code: ClassVar[str]
fault: ClassVar[Literal["client", "server"]]
class $L(ModeledError):
""\"Base error for all errors in the service.

message: str

def __post_init__(self) -> None:
super().__init__(self.message)


@dataclass
class $3L($1L):
""\"Error representing any unknown api errors.""\"
code: ClassVar[str] = 'Unknown'
fault: ClassVar[Literal["client", "server"]] = "client"
""", apiError.getName(), serviceError, unknownApiError.getName());
Some exceptions do not extend from this class, including
synthetic, implicit, and shared exception types.
""\"
""", serviceError.getName());
});
}
}
Loading