music-assistant-server

23.5 KBPY
fades.py
23.5 KB610 lines • python
1"""Smart Fades - Audio fade implementations."""
2
3from __future__ import annotations
4
5import logging
6from abc import ABC, abstractmethod
7from typing import TYPE_CHECKING
8
9import aiofiles
10import numpy as np
11import numpy.typing as npt
12import shortuuid
13
14from music_assistant.constants import VERBOSE_LOG_LEVEL
15from music_assistant.controllers.streams.smart_fades.filters import (
16    CrossfadeFilter,
17    Filter,
18    FrequencySweepFilter,
19    TimeStretchFilter,
20    TrimFilter,
21)
22from music_assistant.helpers.process import communicate
23from music_assistant.helpers.util import remove_file
24from music_assistant.models.smart_fades import (
25    SmartFadesAnalysis,
26)
27
28if TYPE_CHECKING:
29    from music_assistant_models.media_items import AudioFormat
30
# Length in seconds of the audio buffers the smart fade logic works with. Crossfade
# durations are capped to it, and the crossfade is placed at the END of the fade-out
# buffer. It is also the default `buffer_size` for downbeat extrapolation.
SMART_CROSSFADE_DURATION = 45
32
33
class SmartFade(ABC):
    """Abstract base class for Smart Fades.

    Subclasses populate ``self.filters`` in :meth:`_build`. :meth:`apply` chains those
    filters into a single FFmpeg ``-filter_complex`` graph and runs it over two raw PCM
    audio parts.
    """

    # Ordered filter chain; built lazily by _build() on first use
    filters: list[Filter]

    def __init__(self, logger: logging.Logger) -> None:
        """Initialize SmartFade base class.

        Args:
            logger: Logger used for debug and verbose output.
        """
        self.filters = []
        self.logger = logger

    @abstractmethod
    def _build(self) -> None:
        """Build the smart fades filter chain by populating ``self.filters``."""
        ...

    def _get_ffmpeg_filters(
        self,
        input_fadein_label: str = "[1]",
        input_fadeout_label: str = "[0]",
    ) -> list[str]:
        """Get FFmpeg filters for smart fades.

        Builds the filter chain if it is still empty, then lets every filter render its
        FFmpeg filter strings. Each filter reads the output labels of the previous one,
        starting from the given input labels.

        Args:
            input_fadein_label: FFmpeg label of the fade-in input stream.
            input_fadeout_label: FFmpeg label of the fade-out input stream.

        Returns:
            Filter strings in chain order, meant to be joined with ";".
        """
        if not self.filters:
            self._build()
        rendered: list[str] = []
        fadein_label, fadeout_label = input_fadein_label, input_fadeout_label
        for audio_filter in self.filters:
            rendered.extend(audio_filter.apply(fadein_label, fadeout_label))
            # The next filter consumes this filter's outputs
            fadein_label = f"[{audio_filter.output_fadein_label}]"
            fadeout_label = f"[{audio_filter.output_fadeout_label}]"
        return rendered

    @staticmethod
    def _pcm_args(pcm_format: AudioFormat) -> list[str]:
        """Return the FFmpeg codec/format arguments describing raw PCM in ``pcm_format``.

        The same arguments are used for both inputs and for the output, so the
        crossfade result has exactly the input PCM format.
        """
        return [
            "-acodec",
            pcm_format.content_type.name.lower(),  # e.g., "pcm_f32le" not just "f32le"
            "-ac",
            str(pcm_format.channels),
            "-ar",
            str(pcm_format.sample_rate),
            "-channel_layout",
            "mono" if pcm_format.channels == 1 else "stereo",
            "-f",
            pcm_format.content_type.value,
        ]

    async def apply(
        self,
        fade_out_part: bytes,
        fade_in_part: bytes,
        pcm_format: AudioFormat,
    ) -> bytes:
        """Apply the smart fade to the given PCM audio parts.

        The fade-out part is handed to FFmpeg through a temporary file. The fade-in part
        is handed over through stdin.

        Args:
            fade_out_part: Raw PCM audio of the outgoing track.
            fade_in_part: Raw PCM audio of the incoming track.
            pcm_format: PCM format of both parts (and of the result).

        Returns:
            The raw PCM output produced by FFmpeg.

        Raises:
            RuntimeError: If FFmpeg produced no output.
        """
        # Build the filter graph BEFORE creating the temp file. Building may raise, and
        # the cleanup below only covers the FFmpeg run, so the file would otherwise leak.
        smart_fade_filters = self._get_ffmpeg_filters()
        self.logger.debug(
            "Applying smartfade: %s",
            self,
        )

        pcm_args = self._pcm_args(pcm_format)
        fadeout_filename = f"/tmp/{shortuuid.random(20)}.pcm"  # noqa: S108
        args = [
            "ffmpeg",
            "-hide_banner",
            "-loglevel",
            "error",
            # Input [0]: fade_out part (as file)
            *pcm_args,
            "-i",
            fadeout_filename,
            # Input [1]: fade_in part (stdin)
            *pcm_args,
            "-i",
            "-",
            "-filter_complex",
            ";".join(smart_fade_filters),
            # Output format specification - must match input codec format
            *pcm_args,
            "-",
        ]
        self.logger.log(VERBOSE_LOG_LEVEL, "FFmpeg command args: %s", " ".join(args))

        # Write the fade_out_part to a temporary file
        async with aiofiles.open(fadeout_filename, "wb") as outfile:
            await outfile.write(fade_out_part)
        try:
            # Execute the enhanced smart fade with full buffer
            _, raw_crossfade_output, stderr = await communicate(args, fade_in_part)
        finally:
            # Always cleanup temp file, even if ffmpeg fails
            await remove_file(fadeout_filename)

        if raw_crossfade_output:
            return raw_crossfade_output
        # errors="replace": a non-UTF-8 byte must not hide the real FFmpeg error
        stderr_msg = stderr.decode(errors="replace") if stderr else "(no stderr output)"
        raise RuntimeError(f"Smart crossfade failed. FFmpeg stderr: {stderr_msg}")

    def __repr__(self) -> str:
        """Return string representation of SmartFade showing the filter chain."""
        if not self.filters:
            return f"<{self.__class__.__name__}: 0 filters>"

        chain = " → ".join(repr(f) for f in self.filters)
        return f"<{self.__class__.__name__}: {len(self.filters)} filters> {chain}"
155
156
class SmartCrossFade(SmartFade):
    """Smart fades class that implements a Smart Fade mode.

    Uses the BPM/beat/downbeat analysis of both tracks to pick a bar-based crossfade
    length, optionally time-stretch and beat-align the transition, and shape it with
    complementary low-pass/high-pass frequency sweeps.
    """

    # Only apply time stretching if the BPM difference is at most this %
    time_stretch_bpm_percentage_threshold: float = 5.0

    def __init__(
        self,
        logger: logging.Logger,
        fade_out_analysis: SmartFadesAnalysis,
        fade_in_analysis: SmartFadesAnalysis,
    ) -> None:
        """Initialize SmartFades with analysis data.

        Args:
            logger: Logger used for debug/verbose output.
            fade_out_analysis: Analysis data for the outgoing track
            fade_in_analysis: Analysis data for the incoming track
        """
        self.fade_out_analysis = fade_out_analysis
        self.fade_in_analysis = fade_in_analysis
        super().__init__(logger)

    def _build(self) -> None:
        """Build the smart fades filter chain.

        Appends, in order:
        - an optional time-stretch filter;
        - an optional trim of the incoming track to its aligned beat;
        - a low-pass sweep on the outgoing track;
        - a high-pass sweep on the incoming track;
        - the final crossfade.
        """
        fade_out_bpm = self.fade_out_analysis.bpm
        fade_in_bpm = self.fade_in_analysis.bpm
        # Calculate tempo factor for time stretching
        bpm_ratio = fade_in_bpm / fade_out_bpm
        bpm_diff_percent = get_bpm_diff_percentage(fade_in_bpm, fade_out_bpm)

        # Extrapolate downbeats for better bar calculation
        self.extrapolated_fadeout_downbeats = extrapolate_downbeats(
            self.fade_out_analysis.downbeats,
            tempo_factor=1.0,
            bpm=fade_out_bpm,
        )

        # Additional verbose logging to debug rare failures
        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "SmartCrossFade build: fade_out: %s, fade_in: %s",
            self.fade_out_analysis,
            self.fade_in_analysis,
        )

        # Calculate optimal crossfade bars that fit in available buffer
        crossfade_bars = self._calculate_optimal_crossfade_bars()

        # Calculate beat positions for the selected bar count (None = no alignment)
        fadein_start_pos = self._calculate_optimal_fade_timing(crossfade_bars)

        # Calculate initial crossfade duration (may be adjusted later for downbeat alignment)
        crossfade_duration = self._calculate_crossfade_duration(crossfade_bars=crossfade_bars)

        # Add time stretch filter if needed
        if (
            0.1 < bpm_diff_percent <= self.time_stretch_bpm_percentage_threshold
            and crossfade_bars > 4
        ):
            self.filters.append(TimeStretchFilter(logger=self.logger, stretch_ratio=bpm_ratio))
            # Re-extrapolate downbeats with actual tempo factor for time-stretched audio
            self.extrapolated_fadeout_downbeats = extrapolate_downbeats(
                self.fade_out_analysis.downbeats,
                tempo_factor=bpm_ratio,
                bpm=fade_out_bpm,
            )

        # Check if we would have enough audio after beat alignment for the crossfade.
        # NOTE: 0.0 is a valid beat position, so compare against None explicitly
        # (a None value must also never reach the "%.1f" log format below).
        if fadein_start_pos is None:
            self.logger.log(
                VERBOSE_LOG_LEVEL, "Skipping beat alignment: no usable beat positions"
            )
        elif fadein_start_pos + crossfade_duration <= SMART_CROSSFADE_DURATION:
            # Trimming at position 0.0 would be a no-op, so only add the filter when needed
            if fadein_start_pos > 0:
                self.filters.append(
                    TrimFilter(logger=self.logger, fadein_start_pos=fadein_start_pos)
                )
        else:
            self.logger.log(
                VERBOSE_LOG_LEVEL,
                "Skipping beat alignment: not enough audio after trim (%.1fs + %.1fs > %.1fs)",
                fadein_start_pos,
                crossfade_duration,
                SMART_CROSSFADE_DURATION,
            )

        # Adjust crossfade duration to align with outgoing track's downbeats
        crossfade_duration = self._adjust_crossfade_to_downbeats(
            crossfade_duration=crossfade_duration,
            fadein_start_pos=fadein_start_pos,
        )

        # 90 BPM -> 1500Hz, 140 BPM -> 2500Hz
        avg_bpm = (fade_out_bpm + fade_in_bpm) / 2
        crossover_freq = int(np.clip(1500 + (avg_bpm - 90) * 20, 1500, 2500))

        # Adjust for BPM mismatch
        if abs(bpm_ratio - 1.0) > 0.3:
            crossover_freq = int(crossover_freq * 0.85)

        # For shorter fades, use exp/exp curves to avoid abruptness
        if crossfade_bars < 8:
            fadeout_curve = "exponential"
            fadein_curve = "exponential"
        # For long fades, use log/linear curves
        else:
            # Use logarithmic curve to give the next track more space
            fadeout_curve = "logarithmic"
            # Use linear curve for transition, predictable and not too abrupt
            fadein_curve = "linear"

        # Create lowpass filter on the outgoing track (unfiltered → low-pass)
        # Extended lowpass effect to gradually remove bass frequencies
        fadeout_eq_duration = min(max(crossfade_duration * 2.5, 8.0), SMART_CROSSFADE_DURATION)
        # The crossfade always happens at the END of the buffer
        fadeout_eq_start = max(0, SMART_CROSSFADE_DURATION - fadeout_eq_duration)
        fadeout_sweep = FrequencySweepFilter(
            logger=self.logger,
            sweep_type="lowpass",
            target_freq=crossover_freq,
            duration=fadeout_eq_duration,
            start_time=fadeout_eq_start,
            sweep_direction="fade_in",
            poles=1,
            curve_type=fadeout_curve,
            stream_type="fadeout",
        )
        self.filters.append(fadeout_sweep)

        # Create high pass filter on the incoming track (high-pass → unfiltered)
        # Quicker highpass removal to avoid lingering vocals after crossfade
        fadein_eq_duration = crossfade_duration / 1.5
        fadein_sweep = FrequencySweepFilter(
            logger=self.logger,
            sweep_type="highpass",
            target_freq=crossover_freq,
            duration=fadein_eq_duration,
            start_time=0,
            sweep_direction="fade_out",
            poles=1,
            curve_type=fadein_curve,
            stream_type="fadein",
        )
        self.filters.append(fadein_sweep)

        # Add final crossfade filter
        crossfade_filter = CrossfadeFilter(
            logger=self.logger, crossfade_duration=crossfade_duration
        )
        self.filters.append(crossfade_filter)

    def _calculate_crossfade_duration(self, crossfade_bars: int) -> float:
        """Calculate crossfade duration in seconds from musical bars and BPM.

        Uses the incoming track's BPM and assumes 4 beats per bar. The result is
        capped to ``SMART_CROSSFADE_DURATION``.
        """
        beats_per_bar = 4
        seconds_per_beat = 60.0 / self.fade_in_analysis.bpm
        musical_duration = crossfade_bars * beats_per_bar * seconds_per_beat

        # Apply buffer constraint
        actual_duration = min(musical_duration, SMART_CROSSFADE_DURATION)

        # Log if we had to constrain the duration
        if musical_duration > SMART_CROSSFADE_DURATION:
            self.logger.log(
                VERBOSE_LOG_LEVEL,
                "Constraining crossfade duration from %.1fs to %.1fs (buffer limit)",
                musical_duration,
                actual_duration,
            )

        return actual_duration

    def _calculate_optimal_crossfade_bars(self) -> int:
        """Calculate optimal crossfade bars that fit in available buffer.

        Starts from an ideal bar count: 10 when the BPM difference is within the
        time-stretch threshold, otherwise 6. It then steps down through 8, 6, 4, 2, 1
        until a beat-aligned crossfade fits in the fade-in buffer.

        Returns:
            The chosen number of bars; 1 if nothing fits.
        """
        bpm_diff_percent = get_bpm_diff_percentage(
            self.fade_in_analysis.bpm, self.fade_out_analysis.bpm
        )

        # Calculate ideal bars based on BPM compatibility
        ideal_bars = 10 if bpm_diff_percent <= self.time_stretch_bpm_percentage_threshold else 6

        # Descending candidates; the ideal count is tried exactly once
        candidates = [ideal_bars, *(bars for bars in (8, 6, 4, 2, 1) if bars < ideal_bars)]
        for bars in candidates:
            fadein_start_pos = self._calculate_optimal_fade_timing(bars)
            if fadein_start_pos is None:
                continue

            # Check if the resulting duration fits in the fadein buffer
            test_duration = self._calculate_crossfade_duration(crossfade_bars=bars)
            fadein_buffer = SMART_CROSSFADE_DURATION - fadein_start_pos
            if test_duration > fadein_buffer:
                continue

            if bars < ideal_bars:
                self.logger.log(
                    VERBOSE_LOG_LEVEL,
                    "Reduced crossfade from %d to %d bars (fadein buffer=%.1fs, needed=%.1fs)",
                    ideal_bars,
                    bars,
                    fadein_buffer,
                    test_duration,
                )
            return bars

        # Fall back to 1 bar if nothing else fits
        return 1

    def _calculate_optimal_fade_timing(self, crossfade_bars: int) -> float | None:
        """Return the incoming track's beat position to align the crossfade on.

        Downbeats are tried first; both tracks need at least ``crossfade_bars``
        downbeats. Regular beats are the fallback; both tracks then need
        ``crossfade_bars * 4`` beats.

        Returns:
            The first fade-in (down)beat position in seconds (0.0 is a valid
            position), or None when neither track pair has enough beats.
        """
        beats_per_bar = 4

        def first_fadein_beat(
            fade_out_beats: npt.NDArray[np.float64],
            fade_in_beats: npt.NDArray[np.float64],
            num_beats: int,
        ) -> float | None:
            """Return the first fade-in beat if both arrays hold ``num_beats`` beats."""
            required = max(num_beats, 1)  # need at least one fade-in beat to return
            if len(fade_out_beats) < num_beats or len(fade_in_beats) < required:
                return None
            return float(fade_in_beats[0])

        # Try downbeats first for most musical timing
        position = first_fadein_beat(
            self.extrapolated_fadeout_downbeats, self.fade_in_analysis.downbeats, crossfade_bars
        )
        if position is not None:
            return position

        # Try regular beats if downbeats insufficient
        position = first_fadein_beat(
            self.fade_out_analysis.beats,
            self.fade_in_analysis.beats,
            crossfade_bars * beats_per_bar,
        )
        if position is not None:
            return position

        # Fallback: No beat alignment possible
        self.logger.log(VERBOSE_LOG_LEVEL, "No beat alignment possible (insufficient beats)")
        return None

    def _adjust_crossfade_to_downbeats(
        self,
        crossfade_duration: float,
        fadein_start_pos: float | None,
    ) -> float:
        """Adjust crossfade duration to align with outgoing track's downbeats.

        The crossfade ends at the end of the buffer, so its start is snapped to an
        outgoing downbeat. It first tries the last downbeat at or before the ideal
        start, which gives a longer crossfade. It then tries the first downbeat after
        it, which gives a shorter one. A candidate is used only if the incoming audio
        after ``fadein_start_pos`` is long enough.

        Returns:
            The adjusted duration. ``crossfade_duration`` is returned unchanged when
            there are no downbeats, beat alignment is disabled (``fadein_start_pos``
            is None), or no candidate fits.
        """
        if len(self.extrapolated_fadeout_downbeats) == 0 or fadein_start_pos is None:
            return crossfade_duration

        # Calculate where the crossfade would start in the buffer
        ideal_start_pos = SMART_CROSSFADE_DURATION - crossfade_duration

        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "Downbeat adjustment - ideal_start=%.2fs (buffer=%.1fs - crossfade=%.2fs), "
            "fadein_start=%.2fs",
            ideal_start_pos,
            SMART_CROSSFADE_DURATION,
            crossfade_duration,
            fadein_start_pos,
        )

        # Find the closest downbeats (earlier and later); assumes ascending order
        earlier_downbeat = None
        later_downbeat = None
        for downbeat in self.extrapolated_fadeout_downbeats:
            if downbeat <= ideal_start_pos:
                earlier_downbeat = downbeat
            elif downbeat > ideal_start_pos:
                later_downbeat = downbeat
                break

        # Earlier downbeat first (longer crossfade), then later (shorter crossfade)
        for candidate, direction in ((earlier_downbeat, "earlier"), (later_downbeat, "later")):
            if candidate is None:
                continue
            adjusted_duration = float(SMART_CROSSFADE_DURATION - candidate)
            if fadein_start_pos + adjusted_duration > SMART_CROSSFADE_DURATION:
                continue
            if abs(adjusted_duration - crossfade_duration) > 0.1:
                self.logger.log(
                    VERBOSE_LOG_LEVEL,
                    "Adjusted crossfade duration from %.2fs to %.2fs to align with "
                    "downbeat at %.2fs (%s)",
                    crossfade_duration,
                    adjusted_duration,
                    candidate,
                    direction,
                )
            return adjusted_duration

        # If no suitable downbeat found, return original duration
        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "Could not adjust crossfade duration to downbeats, using original %.2fs",
            crossfade_duration,
        )
        return crossfade_duration
466
467
class StandardCrossFade(SmartFade):
    """Standard crossfade class that implements a standard crossfade mode.

    Only the overlapping tail of the outgoing part and head of the incoming part are
    sent through FFmpeg. The remaining audio is passed through untouched.
    """

    def __init__(self, logger: logging.Logger, crossfade_duration: float = 10.0) -> None:
        """Initialize StandardCrossFade with crossfade duration (in seconds)."""
        self.crossfade_duration = crossfade_duration
        super().__init__(logger)

    def _build(self) -> None:
        """Build the standard crossfade filter chain."""
        self.filters = [
            CrossfadeFilter(logger=self.logger, crossfade_duration=self.crossfade_duration),
        ]

    async def apply(
        self, fade_out_part: bytes, fade_in_part: bytes, pcm_format: AudioFormat
    ) -> bytes:
        """Apply the standard crossfade to the given PCM audio parts.

        We override the default apply because the standard crossfade only needs to be
        applied to the overlapping parts, not the full buffers.

        Args:
            fade_out_part: Raw PCM of the outgoing track; its tail is crossfaded.
            fade_in_part: Raw PCM of the incoming track; its head is crossfaded.
            pcm_format: PCM format of both parts.

        Returns:
            The untouched head of ``fade_out_part``, the crossfaded overlap and the
            untouched tail of ``fade_in_part``, concatenated. When there is nothing to
            overlap, the plain concatenation of both parts.

        Note:
            Updates ``self.crossfade_duration`` to the duration actually crossfaded.
        """
        # pcm_sample_size is treated as bytes per second of audio: durations below are
        # computed as len(bytes) / pcm_sample_size. Dividing by the sample rate gives
        # the size of one frame (one sample for every channel).
        frame_size = max(1, int(pcm_format.pcm_sample_size // pcm_format.sample_rate))
        # Use whole frames so the byte-level split never lands mid-frame. A mid-frame
        # split would shift every following sample and corrupt the audio.
        crossfade_size = round(self.crossfade_duration * pcm_format.sample_rate) * frame_size
        if crossfade_size <= 0 or not fade_out_part or not fade_in_part:
            # Nothing to overlap. This also avoids `part[:-0]` (which is b"") and
            # running FFmpeg on empty input.
            return fade_out_part + fade_in_part

        # Split index computed from the front: a crossfade longer than the outgoing
        # part simply uses the whole part.
        split = max(len(fade_out_part) - crossfade_size, 0)
        # Pre-crossfade: outgoing track minus the crossfaded portion
        pre_crossfade = fade_out_part[:split]
        adjusted_fade_out_part = fade_out_part[split:]
        # Post-crossfade: incoming track minus the crossfaded portion
        adjusted_fade_in_part = fade_in_part[:crossfade_size]
        post_crossfade = fade_in_part[crossfade_size:]
        # Adjust the duration to match actual sizes
        self.crossfade_duration = min(
            len(adjusted_fade_in_part) / pcm_format.pcm_sample_size,
            len(adjusted_fade_out_part) / pcm_format.pcm_sample_size,
        )
        # Always (re)build the chain with the duration computed above. The base class
        # only builds lazily, so a filter chain from an earlier call would be stale.
        self._build()
        # Crossfaded portion: user's configured duration
        crossfaded_section = await super().apply(
            adjusted_fade_out_part, adjusted_fade_in_part, pcm_format
        )
        # Full result: everything concatenated
        return pre_crossfade + crossfaded_section + post_crossfade
507
508
509# HELPER METHODS
def get_bpm_diff_percentage(bpm1: float, bpm2: float) -> float:
    """Return how far ``bpm1`` deviates from ``bpm2``, as a percentage of ``bpm2``.

    For example, 150 vs 100 BPM and 50 vs 100 BPM both give 50.0. A ``bpm2`` of 0
    raises ``ZeroDivisionError`` for plain Python numbers.
    """
    ratio = bpm1 / bpm2
    deviation = abs(ratio - 1.0)
    return deviation * 100
513
514
def extrapolate_downbeats(
    downbeats: npt.NDArray[np.float64],
    tempo_factor: float,
    buffer_size: float = SMART_CROSSFADE_DURATION,
    bpm: float | None = None,
    max_extrapolation_distance: float = 25.0,
) -> npt.NDArray[np.float64]:
    """Extrapolate downbeats based on actual intervals when detection is incomplete.

    This is needed when we want to perform beat alignment in an 'atmospheric' outro
    that does not have any detected downbeats.

    Detected downbeats are first divided by ``tempo_factor``. New downbeats are then
    appended at a regular interval after the last one. They must stay before
    ``buffer_size`` and within ``max_extrapolation_distance`` of the last detected
    downbeat.

    Extrapolation is skipped, and only the tempo-adjusted downbeats are returned, when:
    - there are fewer than 2 downbeats;
    - the interval cannot be trusted (2 downbeats whose interval does not match
      ``bpm``, or inconsistent intervals);
    - the interval is not positive;
    - the last downbeat is already within 5s of the buffer end.

    Args:
        downbeats: Array of detected downbeat positions in seconds (assumed ascending)
        tempo_factor: Tempo adjustment factor for time stretching
        buffer_size: Maximum buffer size in seconds
        bpm: Optional BPM for validation when extrapolating with only 2 downbeats
        max_extrapolation_distance: Maximum distance in seconds between the last
            detected downbeat and an extrapolated one.

    Returns:
        The tempo-adjusted detected downbeats followed by any extrapolated downbeats.
    """
    # Adjust detected downbeats for time stretching first
    adjusted_downbeats = downbeats / tempo_factor

    if len(downbeats) < 2:
        # Need at least 2 downbeats to extrapolate
        return adjusted_downbeats

    if len(downbeats) == 2 and bpm is not None:
        # A single interval is only trusted if it matches the BPM
        # (assuming 4/4 time signature) within 15% tolerance.
        interval = float(downbeats[1] - downbeats[0])
        expected_interval = (60.0 / bpm) * 4
        if abs(interval - expected_interval) / expected_interval >= 0.15:
            # Interval doesn't match BPM: do not extrapolate. Falling through to the
            # generic check would always pass, since the std of one interval is 0.
            return adjusted_downbeats
    else:
        # Calculate intervals from ORIGINAL downbeats (before time stretching)
        intervals = np.diff(downbeats)
        # Only extrapolate if intervals are consistent (low standard deviation)
        if float(np.std(intervals)) > 0.2:
            return adjusted_downbeats
        interval = float(np.median(intervals))

    # Adjust the interval for time stretching
    # When slowing down (tempo_factor < 1.0), intervals get longer
    adjusted_interval = interval / tempo_factor
    if adjusted_interval <= 0:
        # Duplicate/unsorted downbeats: the loop below would never advance
        return adjusted_downbeats

    last_downbeat = adjusted_downbeats[-1]
    # If the last downbeat is close to the buffer end, no extrapolation needed
    if last_downbeat >= buffer_size - 5:
        return adjusted_downbeats

    # Extrapolate forward from last adjusted downbeat using adjusted interval
    extrapolated = []
    current_pos = last_downbeat + adjusted_interval
    while (
        current_pos < buffer_size
        and (current_pos - last_downbeat) <= max_extrapolation_distance
    ):
        extrapolated.append(current_pos)
        current_pos += adjusted_interval

    if not extrapolated:
        return adjusted_downbeats
    # Combine adjusted detected downbeats and extrapolated downbeats
    return np.concatenate([adjusted_downbeats, np.array(extrapolated)])
610