/
/
/
1"""Smart Fades - Audio fade implementations."""
2
3from __future__ import annotations
4
5import logging
6from abc import ABC, abstractmethod
7from typing import TYPE_CHECKING
8
9import aiofiles
10import numpy as np
11import numpy.typing as npt
12import shortuuid
13
14from music_assistant.constants import VERBOSE_LOG_LEVEL
15from music_assistant.controllers.streams.smart_fades.filters import (
16 CrossfadeFilter,
17 Filter,
18 FrequencySweepFilter,
19 TimeStretchFilter,
20 TrimFilter,
21)
22from music_assistant.helpers.process import communicate
23from music_assistant.helpers.util import remove_file
24from music_assistant.models.smart_fades import (
25 SmartFadesAnalysis,
26)
27
28if TYPE_CHECKING:
29 from music_assistant_models.media_items import AudioFormat
30
# Length in seconds of the fade-out / fade-in audio buffers that smart fades operate on.
# Crossfade durations are capped to this value, and the crossfade is placed at the end
# of the fade-out buffer.
SMART_CROSSFADE_DURATION: int = 45
32
33
class SmartFade(ABC):
    """Abstract base class for Smart Fades.

    Subclasses populate ``self.filters`` in :meth:`_build`. :meth:`apply` renders that
    filter chain into an FFmpeg ``-filter_complex`` graph and runs it over two raw PCM
    buffers: the outgoing (fade-out) part and the incoming (fade-in) part.
    """

    filters: list[Filter]

    def __init__(self, logger: logging.Logger) -> None:
        """Initialize SmartFade base class.

        Args:
            logger: Logger used for debug/verbose output.
        """
        self.filters = []
        self.logger = logger

    @abstractmethod
    def _build(self) -> None:
        """Build the smart fades filter chain (populate ``self.filters``)."""
        ...

    def _get_ffmpeg_filters(
        self,
        input_fadein_label: str = "[1]",
        input_fadeout_label: str = "[0]",
    ) -> list[str]:
        """Get FFmpeg filters for smart fades.

        Builds the filter chain on first use. Then it feeds each filter's output labels
        into the next filter's inputs.

        Args:
            input_fadein_label: FFmpeg stream label of the incoming (fade-in) audio.
            input_fadeout_label: FFmpeg stream label of the outgoing (fade-out) audio.

        Returns:
            Filter strings in chain order, to be joined with ``;`` into one filtergraph.
        """
        if not self.filters:
            self._build()
        ffmpeg_filters: list[str] = []
        cur_fadein_label = input_fadein_label
        cur_fadeout_label = input_fadeout_label
        for audio_filter in self.filters:
            ffmpeg_filters.extend(audio_filter.apply(cur_fadein_label, cur_fadeout_label))
            # The named outputs of this filter become the inputs of the next one
            cur_fadein_label = f"[{audio_filter.output_fadein_label}]"
            cur_fadeout_label = f"[{audio_filter.output_fadeout_label}]"
        return ffmpeg_filters

    @staticmethod
    def _pcm_format_args(pcm_format: AudioFormat) -> list[str]:
        """Return the FFmpeg codec/channel/rate/format options describing raw PCM audio."""
        return [
            "-acodec",
            pcm_format.content_type.name.lower(),  # e.g., "pcm_f32le" not just "f32le"
            "-ac",
            str(pcm_format.channels),
            "-ar",
            str(pcm_format.sample_rate),
            "-channel_layout",
            "mono" if pcm_format.channels == 1 else "stereo",
            "-f",
            pcm_format.content_type.value,
        ]

    async def apply(
        self,
        fade_out_part: bytes,
        fade_in_part: bytes,
        pcm_format: AudioFormat,
    ) -> bytes:
        """Apply the smart fade to the given PCM audio parts.

        The fade-out part is handed to FFmpeg through a temporary file (input 0). The
        fade-in part goes through stdin (input 1). The mixed result is read from stdout
        in the same PCM format.

        Args:
            fade_out_part: Raw PCM audio of the outgoing track.
            fade_in_part: Raw PCM audio of the incoming track.
            pcm_format: Format of both inputs and of the output.

        Returns:
            The raw PCM output of the filter graph.

        Raises:
            RuntimeError: If FFmpeg produced no output.
        """
        # Build the filter graph before creating the temp file. Building can fail
        # (e.g. on bad analysis data), and nothing can leak if it does.
        smart_fade_filters = self._get_ffmpeg_filters()
        self.logger.debug(
            "Applying smartfade: %s",
            self,
        )

        fadeout_filename = f"/tmp/{shortuuid.random(20)}.pcm"  # noqa: S108
        args = [
            "ffmpeg",
            "-hide_banner",
            "-loglevel",
            "error",
            # Input 1: fadeout part (as file)
            *self._pcm_format_args(pcm_format),
            "-i",
            fadeout_filename,
            # Input 2: fade_in part (stdin)
            *self._pcm_format_args(pcm_format),
            "-i",
            "-",
            "-filter_complex",
            ";".join(smart_fade_filters),
            # Output format specification - must match input codec format
            *self._pcm_format_args(pcm_format),
            "-",
        ]
        self.logger.log(VERBOSE_LOG_LEVEL, "FFmpeg command args: %s", " ".join(args))

        try:
            # Write inside the try so a failed or partial write is cleaned up too
            async with aiofiles.open(fadeout_filename, "wb") as outfile:
                await outfile.write(fade_out_part)
            # Execute the smart fade with the full buffers
            _, raw_crossfade_output, stderr = await communicate(args, fade_in_part)
        finally:
            # Always cleanup temp file, even if writing or ffmpeg fails
            await remove_file(fadeout_filename)

        if raw_crossfade_output:
            return raw_crossfade_output
        # FFmpeg stderr is not guaranteed to be valid UTF-8
        stderr_msg = stderr.decode(errors="replace") if stderr else "(no stderr output)"
        raise RuntimeError(f"Smart crossfade failed. FFmpeg stderr: {stderr_msg}")

    def __repr__(self) -> str:
        """Return string representation of SmartFade showing the filter chain."""
        if not self.filters:
            return f"<{self.__class__.__name__}: 0 filters>"

        chain = " → ".join(repr(f) for f in self.filters)
        return f"<{self.__class__.__name__}: {len(self.filters)} filters> {chain}"
155
156
class SmartCrossFade(SmartFade):
    """Smart fades class that implements a Smart Fade mode.

    Builds a beat-aware transition from the BPM/beat/downbeat analysis of both tracks.
    The chain consists of:
    - optional time stretching;
    - beat-aligned trimming of the incoming track;
    - a low-pass sweep on the outgoing track and a high-pass sweep on the incoming track;
    - a final crossfade.
    """

    # Only apply time stretching if BPM difference is at most this %
    time_stretch_bpm_percentage_threshold: float = 5.0

    def __init__(
        self,
        logger: logging.Logger,
        fade_out_analysis: SmartFadesAnalysis,
        fade_in_analysis: SmartFadesAnalysis,
    ) -> None:
        """Initialize SmartFades with analysis data.

        Args:
            logger: Logger for debug output
            fade_out_analysis: Analysis data for the outgoing track
            fade_in_analysis: Analysis data for the incoming track
        """
        self.fade_out_analysis = fade_out_analysis
        self.fade_in_analysis = fade_in_analysis
        super().__init__(logger)

    def _build(self) -> None:
        """Build the smart fades filter chain."""
        # Calculate tempo factor for time stretching
        bpm_ratio = self.fade_in_analysis.bpm / self.fade_out_analysis.bpm
        bpm_diff_percent = abs(1.0 - bpm_ratio) * 100

        # Extrapolate downbeats for better bar calculation
        self.extrapolated_fadeout_downbeats = extrapolate_downbeats(
            self.fade_out_analysis.downbeats,
            tempo_factor=1.0,
            bpm=self.fade_out_analysis.bpm,
        )

        # Additional verbose logging to debug rare failures
        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "SmartCrossFade build: fade_out: %s, fade_in: %s",
            self.fade_out_analysis,
            self.fade_in_analysis,
        )

        # Calculate optimal crossfade bars that fit in available buffer
        crossfade_bars = self._calculate_optimal_crossfade_bars()

        # Calculate beat positions for the selected bar count
        fadein_start_pos = self._calculate_optimal_fade_timing(crossfade_bars)

        # Calculate initial crossfade duration (may be adjusted later for downbeat alignment)
        crossfade_duration = self._calculate_crossfade_duration(crossfade_bars=crossfade_bars)

        # Add time stretch filter if needed
        if (
            0.1 < bpm_diff_percent <= self.time_stretch_bpm_percentage_threshold
            and crossfade_bars > 4
        ):
            self.filters.append(TimeStretchFilter(logger=self.logger, stretch_ratio=bpm_ratio))
            # Re-extrapolate downbeats with actual tempo factor for time-stretched audio
            self.extrapolated_fadeout_downbeats = extrapolate_downbeats(
                self.fade_out_analysis.downbeats,
                tempo_factor=bpm_ratio,
                bpm=self.fade_out_analysis.bpm,
            )

        # Trim the incoming track so it starts on its first (down)beat. Only do this when
        # enough audio is left after the trim to cover the whole crossfade. A start
        # position of 0.0 is valid and simply needs no trim.
        if fadein_start_pos is None:
            self.logger.log(
                VERBOSE_LOG_LEVEL,
                "Skipping beat alignment: no beat positions available",
            )
        elif fadein_start_pos + crossfade_duration > SMART_CROSSFADE_DURATION:
            self.logger.log(
                VERBOSE_LOG_LEVEL,
                "Skipping beat alignment: not enough audio after trim (%.1fs + %.1fs > %.1fs)",
                fadein_start_pos,
                crossfade_duration,
                SMART_CROSSFADE_DURATION,
            )
        elif fadein_start_pos > 0:
            self.filters.append(TrimFilter(logger=self.logger, fadein_start_pos=fadein_start_pos))

        # Adjust crossfade duration to align with outgoing track's downbeats
        crossfade_duration = self._adjust_crossfade_to_downbeats(
            crossfade_duration=crossfade_duration,
            fadein_start_pos=fadein_start_pos,
        )

        # 90 BPM -> 1500Hz, 140 BPM -> 2500Hz
        avg_bpm = (self.fade_out_analysis.bpm + self.fade_in_analysis.bpm) / 2
        crossover_freq = int(np.clip(1500 + (avg_bpm - 90) * 20, 1500, 2500))

        # Adjust for BPM mismatch
        if abs(bpm_ratio - 1.0) > 0.3:
            crossover_freq = int(crossover_freq * 0.85)

        if crossfade_bars < 8:
            # For shorter fades, use exp/exp curves to avoid abruptness
            fadeout_curve = "exponential"
            fadein_curve = "exponential"
        else:
            # For long fades: a logarithmic fade-out gives the next track more space,
            # and a linear fade-in is predictable and not too abrupt
            fadeout_curve = "logarithmic"
            fadein_curve = "linear"

        # Create lowpass filter on the outgoing track (unfiltered → low-pass)
        # Extended lowpass effect to gradually remove bass frequencies
        fadeout_eq_duration = min(max(crossfade_duration * 2.5, 8.0), SMART_CROSSFADE_DURATION)
        # The crossfade always happens at the END of the buffer
        fadeout_eq_start = max(0, SMART_CROSSFADE_DURATION - fadeout_eq_duration)
        fadeout_sweep = FrequencySweepFilter(
            logger=self.logger,
            sweep_type="lowpass",
            target_freq=crossover_freq,
            duration=fadeout_eq_duration,
            start_time=fadeout_eq_start,
            sweep_direction="fade_in",
            poles=1,
            curve_type=fadeout_curve,
            stream_type="fadeout",
        )
        self.filters.append(fadeout_sweep)

        # Create high pass filter on the incoming track (high-pass → unfiltered)
        # Quicker highpass removal to avoid lingering vocals after crossfade
        fadein_eq_duration = crossfade_duration / 1.5
        fadein_sweep = FrequencySweepFilter(
            logger=self.logger,
            sweep_type="highpass",
            target_freq=crossover_freq,
            duration=fadein_eq_duration,
            start_time=0,
            sweep_direction="fade_out",
            poles=1,
            curve_type=fadein_curve,
            stream_type="fadein",
        )
        self.filters.append(fadein_sweep)

        # Add final crossfade filter
        crossfade_filter = CrossfadeFilter(
            logger=self.logger, crossfade_duration=crossfade_duration
        )
        self.filters.append(crossfade_filter)

    def _calculate_crossfade_duration(self, crossfade_bars: int) -> float:
        """Calculate crossfade duration in seconds from musical bars and the incoming BPM.

        Assumes 4 beats per bar. The result is capped at ``SMART_CROSSFADE_DURATION``.
        """
        beats_per_bar = 4
        seconds_per_beat = 60.0 / self.fade_in_analysis.bpm
        musical_duration = crossfade_bars * beats_per_bar * seconds_per_beat

        # Apply buffer constraint
        actual_duration = min(musical_duration, SMART_CROSSFADE_DURATION)

        # Log if we had to constrain the duration
        if musical_duration > SMART_CROSSFADE_DURATION:
            self.logger.log(
                VERBOSE_LOG_LEVEL,
                "Constraining crossfade duration from %.1fs to %.1fs (buffer limit)",
                musical_duration,
                actual_duration,
            )

        return actual_duration

    def _calculate_optimal_crossfade_bars(self) -> int:
        """Pick the longest crossfade (in bars) that fits in the fade-in buffer.

        The ideal length is 10 bars for BPM-compatible tracks and 6 bars otherwise. It is
        stepped down until the beat-aligned crossfade fits after the fade-in start position.

        Returns:
            The chosen number of bars; 1 if nothing larger fits.
        """
        bpm_diff_percent = get_bpm_diff_percentage(
            self.fade_in_analysis.bpm, self.fade_out_analysis.bpm
        )

        # Calculate ideal bars based on BPM compatibility
        ideal_bars = 10 if bpm_diff_percent <= self.time_stretch_bpm_percentage_threshold else 6

        # Candidates longest first, never exceeding (or repeating) the ideal length
        candidates = [ideal_bars, *(bars for bars in (8, 6, 4, 2, 1) if bars < ideal_bars)]
        for bars in candidates:
            fadein_start_pos = self._calculate_optimal_fade_timing(bars)
            if fadein_start_pos is None:
                continue

            # Calculate what the duration would be
            test_duration = self._calculate_crossfade_duration(crossfade_bars=bars)

            # Check if it fits in fadein buffer
            fadein_buffer = SMART_CROSSFADE_DURATION - fadein_start_pos
            if test_duration <= fadein_buffer:
                if bars < ideal_bars:
                    self.logger.log(
                        VERBOSE_LOG_LEVEL,
                        "Reduced crossfade from %d to %d bars (fadein buffer=%.1fs, needed=%.1fs)",
                        ideal_bars,
                        bars,
                        fadein_buffer,
                        test_duration,
                    )
                return bars

        # Fall back to 1 bar if nothing else fits
        return 1

    def _calculate_optimal_fade_timing(self, crossfade_bars: int) -> float | None:
        """Calculate where the crossfade should start in the incoming track.

        Downbeats are preferred, with regular beats as the fallback. Both tracks must have
        at least as many (down)beats as the crossfade needs. Requires
        ``self.extrapolated_fadeout_downbeats``, which is set by :meth:`_build`.

        Returns:
            Position in seconds of the incoming track's first usable (down)beat. This may
            be 0.0. Returns None if neither downbeats nor beats are sufficient.
        """
        beats_per_bar = 4

        def first_fadein_beat(
            fade_out_beats: npt.NDArray[np.float64],
            fade_in_beats: npt.NDArray[np.float64],
            num_beats: int,
        ) -> float | None:
            """Return the first fade-in beat position if both arrays have enough beats."""
            if (
                len(fade_out_beats) < num_beats
                or len(fade_in_beats) < num_beats
                or len(fade_in_beats) == 0
            ):
                return None
            return float(fade_in_beats[0])

        # Try downbeats first for most musical timing.
        # Compare against None: a downbeat at 0.0 is a valid position.
        downbeat_position = first_fadein_beat(
            self.extrapolated_fadeout_downbeats, self.fade_in_analysis.downbeats, crossfade_bars
        )
        if downbeat_position is not None:
            return downbeat_position

        # Try regular beats if downbeats insufficient
        required_beats = crossfade_bars * beats_per_bar
        beat_position = first_fadein_beat(
            self.fade_out_analysis.beats, self.fade_in_analysis.beats, required_beats
        )
        if beat_position is not None:
            return beat_position

        # Fallback: No beat alignment possible
        self.logger.log(VERBOSE_LOG_LEVEL, "No beat alignment possible (insufficient beats)")
        return None

    def _adjust_crossfade_to_downbeats(
        self,
        crossfade_duration: float,
        fadein_start_pos: float | None,
    ) -> float:
        """Adjust the crossfade duration so it starts on an outgoing-track downbeat.

        The crossfade sits at the end of the fade-out buffer, so it starts at
        ``SMART_CROSSFADE_DURATION - crossfade_duration``. The closest downbeat at or
        before that point is tried first (longer crossfade), then the first one after it
        (shorter crossfade). A candidate is accepted only if the incoming track still has
        enough audio after ``fadein_start_pos``.

        Returns:
            The adjusted duration, or ``crossfade_duration`` unchanged if no downbeat fits.
        """
        # If we don't have downbeats or beat alignment is disabled, return original duration
        if len(self.extrapolated_fadeout_downbeats) == 0 or fadein_start_pos is None:
            return crossfade_duration

        # Calculate where the crossfade would start in the buffer
        ideal_start_pos = SMART_CROSSFADE_DURATION - crossfade_duration

        # Debug logging
        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "Downbeat adjustment - ideal_start=%.2fs (buffer=%.1fs - crossfade=%.2fs), "
            "fadein_start=%.2fs",
            ideal_start_pos,
            SMART_CROSSFADE_DURATION,
            crossfade_duration,
            fadein_start_pos,
        )

        # Find the closest downbeats (earlier and later)
        earlier_downbeat = None
        later_downbeat = None
        for downbeat in self.extrapolated_fadeout_downbeats:
            if downbeat <= ideal_start_pos:
                earlier_downbeat = downbeat
            elif downbeat > ideal_start_pos:
                later_downbeat = downbeat
                break

        # Try earlier downbeat first (longer crossfade), then later one (shorter crossfade)
        for downbeat, direction in ((earlier_downbeat, "earlier"), (later_downbeat, "later")):
            if downbeat is None:
                continue
            adjusted_duration = float(SMART_CROSSFADE_DURATION - downbeat)
            if fadein_start_pos + adjusted_duration <= SMART_CROSSFADE_DURATION:
                if abs(adjusted_duration - crossfade_duration) > 0.1:
                    self.logger.log(
                        VERBOSE_LOG_LEVEL,
                        "Adjusted crossfade duration from %.2fs to %.2fs to align with "
                        "downbeat at %.2fs (%s)",
                        crossfade_duration,
                        adjusted_duration,
                        downbeat,
                        direction,
                    )
                return adjusted_duration

        # If no suitable downbeat found, return original duration
        self.logger.log(
            VERBOSE_LOG_LEVEL,
            "Could not adjust crossfade duration to downbeats, using original %.2fs",
            crossfade_duration,
        )
        return crossfade_duration
466
467
class StandardCrossFade(SmartFade):
    """Standard crossfade class that implements a standard crossfade mode.

    Only the overlapping tail of the outgoing part and head of the incoming part go
    through FFmpeg. The rest of both parts is passed through untouched.
    """

    def __init__(self, logger: logging.Logger, crossfade_duration: float = 10.0) -> None:
        """Initialize StandardCrossFade with crossfade duration.

        Args:
            logger: Logger for debug output.
            crossfade_duration: Requested crossfade length in seconds.
        """
        self.crossfade_duration = crossfade_duration
        super().__init__(logger)

    def _build(self) -> None:
        """Build the standard crossfade filter chain."""
        self.filters = [
            CrossfadeFilter(logger=self.logger, crossfade_duration=self.crossfade_duration),
        ]

    async def apply(
        self, fade_out_part: bytes, fade_in_part: bytes, pcm_format: AudioFormat
    ) -> bytes:
        """Apply the standard crossfade to the given PCM audio parts.

        Returns the untouched head of ``fade_out_part``, then the crossfaded overlap,
        then the untouched tail of ``fade_in_part``. If there is nothing to overlap (zero
        duration or an empty part), the two parts are concatenated as-is.
        """
        # pcm_format.pcm_sample_size is used as bytes per second of audio (the duration
        # below is computed as len(bytes) / pcm_sample_size).
        bytes_per_second = pcm_format.pcm_sample_size
        # Slicing mid-frame would misalign samples/channels, so round down to whole frames
        frame_size = max(1, bytes_per_second // pcm_format.sample_rate)
        crossfade_size = int(bytes_per_second * self.crossfade_duration)
        crossfade_size = max(0, crossfade_size - crossfade_size % frame_size)

        # Explicit split point: fade_out_part[:-0] would be empty, not the whole part
        fadeout_split = max(0, len(fade_out_part) - crossfade_size)
        # Pre-crossfade: outgoing track minus the crossfaded portion
        pre_crossfade = fade_out_part[:fadeout_split]
        # Post-crossfade: incoming track minus the crossfaded portion
        post_crossfade = fade_in_part[crossfade_size:]
        # Adjust portions to exact crossfade size
        adjusted_fade_out_part = fade_out_part[fadeout_split:]
        adjusted_fade_in_part = fade_in_part[:crossfade_size]

        if not adjusted_fade_out_part or not adjusted_fade_in_part:
            # Nothing to overlap: a zero-length crossfade is a plain concatenation
            return fade_out_part + fade_in_part

        # Adjust the duration to match actual sizes (either part may be shorter)
        self.crossfade_duration = min(
            len(adjusted_fade_in_part) / bytes_per_second,
            len(adjusted_fade_out_part) / bytes_per_second,
        )
        # Rebuild so the filter chain uses the adjusted duration, even if the chain
        # was already built (e.g. by an earlier call on this instance)
        self._build()

        # Crossfaded portion: user's configured duration
        crossfaded_section = await super().apply(
            adjusted_fade_out_part, adjusted_fade_in_part, pcm_format
        )
        # Full result: everything concatenated
        return pre_crossfade + crossfaded_section + post_crossfade
507
508
509# HELPER METHODS
def get_bpm_diff_percentage(bpm1: float, bpm2: float) -> float:
    """Return how far ``bpm1`` deviates from ``bpm2``, as a percentage of ``bpm2``.

    Computed as ``|bpm1 / bpm2 - 1| * 100``. For example, 105 vs 100 BPM gives ~5.0.
    Raises ``ZeroDivisionError`` when ``bpm2`` is zero.
    """
    tempo_ratio = bpm1 / bpm2
    return 100 * abs(tempo_ratio - 1.0)
513
514
def _extend_downbeat_grid(
    adjusted_downbeats: npt.NDArray[np.float64],
    adjusted_interval: float,
    buffer_size: float,
    max_extrapolation_distance: float,
) -> npt.NDArray[np.float64]:
    """Append downbeats every ``adjusted_interval`` seconds after the last one.

    Extrapolated positions stay below ``buffer_size`` and at most
    ``max_extrapolation_distance`` seconds past the last given downbeat.
    """
    # A zero, negative or NaN step would never reach the buffer end (infinite loop)
    if not adjusted_interval > 0:
        return adjusted_downbeats

    last_downbeat = adjusted_downbeats[-1]
    extrapolated = []
    current_pos = last_downbeat + adjusted_interval
    while current_pos < buffer_size and (current_pos - last_downbeat) <= max_extrapolation_distance:
        extrapolated.append(current_pos)
        current_pos += adjusted_interval

    if extrapolated:
        # Combine adjusted detected downbeats and extrapolated downbeats
        return np.concatenate([adjusted_downbeats, np.array(extrapolated)])
    return adjusted_downbeats


def extrapolate_downbeats(
    downbeats: npt.NDArray[np.float64],
    tempo_factor: float,
    buffer_size: float = SMART_CROSSFADE_DURATION,
    bpm: float | None = None,
) -> npt.NDArray[np.float64]:
    """Extrapolate downbeats based on actual intervals when detection is incomplete.

    This is needed when we want to perform beat alignment in an 'atmospheric' outro
    that does not have any detected downbeats.

    Args:
        downbeats: Array of detected downbeat positions in seconds
        tempo_factor: Tempo adjustment factor for time stretching (positions are
            divided by it)
        buffer_size: Maximum buffer size in seconds
        bpm: Optional BPM used to validate the interval when there are exactly
            2 downbeats

    Returns:
        The tempo-adjusted downbeats, followed by any extrapolated downbeats.
    """
    # Adjust detected downbeats for time stretching first
    adjusted_downbeats = downbeats / tempo_factor

    if len(downbeats) < 2:
        # Need at least 2 downbeats to extrapolate
        return adjusted_downbeats

    if len(downbeats) == 2 and bpm is not None and bpm > 0:
        # A single interval can't be checked for consistency, so validate it against
        # the BPM instead (assuming 4/4 time signature)
        interval = float(downbeats[1] - downbeats[0])
        expected_interval = (60.0 / bpm) * 4
        # Only extrapolate if interval matches BPM within 15% tolerance. Return here
        # instead of falling through: the generic path below would extrapolate anyway
        # (the std of a single interval is 0).
        if not abs(interval - expected_interval) / expected_interval < 0.15:
            return adjusted_downbeats
        # NOTE(review): the previous comment claimed a 25s cap, but the value is 125s.
        # That is effectively uncapped within the buffer for BPM-validated downbeats;
        # confirm intent.
        max_extrapolation_distance = 125.0
    else:
        # Calculate intervals from ORIGINAL downbeats (before time stretching)
        intervals = np.diff(downbeats)
        # Only extrapolate if intervals are consistent (low standard deviation)
        if float(np.std(intervals)) > 0.2:
            return adjusted_downbeats
        interval = float(np.median(intervals))
        max_extrapolation_distance = 25.0  # Don't extrapolate more than 25s

    # If the last downbeat is close to the buffer end, no extrapolation needed
    if adjusted_downbeats[-1] >= buffer_size - 5:
        return adjusted_downbeats

    # Adjust the interval for time stretching
    # When slowing down (tempo_factor < 1.0), intervals get longer
    return _extend_downbeat_grid(
        adjusted_downbeats,
        interval / tempo_factor,
        buffer_size,
        max_extrapolation_distance,
    )
610