music-assistant-server

3.4 KB · PY
multi_client_stream.py
3.4 KB · 103 lines · python
1"""Implementation of a simple multi-client stream task/job."""
2
3import asyncio
4import logging
5from collections.abc import AsyncGenerator
6from contextlib import suppress
7
8from music_assistant_models.media_items import AudioFormat
9
10from music_assistant.helpers.ffmpeg import get_ffmpeg_stream
11from music_assistant.helpers.util import empty_queue
12
# Module-level logger, named after this module's import path.
LOGGER = logging.getLogger(name=__name__)
14
15
class MultiClientStream:
    """Implementation of a simple multi-client (audio) stream task/job.

    A single background task (created on construction) reads chunks from
    ``audio_source`` and puts each chunk on every subscriber's queue.
    Consumers read either the raw chunks (``subscribe_raw``) or a per-client
    ffmpeg stream (``get_stream``). An empty bytes object (``b""``) on a
    subscriber queue signals end-of-stream.
    """

    def __init__(
        self,
        audio_source: AsyncGenerator[bytes, None],
        audio_format: AudioFormat,
        expected_clients: int = 0,
    ) -> None:
        """Initialize MultiClientStream and start its runner task.

        The runner task is created immediately via ``asyncio.create_task``,
        so this must be called while an event loop is running.

        :param audio_source: Async generator producing the raw audio chunks.
        :param audio_format: Format of the chunks produced by ``audio_source``.
        :param expected_clients: Number of subscribers to wait for (polling
            for roughly 5 seconds) before streaming starts; 0 means wait for
            a single subscriber.
        """
        self.audio_source = audio_source
        self.audio_format = audio_format
        # One bounded queue per active subscriber; the runner feeds all of them.
        self.subscribers: list[asyncio.Queue[bytes]] = []
        self.expected_clients = expected_clients
        self.task = asyncio.create_task(self._runner())

    @property
    def done(self) -> bool:
        """Return if this stream is already done (runner task finished)."""
        return self.task.done()

    async def stop(self) -> None:
        """Stop/cancel the stream.

        Cancels the runner task, discards any buffered chunks and delivers the
        ``b""`` end-of-stream sentinel to every remaining subscriber, so
        consumers waiting on their queue finish instead of hanging.
        """
        if self.done:
            return
        self.task.cancel()
        with suppress(asyncio.CancelledError):
            await self.task
        for sub_queue in list(self.subscribers):
            empty_queue(sub_queue)
            # The cancelled runner will never send EOF itself; without this a
            # subscriber awaiting ``queue.get()`` would wait forever. QueueFull
            # is suppressed in case the helper left items behind.
            with suppress(asyncio.QueueFull):
                sub_queue.put_nowait(b"")

    async def get_stream(
        self,
        output_format: AudioFormat,
        filter_params: list[str] | None = None,
    ) -> AsyncGenerator[bytes, None]:
        """Get (client specific encoded) ffmpeg stream.

        Feeds a new raw subscription into ``get_ffmpeg_stream`` using this
        stream's ``audio_format`` as input format together with the requested
        ``output_format`` and optional ``filter_params``, and yields its output.
        """
        async for chunk in get_ffmpeg_stream(
            audio_input=self.subscribe_raw(),
            input_format=self.audio_format,
            output_format=output_format,
            filter_params=filter_params,
        ):
            yield chunk

    async def subscribe_raw(self) -> AsyncGenerator[bytes, None]:
        """Subscribe to the raw/unaltered audio stream.

        Yields chunks exactly as received from the audio source until the
        ``b""`` end-of-stream sentinel arrives. Subscribing to a stream whose
        runner has already finished yields nothing.
        """
        if self.done:
            # The runner is gone; nothing would ever be put on a new queue.
            return
        # Small bound: the runner awaits put() on every queue, so the slowest
        # subscriber paces the whole stream.
        queue: asyncio.Queue[bytes] = asyncio.Queue(2)
        self.subscribers.append(queue)
        try:
            while True:
                chunk = await queue.get()
                if chunk == b"":
                    break
                yield chunk
        finally:
            with suppress(ValueError):
                self.subscribers.remove(queue)
            # If the runner is blocked in ``queue.put()`` because this queue is
            # full, taking items out wakes that pending put so the fan-out to
            # the other subscribers does not stall forever.
            while not queue.empty():
                queue.get_nowait()

    async def _runner(self) -> None:
        """Run the stream for the given audio source.

        Polls (100ms interval, at most 50 rounds) for the expected number of
        subscribers, then puts every chunk from the audio source on all
        subscriber queues and finally sends the ``b""`` end-of-stream
        sentinel. Stops early if no subscriber is connected for roughly 5
        seconds mid-stream. The audio source is always closed when the runner
        ends, including on cancellation.
        """
        try:
            expected_clients = self.expected_clients or 1
            # wait for first/all subscriber
            count = 0
            while count < 50:
                await asyncio.sleep(0.1)
                count += 1
                if len(self.subscribers) >= expected_clients:
                    break
            LOGGER.debug(
                "Starting multi-client stream with %s/%s clients",
                len(self.subscribers),
                self.expected_clients,
            )
            async for chunk in self.audio_source:
                fail_count = 0
                while not self.subscribers:
                    await asyncio.sleep(0.1)
                    fail_count += 1
                    if fail_count > 50:
                        LOGGER.warning("No clients connected, stopping stream")
                        return
                await asyncio.gather(
                    *[sub.put(chunk) for sub in self.subscribers],
                    return_exceptions=True,
                )
            # EOF: send empty chunk
            await asyncio.gather(
                *[sub.put(b"") for sub in self.subscribers],
                return_exceptions=True,
            )
        finally:
            # Close the source explicitly (on cancellation or the early return
            # above) instead of leaving it suspended until it is finalized by
            # garbage collection; aclose() on a finished generator is a no-op.
            await self.audio_source.aclose()
103