55import warnings
66from typing import Any , overload
77
8- from typing_extensions import deprecated
8+ from typing_extensions import deprecated , override
99
1010from aws_lambda_powertools .utilities .data_classes .common import (
1111 BaseRequestContext ,
@@ -28,9 +28,10 @@ def __init__(
2828 aws_account_id : str ,
2929 api_id : str ,
3030 stage : str ,
31- http_method : str ,
31+ http_method : str | None ,
3232 resource : str ,
3333 partition : str = "aws" ,
34+ is_websocket_authorizer : bool = False ,
3435 ):
3536 self .partition = partition
3637 self .region = region
@@ -40,39 +41,54 @@ def __init__(
4041 self .http_method = http_method
4142 # Remove matching "/" from `resource`.
4243 self .resource = resource .lstrip ("/" )
44+ self .is_websocket_authorizer = is_websocket_authorizer
4345
4446 @property
4547 def arn (self ) -> str :
4648 """Build an arn from its parts
4749 eg: arn:aws:execute-api:us-east-1:123456789012:abcdef123/test/GET/request"""
48- return (
49- f"arn:{ self .partition } :execute-api:{ self .region } :{ self .aws_account_id } :{ self .api_id } /{ self .stage } /"
50- f"{ self .http_method } /{ self .resource } "
51- )
50+ base_arn = f"arn:{ self .partition } :execute-api:{ self .region } :{ self .aws_account_id } :{ self .api_id } /{ self .stage } "
51+
52+ if not self .is_websocket_authorizer :
53+ return f"{ base_arn } /{ self .http_method } /{ self .resource } "
54+ else :
55+ return f"{ base_arn } /{ self .resource } "
5256
5357
54- def parse_api_gateway_arn (arn : str ) -> APIGatewayRouteArn :
58+ def parse_api_gateway_arn (arn : str , is_websocket_authorizer : bool = False ) -> APIGatewayRouteArn :
5559 """Parses a gateway route arn as a APIGatewayRouteArn class
5660
5761 Parameters
5862 ----------
5963 arn : str
6064 ARN string for a methodArn or a routeArn
65+ is_websocket_authorizer: bool
66+ If it's a API Gateway Websocket
67+
6168 Returns
6269 -------
6370 APIGatewayRouteArn
6471 """
6572 arn_parts = arn .split (":" )
6673 api_gateway_arn_parts = arn_parts [5 ].split ("/" )
74+
75+ if not is_websocket_authorizer :
76+ http_method = api_gateway_arn_parts [2 ]
77+ resource = "/" .join (api_gateway_arn_parts [3 :]) if len (api_gateway_arn_parts ) >= 4 else ""
78+ else :
79+ http_method = None
80+ resource = "/" .join (api_gateway_arn_parts [2 :])
81+
6782 return APIGatewayRouteArn (
6883 partition = arn_parts [1 ],
6984 region = arn_parts [3 ],
7085 aws_account_id = arn_parts [4 ],
7186 api_id = api_gateway_arn_parts [0 ],
7287 stage = api_gateway_arn_parts [1 ],
73- http_method = api_gateway_arn_parts [ 2 ] ,
88+ http_method = http_method ,
7489 # conditional allow us to handle /path/{proxy+} resources, as their length changes.
75- resource = "/" .join (api_gateway_arn_parts [3 :]) if len (api_gateway_arn_parts ) >= 4 else "" ,
90+ resource = resource ,
91+ is_websocket_authorizer = is_websocket_authorizer ,
7692 )
7793
7894
@@ -512,13 +528,14 @@ def _add_route(self, effect: str, http_method: str, resource: str, conditions: l
512528 raise ValueError (f"Invalid resource path: { resource } . Path should match { self .path_regex } " )
513529
514530 resource_arn = APIGatewayRouteArn (
515- self .region ,
516- self .aws_account_id ,
517- self .api_id ,
518- self .stage ,
519- http_method ,
520- resource ,
521- self .partition ,
531+ region = self .region ,
532+ aws_account_id = self .aws_account_id ,
533+ api_id = self .api_id ,
534+ stage = self .stage ,
535+ http_method = http_method ,
536+ resource = resource ,
537+ partition = self .partition ,
538+ is_websocket_authorizer = False ,
522539 ).arn
523540
524541 route = {"resourceArn" : resource_arn , "conditions" : conditions }
@@ -617,3 +634,127 @@ def asdict(self) -> dict[str, Any]:
617634 response ["context" ] = self .context
618635
619636 return response
637+
638+
639+ class APIGatewayAuthorizerResponseWebSocket (APIGatewayAuthorizerResponse ):
640+ """The IAM Policy Response required for API Gateway WebSocket APIs
641+
642+ Based on: - https://github.com/awslabs/aws-apigateway-lambda-authorizer-blueprints/blob/\
643+ master/blueprints/python/api-gateway-authorizer-python.py
644+
645+ Documentation:
646+ -------------
647+ - https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-lambda-authorizer.html
648+ - https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-lambda-authorizer-output.html
649+ """
650+
651+ @staticmethod
652+ def from_route_arn (
653+ arn : str ,
654+ principal_id : str ,
655+ context : dict | None = None ,
656+ usage_identifier_key : str | None = None ,
657+ ) -> APIGatewayAuthorizerResponseWebSocket :
658+ parsed_arn = parse_api_gateway_arn (arn , is_websocket_authorizer = True )
659+ return APIGatewayAuthorizerResponseWebSocket (
660+ principal_id ,
661+ parsed_arn .region ,
662+ parsed_arn .aws_account_id ,
663+ parsed_arn .api_id ,
664+ parsed_arn .stage ,
665+ context ,
666+ usage_identifier_key ,
667+ )
668+
669+ # Note: we need ignore[override] because we are removing the http_method field
670+ @override
671+ def _add_route (self , effect : str , resource : str , conditions : list [dict ] | None = None ): # type: ignore[override]
672+ """Adds a route to the internal lists of allowed or denied routes. Each object in
673+ the internal list contains a resource ARN and a condition statement. The condition
674+ statement can be null."""
675+ resource_arn = APIGatewayRouteArn (
676+ region = self .region ,
677+ aws_account_id = self .aws_account_id ,
678+ api_id = self .api_id ,
679+ stage = self .stage ,
680+ http_method = None ,
681+ resource = resource ,
682+ partition = self .partition ,
683+ is_websocket_authorizer = True ,
684+ ).arn
685+
686+ route = {"resourceArn" : resource_arn , "conditions" : conditions }
687+
688+ if effect .lower () == "allow" :
689+ self ._allow_routes .append (route )
690+ else : # deny
691+ self ._deny_routes .append (route )
692+
693+ @override
694+ def allow_all_routes (self ):
695+ """Adds a '*' allow to the policy to authorize access to all methods of an API"""
696+ self ._add_route (effect = "Allow" , resource = "*" )
697+
698+ @override
699+ def deny_all_routes (self ):
700+ """Adds a '*' allow to the policy to deny access to all methods of an API"""
701+
702+ self ._add_route (effect = "Deny" , resource = "*" )
703+
704+ # Note: we need ignore[override] because we are removing the http_method field
705+ @override
706+ def allow_route (self , resource : str , conditions : list [dict ] | None = None ): # type: ignore[override]
707+ """
708+ Add an API Gateway Websocket method to the list of allowed methods for the policy.
709+
710+ This method adds an API Gateway Websocket method Resource path) to the list of
711+ allowed methods for the policy. It optionally includes conditions for the policy statement.
712+
713+ Parameters
714+ ----------
715+ resource : str
716+ The API Gateway resource path to allow.
717+ conditions : list[dict] | None, optional
718+ A list of condition dictionaries to apply to the policy statement.
719+ Default is None.
720+
721+ Notes
722+ -----
723+ For more information on AWS policy conditions, see:
724+ https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition
725+
726+ Example
727+ --------
728+ >>> policy = APIGatewayAuthorizerResponseWebSocket(...)
729+ >>> policy.allow_route("/api/users", [{"StringEquals": {"aws:RequestTag/Environment": "Production"}}])
730+ """
731+ self ._add_route (effect = "Allow" , resource = resource , conditions = conditions )
732+
733+ # Note: we need ignore[override] because we are removing the http_method field
734+ @override
735+ def deny_route (self , resource : str , conditions : list [dict ] | None = None ): # type: ignore[override]
736+ """
737+ Add an API Gateway Websocket method to the list of allowed methods for the policy.
738+
739+ This method adds an API Gateway Websocket method Resource path) to the list of
740+ denied methods for the policy. It optionally includes conditions for the policy statement.
741+
742+ Parameters
743+ ----------
744+ resource : str
745+ The API Gateway resource path to allow.
746+ conditions : list[dict] | None, optional
747+ A list of condition dictionaries to apply to the policy statement.
748+ Default is None.
749+
750+ Notes
751+ -----
752+ For more information on AWS policy conditions, see:
753+ https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition
754+
755+ Example
756+ --------
757+ >>> policy = APIGatewayAuthorizerResponseWebSocket(...)
758+ >>> policy.deny_route("/api/users", [{"StringEquals": {"aws:RequestTag/Environment": "Production"}}])
759+ """
760+ self ._add_route (effect = "Deny" , resource = resource , conditions = conditions )
0 commit comments