music-assistant-server

13.8 KBPY
audio_buffer.py
13.8 KB351 lines • python
1"""Audio buffer implementation for PCM audio streaming."""
2
3from __future__ import annotations
4
5import asyncio
6import logging
7import time
8from collections import deque
9from collections.abc import AsyncGenerator
10from contextlib import suppress
11from typing import TYPE_CHECKING, Any
12
13from music_assistant_models.errors import AudioError
14
15from music_assistant.constants import MASS_LOGGER_NAME, VERBOSE_LOG_LEVEL
16
17if TYPE_CHECKING:
18    from music_assistant_models.media_items import AudioFormat
19
# Module-level logger, named as a child of the MASS_LOGGER_NAME logger.
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.audio_buffer")

# Default capacity of an AudioBuffer, expressed in seconds of audio.
DEFAULT_MAX_BUFFER_SIZE_SECONDS: int = 60 * 8  # 8 minutes
23
24
class AudioBufferEOF(Exception):
    """Signal that an audio buffer has no further data to deliver (end-of-file)."""
27
28
class AudioBuffer:
    """Simple buffer to hold (PCM) audio chunks with seek capability.

    Each chunk represents exactly 1 second of audio (only the last chunk of a
    stream may be shorter). Chunks are stored in a deque for efficient O(1)
    append and popleft operations and are addressed by their absolute,
    0-based position (in seconds) in the stream.
    """

    # Seconds without a ``get()`` call after which the inactivity monitor
    # clears a non-empty buffer.
    INACTIVITY_TIMEOUT_SECONDS = 60 * 5  # 5 minutes
    # Interval (in seconds) at which the inactivity monitor re-checks the buffer.
    INACTIVITY_CHECK_INTERVAL_SECONDS = 30
    # ``is_valid`` refuses reuse when the buffer is within this many seconds
    # of hitting the inactivity timeout.
    INACTIVITY_SAFETY_MARGIN_SECONDS = 30

    def __init__(
        self,
        pcm_format: AudioFormat,
        checksum: str,
        max_size_seconds: int = DEFAULT_MAX_BUFFER_SIZE_SECONDS,
    ) -> None:
        """
        Initialize AudioBuffer.

        Args:
            pcm_format: The PCM audio format specification
            checksum: The checksum for the audio data (for validation purposes)
            max_size_seconds: Maximum buffer size in seconds
        """
        self.pcm_format = pcm_format
        self.checksum = checksum
        self.max_size_seconds = max_size_seconds
        # Store chunks in a deque for O(1) append and popleft operations
        self._chunks: deque[bytes] = deque()
        # Track how many chunks have been discarded from the start
        self._discarded_chunks = 0
        # Both conditions share one lock so producers and consumers
        # observe a consistent view of the buffer state.
        self._lock = asyncio.Lock()
        self._data_available = asyncio.Condition(self._lock)
        self._space_available = asyncio.Condition(self._lock)
        self._eof_received = False
        self._producer_task: asyncio.Task[None] | None = None
        self._last_access_time: float = time.time()
        self._inactivity_task: asyncio.Task[None] | None = None
        self._cancelled = False  # Set to True when buffer is cleared/cancelled
        self._producer_error: Exception | None = None
        # Strong reference to the task that wakes consumers after a producer
        # failure; the event loop itself only keeps weak references to tasks.
        self._notify_task: asyncio.Task[None] | None = None

    @property
    def cancelled(self) -> bool:
        """Return whether the buffer has been cancelled or cleared."""
        if self._cancelled:
            return True
        return self._producer_task is not None and self._producer_task.cancelled()

    @property
    def chunk_size_bytes(self) -> int:
        """Return the size in bytes of one second of PCM audio."""
        return self.pcm_format.pcm_sample_size

    @property
    def size_seconds(self) -> int:
        """Return current size of the buffer in seconds."""
        return len(self._chunks)

    @property
    def seconds_available(self) -> int:
        """Return number of seconds of audio currently available in the buffer."""
        return len(self._chunks)

    def is_valid(self, checksum: str | None = None, seek_position: int = 0) -> bool:
        """
        Validate the buffer's checksum and check if seek position is available.

        Args:
            checksum: The checksum to validate against
            seek_position: The position we want to seek to (0-based)

        Returns:
            True if buffer is valid and seek position is available
        """
        if self.cancelled:
            return False

        if checksum is not None and self.checksum != checksum:
            return False

        # Check if buffer is close to the inactivity timeout
        # to prevent race condition where buffer gets cleared right after validation
        time_since_access = time.time() - self._last_access_time
        reuse_deadline = self.INACTIVITY_TIMEOUT_SECONDS - self.INACTIVITY_SAFETY_MARGIN_SECONDS
        if time_since_access > reuse_deadline:
            # Buffer is close to being cleared, don't reuse it
            return False

        # Reject positions further ahead than one full buffer beyond the oldest kept chunk
        if seek_position > self._discarded_chunks + self.max_size_seconds:
            return False

        # Check if the seek position has already been discarded
        return seek_position >= self._discarded_chunks

    async def put(self, chunk: bytes) -> None:
        """
        Put a chunk of data into the buffer.

        Each chunk represents exactly 1 second of PCM audio
        (except for the last one, which may be shorter).
        Waits if buffer is full. Chunks offered after EOF are dropped.

        Args:
            chunk: Bytes representing 1 second of PCM audio
        """
        async with self._space_available:
            # Wait until there's space in the buffer
            while len(self._chunks) >= self.max_size_seconds and not self._eof_received:
                if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL):
                    LOGGER.log(
                        VERBOSE_LOG_LEVEL,
                        "AudioBuffer.put: Buffer full (%s/%s), waiting for space...",
                        len(self._chunks),
                        self.max_size_seconds,
                    )
                await self._space_available.wait()

            if self._eof_received:
                # Don't accept new data after EOF
                LOGGER.log(
                    VERBOSE_LOG_LEVEL, "AudioBuffer.put: EOF already received, rejecting chunk"
                )
                return

            # Add chunk to the deque (absolute position = discarded + index)
            self._chunks.append(chunk)
            if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL):
                LOGGER.log(
                    VERBOSE_LOG_LEVEL,
                    "AudioBuffer.put: Added chunk at position %s (size: %s bytes, buffer size: %s)",
                    self._discarded_chunks + len(self._chunks) - 1,
                    len(chunk),
                    len(self._chunks),
                )

            # Notify waiting consumers
            self._data_available.notify_all()

    async def get(self, chunk_number: int = 0) -> bytes:
        """
        Get one second of data from the buffer at the specified chunk number.

        Waits until requested chunk is available.
        Discards the oldest chunk if buffer is full.

        Args:
            chunk_number: The chunk index to retrieve (0-based, absolute position).

        Returns:
            Bytes containing one second of audio data

        Raises:
            AudioBufferEOF: If EOF is reached before chunk is available
            AudioError: If chunk has been discarded, or if the buffer was
                cancelled/cleared before the chunk became available
            Exception: Any exception that occurred in the producer task
        """
        # Update last access time
        self._last_access_time = time.time()

        async with self._data_available:
            # Check if producer had an error - raise immediately
            if self._producer_error:
                raise self._producer_error

            # Check if the chunk was already discarded
            if chunk_number < self._discarded_chunks:
                msg = (
                    f"Chunk {chunk_number} has been discarded "
                    f"(buffer starts at {self._discarded_chunks})"
                )
                raise AudioError(msg)

            # Wait until the requested chunk is available or EOF
            buffer_index = chunk_number - self._discarded_chunks
            while buffer_index >= len(self._chunks):
                # Check if producer had an error - raise immediately
                if self._producer_error:
                    raise self._producer_error
                if self._eof_received:
                    raise AudioBufferEOF
                if self.cancelled:
                    # clear() drops all data and resets EOF, so nothing will
                    # ever satisfy this wait; fail instead of hanging forever.
                    msg = (
                        f"Audio buffer was cancelled or cleared "
                        f"before chunk {chunk_number} became available"
                    )
                    raise AudioError(msg)
                await self._data_available.wait()
                buffer_index = chunk_number - self._discarded_chunks

            # Fetch the chunk BEFORE discarding anything: if the requested chunk
            # is the oldest one, discarding first would turn the index into -1
            # and silently return the newest chunk instead.
            chunk = self._chunks[buffer_index]

            # If buffer is at max size, discard the oldest chunk to make room
            if len(self._chunks) >= self.max_size_seconds:
                discarded = self._chunks.popleft()  # O(1) operation with deque
                self._discarded_chunks += 1
                if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL):
                    LOGGER.log(
                        VERBOSE_LOG_LEVEL,
                        "AudioBuffer.get: Discarded chunk %s (size: %s bytes) to free space",
                        self._discarded_chunks - 1,
                        len(discarded),
                    )
                # Notify producers waiting for space
                self._space_available.notify_all()

            return chunk

    async def iter(self, seek_position: int = 0) -> AsyncGenerator[bytes, None]:
        """
        Iterate over seconds of audio data until EOF.

        Errors other than EOF raised by ``get`` propagate to the caller.

        Args:
            seek_position: Optional starting position in seconds (default: 0).

        Yields:
            Bytes containing one second of audio data
        """
        chunk_number = seek_position
        while True:
            try:
                yield await self.get(chunk_number=chunk_number)
                chunk_number += 1
            except AudioBufferEOF:
                break  # EOF reached

    async def clear(self, cancel_inactivity_task: bool = True) -> None:
        """
        Reset the buffer completely, clearing all data.

        Cancels the producer task (and optionally the inactivity monitor),
        drops all chunks and marks the buffer as cancelled.

        Args:
            cancel_inactivity_task: Whether to also cancel the inactivity monitor.
                The monitor passes False when it clears the buffer itself,
                so that it does not cancel (and await) its own task.
        """
        chunk_count = len(self._chunks)
        LOGGER.log(
            VERBOSE_LOG_LEVEL,
            "AudioBuffer.clear: Resetting buffer (had %s chunks, has producer task: %s)",
            chunk_count,
            self._producer_task is not None,
        )
        # Cancel producer task if present and wait for it to finish
        # This ensures any subprocess cleanup happens on the main event loop
        if self._producer_task and not self._producer_task.done():
            self._producer_task.cancel()
            with suppress(asyncio.CancelledError):
                await self._producer_task

        # Cancel inactivity task if present
        if cancel_inactivity_task and self._inactivity_task and not self._inactivity_task.done():
            self._inactivity_task.cancel()
            with suppress(asyncio.CancelledError):
                await self._inactivity_task

        async with self._lock:
            # Replace the deque instead of clearing it to avoid blocking
            # Clearing a large deque can take >100ms
            self._chunks = deque()
            self._discarded_chunks = 0
            self._eof_received = False
            self._cancelled = True  # Mark buffer as cancelled
            self._producer_error = None  # Clear any producer error
            # Notify all waiting tasks
            self._data_available.notify_all()
            self._space_available.notify_all()

    async def set_eof(self) -> None:
        """Signal that no more data will be added to the buffer."""
        async with self._lock:
            LOGGER.log(
                VERBOSE_LOG_LEVEL,
                "AudioBuffer.set_eof: Marking EOF (buffer has %s chunks)",
                len(self._chunks),
            )
            self._eof_received = True
            # Wake up all waiting consumers and producers
            self._data_available.notify_all()
            self._space_available.notify_all()

    async def _monitor_inactivity(self) -> None:
        """Monitor buffer for inactivity and clear it once the timeout is exceeded."""
        while True:
            await asyncio.sleep(self.INACTIVITY_CHECK_INTERVAL_SECONDS)
            # Check if buffer has been inactive (holds data but nobody reads it)
            time_since_access = time.time() - self._last_access_time
            # If buffer hasn't been accessed for timeout period,
            # it likely means the producer failed or stream was abandoned
            if self._chunks and time_since_access > self.INACTIVITY_TIMEOUT_SECONDS:
                LOGGER.log(
                    VERBOSE_LOG_LEVEL,
                    "AudioBuffer: No activity for %.1f seconds, clearing buffer (had %s chunks)",
                    time_since_access,
                    len(self._chunks),
                )
                break  # Stop monitoring after clearing
        # if we reach here, we have broken out of the loop due to inactivity
        await self.clear(cancel_inactivity_task=False)

    async def _notify_on_producer_error(self) -> None:
        """Notify waiting consumers that producer has failed.

        This is scheduled from the producer task done callback and properly
        acquires the lock before calling notify_all.
        """
        async with self._lock:
            self._data_available.notify_all()

    def attach_producer_task(self, task: asyncio.Task[Any]) -> None:
        """
        Attach a background task that fills the buffer.

        Registers a done callback that records a producer failure and wakes
        waiting consumers, and starts the inactivity monitor if it is not
        already running. Must be called while an event loop is running.

        Args:
            task: The task that produces data for this buffer
        """
        self._producer_task = task

        # Add a callback to capture any exceptions from the producer task
        def _on_producer_done(t: asyncio.Task[Any]) -> None:
            """Handle producer task completion."""
            if t.cancelled():
                return
            # Capture any exception that occurred
            exc = t.exception()
            if isinstance(exc, Exception):
                self._producer_error = exc
                # Mark buffer as cancelled when producer fails
                # This prevents reuse of a buffer in error state
                self._cancelled = True
                # Wake up any waiting consumers so they can see the error
                # We need to acquire the lock before calling notify_all,
                # which requires a coroutine. Keep a reference to the task:
                # the event loop only holds weak references, so an
                # unreferenced task may be garbage collected before it runs.
                loop = asyncio.get_running_loop()
                self._notify_task = loop.create_task(self._notify_on_producer_error())

        task.add_done_callback(_on_producer_done)

        # Start inactivity monitor if not already running
        if self._inactivity_task is None or self._inactivity_task.done():
            self._last_access_time = time.time()  # Initialize access time
            loop = asyncio.get_running_loop()
            self._inactivity_task = loop.create_task(self._monitor_inactivity())
351