"""Smart Fades - Audio fade implementations."""
2
3from __future__ import annotations
4
5import logging
6from abc import ABC, abstractmethod
7from typing import TYPE_CHECKING
8
9import aiofiles
10import numpy as np
11import numpy.typing as npt
12import shortuuid
13
14from music_assistant.constants import VERBOSE_LOG_LEVEL
15from music_assistant.controllers.streams.smart_fades.filters import (
16 CrossfadeFilter,
17 Filter,
18 FrequencySweepFilter,
19 TimeStretchFilter,
20 TrimFilter,
21)
22from music_assistant.helpers.process import communicate
23from music_assistant.helpers.util import remove_file
24from music_assistant.models.smart_fades import (
25 SmartFadesAnalysis,
26)
27
28if TYPE_CHECKING:
29 from music_assistant_models.media_items import AudioFormat
30
# Duration, in seconds, of the fade-out / fade-in audio buffers the smart fades work on.
# Also used as the upper bound for any computed crossfade duration.
SMART_CROSSFADE_DURATION: int = 45
32
33
class SmartFade(ABC):
    """Abstract base class for Smart Fades.

    Subclasses populate ``self.filters`` in :meth:`_build`. :meth:`apply` chains those
    filters into one FFmpeg ``-filter_complex`` graph and runs it over two raw PCM
    buffers: the outgoing audio (FFmpeg input 0, via a temp file) and the incoming
    audio (FFmpeg input 1, via stdin).
    """

    filters: list[Filter]

    def __init__(self, logger: logging.Logger) -> None:
        """Initialize SmartFade base class.

        Args:
            logger: Logger used for debug/verbose output.
        """
        self.filters = []
        self.logger = logger

    @abstractmethod
    def _build(self) -> None:
        """Build the smart fades filter chain (populate ``self.filters``)."""
        ...

    def _get_ffmpeg_filters(
        self,
        input_fadein_label: str = "[1]",
        input_fadeout_label: str = "[0]",
    ) -> list[str]:
        """Get FFmpeg filters for smart fades.

        The filter chain is built lazily on first use. Every filter receives the
        output labels of the filter before it, so filters are chained in list order.

        Args:
            input_fadein_label: Label of the incoming-audio input ("[1]" = stdin input).
            input_fadeout_label: Label of the outgoing-audio input ("[0]" = file input).

        Returns:
            The filtergraph statements, in order.
        """
        if not self.filters:
            self._build()
        filter_strings: list[str] = []
        fadein_label = input_fadein_label
        fadeout_label = input_fadeout_label
        for audio_filter in self.filters:
            filter_strings.extend(audio_filter.apply(fadein_label, fadeout_label))
            fadein_label = f"[{audio_filter.output_fadein_label}]"
            fadeout_label = f"[{audio_filter.output_fadeout_label}]"
        return filter_strings

    @staticmethod
    def _pcm_format_args(pcm_format: AudioFormat) -> list[str]:
        """Return the FFmpeg options describing raw audio in ``pcm_format``."""
        return [
            "-acodec",
            pcm_format.content_type.name.lower(),  # e.g., "pcm_f32le" not just "f32le"
            "-ac",
            str(pcm_format.channels),
            "-ar",
            str(pcm_format.sample_rate),
            "-channel_layout",
            "mono" if pcm_format.channels == 1 else "stereo",
            "-f",
            pcm_format.content_type.value,
        ]

    async def apply(
        self,
        fade_out_part: bytes,
        fade_in_part: bytes,
        pcm_format: AudioFormat,
    ) -> bytes:
        """Apply the smart fade to the given PCM audio parts.

        Args:
            fade_out_part: Raw PCM audio of the outgoing track.
            fade_in_part: Raw PCM audio of the incoming track.
            pcm_format: Format of both inputs; the output uses the same format.

        Returns:
            The raw PCM output produced by the FFmpeg filtergraph.

        Raises:
            RuntimeError: If FFmpeg produced no output.
        """
        # Only one input can be piped via stdin, so the fade-out part goes to a temp file.
        fadeout_filename = f"/tmp/{shortuuid.random(20)}.pcm"  # noqa: S108
        async with aiofiles.open(fadeout_filename, "wb") as outfile:
            await outfile.write(fade_out_part)
        try:
            pcm_args = self._pcm_format_args(pcm_format)
            smart_fade_filters = self._get_ffmpeg_filters()
            self.logger.debug(
                "Applying smartfade: %s",
                self,
            )
            args = [
                "ffmpeg",
                "-hide_banner",
                "-loglevel",
                "error",
                # Input 1: fadeout part (as file)
                *pcm_args,
                "-i",
                fadeout_filename,
                # Input 2: fade_in part (stdin)
                *pcm_args,
                "-i",
                "-",
                "-filter_complex",
                ";".join(smart_fade_filters),
                # Output format specification - must match input codec format
                *pcm_args,
                "-",
            ]
            self.logger.log(VERBOSE_LOG_LEVEL, "FFmpeg command args: %s", " ".join(args))

            # Execute the enhanced smart fade with full buffer
            _, raw_crossfade_output, stderr = await communicate(args, fade_in_part)
        finally:
            # Remove the temp file even when building the filters or running FFmpeg fails.
            await remove_file(fadeout_filename)

        if raw_crossfade_output:
            return raw_crossfade_output
        # Undecodable stderr must not mask the actual FFmpeg failure.
        stderr_msg = stderr.decode(errors="replace") if stderr else "(no stderr output)"
        raise RuntimeError(f"Smart crossfade failed. FFmpeg stderr: {stderr_msg}")

    def __repr__(self) -> str:
        """Return string representation of SmartFade showing the filter chain."""
        if not self.filters:
            return f"<{self.__class__.__name__}: 0 filters>"

        chain = " → ".join(repr(f) for f in self.filters)
        return f"<{self.__class__.__name__}: {len(self.filters)} filters> {chain}"
151
152
class SmartCrossFade(SmartFade):
    """Smart fades class that implements a Smart Fade mode.

    Uses the beat/downbeat analysis of both tracks to pick a bar-based crossfade
    length. The filter chain is built in this order:

    - an optional time-stretch filter;
    - a trim that aligns the incoming track to its first usable (down)beat;
    - complementary low-pass (outgoing) and high-pass (incoming) sweeps;
    - the final crossfade.
    """

    # Only apply time stretching if BPM difference is <= this %
    time_stretch_bpm_percentage_threshold: float = 5.0

    def __init__(
        self,
        logger: logging.Logger,
        fade_out_analysis: SmartFadesAnalysis,
        fade_in_analysis: SmartFadesAnalysis,
    ) -> None:
        """Initialize SmartFades with analysis data.

        Args:
            logger: Logger for debug output
            fade_out_analysis: Analysis data for the outgoing track
            fade_in_analysis: Analysis data for the incoming track
        """
        self.fade_out_analysis = fade_out_analysis
        self.fade_in_analysis = fade_in_analysis
        super().__init__(logger)

    def _build(self) -> None:
        """Build the smart fades filter chain."""
        # Calculate tempo factor for time stretching
        bpm_ratio = self.fade_in_analysis.bpm / self.fade_out_analysis.bpm
        bpm_diff_percent = abs(1.0 - bpm_ratio) * 100

        # Extrapolate downbeats for better bar calculation
        self.extrapolated_fadeout_downbeats = extrapolate_downbeats(
            self.fade_out_analysis.downbeats,
            tempo_factor=1.0,
            bpm=self.fade_out_analysis.bpm,
        )

        # Additional verbose logging to debug rare failures
        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "SmartCrossFade build: fade_out: %s, fade_in: %s",
            self.fade_out_analysis,
            self.fade_in_analysis,
        )

        # Calculate optimal crossfade bars that fit in available buffer
        crossfade_bars = self._calculate_optimal_crossfade_bars()

        # Calculate beat positions for the selected bar count
        fadein_start_pos = self._calculate_optimal_fade_timing(crossfade_bars)

        # Calculate initial crossfade duration (may be adjusted later for downbeat alignment)
        crossfade_duration = self._calculate_crossfade_duration(crossfade_bars=crossfade_bars)

        # Add time stretch filter if needed
        if (
            0.1 < bpm_diff_percent <= self.time_stretch_bpm_percentage_threshold
            and crossfade_bars > 4
        ):
            self.filters.append(TimeStretchFilter(logger=self.logger, stretch_ratio=bpm_ratio))
            # Re-extrapolate downbeats with actual tempo factor for time-stretched audio
            self.extrapolated_fadeout_downbeats = extrapolate_downbeats(
                self.fade_out_analysis.downbeats,
                tempo_factor=bpm_ratio,
                bpm=self.fade_out_analysis.bpm,
            )

        # Beat-align the incoming track by trimming it up to the selected start position,
        # provided enough audio remains after the trim for the whole crossfade.
        if fadein_start_pos is None:
            self.logger.log(
                VERBOSE_LOG_LEVEL, "Skipping beat alignment: no beat positions available"
            )
        elif fadein_start_pos + crossfade_duration > SMART_CROSSFADE_DURATION:
            self.logger.log(
                VERBOSE_LOG_LEVEL,
                "Skipping beat alignment: not enough audio after trim (%.1fs + %.1fs > %.1fs)",
                fadein_start_pos,
                crossfade_duration,
                SMART_CROSSFADE_DURATION,
            )
            # No trim is applied, so the incoming audio effectively starts at 0s; don't let
            # the downbeat adjustment below reason about a trim that never happened.
            fadein_start_pos = 0.0
        elif fadein_start_pos > 0:
            self.filters.append(TrimFilter(logger=self.logger, fadein_start_pos=fadein_start_pos))
        # A start position of exactly 0.0 is already aligned and needs no trim.

        # Adjust crossfade duration to align with outgoing track's downbeats
        crossfade_duration = self._adjust_crossfade_to_downbeats(
            crossfade_duration=crossfade_duration,
            fadein_start_pos=fadein_start_pos,
        )

        # 90 BPM -> 1500Hz, 140 BPM -> 2500Hz
        avg_bpm = (self.fade_out_analysis.bpm + self.fade_in_analysis.bpm) / 2
        crossover_freq = int(np.clip(1500 + (avg_bpm - 90) * 20, 1500, 2500))

        # Adjust for BPM mismatch
        if abs(bpm_ratio - 1.0) > 0.3:
            crossover_freq = int(crossover_freq * 0.85)

        if crossfade_bars < 8:
            # For shorter fades, use exp/exp curves to avoid abruptness
            fadeout_curve = "exponential"
            fadein_curve = "exponential"
        else:
            # For long fades: logarithmic fade-out gives the next track more space,
            # linear fade-in is predictable and not too abrupt
            fadeout_curve = "logarithmic"
            fadein_curve = "linear"

        # Create lowpass filter on the outgoing track (unfiltered → low-pass)
        # Extended lowpass effect to gradually remove bass frequencies
        fadeout_eq_duration = min(max(crossfade_duration * 2.5, 8.0), SMART_CROSSFADE_DURATION)
        # The crossfade always happens at the END of the buffer
        fadeout_eq_start = max(0, SMART_CROSSFADE_DURATION - fadeout_eq_duration)
        fadeout_sweep = FrequencySweepFilter(
            logger=self.logger,
            sweep_type="lowpass",
            target_freq=crossover_freq,
            duration=fadeout_eq_duration,
            start_time=fadeout_eq_start,
            sweep_direction="fade_in",
            poles=1,
            curve_type=fadeout_curve,
            stream_type="fadeout",
        )
        self.filters.append(fadeout_sweep)

        # Create high pass filter on the incoming track (high-pass → unfiltered)
        # Quicker highpass removal to avoid lingering vocals after crossfade
        fadein_eq_duration = crossfade_duration / 1.5
        fadein_sweep = FrequencySweepFilter(
            logger=self.logger,
            sweep_type="highpass",
            target_freq=crossover_freq,
            duration=fadein_eq_duration,
            start_time=0,
            sweep_direction="fade_out",
            poles=1,
            curve_type=fadein_curve,
            stream_type="fadein",
        )
        self.filters.append(fadein_sweep)

        # Add final crossfade filter
        crossfade_filter = CrossfadeFilter(
            logger=self.logger, crossfade_duration=crossfade_duration
        )
        self.filters.append(crossfade_filter)

    def _calculate_crossfade_duration(self, crossfade_bars: int) -> float:
        """Return the crossfade duration in seconds for ``crossfade_bars`` bars.

        Duration is ``bars * 4 beats * (60 / incoming BPM)``, capped at
        ``SMART_CROSSFADE_DURATION`` (the buffer length).
        """
        beats_per_bar = 4
        seconds_per_beat = 60.0 / self.fade_in_analysis.bpm
        musical_duration = crossfade_bars * beats_per_bar * seconds_per_beat

        # Apply buffer constraint
        actual_duration = min(musical_duration, SMART_CROSSFADE_DURATION)

        # Log if we had to constrain the duration
        if musical_duration > SMART_CROSSFADE_DURATION:
            self.logger.log(
                VERBOSE_LOG_LEVEL,
                "Constraining crossfade duration from %.1fs to %.1fs (buffer limit)",
                musical_duration,
                actual_duration,
            )

        return actual_duration

    def _calculate_optimal_crossfade_bars(self) -> int:
        """Calculate optimal crossfade bars that fit in available buffer.

        Starts at 10 bars when the BPMs are within the time-stretch threshold, else at
        6 bars. Then tries fewer bars until the tracks have enough beats for alignment
        and the crossfade fits in the incoming buffer after the aligned start.
        Falls back to 1 bar.
        """
        bpm_diff_percent = get_bpm_diff_percentage(
            self.fade_in_analysis.bpm, self.fade_out_analysis.bpm
        )

        # Calculate ideal bars based on BPM compatibility
        ideal_bars = 10 if bpm_diff_percent <= self.time_stretch_bpm_percentage_threshold else 6

        # Candidates from longest to shortest, never longer than ideal and without repeats
        candidates = [ideal_bars, *(bars for bars in (8, 6, 4, 2, 1) if bars < ideal_bars)]

        # Reduce bars until it fits in the fadein buffer
        for bars in candidates:
            fadein_start_pos = self._calculate_optimal_fade_timing(bars)
            if fadein_start_pos is None:
                continue

            # Calculate what the duration would be
            test_duration = self._calculate_crossfade_duration(crossfade_bars=bars)

            # Check if it fits in fadein buffer
            fadein_buffer = SMART_CROSSFADE_DURATION - fadein_start_pos
            if test_duration <= fadein_buffer:
                if bars < ideal_bars:
                    self.logger.log(
                        VERBOSE_LOG_LEVEL,
                        "Reduced crossfade from %d to %d bars (fadein buffer=%.1fs, needed=%.1fs)",
                        ideal_bars,
                        bars,
                        fadein_buffer,
                        test_duration,
                    )
                return bars

        # Fall back to 1 bar if nothing else fits
        return 1

    def _calculate_optimal_fade_timing(self, crossfade_bars: int) -> float | None:
        """Return the incoming track's start position for beat alignment, if any.

        Downbeats are tried first (one per bar), then regular beats (four per bar).
        Returns the first incoming (down)beat when both tracks have enough of them,
        otherwise ``None``. A position of ``0.0`` is a valid result.
        """
        beats_per_bar = 4

        def calculate_beat_positions(
            fade_out_beats: npt.NDArray[np.float64],
            fade_in_beats: npt.NDArray[np.float64],
            num_beats: int,
        ) -> float | None:
            """Return the first incoming beat if both tracks have ``num_beats`` beats."""
            if len(fade_out_beats) < num_beats or len(fade_in_beats) < num_beats:
                return None
            return float(fade_in_beats[0])

        # Try downbeats first for most musical timing
        downbeat_position = calculate_beat_positions(
            self.extrapolated_fadeout_downbeats, self.fade_in_analysis.downbeats, crossfade_bars
        )
        if downbeat_position is not None:
            return downbeat_position

        # Try regular beats if downbeats insufficient
        required_beats = crossfade_bars * beats_per_bar
        beat_position = calculate_beat_positions(
            self.fade_out_analysis.beats, self.fade_in_analysis.beats, required_beats
        )
        if beat_position is not None:
            return beat_position

        # Fallback: No beat alignment possible
        self.logger.log(VERBOSE_LOG_LEVEL, "No beat alignment possible (insufficient beats)")
        return None

    def _adjust_crossfade_to_downbeats(
        self,
        crossfade_duration: float,
        fadein_start_pos: float | None,
    ) -> float:
        """Adjust crossfade duration to align with outgoing track's downbeats.

        The crossfade ends at the buffer end, so it starts at
        ``SMART_CROSSFADE_DURATION - crossfade_duration``. That start is moved to the
        closest earlier outgoing downbeat (longer crossfade) or, failing that, the next
        later one (shorter crossfade). Either choice must still fit in the incoming
        audio after ``fadein_start_pos``.
        """
        # If we don't have downbeats or beat alignment is disabled, return original duration
        if len(self.extrapolated_fadeout_downbeats) == 0 or fadein_start_pos is None:
            return crossfade_duration

        # Calculate where the crossfade would start in the buffer
        ideal_start_pos = SMART_CROSSFADE_DURATION - crossfade_duration

        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "Downbeat adjustment - ideal_start=%.2fs (buffer=%.1fs - crossfade=%.2fs), "
            "fadein_start=%.2fs",
            ideal_start_pos,
            SMART_CROSSFADE_DURATION,
            crossfade_duration,
            fadein_start_pos,
        )

        # Find the closest downbeats (earlier and later); assumes ascending downbeats
        earlier_downbeat = None
        later_downbeat = None
        for downbeat in self.extrapolated_fadeout_downbeats:
            if downbeat <= ideal_start_pos:
                earlier_downbeat = downbeat
            elif downbeat > ideal_start_pos:
                later_downbeat = downbeat
                break

        # Try earlier downbeat first (longer crossfade)
        if earlier_downbeat is not None:
            adjusted_duration = float(SMART_CROSSFADE_DURATION - earlier_downbeat)
            if fadein_start_pos + adjusted_duration <= SMART_CROSSFADE_DURATION:
                if abs(adjusted_duration - crossfade_duration) > 0.1:
                    self.logger.log(
                        VERBOSE_LOG_LEVEL,
                        "Adjusted crossfade duration from %.2fs to %.2fs to align with "
                        "downbeat at %.2fs (earlier)",
                        crossfade_duration,
                        adjusted_duration,
                        earlier_downbeat,
                    )
                return adjusted_duration

        # Try later downbeat (shorter crossfade)
        if later_downbeat is not None:
            adjusted_duration = float(SMART_CROSSFADE_DURATION - later_downbeat)
            if fadein_start_pos + adjusted_duration <= SMART_CROSSFADE_DURATION:
                if abs(adjusted_duration - crossfade_duration) > 0.1:
                    self.logger.log(
                        VERBOSE_LOG_LEVEL,
                        "Adjusted crossfade duration from %.2fs to %.2fs to align with "
                        "downbeat at %.2fs (later)",
                        crossfade_duration,
                        adjusted_duration,
                        later_downbeat,
                    )
                return adjusted_duration

        # If no suitable downbeat found, return original duration
        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "Could not adjust crossfade duration to downbeats, using original %.2fs",
            crossfade_duration,
        )
        return crossfade_duration
462
463
class StandardCrossFade(SmartFade):
    """Standard crossfade class that implements a standard crossfade mode.

    Only the overlapping tail of the outgoing part and head of the incoming part
    are run through FFmpeg; the remaining audio is passed through unchanged.
    """

    def __init__(self, logger: logging.Logger, crossfade_duration: float = 10.0) -> None:
        """Initialize StandardCrossFade with crossfade duration.

        Args:
            logger: Logger for debug output.
            crossfade_duration: Requested overlap in seconds.
        """
        self.crossfade_duration = crossfade_duration
        super().__init__(logger)

    def _build(self) -> None:
        """Build the standard crossfade filter chain."""
        self.filters = [
            CrossfadeFilter(logger=self.logger, crossfade_duration=self.crossfade_duration),
        ]

    async def apply(
        self, fade_out_part: bytes, fade_in_part: bytes, pcm_format: AudioFormat
    ) -> bytes:
        """Apply the standard crossfade to the given PCM audio parts.

        Returns ``fade_out_part`` minus its tail, then the crossfaded overlap, then
        ``fade_in_part`` minus its head. The overlap is the configured duration,
        limited to the audio available in both parts.
        """
        # We need to override the default apply here, since standard crossfade only needs to be
        # applied to the overlapping parts, not the full buffers.
        # pcm_sample_size is used as bytes per second (bytes / pcm_sample_size == seconds).
        bytes_per_second = pcm_format.pcm_sample_size
        # NOTE(review): assumes bytes_per_second == sample_rate * bytes-per-frame, so this is
        # the size of one frame (one sample for all channels).
        frame_size = max(1, int(bytes_per_second // pcm_format.sample_rate))

        requested_size = int(bytes_per_second * self.crossfade_duration)
        # Limit the overlap to the audio actually available in both parts, and keep it
        # frame-aligned so neither slice cuts through a sample frame.
        crossfade_size = min(requested_size, len(fade_out_part), len(fade_in_part))
        crossfade_size -= crossfade_size % frame_size
        if crossfade_size <= 0:
            # Nothing to overlap (zero duration or empty input). Slicing with [:-0] would
            # drop the whole outgoing part, so just play both parts back to back.
            return fade_out_part + fade_in_part

        split_pos = len(fade_out_part) - crossfade_size
        # Pre-crossfade: outgoing track minus the crossfaded portion
        pre_crossfade = fade_out_part[:split_pos]
        # Post-crossfade: incoming track minus the crossfaded portion
        post_crossfade = fade_in_part[crossfade_size:]
        # The overlapping portions, exactly crossfade_size bytes each
        adjusted_fade_out_part = fade_out_part[split_pos:]
        adjusted_fade_in_part = fade_in_part[:crossfade_size]
        # Match the duration to the actual overlap and rebuild the filter chain, so a
        # reused instance never crossfades with a stale duration.
        self.crossfade_duration = crossfade_size / bytes_per_second
        self.filters = []
        crossfaded_section = await super().apply(
            adjusted_fade_out_part, adjusted_fade_in_part, pcm_format
        )
        # Full result: everything concatenated
        return pre_crossfade + crossfaded_section + post_crossfade
503
504
# HELPER FUNCTIONS
def get_bpm_diff_percentage(bpm1: float, bpm2: float) -> float:
    """Return how far ``bpm1`` deviates from ``bpm2``, as a percentage of ``bpm2``.

    For example, 105 BPM compared to 100 BPM gives (approximately) 5.0.
    """
    ratio = bpm1 / bpm2
    return 100 * abs(ratio - 1.0)
509
510
def extrapolate_downbeats(
    downbeats: npt.NDArray[np.float64],
    tempo_factor: float,
    buffer_size: float = SMART_CROSSFADE_DURATION,
    bpm: float | None = None,
) -> npt.NDArray[np.float64]:
    """Extrapolate downbeats based on actual intervals when detection is incomplete.

    This is needed when we want to perform beat alignment in an 'atmospheric' outro
    that does not have any detected downbeats.

    Args:
        downbeats: Array of detected downbeat positions in seconds (assumed ascending).
        tempo_factor: Tempo adjustment factor for time stretching; positions and
            intervals are divided by it.
        buffer_size: Maximum buffer size in seconds; extrapolated downbeats stay below it.
        bpm: Optional BPM for validation when extrapolating with only 2 downbeats.

    Returns:
        The detected downbeats divided by ``tempo_factor``, followed by any
        extrapolated downbeats. Nothing is extrapolated in these cases:

        - fewer than 2 downbeats were detected;
        - the last downbeat is within 5s of the buffer end;
        - the intervals are inconsistent (std > 0.2s);
        - exactly 2 downbeats with a positive BPM, whose interval deviates 15% or
          more from one 4/4 bar at that BPM.
    """
    # Adjust detected downbeats for time stretching first
    adjusted_downbeats = downbeats / tempo_factor

    # Exactly 2 downbeats: a single interval says nothing about consistency, so it is
    # validated against the BPM instead (assuming a 4/4 time signature).
    if len(downbeats) == 2 and bpm is not None and bpm > 0:
        interval = float(downbeats[1] - downbeats[0])
        expected_interval = (60.0 / bpm) * 4
        if not abs(interval - expected_interval) / expected_interval < 0.15:
            # Interval doesn't match BPM: don't extrapolate from an untrusted interval
            return adjusted_downbeats
        # NOTE(review): this cap was 125s (effectively unbounded for the default buffer)
        # although its comment said 25s like the general path below; confirm intent.
        return _extend_downbeats(
            adjusted_downbeats, interval / tempo_factor, buffer_size, max_distance=125.0
        )

    if len(downbeats) < 2:
        # Need at least 2 downbeats to extrapolate
        return adjusted_downbeats

    # Calculate intervals from ORIGINAL downbeats (before time stretching)
    intervals = np.diff(downbeats)
    # Only extrapolate if intervals are consistent (low standard deviation)
    if float(np.std(intervals)) > 0.2:
        return adjusted_downbeats

    median_interval = float(np.median(intervals))
    # When slowing down (tempo_factor < 1.0), intervals get longer
    return _extend_downbeats(
        adjusted_downbeats, median_interval / tempo_factor, buffer_size, max_distance=25.0
    )


def _extend_downbeats(
    adjusted_downbeats: npt.NDArray[np.float64],
    adjusted_interval: float,
    buffer_size: float,
    max_distance: float,
) -> npt.NDArray[np.float64]:
    """Append evenly spaced downbeats after the last one, below ``buffer_size``.

    At most ``max_distance`` seconds past the last detected downbeat are filled.
    """
    last_downbeat = float(adjusted_downbeats[-1])

    # If the last downbeat is close to the buffer end, no extrapolation needed
    if last_downbeat >= buffer_size - 5:
        return adjusted_downbeats

    # A zero or negative interval would never move past buffer_size (endless loop);
    # a NaN interval is rejected as well.
    if not adjusted_interval > 0:
        return adjusted_downbeats

    extrapolated: list[float] = []
    current_pos = last_downbeat + adjusted_interval
    while current_pos < buffer_size and (current_pos - last_downbeat) <= max_distance:
        extrapolated.append(current_pos)
        current_pos += adjusted_interval

    if not extrapolated:
        return adjusted_downbeats

    # Combine adjusted detected downbeats and extrapolated downbeats
    return np.concatenate([adjusted_downbeats, np.array(extrapolated)])
606