music-assistant-server

23.4 KBPY
fades.py
23.4 KB606 lines • python
1"""Smart Fades - Audio fade implementations."""
2
3from __future__ import annotations
4
5import logging
6from abc import ABC, abstractmethod
7from typing import TYPE_CHECKING
8
9import aiofiles
10import numpy as np
11import numpy.typing as npt
12import shortuuid
13
14from music_assistant.constants import VERBOSE_LOG_LEVEL
15from music_assistant.controllers.streams.smart_fades.filters import (
16    CrossfadeFilter,
17    Filter,
18    FrequencySweepFilter,
19    TimeStretchFilter,
20    TrimFilter,
21)
22from music_assistant.helpers.process import communicate
23from music_assistant.helpers.util import remove_file
24from music_assistant.models.smart_fades import (
25    SmartFadesAnalysis,
26)
27
28if TYPE_CHECKING:
29    from music_assistant_models.media_items import AudioFormat
30
# Length, in seconds, of the audio buffer that smart fades work within: it caps
# crossfade durations and is the default ``buffer_size`` for downbeat extrapolation.
SMART_CROSSFADE_DURATION = 45
32
33
class SmartFade(ABC):
    """Abstract base class for Smart Fades.

    A smart fade is an ordered chain of ``Filter`` objects that is rendered into an
    FFmpeg ``-filter_complex`` graph. Subclasses populate ``self.filters`` in
    ``_build``; ``apply`` runs FFmpeg over two PCM buffers using that graph.
    """

    filters: list[Filter]

    def __init__(self, logger: logging.Logger) -> None:
        """Initialize SmartFade base class.

        Args:
            logger: Logger used for debug and verbose output.
        """
        self.filters = []
        self.logger = logger

    @abstractmethod
    def _build(self) -> None:
        """Build the smart fades filter chain."""
        ...

    @staticmethod
    def _pcm_format_args(pcm_format: AudioFormat) -> list[str]:
        """Return the FFmpeg options describing raw PCM audio in ``pcm_format``.

        The same option set is used for both inputs and for the output so that all
        three streams share codec, channel count, sample rate and container format.
        """
        return [
            "-acodec",
            pcm_format.content_type.name.lower(),  # e.g., "pcm_f32le" not just "f32le"
            "-ac",
            str(pcm_format.channels),
            "-ar",
            str(pcm_format.sample_rate),
            "-channel_layout",
            "mono" if pcm_format.channels == 1 else "stereo",
            "-f",
            pcm_format.content_type.value,
        ]

    def _get_ffmpeg_filters(
        self,
        input_fadein_label: str = "[1]",
        input_fadeout_label: str = "[0]",
    ) -> list[str]:
        """Get FFmpeg filters for smart fades.

        Builds the filter chain on first use, then threads the fade-in/fade-out
        stream labels through every filter: the output labels of one filter become
        the input labels of the next.

        Args:
            input_fadein_label: FFmpeg label of the incoming (fade-in) stream.
            input_fadeout_label: FFmpeg label of the outgoing (fade-out) stream.

        Returns:
            The filter strings of all filters, in chain order.
        """
        if not self.filters:
            self._build()
        filters: list[str] = []
        _cur_fadein_label = input_fadein_label
        _cur_fadeout_label = input_fadeout_label
        for audio_filter in self.filters:
            filters.extend(audio_filter.apply(_cur_fadein_label, _cur_fadeout_label))
            _cur_fadein_label = f"[{audio_filter.output_fadein_label}]"
            _cur_fadeout_label = f"[{audio_filter.output_fadeout_label}]"
        return filters

    async def apply(
        self,
        fade_out_part: bytes,
        fade_in_part: bytes,
        pcm_format: AudioFormat,
    ) -> bytes:
        """Apply the smart fade to the given PCM audio parts.

        The fade-out part is handed to FFmpeg through a temporary file and the
        fade-in part through stdin; the processed audio is read from stdout.

        Args:
            fade_out_part: Raw PCM audio of the outgoing track.
            fade_in_part: Raw PCM audio of the incoming track.
            pcm_format: Format shared by both parts and by the returned audio.

        Returns:
            The raw PCM audio produced by FFmpeg.

        Raises:
            RuntimeError: If FFmpeg produced no output.
        """
        pcm_args = self._pcm_format_args(pcm_format)

        # Write the fade_out_part to a temporary file
        fadeout_filename = f"/tmp/{shortuuid.random(20)}.pcm"  # noqa: S108
        async with aiofiles.open(fadeout_filename, "wb") as outfile:
            await outfile.write(fade_out_part)

        # From here on the temp file exists: always remove it, even when building
        # the filter chain or running FFmpeg raises or the task gets cancelled.
        try:
            smart_fade_filters = self._get_ffmpeg_filters()
            self.logger.debug(
                "Applying smartfade: %s",
                self,
            )
            args = [
                "ffmpeg",
                "-hide_banner",
                "-loglevel",
                "error",
                # Input 1: fadeout part (as file)
                *pcm_args,
                "-i",
                fadeout_filename,
                # Input 2: fade_in part (stdin)
                *pcm_args,
                "-i",
                "-",
                "-filter_complex",
                ";".join(smart_fade_filters),
                # Output format specification - must match input codec format
                *pcm_args,
                "-",
            ]
            self.logger.log(VERBOSE_LOG_LEVEL, "FFmpeg command args: %s", " ".join(args))

            # Execute the enhanced smart fade with full buffer
            _, raw_crossfade_output, stderr = await communicate(args, fade_in_part)
        finally:
            await remove_file(fadeout_filename)

        if raw_crossfade_output:
            return raw_crossfade_output
        stderr_msg = stderr.decode() if stderr else "(no stderr output)"
        raise RuntimeError(f"Smart crossfade failed. FFmpeg stderr: {stderr_msg}")

    def __repr__(self) -> str:
        """Return string representation of SmartFade showing the filter chain."""
        if not self.filters:
            return f"<{self.__class__.__name__}: 0 filters>"

        chain = " → ".join(repr(f) for f in self.filters)
        return f"<{self.__class__.__name__}: {len(self.filters)} filters> {chain}"
151
152
class SmartCrossFade(SmartFade):
    """Smart fades class that implements a Smart Fade mode.

    Builds a beat-aware crossfade from the analysis of both tracks: optional time
    stretching, trimming the incoming track to a (down)beat, complementary
    lowpass/highpass frequency sweeps and the final crossfade.
    """

    # Only apply time stretching if BPM difference is at most this %
    time_stretch_bpm_percentage_threshold: float = 5.0

    def __init__(
        self,
        logger: logging.Logger,
        fade_out_analysis: SmartFadesAnalysis,
        fade_in_analysis: SmartFadesAnalysis,
    ) -> None:
        """Initialize SmartFades with analysis data.

        Args:
            logger: Logger for debug output
            fade_out_analysis: Analysis data for the outgoing track
            fade_in_analysis: Analysis data for the incoming track
        """
        self.fade_out_analysis = fade_out_analysis
        self.fade_in_analysis = fade_in_analysis
        super().__init__(logger)

    def _build(self) -> None:
        """Build the smart fades filter chain.

        Appends, in order: an optional ``TimeStretchFilter``, an optional
        ``TrimFilter`` (beat alignment of the incoming track), a lowpass sweep on
        the outgoing track, a highpass sweep on the incoming track and the final
        ``CrossfadeFilter``.
        """
        fade_in_bpm = self.fade_in_analysis.bpm
        fade_out_bpm = self.fade_out_analysis.bpm

        # Calculate tempo factor for time stretching
        bpm_ratio = fade_in_bpm / fade_out_bpm
        bpm_diff_percent = get_bpm_diff_percentage(fade_in_bpm, fade_out_bpm)

        # Extrapolate downbeats for better bar calculation
        self.extrapolated_fadeout_downbeats = extrapolate_downbeats(
            self.fade_out_analysis.downbeats,
            tempo_factor=1.0,
            bpm=fade_out_bpm,
        )

        # Additional verbose logging to debug rare failures
        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "SmartCrossFade build: fade_out: %s, fade_in: %s",
            self.fade_out_analysis,
            self.fade_in_analysis,
        )

        # Calculate optimal crossfade bars that fit in available buffer
        crossfade_bars = self._calculate_optimal_crossfade_bars()

        # Calculate beat positions for the selected bar count
        fadein_start_pos = self._calculate_optimal_fade_timing(crossfade_bars)

        # Calculate initial crossfade duration (may be adjusted later for downbeat alignment)
        crossfade_duration = self._calculate_crossfade_duration(crossfade_bars=crossfade_bars)

        # Add time stretch filter if needed
        if (
            0.1 < bpm_diff_percent <= self.time_stretch_bpm_percentage_threshold
            and crossfade_bars > 4
        ):
            self.filters.append(TimeStretchFilter(logger=self.logger, stretch_ratio=bpm_ratio))
            # Re-extrapolate downbeats with actual tempo factor for time-stretched audio
            self.extrapolated_fadeout_downbeats = extrapolate_downbeats(
                self.fade_out_analysis.downbeats,
                tempo_factor=bpm_ratio,
                bpm=fade_out_bpm,
            )

        # Check if we would have enough audio after beat alignment for the crossfade
        if not fadein_start_pos:
            # None: no beat position is known. 0.0: the incoming track already
            # starts on the beat. Either way there is nothing to trim; logging is
            # kept separate because "%.1f" cannot format None.
            self.logger.log(
                VERBOSE_LOG_LEVEL,
                "Skipping trim: no beat offset to trim in incoming track (%s)",
                fadein_start_pos,
            )
        elif fadein_start_pos + crossfade_duration <= SMART_CROSSFADE_DURATION:
            self.filters.append(TrimFilter(logger=self.logger, fadein_start_pos=fadein_start_pos))
        else:
            self.logger.log(
                VERBOSE_LOG_LEVEL,
                "Skipping beat alignment: not enough audio after trim (%.1fs + %.1fs > %.1fs)",
                fadein_start_pos,
                crossfade_duration,
                SMART_CROSSFADE_DURATION,
            )

        # Adjust crossfade duration to align with outgoing track's downbeats
        crossfade_duration = self._adjust_crossfade_to_downbeats(
            crossfade_duration=crossfade_duration,
            fadein_start_pos=fadein_start_pos,
        )

        # 90 BPM -> 1500Hz, 140 BPM -> 2500Hz
        avg_bpm = (fade_out_bpm + fade_in_bpm) / 2
        crossover_freq = int(np.clip(1500 + (avg_bpm - 90) * 20, 1500, 2500))

        # Adjust for BPM mismatch
        if abs(bpm_ratio - 1.0) > 0.3:
            crossover_freq = int(crossover_freq * 0.85)

        # For shorter fades, use exp/exp curves to avoid abruptness
        if crossfade_bars < 8:
            fadeout_curve = "exponential"
            fadein_curve = "exponential"
        # For long fades, use log/linear curves
        else:
            # Use logarithmic curve to give the next track more space
            fadeout_curve = "logarithmic"
            # Use linear curve for transition, predictable and not too abrupt
            fadein_curve = "linear"

        # Create lowpass filter on the outgoing track (unfiltered → low-pass)
        # Extended lowpass effect to gradually remove bass frequencies
        fadeout_eq_duration = min(max(crossfade_duration * 2.5, 8.0), SMART_CROSSFADE_DURATION)
        # The crossfade always happens at the END of the buffer
        fadeout_eq_start = max(0, SMART_CROSSFADE_DURATION - fadeout_eq_duration)
        self.filters.append(
            FrequencySweepFilter(
                logger=self.logger,
                sweep_type="lowpass",
                target_freq=crossover_freq,
                duration=fadeout_eq_duration,
                start_time=fadeout_eq_start,
                sweep_direction="fade_in",
                poles=1,
                curve_type=fadeout_curve,
                stream_type="fadeout",
            )
        )

        # Create high pass filter on the incoming track (high-pass → unfiltered)
        # Quicker highpass removal to avoid lingering vocals after crossfade
        fadein_eq_duration = crossfade_duration / 1.5
        self.filters.append(
            FrequencySweepFilter(
                logger=self.logger,
                sweep_type="highpass",
                target_freq=crossover_freq,
                duration=fadein_eq_duration,
                start_time=0,
                sweep_direction="fade_out",
                poles=1,
                curve_type=fadein_curve,
                stream_type="fadein",
            )
        )

        # Add final crossfade filter
        self.filters.append(
            CrossfadeFilter(logger=self.logger, crossfade_duration=crossfade_duration)
        )

    def _calculate_crossfade_duration(self, crossfade_bars: int) -> float:
        """Calculate final crossfade duration based on musical bars and BPM.

        Args:
            crossfade_bars: Number of bars (4 beats each) the crossfade should span.

        Returns:
            Duration in seconds at the incoming track's BPM, capped at
            ``SMART_CROSSFADE_DURATION``.
        """
        # Calculate crossfade duration based on incoming track's BPM
        beats_per_bar = 4
        seconds_per_beat = 60.0 / self.fade_in_analysis.bpm
        musical_duration = crossfade_bars * beats_per_bar * seconds_per_beat

        # Apply buffer constraint
        actual_duration = min(musical_duration, SMART_CROSSFADE_DURATION)

        # Log if we had to constrain the duration
        if musical_duration > SMART_CROSSFADE_DURATION:
            self.logger.log(
                VERBOSE_LOG_LEVEL,
                "Constraining crossfade duration from %.1fs to %.1fs (buffer limit)",
                musical_duration,
                actual_duration,
            )

        return actual_duration

    def _calculate_optimal_crossfade_bars(self) -> int:
        """Calculate optimal crossfade bars that fit in available buffer.

        Starts from an ideal bar count (10 for BPM-compatible tracks, 6 otherwise)
        and steps down until the crossfade fits behind the fade-in start position.

        Returns:
            The largest candidate bar count that fits, or 1 when none does.
        """
        bpm_diff_percent = get_bpm_diff_percentage(
            self.fade_in_analysis.bpm, self.fade_out_analysis.bpm
        )

        # Calculate ideal bars based on BPM compatibility
        ideal_bars = 10 if bpm_diff_percent <= self.time_stretch_bpm_percentage_threshold else 6

        # Reduce bars until it fits in the fadein buffer (each candidate tried once)
        candidates = [ideal_bars, *(bars for bars in (8, 6, 4, 2, 1) if bars < ideal_bars)]
        for bars in candidates:
            fadein_start_pos = self._calculate_optimal_fade_timing(bars)
            if fadein_start_pos is None:
                continue

            # Calculate what the duration would be
            test_duration = self._calculate_crossfade_duration(crossfade_bars=bars)

            # Check if it fits in fadein buffer
            fadein_buffer = SMART_CROSSFADE_DURATION - fadein_start_pos
            if test_duration > fadein_buffer:
                continue

            if bars < ideal_bars:
                self.logger.log(
                    VERBOSE_LOG_LEVEL,
                    "Reduced crossfade from %d to %d bars (fadein buffer=%.1fs, needed=%.1fs)",
                    ideal_bars,
                    bars,
                    fadein_buffer,
                    test_duration,
                )
            return bars

        # Fall back to 1 bar if nothing else fits
        return 1

    @staticmethod
    def _first_fadein_beat(
        fade_out_beats: npt.NDArray[np.float64],
        fade_in_beats: npt.NDArray[np.float64],
        num_beats: int,
    ) -> float | None:
        """Return the first fade-in beat if both arrays hold at least ``num_beats`` entries."""
        if len(fade_out_beats) < num_beats or len(fade_in_beats) < num_beats:
            return None
        return float(fade_in_beats[:num_beats][0])

    def _calculate_optimal_fade_timing(self, crossfade_bars: int) -> float | None:
        """Calculate beat positions for alignment.

        Args:
            crossfade_bars: Number of bars the crossfade should span.

        Returns:
            Position (seconds) of the first downbeat of the incoming track, or of
            its first beat when there are too few downbeats; ``None`` when neither
            track pair has enough (down)beats.
        """
        beats_per_bar = 4

        # Try downbeats first for most musical timing.
        # Compare against None explicitly: a position of 0.0 is a valid result.
        downbeat_position = self._first_fadein_beat(
            self.extrapolated_fadeout_downbeats, self.fade_in_analysis.downbeats, crossfade_bars
        )
        if downbeat_position is not None:
            return downbeat_position

        # Try regular beats if downbeats insufficient
        beat_position = self._first_fadein_beat(
            self.fade_out_analysis.beats,
            self.fade_in_analysis.beats,
            crossfade_bars * beats_per_bar,
        )
        if beat_position is not None:
            return beat_position

        # Fallback: No beat alignment possible
        self.logger.log(VERBOSE_LOG_LEVEL, "No beat alignment possible (insufficient beats)")
        return None

    def _adjust_crossfade_to_downbeats(
        self,
        crossfade_duration: float,
        fadein_start_pos: float | None,
    ) -> float:
        """Adjust crossfade duration to align with outgoing track's downbeats.

        The crossfade ends at the end of the buffer, so moving its start onto a
        downbeat of the outgoing track changes its duration.

        Args:
            crossfade_duration: Initial crossfade duration in seconds.
            fadein_start_pos: Start offset of the incoming track, or ``None`` when
                no beat alignment is available.

        Returns:
            The duration aligned to the nearest usable downbeat, or the unchanged
            ``crossfade_duration`` when no downbeat is usable.
        """
        # If we don't have downbeats or beat alignment is disabled, return original duration
        if len(self.extrapolated_fadeout_downbeats) == 0 or fadein_start_pos is None:
            return crossfade_duration

        # Calculate where the crossfade would start in the buffer
        ideal_start_pos = SMART_CROSSFADE_DURATION - crossfade_duration

        # Debug logging
        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "Downbeat adjustment - ideal_start=%.2fs (buffer=%.1fs - crossfade=%.2fs), "
            "fadein_start=%.2fs",
            ideal_start_pos,
            SMART_CROSSFADE_DURATION,
            crossfade_duration,
            fadein_start_pos,
        )

        # Find the closest downbeats (earlier and later).
        # NOTE(review): assumes the downbeats are sorted ascending - confirm.
        earlier_downbeat = None
        later_downbeat = None
        for downbeat in self.extrapolated_fadeout_downbeats:
            if downbeat <= ideal_start_pos:
                earlier_downbeat = downbeat
            elif downbeat > ideal_start_pos:
                later_downbeat = downbeat
                break

        # Try earlier downbeat first (longer crossfade), then the later one (shorter)
        for downbeat, side in ((earlier_downbeat, "earlier"), (later_downbeat, "later")):
            if downbeat is None:
                continue
            adjusted_duration = float(SMART_CROSSFADE_DURATION - downbeat)
            # A downbeat at/after the buffer end would yield a zero or negative
            # crossfade duration, which is never usable.
            if adjusted_duration <= 0:
                continue
            if fadein_start_pos + adjusted_duration > SMART_CROSSFADE_DURATION:
                continue
            if abs(adjusted_duration - crossfade_duration) > 0.1:
                self.logger.log(
                    VERBOSE_LOG_LEVEL,
                    "Adjusted crossfade duration from %.2fs to %.2fs to align with "
                    "downbeat at %.2fs (%s)",
                    crossfade_duration,
                    adjusted_duration,
                    downbeat,
                    side,
                )
            return adjusted_duration

        # If no suitable downbeat found, return original duration
        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "Could not adjust crossfade duration to downbeats, using original %.2fs",
            crossfade_duration,
        )
        return crossfade_duration
462
463
class StandardCrossFade(SmartFade):
    """Standard crossfade class that implements a standard crossfade mode."""

    def __init__(self, logger: logging.Logger, crossfade_duration: float = 10.0) -> None:
        """Initialize StandardCrossFade with crossfade duration.

        Args:
            logger: Logger for debug output.
            crossfade_duration: Requested crossfade length in seconds.
        """
        self.crossfade_duration = crossfade_duration
        super().__init__(logger)

    def _build(self) -> None:
        """Build the standard crossfade filter chain."""
        self.filters = [
            CrossfadeFilter(logger=self.logger, crossfade_duration=self.crossfade_duration),
        ]

    async def apply(
        self, fade_out_part: bytes, fade_in_part: bytes, pcm_format: AudioFormat
    ) -> bytes:
        """Apply the standard crossfade to the given PCM audio parts.

        Only the overlapping tail of ``fade_out_part`` and head of ``fade_in_part``
        are sent through FFmpeg; the audio before and after is passed through.

        Args:
            fade_out_part: Raw PCM audio of the outgoing track.
            fade_in_part: Raw PCM audio of the incoming track.
            pcm_format: Format shared by both parts.

        Returns:
            Untouched start of the outgoing track, the crossfaded section and the
            untouched remainder of the incoming track, concatenated.
        """
        # We need to override the default apply here, since standard crossfade only needs to be
        # applied to the overlapping parts, not the full buffers.
        # NOTE(review): pcm_sample_size is presumably bytes per second of audio - confirm.
        crossfade_size = int(pcm_format.pcm_sample_size * self.crossfade_duration)
        if crossfade_size <= 0 or not fade_out_part or not fade_in_part:
            # Nothing to overlap. This also avoids the slicing pitfall where
            # data[:-0] is empty and data[-0:] is the whole buffer.
            return fade_out_part + fade_in_part

        # Pre-crossfade: outgoing track minus the crossfaded portion
        pre_crossfade = fade_out_part[:-crossfade_size]
        # Post-crossfade: incoming track minus the crossfaded portion
        post_crossfade = fade_in_part[crossfade_size:]
        # Adjust portions to exact crossfade size
        adjusted_fade_in_part = fade_in_part[:crossfade_size]
        adjusted_fade_out_part = fade_out_part[-crossfade_size:]
        # Adjust the duration to match actual sizes
        self.crossfade_duration = min(
            len(adjusted_fade_in_part) / pcm_format.pcm_sample_size,
            len(adjusted_fade_out_part) / pcm_format.pcm_sample_size,
        )
        # Rebuild the chain so the filter always uses the adjusted duration; the
        # base class only builds when no filters exist yet, which would leave a
        # stale duration in place when this instance is reused.
        self._build()
        # Crossfaded portion: user's configured duration
        crossfaded_section = await super().apply(
            adjusted_fade_out_part, adjusted_fade_in_part, pcm_format
        )
        # Full result: everything concatenated
        return pre_crossfade + crossfaded_section + post_crossfade
503
504
505# HELPER METHODS
def get_bpm_diff_percentage(bpm1: float, bpm2: float) -> float:
    """Return how far ``bpm1`` deviates from ``bpm2``, as a percentage of ``bpm2``."""
    tempo_ratio = bpm1 / bpm2
    deviation = abs(1.0 - tempo_ratio)
    return deviation * 100
509
510
def _extend_downbeats(
    adjusted_downbeats: npt.NDArray[np.float64],
    adjusted_interval: float,
    buffer_size: float,
    max_extrapolation_distance: float = 25.0,
) -> npt.NDArray[np.float64]:
    """Append evenly spaced downbeats after the last entry of ``adjusted_downbeats``.

    New downbeats are added every ``adjusted_interval`` seconds while they stay
    below ``buffer_size`` and within ``max_extrapolation_distance`` seconds of the
    last detected downbeat.
    """
    # A zero, negative or NaN interval would never advance the position.
    if not adjusted_interval > 0:
        return adjusted_downbeats

    last_downbeat = adjusted_downbeats[-1]
    extrapolated = []
    current_pos = last_downbeat + adjusted_interval
    while (
        current_pos < buffer_size
        and (current_pos - last_downbeat) <= max_extrapolation_distance
    ):
        extrapolated.append(current_pos)
        current_pos += adjusted_interval

    if not extrapolated:
        return adjusted_downbeats
    # Combine adjusted detected downbeats and extrapolated downbeats
    return np.concatenate([adjusted_downbeats, np.array(extrapolated)])


def extrapolate_downbeats(
    downbeats: npt.NDArray[np.float64],
    tempo_factor: float,
    buffer_size: float = SMART_CROSSFADE_DURATION,
    bpm: float | None = None,
) -> npt.NDArray[np.float64]:
    """Extrapolate downbeats based on actual intervals when detection is incomplete.

    This is needed when we want to perform beat alignment in an 'atmospheric' outro
    that does not have any detected downbeats.

    Extrapolation is skipped (only the tempo adjustment is applied) when fewer than
    2 downbeats are given, when the last downbeat is within 5s of the buffer end,
    when exactly 2 downbeats are given whose interval is not within 15% of one 4/4
    bar at ``bpm``, or when the intervals are inconsistent (std > 0.2s). New
    downbeats are never placed more than 25s after the last detected one.

    Args:
        downbeats: Array of detected downbeat positions in seconds
        tempo_factor: Tempo adjustment factor for time stretching
        buffer_size: Maximum buffer size in seconds
        bpm: Optional BPM for validation when extrapolating with only 2 downbeats

    Returns:
        The downbeats divided by ``tempo_factor``, followed by any extrapolated
        positions.
    """
    # Adjust detected downbeats for time stretching first
    adjusted_downbeats = downbeats / tempo_factor

    if len(downbeats) < 2:
        # Need at least 2 downbeats to extrapolate
        return adjusted_downbeats

    # If the last downbeat is close to the buffer end, no extrapolation needed
    if adjusted_downbeats[-1] >= buffer_size - 5:
        return adjusted_downbeats

    # Intervals are always taken from the ORIGINAL downbeats (before time stretching)
    if len(downbeats) == 2 and bpm is not None and bpm > 0:
        # A single interval cannot be checked for consistency, so validate it
        # against the expected bar length (assuming 4/4 time signature) instead.
        interval = float(downbeats[1] - downbeats[0])
        expected_interval = (60.0 / bpm) * 4
        # Only extrapolate if interval matches BPM within 15% tolerance. A
        # mismatch must return here: the generic check below would always accept
        # a single interval, because its standard deviation is 0.
        if abs(interval - expected_interval) / expected_interval >= 0.15:
            return adjusted_downbeats
    else:
        intervals = np.diff(downbeats)
        interval = float(np.median(intervals))
        # Only extrapolate if intervals are consistent (low standard deviation)
        if float(np.std(intervals)) > 0.2:
            return adjusted_downbeats

    # Adjust the interval for time stretching
    # When slowing down (tempo_factor < 1.0), intervals get longer
    return _extend_downbeats(adjusted_downbeats, interval / tempo_factor, buffer_size)
606