music-assistant-server

17.1 KBPY
stream_session.py
17.1 KB412 lines • python
1"""Unified AirPlay/RAOP stream session logic for AirPlay devices."""
2
3from __future__ import annotations
4
5import asyncio
6import time
7from collections import deque
8from collections.abc import AsyncGenerator
9from contextlib import suppress
10from typing import TYPE_CHECKING
11
12from music_assistant_models.enums import PlaybackState
13from music_assistant_models.errors import PlayerCommandFailed
14
15from music_assistant.constants import CONF_SYNC_ADJUST
16from music_assistant.helpers.audio import get_player_filter_params
17from music_assistant.helpers.ffmpeg import FFMpeg
18from music_assistant.providers.airplay.helpers import ntp_to_unix_time, unix_time_to_ntp
19
20from .constants import (
21    AIRPLAY2_CONNECT_TIME_MS,
22    CONF_ENABLE_LATE_JOIN,
23    ENABLE_LATE_JOIN_DEFAULT,
24    RAOP_CONNECT_TIME_MS,
25    StreamingProtocol,
26)
27from .protocols.airplay2 import AirPlay2Stream
28from .protocols.raop import RaopStream
29
30if TYPE_CHECKING:
31    from music_assistant_models.media_items import AudioFormat
32
33    from .player import AirPlayPlayer
34    from .provider import AirPlayProvider
35
36
class AirPlayStreamSession:
    """Stream session (RAOP or AirPlay2) to one or more players.

    A session owns one ffmpeg process per player (``_player_ffmpeg``) that
    converts the shared PCM input into the format of that player's stream and
    writes it to the stdin pipe of the player's CLI process. Incoming audio
    chunks are fanned out to every running client, and the most recent chunks
    are kept in a small ring buffer so late joiners can be primed.
    """

    def __init__(
        self,
        airplay_provider: AirPlayProvider,
        sync_clients: list[AirPlayPlayer],
        pcm_format: AudioFormat,
    ) -> None:
        """Initialize AirPlayStreamSession.

        :param airplay_provider: The AirPlay provider instance.
        :param sync_clients: List of AirPlay players to stream to (must not be empty).
        :param pcm_format: PCM format of the input stream.
        """
        assert sync_clients
        self.prov = airplay_provider
        self.mass = airplay_provider.mass
        self.pcm_format = pcm_format
        # NOTE: this is the caller's list object; remove_client/add_client mutate it.
        self.sync_clients = sync_clients
        self._audio_source_task: asyncio.Task[None] | None = None
        # player_id -> ffmpeg process feeding that player's CLI process
        self._player_ffmpeg: dict[str, FFMpeg] = {}
        # Serializes chunk writes, client (de)registration and EOF handling
        self._lock = asyncio.Lock()
        self.start_ntp: int = 0
        self.start_time: float = 0.0
        self.wait_start: float = 0.0
        # Seconds of audio (including silence padding) handed to the players so far
        self.seconds_streamed: float = 0
        self._first_chunk_received = asyncio.Event()
        # Ring buffer for late joiners: stores (chunk_data, seconds_offset) tuples.
        # Chunks from the streams controller are ~1 second each (pcm_sample_size bytes);
        # only the 10 most recent chunks are retained and add_client() replays at most
        # the last ~5 seconds of them.
        self._chunk_buffer: deque[tuple[bytes, float]] = deque(maxlen=10)

    async def start(self, audio_source: AsyncGenerator[bytes, None]) -> None:
        """Initialize stream session for all players.

        Picks a common start time (now + connect allowance), starts every client
        with the matching NTP timestamp, launches the audio streamer task and
        waits until all clients report a connection.

        :param audio_source: Async generator yielding PCM chunks in ``pcm_format``.
        :raises PlayerCommandFailed: If waiting for a client connection fails.
        """
        cur_time = time.time()
        has_airplay2_client = any(
            p.protocol == StreamingProtocol.AIRPLAY2 for p in self.sync_clients
        )
        if has_airplay2_client:
            # With an AirPlay2 client in the group, the largest output buffer of the
            # group is added on top of the AirPlay2 connect allowance.
            max_output_buffer_ms = max(p.output_buffer_duration_ms for p in self.sync_clients)
            wait_start = AIRPLAY2_CONNECT_TIME_MS + max_output_buffer_ms
        else:
            wait_start = RAOP_CONNECT_TIME_MS
        wait_start_seconds = wait_start / 1000
        self.wait_start = wait_start_seconds
        self.start_time = cur_time + wait_start_seconds
        self.start_ntp = unix_time_to_ntp(self.start_time)
        await asyncio.gather(*[self._start_client(p, self.start_ntp) for p in self.sync_clients])
        self._audio_source_task = asyncio.create_task(self._audio_streamer(audio_source))
        try:
            await asyncio.gather(
                *[p.stream.wait_for_connection() for p in self.sync_clients if p.stream]
            )
        except Exception as err:
            # playback failed to start, cleanup (keep the original error as the cause)
            await self.stop()
            raise PlayerCommandFailed("Playback failed to start") from err

    async def stop(self) -> None:
        """Stop playback and cleanup.

        Cancels the audio streamer task (if still running) and removes every
        client that is still part of the session.
        """
        if self._audio_source_task and not self._audio_source_task.done():
            self._audio_source_task.cancel()
            with suppress(asyncio.CancelledError):
                await self._audio_source_task
        # The coroutine list is fully built before any removal runs, so mutating
        # self.sync_clients inside remove_client() does not disturb the iteration.
        await asyncio.gather(
            *[self.remove_client(x) for x in self.sync_clients],
        )

    async def remove_client(self, airplay_player: AirPlayPlayer) -> None:
        """Remove a sync client from the session.

        No-op when the player is not (or no longer) part of the session.

        :param airplay_player: The player to remove.
        """
        async with self._lock:
            if airplay_player not in self.sync_clients:
                return
            self.sync_clients.remove(airplay_player)
        await self.stop_client(airplay_player)
        airplay_player.set_state_from_stream(PlaybackState.IDLE)
        # If this was the last client, stop the session
        if not self.sync_clients:
            await self.stop()

    async def stop_client(self, airplay_player: AirPlayPlayer) -> None:
        """Stop a client's stream and ffmpeg.

        :param airplay_player: The player to stop.
        """
        ffmpeg = self._player_ffmpeg.pop(airplay_player.player_id, None)
        # note that we use kill instead of graceful close here,
        # because otherwise it can take a very long time for the process to exit.
        if ffmpeg and not ffmpeg.closed:
            await ffmpeg.kill()
        # Only stop the stream if it still belongs to this session
        if airplay_player.stream and airplay_player.stream.session == self:
            await airplay_player.stream.stop(force=True)

    async def add_client(self, airplay_player: AirPlayPlayer) -> None:
        """Add a sync client to the session as a late joiner.

        The late joiner will:
        1. Start with NTP timestamp accounting for buffered chunks we'll send
        2. Receive buffered chunks immediately to prime the ffmpeg/CLI pipeline
        3. Join the real-time stream in perfect sync with other players

        Nothing happens when the session has no clients left or when the sync
        leader's stream is not running. When late join is disabled in the
        provider config, the session is stopped instead and a resume of the
        sync leader is scheduled (if it has current media).

        :param airplay_player: The player that joins the running session.
        """
        if not self.sync_clients:
            # All clients were removed already; there is no running stream to join.
            return
        sync_leader = self.sync_clients[0]
        if not sync_leader.stream or not sync_leader.stream.running:
            return

        allow_late_join = self.prov.config.get_value(
            CONF_ENABLE_LATE_JOIN, ENABLE_LATE_JOIN_DEFAULT
        )
        if not allow_late_join:
            await self.stop()
            if sync_leader.state.current_media:
                self.mass.call_later(
                    0.5,
                    self.mass.players.cmd_resume(sync_leader.player_id),
                    task_id=f"resync_session_{sync_leader.player_id}",
                )
            return

        async with self._lock:
            # Get buffered chunks to send, but limit to ~5 seconds to avoid
            # blocking real-time streaming to other players (causes packet loss)
            max_late_join_buffer_seconds = 5.0
            min_position = self.seconds_streamed - max_late_join_buffer_seconds
            buffered_chunks = [
                (chunk, pos) for chunk, pos in self._chunk_buffer if pos >= min_position
            ]

            if buffered_chunks:
                # The joiner starts at the position of the oldest chunk we replay,
                # i.e. (current_position - buffer_duration), and catches up from there.
                first_chunk_position = buffered_chunks[0][1]
                buffer_duration = self.seconds_streamed - first_chunk_position
                start_at = self.start_time + (self.seconds_streamed - buffer_duration)

                self.prov.logger.debug(
                    "Late joiner %s: sending %.2fs of buffered audio, start at %.2fs",
                    airplay_player.player_id,
                    buffer_duration,
                    self.seconds_streamed - buffer_duration,
                )
            else:
                # No buffer available, start from current position
                start_at = self.start_time + self.seconds_streamed
                self.prov.logger.debug(
                    "Late joiner %s: no buffered chunks available, starting at %.2fs",
                    airplay_player.player_id,
                    self.seconds_streamed,
                )

            start_ntp = unix_time_to_ntp(start_at)

            if airplay_player not in self.sync_clients:
                self.sync_clients.append(airplay_player)

            await self._start_client(airplay_player, start_ntp)
            if airplay_player.stream:
                await airplay_player.stream.wait_for_connection()

            # Feed buffered chunks INSIDE the lock to prevent race conditions
            # This ensures we don't send a new real-time chunk while feeding the buffer
            if buffered_chunks:
                await self._feed_buffered_chunks(airplay_player, buffered_chunks)

    async def _audio_streamer(self, audio_source: AsyncGenerator[bytes, None]) -> None:
        """Stream audio to all players.

        Runs a silence watchdog until the first chunk arrives, forwards every
        chunk to all running clients and finally writes EOF to the clients that
        are still running. Cancellation is re-raised and skips the EOF step.

        :param audio_source: Async generator yielding PCM chunks in ``pcm_format``.
        """
        pcm_sample_size = self.pcm_format.pcm_sample_size
        watchdog_task = asyncio.create_task(self._silence_watchdog(pcm_sample_size))
        stream_error: BaseException | None = None
        try:
            async for chunk in audio_source:
                if not self._first_chunk_received.is_set():
                    # Real audio arrived: make sure the watchdog is fully stopped
                    # before the first real chunk is written.
                    watchdog_task.cancel()
                    with suppress(asyncio.CancelledError):
                        await watchdog_task
                    self._first_chunk_received.set()

                if not self.sync_clients:
                    break

                has_running_clients = await self._write_chunk_to_all_players(chunk)
                if not has_running_clients:
                    self.prov.logger.debug("No running clients remaining, stopping audio streamer")
                    break
                self.seconds_streamed += len(chunk) / pcm_sample_size
        except asyncio.CancelledError:
            self.prov.logger.debug("Audio streamer cancelled after %.1fs", self.seconds_streamed)
            raise
        except Exception as err:
            # Errors from the audio source are logged, not propagated: the players
            # still receive EOF below.
            stream_error = err
            self.prov.logger.error(
                "Audio source error after %.1fs of streaming: %s",
                self.seconds_streamed,
                err,
                exc_info=err,
            )
        finally:
            if not watchdog_task.done():
                watchdog_task.cancel()
                with suppress(asyncio.CancelledError):
                    await watchdog_task
            if stream_error:
                self.prov.logger.warning(
                    "Stream ended prematurely due to error - notifying players"
                )
        async with self._lock:
            await asyncio.gather(
                *[
                    self._write_eof_to_player(x)
                    for x in self.sync_clients
                    if x.stream and x.stream.running
                ],
                return_exceptions=True,
            )

    async def _silence_watchdog(self, pcm_sample_size: int) -> None:
        """Insert silence if audio source is slow to deliver first chunk.

        After a short grace period, zero-filled chunks of 0.1s are written to all
        players until the first real chunk arrives, no client is running anymore
        or 5 seconds of silence have been inserted.

        :param pcm_sample_size: Number of bytes per second of PCM audio.
        """
        grace_period = 0.2
        max_silence_padding = 5.0
        silence_duration = 0.1
        silence_inserted = 0.0
        # The silence chunk is identical on every iteration, so build it once.
        silence_chunk = bytes(int(pcm_sample_size * silence_duration))

        await asyncio.sleep(grace_period)
        while not self._first_chunk_received.is_set() and silence_inserted < max_silence_padding:
            has_running_clients = await self._write_chunk_to_all_players(silence_chunk)
            if not has_running_clients:
                break
            self.seconds_streamed += silence_duration
            silence_inserted += silence_duration
            await asyncio.sleep(0.05)

        if silence_inserted > 0:
            self.prov.logger.warning(
                "Inserted %.1fs silence padding while waiting for audio source",
                silence_inserted,
            )

    async def _write_chunk_to_all_players(self, chunk: bytes) -> bool:
        """Write a chunk to all connected players.

        The chunk is also stored in the late-join ring buffer. Players whose
        write times out or fails are scheduled for removal from the session.

        :param chunk: PCM audio data to write.
        :return: True if there are still running clients, False otherwise.
        """
        async with self._lock:
            running_clients = [x for x in self.sync_clients if x.stream and x.stream.running]
            if not running_clients:
                return False

            # Add chunk to ring buffer for late joiners (before seconds_streamed is updated)
            self._chunk_buffer.append((chunk, self.seconds_streamed))

            # Write chunk to all players; gather() returns results in input order,
            # so results[i] always belongs to running_clients[i].
            results = await asyncio.gather(
                *(self._write_chunk_to_player(x, chunk) for x in running_clients),
                return_exceptions=True,
            )

            # Check for write errors or timeouts
            players_to_remove: list[AirPlayPlayer] = []
            for player, result in zip(running_clients, results):
                if isinstance(result, TimeoutError):
                    self.prov.logger.warning(
                        "Removing player %s from session: stopped reading data (write timeout)",
                        player.player_id,
                    )
                elif isinstance(result, Exception):
                    self.prov.logger.warning(
                        "Removing player %s from session due to write error: %s",
                        player.player_id,
                        result,
                    )
                else:
                    continue
                players_to_remove.append(player)

            # remove_client() needs the lock we are holding, so run it as a task
            for player in players_to_remove:
                self.mass.create_task(self.remove_client(player))

            # Return False if all clients were removed (or scheduled for removal)
            return len(running_clients) > len(players_to_remove)

    async def _write_chunk_to_player(self, airplay_player: AirPlayPlayer, chunk: bytes) -> None:
        """Write audio chunk to a player's ffmpeg process.

        Silently does nothing when the player has no ffmpeg process registered
        or the process is already closed.

        :param airplay_player: The player to write to.
        :param chunk: PCM audio data to write.
        :raises TimeoutError: If the write does not complete within 35 seconds.
        """
        ffmpeg = self._player_ffmpeg.get(airplay_player.player_id)
        if not ffmpeg or ffmpeg.closed:
            return
        await asyncio.wait_for(ffmpeg.write(chunk), timeout=35.0)

    async def _feed_buffered_chunks(
        self,
        airplay_player: AirPlayPlayer,
        buffered_chunks: list[tuple[bytes, float]],
    ) -> None:
        """Feed buffered chunks to a late joiner to prime the ffmpeg pipeline.

        On any write failure the player is scheduled for removal from the session.

        :param airplay_player: The late joiner player.
        :param buffered_chunks: List of (chunk_data, position) tuples to send.
        """
        try:
            for chunk, _position in buffered_chunks:
                await self._write_chunk_to_player(airplay_player, chunk)
        except Exception as err:
            self.prov.logger.warning(
                "Failed to feed buffered chunks to late joiner %s: %s",
                airplay_player.player_id,
                err,
            )
            # Remove the client if feeding buffered chunks fails
            self.mass.create_task(self.remove_client(airplay_player))

    async def _write_eof_to_player(self, airplay_player: AirPlayPlayer) -> None:
        """Write EOF to a specific player.

        Closes the input of the player's ffmpeg process, waits (up to 30s) for it
        to finish and then signals EOF to the player's CLI process.

        :param airplay_player: The player to signal end-of-stream to.
        """
        ffmpeg = self._player_ffmpeg.pop(airplay_player.player_id, None)
        if not ffmpeg:
            return
        await ffmpeg.write_eof()
        await ffmpeg.wait_with_timeout(30)
        if airplay_player.stream and airplay_player.stream._cli_proc:
            await airplay_player.stream._cli_proc.write_eof()

    async def _start_client(self, airplay_player: AirPlayPlayer, start_ntp: int) -> None:
        """Start CLI process and ffmpeg for a single client.

        :param airplay_player: The player to start.
        :param start_ntp: NTP timestamp at which playback should start; the
            player's sync adjustment (milliseconds) is applied on top of it.
        """
        if airplay_player.stream and airplay_player.stream.running:
            await airplay_player.stream.stop()
        if airplay_player.protocol == StreamingProtocol.AIRPLAY2:
            airplay_player.stream = AirPlay2Stream(airplay_player)
        else:
            airplay_player.stream = RaopStream(airplay_player)
        airplay_player.stream.session = self
        sync_adjust = airplay_player.config.get_value(CONF_SYNC_ADJUST, 0)
        assert isinstance(sync_adjust, int)
        if sync_adjust != 0:
            # sync_adjust is in milliseconds; shift the start time accordingly
            start_ntp = unix_time_to_ntp(ntp_to_unix_time(start_ntp) + (sync_adjust / 1000))
        await airplay_player.stream.start(start_ntp)
        # Start ffmpeg to feed audio to CLI stdin
        if ffmpeg := self._player_ffmpeg.pop(airplay_player.player_id, None):
            # Dispose of a leftover ffmpeg process from a previous start
            await ffmpeg.close()
        filter_params = get_player_filter_params(
            self.mass,
            airplay_player.player_id,
            self.pcm_format,
            airplay_player.stream.pcm_format,
        )
        cli_proc = airplay_player.stream._cli_proc
        assert cli_proc
        assert cli_proc.proc
        assert cli_proc.proc.stdin
        # ffmpeg writes straight into the file descriptor of the CLI process' stdin pipe
        stdin_transport = cli_proc.proc.stdin.transport
        audio_output: str | int = stdin_transport.get_extra_info("pipe").fileno()
        ffmpeg = FFMpeg(
            audio_input="-",
            input_format=self.pcm_format,
            output_format=airplay_player.stream.pcm_format,
            filter_params=filter_params,
            audio_output=audio_output,
        )
        await ffmpeg.start()
        self._player_ffmpeg[airplay_player.player_id] = ffmpeg
412