/
/
/
1"""Implementation of a simple multi-client stream task/job."""
2
3import asyncio
4import logging
5from collections.abc import AsyncGenerator
6from contextlib import suppress
7
8from music_assistant_models.media_items import AudioFormat
9
10from music_assistant.helpers.ffmpeg import get_ffmpeg_stream
11from music_assistant.helpers.util import empty_queue
12
# Module-level logger for the multi-client stream helper.
LOGGER = logging.getLogger(__name__)
14
15
class MultiClientStream:
    """Implementation of a simple multi-client (audio) stream task/job.

    A single background task reads ``audio_source`` and copies every chunk into a
    small bounded queue (2 chunks) per subscriber. Because the runner awaits the
    ``put`` on every subscriber queue, the slowest subscriber paces the stream.
    An empty chunk (``b""``) on a queue marks end-of-stream for that subscriber.
    """

    def __init__(
        self,
        audio_source: AsyncGenerator[bytes, None],
        audio_format: AudioFormat,
        expected_clients: int = 0,
    ) -> None:
        """Initialize MultiClientStream and start its runner task.

        audio_source: async generator yielding the raw audio chunks to distribute.
        audio_format: format of the chunks yielded by ``audio_source``.
        expected_clients: number of subscribers to wait for (at most ~5 seconds)
            before distribution starts; 0 is treated as 1.
        """
        self.audio_source = audio_source
        self.audio_format = audio_format
        self.subscribers: list[asyncio.Queue[bytes]] = []
        self.expected_clients = expected_clients
        # asyncio.create_task requires a running event loop
        self.task = asyncio.create_task(self._runner())

    @property
    def done(self) -> bool:
        """Return if this stream is already done."""
        return self.task.done()

    async def stop(self) -> None:
        """Stop/cancel the stream.

        Cancels the runner task and then wakes every remaining subscriber with the
        EOF marker. Chunks that were buffered but not yet consumed are discarded.
        This is done here as well as in the runner because a task cancelled
        before it ever started never executes the runner's cleanup.
        """
        if self.done:
            return
        self.task.cancel()
        with suppress(asyncio.CancelledError):
            await self.task
        for sub_queue in list(self.subscribers):
            self._signal_eof(sub_queue)

    async def get_stream(
        self,
        output_format: AudioFormat,
        filter_params: list[str] | None = None,
    ) -> AsyncGenerator[bytes, None]:
        """Get (client specific encoded) ffmpeg stream.

        Subscribes to the raw stream and passes it through ``get_ffmpeg_stream``,
        converting from ``self.audio_format`` to ``output_format`` with the
        optional ``filter_params``.
        """
        async for chunk in get_ffmpeg_stream(
            audio_input=self.subscribe_raw(),
            input_format=self.audio_format,
            output_format=output_format,
            filter_params=filter_params,
        ):
            yield chunk

    async def subscribe_raw(self) -> AsyncGenerator[bytes, None]:
        """Subscribe to the raw/unaltered audio stream.

        Yields chunks until the EOF marker (``b""``) is received. Returns
        immediately if the runner has already finished, because nothing would
        ever be queued for a new subscriber.
        """
        if self.done:
            return
        queue: asyncio.Queue[bytes] = asyncio.Queue(2)
        self.subscribers.append(queue)
        try:
            while True:
                chunk = await queue.get()
                if chunk == b"":
                    break
                yield chunk
        finally:
            with suppress(ValueError):
                self.subscribers.remove(queue)
            # The runner may be blocked in put() on this (full) queue; draining it
            # wakes that putter so the other subscribers are not stalled forever.
            empty_queue(queue)

    async def _runner(self) -> None:
        """Run the stream for the given audio source.

        When the source is exhausted normally, each subscriber gets the EOF marker
        queued behind any audio still buffered for it. If the runner ends any
        other way (source error, cancellation, no clients left), the remaining
        subscribers are still woken with EOF so they do not wait forever. The
        audio source is closed in every case.
        """
        eof_sent = False
        try:
            await self._wait_for_subscribers()
            if await self._distribute():
                # EOF: send empty chunk (awaiting put keeps buffered audio intact)
                await asyncio.gather(
                    *[sub.put(b"") for sub in self.subscribers], return_exceptions=True
                )
                eof_sent = True
        finally:
            if not eof_sent:
                # abnormal end: drop buffered chunks and wake consumers without blocking
                for sub_queue in list(self.subscribers):
                    self._signal_eof(sub_queue)
            await self._close_source()

    async def _wait_for_subscribers(self) -> None:
        """Wait up to 50 x 0.1 s until the expected number of subscribers joined."""
        expected_clients = self.expected_clients or 1
        for _ in range(50):
            await asyncio.sleep(0.1)
            if len(self.subscribers) >= expected_clients:
                break
        LOGGER.debug(
            "Starting multi-client stream with %s/%s clients",
            len(self.subscribers),
            expected_clients,
        )

    async def _distribute(self) -> bool:
        """Fan out every chunk of the audio source to all current subscribers.

        Returns True when the source was fully consumed. Returns False when the
        stream was aborted because no client was connected for roughly 5 seconds.
        """
        async for chunk in self.audio_source:
            fail_count = 0
            while not self.subscribers:
                await asyncio.sleep(0.1)
                fail_count += 1
                if fail_count > 50:
                    LOGGER.warning("No clients connected, stopping stream")
                    return False
            await asyncio.gather(
                *[sub.put(chunk) for sub in self.subscribers], return_exceptions=True
            )
        return True

    @staticmethod
    def _signal_eof(queue: asyncio.Queue[bytes]) -> None:
        """Replace whatever is buffered in ``queue`` with the EOF marker, without blocking."""
        empty_queue(queue)
        with suppress(asyncio.QueueFull):
            queue.put_nowait(b"")

    async def _close_source(self) -> None:
        """Close the audio source generator so its cleanup runs deterministically."""
        aclose = getattr(self.audio_source, "aclose", None)
        if aclose is None:
            return
        try:
            await aclose()
        except Exception:  # best-effort cleanup; the stream is ending anyway
            LOGGER.debug("Error while closing audio source", exc_info=True)