Skip to content

Commit a6cd2a1

Browse files
committed
Implement Eyrie application
1 parent 991a970 commit a6cd2a1

File tree

1 file changed

+146
-0
lines changed

1 file changed

+146
-0
lines changed

eyrie/core/eyrie_application.py

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import (
4+
dataclass,
5+
field,
6+
)
7+
from typing import TYPE_CHECKING
8+
9+
from fastapi import Depends
10+
11+
from eyrie.common.interfaces import CanActivate
12+
from eyrie.common.interfaces.middlewares import EyrieGlobalMiddleware
13+
14+
from .injector import EyrieContainer
15+
from .router import RouterFactory
16+
from .scanner import DependenciesScanner
17+
from .utils import (
18+
sanitize_path,
19+
should_apply_middleware,
20+
)
21+
22+
if TYPE_CHECKING:
23+
from fastapi import FastAPI
24+
from fastapi.routing import APIRoute
25+
26+
from eyrie.core.injector import Module
27+
28+
29+
@dataclass(slots=True)
30+
class EyrieApplication:
31+
module_map: dict
32+
http_server: FastAPI
33+
global_prefix: str | None = None
34+
container_context: set = field(default_factory=set)
35+
36+
def _aggregate_module_components(self, modules: list[Module]):
37+
providers, controllers = {}, {}
38+
for module in modules:
39+
providers.update(module._providers)
40+
controllers.update(module._controllers)
41+
return providers, controllers
42+
43+
def add_global_middlewares(self, *middlewares: EyrieGlobalMiddleware | dict):
44+
for middleware in middlewares:
45+
if isinstance(middleware, dict):
46+
self.http_server.add_middleware(**middleware)
47+
elif issubclass(middleware, EyrieGlobalMiddleware):
48+
self.http_server.add_middleware(middleware)
49+
else:
50+
self.http_server.add_middleware(*middleware)
51+
52+
def set_global_prefix(self, prefix: str):
53+
self.global_prefix = sanitize_path(None, prefix)
54+
55+
def _register_providers(
56+
self, modules: list[Module], container: EyrieContainer, providers: dict
57+
):
58+
DependenciesScanner(modules=modules, providers=providers).scan()
59+
resolved_providers = container.get_resolved_providers(providers)
60+
for provider in resolved_providers:
61+
container.add_provider(providers[provider])
62+
63+
def _build_api_routers(self, controllers: dict):
64+
routers = []
65+
for controller in controllers:
66+
self.container_context.add(controller.__module__)
67+
if not isinstance(controller, type):
68+
continue
69+
router = RouterFactory(
70+
controller=controller, global_prefix=self.global_prefix
71+
).create()
72+
73+
controller_guards = getattr(controller, "__guards__", [])
74+
for route in router.router.routes:
75+
for guard in controller_guards:
76+
self.container_context.add(guard.__module__)
77+
deps = getattr(route, "dependencies", [])
78+
if isinstance(guard, type):
79+
if issubclass(guard, CanActivate) or hasattr(
80+
guard, "can_activate"
81+
):
82+
deps.append(Depends(guard().can_activate))
83+
else:
84+
deps.append(Depends(guard()))
85+
else:
86+
deps.append(Depends(guard))
87+
88+
method_guards = getattr(route.endpoint, "__guards__", [])
89+
for method_guard in method_guards:
90+
self.container_context.add(method_guard.__module__)
91+
deps = getattr(route, "dependencies", [])
92+
if isinstance(guard, type):
93+
if issubclass(guard, CanActivate) or hasattr(
94+
guard, "can_activate"
95+
):
96+
deps.append(Depends(guard().can_activate))
97+
else:
98+
deps.append(Depends(guard()))
99+
else:
100+
deps.append(Depends(guard))
101+
102+
routers.append(router.router)
103+
return routers
104+
105+
def _build_api_middlewares(self, routes: list[APIRoute]):
106+
middleware_info_list = []
107+
module_middlewares = list(
108+
map(lambda value: value.get("middlewares"), self.module_map.values())
109+
)
110+
for module_middleware in module_middlewares:
111+
if module_middleware:
112+
for item in module_middleware:
113+
if item not in middleware_info_list:
114+
middleware_info_list.append(item)
115+
116+
for route in routes:
117+
for middleware_info in middleware_info_list:
118+
middlewares = middleware_info.get("middlewares", [])
119+
methods = list(route.methods)
120+
for method in methods:
121+
should_apply = should_apply_middleware(
122+
method, route.path, middleware_info
123+
)
124+
for middleware in middlewares:
125+
self.container_context.add(middleware.__module__)
126+
if should_apply:
127+
deps = getattr(route, "dependencies", [])
128+
deps.append(Depends(middleware()))
129+
130+
def start(self):
131+
modules = list(map(lambda value: value["module"], self.module_map.values()))
132+
container = EyrieContainer()
133+
providers, controllers = self._aggregate_module_components(modules)
134+
self._register_providers(modules, container, providers)
135+
routers = self._build_api_routers(controllers)
136+
routes = []
137+
for router in routers:
138+
routes.extend(router.routes)
139+
140+
self._build_api_middlewares(routes)
141+
for router in routers:
142+
self.http_server.include_router(router)
143+
144+
container.wire(modules=self.container_context)
145+
self.http_server.container = container
146+
return self.http_server

0 commit comments

Comments
 (0)