/
/
/
1"""Base Webserver logic for an HTTPServer that can handle dynamic routes."""
2
3from __future__ import annotations
4
5from collections.abc import Callable, Coroutine
6from typing import TYPE_CHECKING, Any, Final
7
8from aiohttp import web
9
10if TYPE_CHECKING:
11 import logging
12
13 from aiohttp.typedefs import Handler
14
15
# Largest request body the aiohttp application accepts: 16 MiB.
MAX_CLIENT_SIZE: Final = 16 * 1024 * 1024
# Shared limit used for both the HTTP request line and each header field.
MAX_LINE_SIZE: Final = 24570

# Signature every dynamic route handler must satisfy: an async callable that
# takes the aiohttp request and resolves to a plain or a streaming response.
DynamicRouteHandler = Callable[
    [web.Request],
    Coroutine[Any, Any, web.Response | web.StreamResponse],
]


class Webserver:
    """Base Webserver logic for an HTTPServer that can handle dynamic routes."""

    def __init__(
        self,
        logger: logging.Logger,
        enable_dynamic_routes: bool = False,
    ) -> None:
        """Initialize instance.

        :param logger: Logger passed to the aiohttp application and used for warnings.
        :param enable_dynamic_routes: If True, setup() installs a catch-all route that
            dispatches requests to handlers added with register_dynamic_route().
        """
        self.logger = logger
        # the below gets initialized in async setup
        self._base_url: str = ""
        self._apprunner: web.AppRunner | None = None
        self._webapp: web.Application | None = None
        self._tcp_site: web.TCPSite | None = None
        self._static_routes: list[tuple[str, str, Handler]] | None = None
        # Maps "<METHOD>.<path>" (METHOD may be "*") to its handler;
        # None means dynamic routing is disabled.
        self._dynamic_routes: dict[str, DynamicRouteHandler] | None = (
            {} if enable_dynamic_routes else None
        )
        self._bind_port: int | None = None
        self._ingress_tcp_site: web.TCPSite | None = None

    async def setup(
        self,
        bind_ip: str | None,
        bind_port: int,
        base_url: str,
        static_routes: list[tuple[str, str, Handler]] | None = None,
        static_content: tuple[str, str, str] | None = None,
        ingress_tcp_site_params: tuple[str, int] | None = None,
        app_state: dict[str, Any] | None = None,
        ssl_context: Any | None = None,
    ) -> None:
        """Async initialize of module.

        :param bind_ip: IP address to bind to.
        :param bind_port: Port to bind to.
        :param base_url: Base URL for the server.
        :param static_routes: List of static routes to register.
        :param static_content: Tuple of (path, directory, name) for static content.
        :param ingress_tcp_site_params: Tuple of (host, port) for ingress TCP site.
        :param app_state: Optional dict of key-value pairs to set on app before starting.
        :param ssl_context: Optional SSL context for HTTPS support.
        """
        self._base_url = base_url.removesuffix("/")
        self._bind_port = bind_port
        self._static_routes = static_routes
        self._webapp = web.Application(
            logger=self.logger,
            client_max_size=MAX_CLIENT_SIZE,
            handler_args={
                "max_line_size": MAX_LINE_SIZE,
                "max_field_size": MAX_LINE_SIZE,
            },
        )
        # All app state must be stored BEFORE the runner is set up:
        # AppRunner.setup() freezes the application, after which item
        # assignment is deprecated (and rejected in newer aiohttp).
        if app_state:
            for key, value in app_state.items():
                self._webapp[key] = value
        if ingress_tcp_site_params:
            # Store ingress site reference in app for security checks
            self._webapp["ingress_site"] = ingress_tcp_site_params
        self._apprunner = web.AppRunner(self._webapp, access_log=None, shutdown_timeout=10)
        # add static routes
        if self._static_routes:
            for method, path, handler in self._static_routes:
                self._webapp.router.add_route(method, path, handler)
        if static_content:
            self._webapp.router.add_static(
                static_content[0], static_content[1], name=static_content[2]
            )
        # register catch-all route to handle dynamic routes (if enabled);
        # added last so the static routes above take precedence
        if self._dynamic_routes is not None:
            self._webapp.router.add_route("*", "/{tail:.*}", self._handle_catch_all)
        await self._apprunner.setup()
        await self._start_main_site(self._apprunner, bind_ip, bind_port, ssl_context)
        # start additional ingress TCP site if configured
        # this is only used if we're running in the context of an HA add-on
        # which proxies our frontend and api through ingress
        if ingress_tcp_site_params:
            self._ingress_tcp_site = web.TCPSite(
                self._apprunner,
                host=ingress_tcp_site_params[0],
                port=ingress_tcp_site_params[1],
            )
            await self._ingress_tcp_site.start()

    async def _start_main_site(
        self,
        runner: web.AppRunner,
        bind_ip: str | None,
        bind_port: int,
        ssl_context: Any | None,
    ) -> None:
        """Start the main TCP site; retry on all interfaces if bind_ip cannot be bound.

        :raises OSError: If binding fails while already targeting all interfaces.
        """
        # set host to None to bind to all addresses on both IPv4 and IPv6
        host = None if bind_ip in ("0.0.0.0", "::") else bind_ip
        try:
            self._tcp_site = web.TCPSite(
                runner, host=host, port=bind_port, ssl_context=ssl_context
            )
            await self._tcp_site.start()
        except OSError:
            if host is None:
                raise
            # the configured interface is not available, retry on all interfaces
            self.logger.error(
                "Could not bind to %s, will start on all interfaces as fallback!", host
            )
            self._tcp_site = web.TCPSite(
                runner, host=None, port=bind_port, ssl_context=ssl_context
            )
            await self._tcp_site.start()

    async def close(self) -> None:
        """Cleanup on exit.

        Stops the sites and tears down the runner/application exactly once;
        calling it again (or before setup()) is a no-op.
        """
        # stop/clean webserver
        if self._tcp_site is not None:
            await self._tcp_site.stop()
            self._tcp_site = None
        if self._ingress_tcp_site is not None:
            await self._ingress_tcp_site.stop()
            self._ingress_tcp_site = None
        if self._apprunner is not None:
            # AppRunner.cleanup() already runs Application.shutdown() and
            # Application.cleanup(); calling them again here would fire the
            # app's on_shutdown/on_cleanup handlers a second time.
            await self._apprunner.cleanup()
        elif self._webapp is not None:
            # no runner was ever created, so tear the application down directly
            await self._webapp.shutdown()
            await self._webapp.cleanup()
        self._apprunner = None
        self._webapp = None

    @property
    def base_url(self) -> str:
        """Return the base URL of this webserver (empty string before setup())."""
        return self._base_url

    @property
    def port(self) -> int | None:
        """Return the port of this webserver (None before setup())."""
        return self._bind_port

    @staticmethod
    def _route_key(method: str, path: str) -> str:
        """Build the dynamic-route lookup key; methods are upper-cased to match requests."""
        return f"{method.upper()}.{path}"

    @staticmethod
    def _path_in_prefix(prefix: str, path: str) -> bool:
        """Return True if path equals prefix or lies below it on a segment boundary."""
        return path == prefix or path.startswith(f"{prefix}/")

    def register_dynamic_route(
        self,
        path: str,
        handler: DynamicRouteHandler,
        method: str = "*",
    ) -> Callable[[], None]:
        """Register a dynamic route on the webserver, returns handler to unregister.

        :param path: Exact request path, or a prefix pattern ending in ``/*`` that
            matches the prefix itself and any path below it.
        :param handler: Async callable invoked with the matching request.
        :param method: HTTP method to match (case-insensitive), or ``*`` for any.
        :raises RuntimeError: If dynamic routes are disabled or the route exists.
        :return: Callable that removes this registration again.
        """
        if self._dynamic_routes is None:
            msg = "Dynamic routes are not enabled"
            raise RuntimeError(msg)
        key = self._route_key(method, path)
        if key in self._dynamic_routes:
            msg = f"Route {path} already registered."
            raise RuntimeError(msg)
        self._dynamic_routes[key] = handler

        def _remove() -> None:
            routes = self._dynamic_routes
            # Only drop the entry while it still belongs to this registration,
            # so a stale remover cannot delete a later re-registration.
            if routes is not None and routes.get(key) is handler:
                del routes[key]

        return _remove

    def unregister_dynamic_route(self, path: str, method: str = "*") -> None:
        """Unregister a dynamic route from the webserver (no-op if not registered).

        :raises RuntimeError: If dynamic routes are disabled.
        """
        if self._dynamic_routes is None:
            msg = "Dynamic routes are not enabled"
            raise RuntimeError(msg)
        self._dynamic_routes.pop(self._route_key(method, path), None)

    async def serve_static(self, file_path: str, request: web.Request) -> web.FileResponse:
        """Serve file response with caching disabled (``request`` is unused)."""
        headers = {"Cache-Control": "no-cache"}
        return web.FileResponse(file_path, headers=headers)

    async def _handle_catch_all(self, request: web.Request) -> web.Response | web.StreamResponse:
        """Redirect request to correct destination, or answer 404 if nothing matches."""
        routes = self._dynamic_routes
        if routes is not None:
            # Exact match first; a method-specific route beats a wildcard one.
            for key in (f"{request.method}.{request.path}", f"*.{request.path}"):
                if handler := routes.get(key):
                    return await handler(request)
            # Prefix match (for routes registered with /*). We return right after
            # awaiting, so the dict is never iterated across an await point.
            for route_key, handler in routes.items():
                method, path = route_key.split(".", 1)
                if method not in (request.method, "*") or not path.endswith("/*"):
                    continue
                if self._path_in_prefix(path[:-2], request.path):
                    return await handler(request)
        # deny all other requests
        self.logger.warning(
            "Received unhandled %s request to %s from %s\nheaders: %s\n",
            request.method,
            request.path,
            request.remote,
            request.headers,
        )
        return web.Response(status=404)
