Skip to content

Commit

Permalink
Upgrade to Sanic 22.12LTS (#1103)
Browse files Browse the repository at this point in the history
* update Sanic version and Sanic app instantiation for testing purposes

* add listener to share context between Sanic workers

* update sharing of context

* update how the tracer provider listener is shared across sanic workers

* update some tests

* revert scope

* update listener

* change listener and context types

* use legacy arg in app.run

* add changelog entry

* move custom actions used in tests to conftest.py
  • Loading branch information
ancalita authored Jun 4, 2024
1 parent c4e3d1d commit 2dd8011
Show file tree
Hide file tree
Showing 17 changed files with 268 additions and 221 deletions.
2 changes: 2 additions & 0 deletions changelog/1103.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Upgrade Sanic to v22.12LTS.
Refactor loading of tracer provider to be triggered by Sanic `before_server_start` event listener.
35 changes: 18 additions & 17 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ select = [ "D", "E", "F", "W", "RUF",]
[tool.poetry.dependencies]
python = ">=3.8,<3.11"
coloredlogs = ">=10,<16"
sanic = "^21.12.0"
sanic = "^22.12"
typing-extensions = ">=4.1.1,<5.0.0"
Sanic-Cors = "^2.0.0"
prompt-toolkit = "^3.0,<3.0.29"
Expand All @@ -99,7 +99,7 @@ toml = "^0.10.0"
pep440-version-utils = "^0.3.0"
semantic_version = "^2.8.5"
mypy = "^1.5"
sanic-testing = "^22.3.0, <22.9.0"
sanic-testing = "^22.12"

[tool.ruff.pydocstyle]
convention = "google"
Expand Down
4 changes: 1 addition & 3 deletions rasa_sdk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from rasa_sdk import utils
from rasa_sdk.endpoint import create_argument_parser, run
from rasa_sdk.constants import APPLICATION_ROOT_LOGGER_NAME
from rasa_sdk.tracing.utils import get_tracer_provider


def main_from_args(args):
Expand All @@ -18,7 +17,6 @@ def main_from_args(args):
args.logging_config_file,
)
utils.update_sanic_log_level()
tracer_provider = get_tracer_provider(args)

run(
args.actions,
Expand All @@ -28,7 +26,7 @@ def main_from_args(args):
args.ssl_keyfile,
args.ssl_password,
args.auto_reload,
tracer_provider,
args.endpoints,
)


Expand Down
65 changes: 48 additions & 17 deletions rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import warnings
import zlib
import json
from functools import partial
from typing import List, Text, Union, Optional
from ssl import SSLContext
from sanic import Sanic, response
from sanic.response import HTTPResponse
from sanic.worker.loader import AppLoader

# catching:
# - all `pkg_resources` deprecation warning from multiple dependencies
Expand All @@ -24,16 +26,23 @@
category=DeprecationWarning,
message="distutils Version classes are deprecated",
)
from opentelemetry.sdk.trace import TracerProvider
from sanic_cors import CORS
from sanic.request import Request
from rasa_sdk import utils
from rasa_sdk.cli.arguments import add_endpoint_arguments
from rasa_sdk.constants import DEFAULT_KEEP_ALIVE_TIMEOUT, DEFAULT_SERVER_PORT
from rasa_sdk.constants import (
DEFAULT_ENDPOINTS_PATH,
DEFAULT_KEEP_ALIVE_TIMEOUT,
DEFAULT_SERVER_PORT,
)
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException
from rasa_sdk.plugin import plugin_manager
from rasa_sdk.tracing.utils import get_tracer_and_context, set_span_attributes
from rasa_sdk.tracing.utils import (
get_tracer_and_context,
get_tracer_provider,
set_span_attributes,
)

logger = logging.getLogger(__name__)

Expand All @@ -42,7 +51,6 @@ def configure_cors(
app: Sanic, cors_origins: Union[Text, List[Text], None] = ""
) -> None:
"""Configure CORS origins for the given app."""

CORS(
app, resources={r"/*": {"origins": cors_origins or ""}}, automatic_options=True
)
Expand All @@ -54,7 +62,6 @@ def create_ssl_context(
ssl_password: Optional[Text] = None,
) -> Optional[SSLContext]:
"""Create a SSL context if a certificate is passed."""

if ssl_certificate:
import ssl

Expand All @@ -69,19 +76,23 @@ def create_ssl_context(

def create_argument_parser():
"""Parse all the command line arguments for the run script."""

parser = argparse.ArgumentParser(description="starts the action endpoint")
add_endpoint_arguments(parser)
utils.add_logging_level_option_arguments(parser)
utils.add_logging_file_arguments(parser)
return parser


async def load_tracer_provider(endpoints: str, app: Sanic):
"""Load the tracer provider into the Sanic app."""
tracer_provider = get_tracer_provider(endpoints)
app.ctx.tracer_provider = tracer_provider


def create_app(
action_package_name: Union[Text, types.ModuleType],
cors_origins: Union[Text, List[Text], None] = "*",
auto_reload: bool = False,
tracer_provider: Optional[TracerProvider] = None,
) -> Sanic:
"""Create a Sanic application and return it.
Expand All @@ -90,7 +101,6 @@ def create_app(
from.
cors_origins: CORS origins to allow.
auto_reload: When `True`, auto-reloading of actions is enabled.
tracer_provider: Tracer provider to use for tracing.
Returns:
A new Sanic application ready to be run.
Expand All @@ -102,6 +112,8 @@ def create_app(
executor = ActionExecutor()
executor.register_package(action_package_name)

app.ctx.tracer_provider = None

@app.get("/health")
async def health(_) -> HTTPResponse:
"""Ping endpoint to check if the server is running and well."""
Expand All @@ -111,7 +123,9 @@ async def health(_) -> HTTPResponse:
@app.post("/webhook")
async def webhook(request: Request) -> HTTPResponse:
"""Webhook to retrieve action calls."""
tracer, context, span_name = get_tracer_and_context(tracer_provider, request)
tracer, context, span_name = get_tracer_and_context(
request.app.ctx.tracer_provider, request
)

with tracer.start_as_current_span(span_name, context=context) as span:
if request.headers.get("Content-Encoding") == "deflate":
Expand Down Expand Up @@ -173,27 +187,44 @@ def run(
ssl_keyfile: Optional[Text] = None,
ssl_password: Optional[Text] = None,
auto_reload: bool = False,
tracer_provider: Optional[TracerProvider] = None,
endpoints: str = DEFAULT_ENDPOINTS_PATH,
keep_alive_timeout: int = DEFAULT_KEEP_ALIVE_TIMEOUT,
) -> None:
"""Starts the action endpoint server with given config values."""
logger.info("Starting action endpoint server...")
app = create_app(
action_package_name,
cors_origins=cors_origins,
auto_reload=auto_reload,
tracer_provider=tracer_provider,
loader = AppLoader(
factory=partial(
create_app,
action_package_name,
cors_origins=cors_origins,
auto_reload=auto_reload,
),
)
app = loader.load()

app.config.KEEP_ALIVE_TIMEOUT = keep_alive_timeout
## Attach additional sanic extensions: listeners, middleware and routing

app.register_listener(
partial(load_tracer_provider, endpoints),
"before_server_start",
)

# Attach additional sanic extensions: listeners, middleware and routing
logger.info("Starting plugins...")
plugin_manager().hook.attach_sanic_app_extensions(app=app)

ssl_context = create_ssl_context(ssl_certificate, ssl_keyfile, ssl_password)
protocol = "https" if ssl_context else "http"
host = os.environ.get("SANIC_HOST", "0.0.0.0")

logger.info(f"Action endpoint is up and running on {protocol}://{host}:{port}")
app.run(host, port, ssl=ssl_context, workers=utils.number_of_sanic_workers())
app.run(
host=host,
port=port,
ssl=ssl_context,
workers=utils.number_of_sanic_workers(),
legacy=True,
)


if __name__ == "__main__":
Expand Down
20 changes: 6 additions & 14 deletions rasa_sdk/tracing/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
from rasa_sdk.tracing import config
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
Expand All @@ -9,25 +8,18 @@
from typing import Optional, Tuple, Any, Text


def get_tracer_provider(
cmdline_arguments: argparse.Namespace,
) -> Optional[TracerProvider]:
def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]:
"""Gets the tracer provider from the command line arguments."""
tracer_provider = None
endpoints_file = ""
if "endpoints" in cmdline_arguments:
endpoints_file = cmdline_arguments.endpoints

if endpoints_file is not None:
tracer_provider = config.get_tracer_provider(endpoints_file)
config.configure_tracing(tracer_provider)
tracer_provider = config.get_tracer_provider(endpoints_file)
config.configure_tracing(tracer_provider)

return tracer_provider


def get_tracer_and_context(
tracer_provider: Optional[TracerProvider], request: Request
) -> Tuple[Any, Any, Text]:
"""Gets tracer and context"""
"""Gets tracer and context."""
span_name = "create_app.webhook"
if tracer_provider is None:
tracer = trace.get_tracer(span_name)
Expand All @@ -39,7 +31,7 @@ def get_tracer_and_context(


def set_span_attributes(span: Any, action_call: dict) -> None:
"""Sets span attributes"""
"""Sets span attributes."""
tracker = action_call.get("tracker", {})
set_span_attributes = {
"http.method": "POST",
Expand Down
Loading

0 comments on commit 2dd8011

Please sign in to comment.