music-assistant-server

13.6 KBPY
timed_client_stream.py
13.6 KB332 lines • python
1"""
Timestamped multi-client audio stream for position-aware playback.

This module provides a multi-client streaming implementation optimized for
aiosendspin's synchronized multi-room audio playback. Each audio chunk is
timestamped with its position in the stream (seconds since stream start),
allowing late-joining players to start at the correct position for
synchronized playback across multiple devices.
8"""
9
10import asyncio
11import logging
12from collections import deque
13from collections.abc import AsyncGenerator
14from contextlib import suppress
15from uuid import UUID, uuid4
16
17from music_assistant_models.media_items import AudioFormat
18
19from music_assistant.helpers.ffmpeg import get_ffmpeg_stream
20
LOGGER = logging.getLogger(__name__)

# Minimum/target buffer retention time in seconds.
# This 10s buffer is currently required since:
# - aiosendspin currently uses a fixed 5s buffer to allow up to ~4s of network interruption
# - ~2s allows for ffmpeg processing time and some margin
# - ~3s are currently needed internally by aiosendspin for initial buffering
MIN_BUFFER_DURATION = 10.0
# Maximum buffered duration in seconds before an error is raised (safety mechanism
# against unbounded buffer growth). Derived from the minimum so both move together.
MAX_BUFFER_DURATION = MIN_BUFFER_DURATION + 5.0
31
32
33class TimedClientStream:
34    """Multi-client audio stream with timestamped chunks for synchronized playback."""
35
36    audio_source: AsyncGenerator[bytes, None]
37    """The source audio stream to read from."""
38    audio_format: AudioFormat
39    """The audio format of the source stream."""
40    chunk_buffer: deque[tuple[bytes, float]]
41    """Buffer storing chunks with their timestamps in seconds (chunk_data, timestamp_seconds)."""
42    subscriber_positions: dict[UUID, int]
43    """Subscriber positions: maps subscriber_id to position (index into chunk_buffer)."""
44    buffer_lock: asyncio.Lock
45    """Lock for buffer and shared state access."""
46    source_read_lock: asyncio.Lock
47    """Lock to serialize audio source reads."""
48    stream_ended: bool = False
49    """Track if stream has ended."""
50    current_position: float = 0.0
51    """Current position in seconds (from stream start)."""
52
53    def __init__(
54        self,
55        audio_source: AsyncGenerator[bytes, None],
56        audio_format: AudioFormat,
57    ) -> None:
58        """Initialize TimedClientStream."""
59        self.audio_source = audio_source
60        self.audio_format = audio_format
61        self.chunk_buffer = deque()
62        self.subscriber_positions = {}
63        self.buffer_lock = asyncio.Lock()
64        self.source_read_lock = asyncio.Lock()
65
66    def _get_bytes_per_second(self) -> int:
67        """Get bytes per second for the audio format."""
68        return (
69            self.audio_format.sample_rate
70            * self.audio_format.channels
71            * (self.audio_format.bit_depth // 8)
72        )
73
74    def _bytes_to_seconds(self, num_bytes: int) -> float:
75        """Convert bytes to seconds based on audio format."""
76        bytes_per_second = self._get_bytes_per_second()
77        if bytes_per_second == 0:
78            return 0.0
79        return num_bytes / bytes_per_second
80
81    def _get_buffer_duration(self) -> float:
82        """Calculate total duration of buffered chunks in seconds."""
83        if not self.chunk_buffer:
84            return 0.0
85        # Duration is from first chunk timestamp to current position
86        first_chunk_timestamp = self.chunk_buffer[0][1]
87        return self.current_position - first_chunk_timestamp
88
89    def _cleanup_old_chunks(self) -> None:
90        """Remove old chunks when all subscribers read them and min duration exceeded."""
91        # Find the oldest position still needed by any subscriber
92        if self.subscriber_positions:
93            min_position = min(self.subscriber_positions.values())
94        else:
95            min_position = len(self.chunk_buffer)
96
97        # Calculate target oldest timestamp
98        # This ensures buffer contains at least MIN_BUFFER_DURATION seconds of recent data
99        target_oldest = self.current_position - MIN_BUFFER_DURATION
100
101        # Remove old chunks that meet both conditions:
102        # 1. Before min_position (no subscriber needs them)
103        # 2. Older than target_oldest (outside minimum retention window)
104        chunks_removed = 0
105        while chunks_removed < min_position and self.chunk_buffer:
106            _chunk_bytes, chunk_timestamp = self.chunk_buffer[0]
107            if chunk_timestamp < target_oldest:
108                self.chunk_buffer.popleft()
109                chunks_removed += 1
110            else:
111                # Stop when we reach chunks we want to keep
112                break
113
114        # Adjust all subscriber positions to account for removed chunks
115        for sub_id in self.subscriber_positions:
116            self.subscriber_positions[sub_id] -= chunks_removed
117
118    async def _read_chunk_from_source(self) -> None:
119        """Read next chunk from audio source and add to buffer."""
120        try:
121            chunk = await anext(self.audio_source)
122            async with self.buffer_lock:
123                # Calculate timestamp for this chunk
124                chunk_timestamp = self.current_position
125                chunk_duration = self._bytes_to_seconds(len(chunk))
126
127                # Append chunk with its timestamp
128                self.chunk_buffer.append((chunk, chunk_timestamp))
129
130                # Update current position
131                self.current_position += chunk_duration
132
133                # Safety check: ensure buffer doesn't grow unbounded
134                if self._get_buffer_duration() > MAX_BUFFER_DURATION:
135                    msg = f"Buffer exceeded maximum duration ({MAX_BUFFER_DURATION}s)"
136                    raise RuntimeError(msg)
137        except StopAsyncIteration:
138            # Source exhausted, add EOF marker
139            async with self.buffer_lock:
140                self.chunk_buffer.append((b"", self.current_position))
141                self.stream_ended = True
142        except Exception:
143            # Source errored or was canceled, mark stream as ended
144            async with self.buffer_lock:
145                self.stream_ended = True
146            raise
147
148    async def _check_buffer(self, subscriber_id: UUID) -> bool | None:
149        """
150        Check if buffer has grown or stream ended.
151
152        REQUIRES: Caller must hold self.source_read_lock before calling.
153
154        Returns:
155            True if should continue reading loop (chunk found in buffer),
156            False if should break (stream ended),
157            None if should proceed to read from source.
158        """
159        async with self.buffer_lock:
160            position = self.subscriber_positions[subscriber_id]
161            if position < len(self.chunk_buffer):
162                # Another subscriber already read the chunk
163                return True
164            if self.stream_ended:
165                # Stream ended while waiting for source lock
166                return False
167        return None  # Continue to read from source
168
169    async def _get_chunk_from_buffer(self, subscriber_id: UUID) -> bytes | None:
170        """
171        Get next chunk from buffer for subscriber.
172
173        Returns:
174            Chunk bytes if available, None if no chunk available, or empty bytes for EOF.
175        """
176        async with self.buffer_lock:
177            position = self.subscriber_positions[subscriber_id]
178
179            # Check if we have a chunk at this position
180            if position < len(self.chunk_buffer):
181                # Chunk available in buffer
182                chunk_data, _ = self.chunk_buffer[position]
183
184                # Move to next position
185                self.subscriber_positions[subscriber_id] = position + 1
186
187                # Cleanup old chunks that no one needs
188                self._cleanup_old_chunks()
189                return chunk_data
190            if self.stream_ended:
191                # Stream ended and we've read all buffered chunks
192                return b""
193        return None
194
195    async def _cleanup_subscriber(self, subscriber_id: UUID) -> None:
196        """Clean up subscriber and close stream if no subscribers left."""
197        async with self.buffer_lock:
198            if subscriber_id in self.subscriber_positions:
199                del self.subscriber_positions[subscriber_id]
200
201            # If no subscribers left, close the stream
202            if not self.subscriber_positions and not self.stream_ended:
203                self.stream_ended = True
204                # Close the audio source generator to prevent resource leak
205                with suppress(Exception):
206                    await self.audio_source.aclose()
207
208    async def get_stream(
209        self,
210        output_format: AudioFormat,
211        filter_params: list[str] | None = None,
212    ) -> tuple[AsyncGenerator[bytes, None], float]:
213        """
214        Get (client specific encoded) ffmpeg stream.
215
216        Returns:
217            A tuple of (audio generator, actual position in seconds)
218        """
219        audio_gen, position = await self.subscribe_raw()
220
221        # Calculate frame size for alignment
222        # Frame size = channels * bytes_per_sample
223        bytes_per_frame = output_format.channels * (output_format.bit_depth // 8)
224
225        async def _stream_with_ffmpeg() -> AsyncGenerator[bytes, None]:
226            buffer = b""
227            try:
228                async for chunk in get_ffmpeg_stream(
229                    audio_input=audio_gen,
230                    input_format=self.audio_format,
231                    output_format=output_format,
232                    filter_params=filter_params,
233                ):
234                    buffer += chunk
235                    # Yield only complete frames
236                    aligned_size = (len(buffer) // bytes_per_frame) * bytes_per_frame
237                    if aligned_size > 0:
238                        yield buffer[:aligned_size]
239                        buffer = buffer[aligned_size:]
240                # Yield any remaining complete frames at end of stream
241                if buffer:
242                    aligned_size = (len(buffer) // bytes_per_frame) * bytes_per_frame
243                    if aligned_size > 0:
244                        yield buffer[:aligned_size]
245            finally:
246                # Ensure audio_gen cleanup runs immediately
247                with suppress(Exception):
248                    await audio_gen.aclose()
249
250        return _stream_with_ffmpeg(), position
251
252    async def _generate(self, subscriber_id: UUID) -> AsyncGenerator[bytes, None]:
253        """
254        Generate audio chunks for a subscriber.
255
256        Yields chunks from the buffer until the stream ends, reading from the source
257        as needed. Automatically cleans up the subscriber on exit.
258        """
259        try:
260            # Position already set above atomically with timestamp capture
261            while True:
262                # Try to get chunk from buffer
263                chunk_bytes = await self._get_chunk_from_buffer(subscriber_id)
264
265                # Release lock before yielding to avoid deadlock
266                if chunk_bytes is not None:
267                    if chunk_bytes == b"":
268                        # End of stream marker
269                        break
270                    yield chunk_bytes
271                else:
272                    # No chunk available, need to read from source
273                    # Use source_read_lock to ensure only one subscriber reads at a time
274                    async with self.source_read_lock:
275                        # Check again if buffer has grown or stream ended while waiting
276                        check_result = await self._check_buffer(subscriber_id)
277                        if check_result is True:
278                            # Another subscriber already read the chunk
279                            continue
280                        if check_result is False:
281                            # Stream ended while waiting for source lock
282                            break
283
284                        # Read next chunk from source (check_result is None)
285                        # Note: This may block if the audio_source does synchronous I/O
286                        await self._read_chunk_from_source()
287
288        finally:
289            await self._cleanup_subscriber(subscriber_id)
290
291    async def subscribe_raw(self) -> tuple[AsyncGenerator[bytes, None], float]:
292        """
293        Subscribe to the raw/unaltered audio stream.
294
295        Returns:
296            A tuple of (audio generator, actual position in seconds).
297            The position indicates where in the stream the first chunk will be from.
298
299        Note:
300            Callers must properly consume or cancel the returned generator to prevent
301            resource leaks.
302        """
303        subscriber_id = uuid4()
304
305        # Atomically capture starting position and register subscriber while holding lock
306        async with self.buffer_lock:
307            if self.chunk_buffer:
308                _, starting_position = self.chunk_buffer[0]
309                # Log buffer time range for debugging
310                newest_ts = self.chunk_buffer[-1][1]
311                oldest_relative = starting_position - self.current_position
312                newest_relative = newest_ts - self.current_position
313                LOGGER.debug(
314                    "New subscriber joining: buffer contains %.3fs (from %.3fs to %.3fs, "
315                    "current_position=%.3fs)",
316                    newest_ts - starting_position,
317                    oldest_relative,
318                    newest_relative,
319                    self.current_position,
320                )
321            else:
322                starting_position = self.current_position
323                LOGGER.debug(
324                    "New subscriber joining: buffer is empty, starting at current_position=%.3fs",
325                    self.current_position,
326                )
327            # Register subscriber at position 0 (start of buffer)
328            self.subscriber_positions[subscriber_id] = 0
329
330        # Return generator and starting position in seconds
331        return self._generate(subscriber_id), starting_position
332