music-assistant-server

5.8 KBPY
aiohttp_client.py
5.8 KB196 lines • python
1"""Helpers for setting up a aiohttp session (and related)."""
2
3from __future__ import annotations
4
5import asyncio
6import socket
7import sys
8from contextlib import suppress
9from functools import cache
10from ssl import SSLContext
11from types import MappingProxyType
12from typing import TYPE_CHECKING, Any, Self
13
14import aiohttp
15from aiohttp import web
16from aiohttp.hdrs import USER_AGENT
17from aiohttp_asyncmdnsresolver.api import AsyncDualMDNSResolver
18from music_assistant_models.enums import EventType
19
20from music_assistant.constants import APPLICATION_NAME
21
22from . import ssl as ssl_util
23from .json import json_dumps, json_loads
24
25if TYPE_CHECKING:
26    from aiohttp.typedefs import JSONDecoder
27    from music_assistant_models.event import MassEvent
28
29    from music_assistant.mass import MusicAssistant
30
31
# Total connection cap for each connector (passed to aiohttp as ``limit``).
MAXIMUM_CONNECTIONS = 4_096
# Connection cap towards a single host (passed to aiohttp as ``limit_per_host``).
MAXIMUM_CONNECTIONS_PER_HOST = 1_00
34
35
def create_clientsession(
    mass: MusicAssistant,
    verify_ssl: bool = True,
    **kwargs: Any,
) -> aiohttp.ClientSession:
    """Create a new aiohttp ClientSession configured for Music Assistant.

    Extra keyword arguments (e.g. ``cookies``) are forwarded to
    ``aiohttp.ClientSession``. The session uses its own connector, serializes
    JSON bodies with ``json_dumps`` and wraps responses in
    ``MassClientResponse``.

    Note that the session-level default headers are replaced by a single
    User-Agent header, so any ``headers`` passed through ``kwargs`` are
    discarded. Pass custom headers per request instead.

    :param mass: The MusicAssistant instance; its version is embedded in the
        User-Agent header.
    :param verify_ssl: Whether TLS certificates are verified.
    :return: The newly created ClientSession.
    """
    session = aiohttp.ClientSession(
        connector=_get_connector(mass, verify_ssl),
        json_serialize=json_dumps,
        response_class=MassClientResponse,
        **kwargs,
    )
    py_version = f"{sys.version_info.major}.{sys.version_info.minor}"
    agent = (
        f"{APPLICATION_NAME}/{mass.version} "
        f"aiohttp/{aiohttp.__version__} Python/{py_version}"
    )
    # Music Assistant must always identify itself, so the defaults are swapped
    # for a read-only mapping that packages cannot mutate by accident. A package
    # that needs a different User-Agent can pass its own headers per request.
    session._default_headers = MappingProxyType(  # type: ignore[assignment]
        {USER_AGENT: agent},
    )
    return session
60
61
async def async_aiohttp_proxy_stream(
    mass: MusicAssistant,
    request: web.BaseRequest,
    stream: aiohttp.StreamReader,
    content_type: str | None,
    buffer_size: int = 102400,
    timeout: int = 10,
) -> web.StreamResponse:
    """Pipe an aiohttp StreamReader into a streaming web response.

    Data is copied in reads of at most ``buffer_size`` bytes. Copying stops when
    the source is exhausted, when Music Assistant is closing, when a single read
    takes longer than ``timeout`` seconds, or when an aiohttp client error occurs.
    The timeout and client-error cases end the response silently.

    :param mass: The MusicAssistant instance (its ``closing`` flag is polled).
    :param request: The incoming web request to answer.
    :param stream: Source of the bytes to forward.
    :param content_type: Content type for the response, or None to leave unset.
    :param buffer_size: Maximum number of bytes requested per read.
    :param timeout: Seconds allowed for each individual read.
    :return: The (prepared and written) StreamResponse.
    """
    resp = web.StreamResponse()
    if content_type is not None:
        resp.content_type = content_type
    await resp.prepare(request)

    try:
        while not mass.closing:
            async with asyncio.timeout(timeout):
                chunk = await stream.read(buffer_size)
            if not chunk:
                # Source exhausted.
                break
            await resp.write(chunk)
    except (TimeoutError, aiohttp.ClientError):
        # A stalled or failed fetch, or a closed connection, just ends the stream.
        pass

    return resp
87
88
class MassAsyncDNSResolver(AsyncDualMDNSResolver):
    """AsyncDualMDNSResolver whose ordinary ``close()`` does nothing.

    Any caller that invokes ``close()`` leaves the resolver running. Only
    ``real_close()`` shuts it down, and ``_get_resolver`` wires that to the
    Music Assistant SHUTDOWN event.
    """

    async def close(self) -> None:
        """Ignore the close request; use ``real_close()`` to really close."""
        return

    async def real_close(self) -> None:
        """Close the underlying AsyncDualMDNSResolver."""
        await super().close()
102
103
class MassClientResponse(aiohttp.ClientResponse):
    """ClientResponse whose ``json()`` decodes with ``json_loads`` by default."""

    async def json(
        self,
        *args: Any,
        loads: JSONDecoder = json_loads,
        **kwargs: Any,
    ) -> Any:
        """Read the response body and decode it as JSON.

        Behaves like ``aiohttp.ClientResponse.json`` but uses the project's
        ``json_loads`` decoder unless another ``loads`` is supplied.
        """
        decoded = await super().json(*args, loads=loads, **kwargs)
        return decoded
115
116
117class ChunkAsyncStreamIterator:
118    """
119    Async iterator for chunked streams.
120
121    Based on aiohttp.streams.ChunkTupleAsyncStreamIterator, but yields
122    bytes instead of tuple[bytes, bool].
123    """
124
125    __slots__ = ("_stream",)
126
127    def __init__(self, stream: aiohttp.StreamReader) -> None:
128        """Initialize."""
129        self._stream = stream
130
131    def __aiter__(self) -> Self:
132        """Iterate."""
133        return self
134
135    async def __anext__(self) -> bytes:
136        """Yield next chunk."""
137        rv = await self._stream.readchunk()
138        if rv == (b"", False):
139            raise StopAsyncIteration
140        return rv[0]
141
142
class MusicAssistantTCPConnector(aiohttp.TCPConnector):
    """aiohttp.TCPConnector with a 60 second cleanup_closed period.

    aiohttp's own default period is 2 seconds, which is too short given how
    many connections Music Assistant churns through. Waiting 60 seconds reduces
    the overhead of aborting TLS connections that have most likely already
    closed on their own.

    aiohttp only uses this period when the connector is created with
    ``enable_cleanup_closed=True``.
    """

    # Abort transports after 60 seconds (cleans up broken connections).
    _cleanup_closed_period: float = 60.0
156
157
# aiohttp's cleanup_closed handling is no longer needed once
# https://github.com/python/cpython/pull/118960 is present. That fix first
# shipped in Python 3.12.7 and 3.13.1, so older interpreters (anything before
# 3.12.7, plus 3.13.0) still need it.
_ENABLE_CLEANUP_CLOSED = sys.version_info < (3, 12, 7) or (
    (3, 13, 0) <= sys.version_info < (3, 13, 1)
)


def _get_connector(
    mass: MusicAssistant,
    verify_ssl: bool = True,
    family: socket.AddressFamily = socket.AF_UNSPEC,
    ssl_cipher: ssl_util.SSLCipherList = ssl_util.SSLCipherList.PYTHON_DEFAULT,
) -> aiohttp.BaseConnector:
    """
    Return a new connector for an aiohttp ClientSession.

    This method must be run in the event loop.

    :param mass: The MusicAssistant instance; selects the cached DNS resolver.
    :param verify_ssl: Verify TLS certificates when True.
    :param family: Socket address family for the connector.
    :param ssl_cipher: Cipher list used to build the SSL context.
    :return: A MusicAssistantTCPConnector configured with the limits above.
    """
    if verify_ssl:
        ssl_context: SSLContext = ssl_util.client_context(ssl_cipher)
    else:
        ssl_context = ssl_util.client_context_no_verify(ssl_cipher)

    return MusicAssistantTCPConnector(
        family=family,
        # Only enabled on interpreters lacking the CPython fix. There,
        # MusicAssistantTCPConnector's longer cleanup period takes effect.
        enable_cleanup_closed=_ENABLE_CLEANUP_CLOSED,
        ssl=ssl_context,
        limit=MAXIMUM_CONNECTIONS,
        limit_per_host=MAXIMUM_CONNECTIONS_PER_HOST,
        resolver=_get_resolver(mass),
    )
184
185
@cache
def _get_resolver(mass: MusicAssistant) -> MassAsyncDNSResolver:
    """Return the MassAsyncDNSResolver belonging to ``mass``.

    ``functools.cache`` makes repeated calls with the same ``mass`` return the
    same resolver instance. That resolver is really closed when Music Assistant
    emits its SHUTDOWN event.
    """
    dns_resolver = MassAsyncDNSResolver(async_zeroconf=mass.aiozc)

    async def _on_shutdown(event: MassEvent) -> None:  # noqa: ARG001
        # The resolver's plain close() is a no-op, so real_close() is required.
        await dns_resolver.real_close()

    mass.subscribe(_on_shutdown, EventType.SHUTDOWN)
    return dns_resolver
196