/
/
/
1"""Various (server-only) tools and helpers."""
2
3from __future__ import annotations
4
5import asyncio
6import functools
7import importlib
8import logging
9import os
10import re
11import shutil
12import socket
13import urllib.error
14import urllib.request
15from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Coroutine
16from contextlib import suppress
17from functools import lru_cache
18from importlib.metadata import PackageNotFoundError
19from importlib.metadata import version as pkg_version
20from pathlib import Path
21from types import TracebackType
22from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Self, TypeVar, cast
23from urllib.parse import urlparse
24
25import chardet
26import ifaddr
27from music_assistant_models.enums import AlbumType
28from zeroconf import IPVersion
29
30from music_assistant.constants import (
31 ANNOUNCE_ALERT_FILE,
32 LIVE_INDICATORS,
33 SOUNDTRACK_INDICATORS,
34 VERBOSE_LOG_LEVEL,
35)
36from music_assistant.helpers.process import check_output
37
38if TYPE_CHECKING:
39 from collections.abc import Iterator
40
41 from chardet.resultdict import ResultDict
42 from zeroconf.asyncio import AsyncServiceInfo
43
44 from music_assistant.mass import MusicAssistant
45 from music_assistant.models import ProviderModuleType
46 from music_assistant.models.core_controller import CoreController
47 from music_assistant.models.provider import Provider
48
49from dataclasses import fields, is_dataclass
50
LOGGER = logging.getLogger(__name__)

# Index of pre-built musllinux wheels used when installing provider requirements.
HA_WHEELS = "https://wheels.home-assistant.io/musllinux/"

T = TypeVar("T")
# Signature of a register/unregister cleanup callback.
CALLBACK_TYPE = Callable[[], None]
57
58
def get_total_system_memory() -> float:
    """Get total system memory in GB."""
    try:
        # POSIX sysconf is available on Linux and macOS
        page_size = os.sysconf("SC_PAGE_SIZE")
        page_count = os.sysconf("SC_PHYS_PAGES")
    except (AttributeError, ValueError):
        # Fallback if sysconf is not available (e.g., Windows)
        # Return a conservative default to disable buffering by default
        return 0.0
    # convert bytes to GB
    return (page_size * page_count) / (1024**3)
69
70
# Pre-compiled patterns used by the radio streamtitle cleanup helpers below.
keyword_pattern = re.compile("title=|artist=")  # detects key="value" style streamtitles
title_pattern = re.compile(r"title=\"(?P<title>.*?)\"")
artist_pattern = re.compile(r"artist=\"(?P<artist>.*?)\"")
# bare domain names without a scheme, e.g. "example.com" or "(sub.example.com)"
dot_com_pattern = re.compile(r"(?P<netloc>\(?\w+\.(?:\w+\.)?(\w{2,3})\)?)")
# common markers for advertisement breaks in stream metadata
ad_pattern = re.compile(r"((ad|advertisement)_)|^AD\s\d+$|ADBREAK", flags=re.IGNORECASE)
# "<title> By: <artist>" ordering used by some stations
title_artist_order_pattern = re.compile(r"(?P<title>.+)\sBy:\s(?P<artist>.+)", flags=re.IGNORECASE)
multi_space_pattern = re.compile(r"\s{2,}")  # runs of 2+ whitespace characters
end_junk_pattern = re.compile(r"(.+?)(\s\W+)$")  # trailing non-word junk
79
# Keyword tuples used by parse_title_and_version below.
VERSION_PARTS = (
    # list of common version strings
    "version",
    "live",
    "edit",
    "remix",
    "mix",
    "acoustic",
    "instrumental",
    "karaoke",
    "remaster",
    "versie",
    "unplugged",
    "disco",
    "akoestisch",
    "deluxe",
)
IGNORE_TITLE_PARTS = (
    # strings that may be stripped off a title part
    # (most important the featuring parts)
    "feat.",
    "featuring",
    "ft.",
    "with ",
    "explicit",
)
WITH_TITLE_WORDS = (
    # words that, when following "with", indicate this is part of the song title
    # not a featuring credit.
    "someone",
    "the",
    "u",
    "you",
    "no",
)
115
116
def filename_from_string(string: str) -> str:
    """Create filename from unsafe string."""
    allowed_extra = {" ", ".", "_"}
    safe_chars = [char for char in string if char.isalnum() or char in allowed_extra]
    return "".join(safe_chars).rstrip()
121
122
123def try_parse_int(possible_int: Any, default: int | None = 0) -> int | None:
124 """Try to parse an int."""
125 try:
126 return int(float(possible_int))
127 except (TypeError, ValueError):
128 return default
129
130
131def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float | None:
132 """Try to parse a float."""
133 try:
134 return float(possible_float)
135 except (TypeError, ValueError):
136 return default
137
138
def try_parse_bool(possible_bool: Any) -> bool:
    """Try to parse a bool."""
    if isinstance(possible_bool, bool):
        return possible_bool
    truthy_values = ("true", "True", "1", "on", "ON", 1)
    return possible_bool in truthy_values
144
145
146def try_parse_duration(duration_str: str) -> float:
147 """Try to parse a duration in seconds from a duration (HH:MM:SS) string."""
148 milliseconds = float("0." + duration_str.split(".")[-1]) if "." in duration_str else 0.0
149 duration_parts = duration_str.split(".")[0].split(",")[0].split(":")
150 if len(duration_parts) == 3:
151 seconds = sum(x * int(t) for x, t in zip([3600, 60, 1], duration_parts, strict=False))
152 elif len(duration_parts) == 2:
153 seconds = sum(x * int(t) for x, t in zip([60, 1], duration_parts, strict=False))
154 else:
155 seconds = int(duration_parts[0])
156 return seconds + milliseconds
157
158
def parse_title_and_version(title: str, track_version: str | None = None) -> tuple[str, str]:
    """
    Try to parse version from the title.

    Returns a (cleaned title, version) tuple. Bracketed/dashed title parts
    matching a VERSION_PARTS keyword are moved into the version string;
    featuring/explicit parts (IGNORE_TITLE_PARTS) are stripped entirely.
    """
    version = track_version or ""
    # scan parenthesized, bracketed and dash-suffixed parts, in that order
    for regex in (r"\(.*?\)", r"\[.*?\]", r" - .*"):
        for title_part in re.findall(regex, title):
            # Extract the content without brackets/dashes for checking
            clean_part = title_part.translate(str.maketrans("", "", "()[]-")).strip().lower()

            # Check if this should be ignored (featuring/explicit parts)
            should_ignore = False
            for ignore_str in IGNORE_TITLE_PARTS:
                if clean_part.startswith(ignore_str):
                    # Special handling for "with " - check if followed by title words
                    if ignore_str == "with ":
                        # Extract the word after "with "
                        after_with = (
                            clean_part[len("with ") :].split()[0]
                            if len(clean_part) > len("with ")
                            else ""
                        )
                        if after_with in WITH_TITLE_WORDS:
                            # This is part of the title (e.g., "with you"), don't ignore
                            break
                    # Remove this part from the title
                    title = title.replace(title_part, "").strip()
                    should_ignore = True
                    break

            if should_ignore:
                continue

            # Check if this part is a version
            for version_str in VERSION_PARTS:
                if version_str in clean_part:
                    # Preserve original casing for output
                    version = title_part.strip("()[]- ").strip()
                    title = title.replace(title_part, "").strip()
                    # first matching version wins; stop scanning
                    return title, version
    return title, version
198
199
def infer_album_type(title: str, version: str) -> AlbumType:
    """Infer album type by looking for live or soundtrack indicators."""
    haystack = f"{title} {version}".lower()
    if any(re.search(pattern, haystack) for pattern in LIVE_INDICATORS):
        return AlbumType.LIVE
    if any(re.search(pattern, haystack) for pattern in SOUNDTRACK_INDICATORS):
        return AlbumType.SOUNDTRACK
    return AlbumType.UNKNOWN
210
211
def strip_ads(line: str) -> str:
    """Strip Ads from line."""
    # replace the whole line when an advertisement marker is detected
    return "Advert" if ad_pattern.search(line) else line
217
218
def strip_url(line: str) -> str:
    """Strip URL from line."""

    def _is_url(token: str) -> bool:
        parsed = urlparse(token)
        return bool(parsed.scheme) and bool(parsed.netloc)

    kept_tokens = [token for token in line.split() if not _is_url(token)]
    return " ".join(kept_tokens).rstrip()
224
225
def strip_dotcom(line: str) -> str:
    """Remove bare (scheme-less) domain names such as 'example.com' from line."""
    cleaned = dot_com_pattern.sub("", line)
    return cleaned
229
230
def strip_end_junk(line: str) -> str:
    """Strip non-word info from end of line."""
    # keep only the leading capture group, dropping the trailing junk group
    cleaned = end_junk_pattern.sub(r"\1", line)
    return cleaned
234
235
def swap_title_artist_order(line: str) -> str:
    """Rewrite '<title> By: <artist>' lines as '<artist> - <title>'."""
    swapped = title_artist_order_pattern.sub(r"\g<artist> - \g<title>", line)
    return swapped
239
240
def strip_multi_space(line: str) -> str:
    """Collapse runs of whitespace in line to a single space."""
    collapsed = multi_space_pattern.sub(" ", line)
    return collapsed
244
245
def multi_strip(line: str) -> str:
    """Strip assorted junk from line."""
    # apply the individual cleanup passes in the same order as before
    cleaned = strip_ads(line)
    cleaned = strip_url(cleaned)
    cleaned = strip_dotcom(cleaned)
    cleaned = strip_end_junk(cleaned)
    cleaned = swap_title_artist_order(cleaned)
    return strip_multi_space(cleaned).rstrip()
251
252
def clean_stream_title(line: str) -> str:
    """Strip junk text from radio streamtitle."""
    # plain streamtitle without key="value" markers: just run the cleanup passes
    if not keyword_pattern.search(line):
        return multi_strip(line)

    title = ""
    artist = ""
    if title_match := title_pattern.search(line):
        title = multi_strip(title_match.group("title"))
    if artist_match := artist_pattern.search(line):
        candidate = multi_strip(artist_match.group("artist"))
        # ignore an artist value that merely duplicates the title
        if candidate and candidate != title:
            artist = candidate

    if not title and not artist:
        return ""
    # title already contains "artist - title" formatting, or there is no artist
    if title and (" - " in title or not artist):
        return title
    if title and artist:
        return f"{artist} - {title}"
    if artist:
        return artist
    return line
282
283
async def get_ip_addresses(include_ipv6: bool = False) -> tuple[str, ...]:
    """
    Return all IP-adresses of all network interfaces.

    Addresses are sorted by preference: the primary (default-route) address
    first, then 192.168.x addresses, then other private ranges, then the rest.
    The blocking socket/ifaddr work runs in a worker thread.
    """

    def call() -> tuple[str, ...]:
        # list of (score, ip) tuples; higher score sorts first
        result: list[tuple[int, str]] = []
        # try to get the primary IP address
        # this is the IP address of the default route
        _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        _sock.settimeout(0)
        try:
            # doesn't even have to be reachable
            _sock.connect(("10.254.254.254", 1))
            primary_ip = _sock.getsockname()[0]
        except Exception:
            primary_ip = ""
        finally:
            _sock.close()
        # get all IP addresses of all network interfaces
        adapters = ifaddr.get_adapters()
        for adapter in adapters:
            for ip in adapter.ips:
                if ip.is_IPv6 and not include_ipv6:
                    continue
                # ifaddr returns IPv6 addresses as (address, flowinfo, scope_id) tuples
                ip_str = ip.ip[0] if isinstance(ip.ip, tuple) else ip.ip
                if ip_str.startswith(("127", "169.254")):
                    # filter out IPv4 loopback/APIPA address
                    continue
                if ip_str.startswith(("::1", "::ffff:", "fe80")):
                    # filter out IPv6 loopback/link-local address
                    continue
                if ip_str == primary_ip:
                    score = 10
                elif ip_str.startswith(("192.168.",)):
                    # we rank the 192.168 range a bit higher as its most
                    # often used as the private network subnet
                    score = 2
                elif ip_str.startswith(("172.", "10.", "192.")):
                    # we rank the 172 range a bit lower as its most
                    # often used as the private docker network
                    score = 1
                else:
                    score = 0
                result.append((score, ip_str))
        result.sort(key=lambda x: x[0], reverse=True)
        return tuple(ip[1] for ip in result)

    return await asyncio.to_thread(call)
332
333
334async def get_primary_ip_address() -> str | None:
335 """Return the primary IP address of the system."""
336
337
async def is_port_in_use(port: int) -> bool:
    """Check if port is in use."""

    def _check() -> bool:
        # A port counts as free when it can be bound on at least one address
        # family, supporting IPv4-only, IPv6-only and dual-stack systems.
        candidates = ((socket.AF_INET, "0.0.0.0"), (socket.AF_INET6, "::"))
        for family, bind_addr in candidates:
            try:
                test_sock = socket.socket(family, socket.SOCK_STREAM)
            except OSError:
                # address family not available on this host
                continue
            with test_sock:
                try:
                    # SO_REUSEADDR matches asyncio.start_server behavior and
                    # allows binding to ports still in TIME_WAIT state
                    test_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    test_sock.bind((bind_addr, port))
                except OSError:
                    continue
                return False
        return True

    return await asyncio.to_thread(_check)
357
358
async def select_free_port(range_start: int, range_end: int) -> int:
    """Automatically find available port within range."""
    for candidate in range(range_start, range_end):
        if await is_port_in_use(candidate):
            continue
        return candidate
    raise OSError("No free port available")
366
367
368async def get_ip_from_host(dns_name: str) -> str | None:
369 """Resolve (first) IP-address for given dns name."""
370
371 def _resolve() -> str | None:
372 try:
373 return socket.gethostbyname(dns_name)
374 except Exception:
375 # fail gracefully!
376 return None
377
378 return await asyncio.to_thread(_resolve)
379
380
async def get_ip_pton(ip_string: str) -> bytes:
    """Return socket pton for a local ip."""

    def _pton() -> bytes:
        try:
            return socket.inet_pton(socket.AF_INET, ip_string)
        except OSError:
            # not a valid IPv4 address; try IPv6 (raises OSError when invalid too)
            return socket.inet_pton(socket.AF_INET6, ip_string)

    return await asyncio.to_thread(_pton)
387
388
def format_ip_for_url(ip_address: str) -> str:
    """Wrap IPv6 addresses in brackets for use in URLs (RFC 2732)."""
    is_ipv6 = ":" in ip_address
    return f"[{ip_address}]" if is_ipv6 else ip_address
394
395
async def get_folder_size(folderpath: str) -> float:
    """Return folder size in gb."""

    def _calc(path: str) -> float:
        # sum file sizes over the whole directory tree
        total = sum(
            Path(dirpath, name).stat().st_size
            for dirpath, _dirs, files in os.walk(path)
            for name in files
        )
        return total / float(1 << 30)

    return await asyncio.to_thread(_calc, folderpath)
408
409
def get_changed_keys(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> set[str]:
    """Compare 2 dicts and return set of changed keys."""
    # TODO: Check with Marcel whether we should calculate new dicts based on ignore_keys
    return set(get_changed_dict_values(dict1, dict2, recursive))
419
420
def get_changed_dict_values(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dicts and return dict of changed values.

    dict key is the changed key, value is tuple of old and new values.
    Keys only present in dict2 are reported with None as the old value;
    when dict2 is empty, keys from dict1 are reported with None as the
    new value. With recursive=True, nested dicts are flattened using
    dot-notation keys ("parent.child").
    """
    if not dict1 and not dict2:
        return {}
    if not dict1:
        # everything in dict2 is new
        return {key: (None, value) for key, value in dict2.items()}
    if not dict2:
        # everything in dict1 was removed: old value, no new value
        return {key: (value, None) for key, value in dict1.items()}
    changed_values: dict[str, tuple[Any, Any]] = {}
    for key, value in dict2.items():
        if key not in dict1:
            # new key (checked first to avoid KeyError on dict1[key] below)
            changed_values[key] = (None, value)
            continue
        if recursive and isinstance(value, dict) and isinstance(dict1[key], dict):
            # recurse into nested dicts, flattening subkeys with dot notation
            for subkey, subvalue in get_changed_dict_values(dict1[key], value, recursive).items():
                changed_values[f"{key}.{subkey}"] = subvalue
            continue
        if dict1[key] != value:
            changed_values[key] = (dict1[key], value)
    return changed_values
450
451
def get_changed_dataclass_values(
    obj1: T,
    obj2: T,
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dataclass instances of the same type and return dict of changed field values.

    dict key is the changed field name, value is tuple of old and new values.
    """
    if not (is_dataclass(obj1) and is_dataclass(obj2)):
        raise ValueError("Both objects must be dataclass instances")

    changed: dict[str, tuple[Any, Any]] = {}
    for fld in fields(obj1):
        old_val = getattr(obj1, fld.name, None)
        new_val = getattr(obj2, fld.name, None)
        # nested dataclasses / dicts get flattened with dot-notation keys
        nested: dict[str, tuple[Any, Any]] | None = None
        if recursive and is_dataclass(old_val) and is_dataclass(new_val):
            nested = get_changed_dataclass_values(old_val, new_val, recursive)
        elif recursive and isinstance(old_val, dict) and isinstance(new_val, dict):
            nested = get_changed_dict_values(old_val, new_val, recursive=recursive)
        if nested is not None:
            for sub_name, sub_change in nested.items():
                changed[f"{fld.name}.{sub_name}"] = sub_change
        elif old_val != new_val:
            changed[fld.name] = (old_val, new_val)
    return changed
482
483
484def empty_queue[T](q: asyncio.Queue[T]) -> None:
485 """Empty an asyncio Queue."""
486 for _ in range(q.qsize()):
487 try:
488 q.get_nowait()
489 q.task_done()
490 except (asyncio.QueueEmpty, ValueError):
491 pass
492
493
async def install_package(package: str) -> None:
    """Install package with pip, raise when install failed."""
    LOGGER.debug("Installing python package %s", package)
    return_code, output = await check_output(
        "uv", "pip", "install", "--no-cache", "--find-links", HA_WHEELS, package
    )
    if return_code != 0:
        raise RuntimeError(f"Failed to install package {package}\n{output.decode()}")
502
503
504async def get_package_version(pkg_name: str) -> str | None:
505 """
506 Return the version of an installed (python) package.
507
508 Will return None if the package is not found.
509 """
510 try:
511 return await asyncio.to_thread(pkg_version, pkg_name)
512 except PackageNotFoundError:
513 return None
514
515
async def is_hass_supervisor() -> bool:
    """Return if we're running inside the HA Supervisor (e.g. HAOS)."""

    def _probe() -> bool:
        try:
            urllib.request.urlopen("http://supervisor/core", timeout=1)
        except urllib.error.URLError as err:
            # the supervisor API returns 401 unauthorized when it exists
            return getattr(err, "code", 999) == 401
        except Exception:
            return False
        # request unexpectedly succeeded without auth: not the supervisor
        return False

    return await asyncio.to_thread(_probe)
530
531
async def load_provider_module(domain: str, requirements: list[str]) -> ProviderModuleType:
    """
    Return module for given provider domain and make sure the requirements are met.

    Pinned (==) requirements are checked against the installed versions and
    (re)installed when they differ; on ImportError all requirements are
    reinstalled and the import is retried once.
    """

    @lru_cache
    def _get_provider_module(domain: str) -> ProviderModuleType:
        # NOTE(review): this lru_cache is recreated on every outer call, so it
        # never caches across calls — importlib's own module cache makes repeat
        # imports cheap anyway, but the decorator is effectively a no-op here.
        return cast(
            "ProviderModuleType", importlib.import_module(f".{domain}", "music_assistant.providers")
        )

    # ensure module requirements are met
    for requirement in requirements:
        if "==" not in requirement:
            # we should really get rid of unpinned requirements
            continue
        package_name, version = requirement.split("==", 1)
        installed_version = await get_package_version(package_name)
        if installed_version == "0.0.0":
            # ignore editable installs
            continue
        if installed_version != version:
            await install_package(requirement)

    # try to load the module
    try:
        return await asyncio.to_thread(_get_provider_module, domain)
    except ImportError:
        # (re)install ALL requirements
        for requirement in requirements:
            await install_package(requirement)
        # try loading the provider again to be safe
        # this will fail if something else is wrong (as it should)
        return await asyncio.to_thread(_get_provider_module, domain)
564
565
async def has_tmpfs_mount() -> bool:
    """Check if we have a tmpfs mount."""

    def _check() -> bool:
        # scan the kernel mount table for a rw tmpfs mounted on /tmp
        try:
            with open("/proc/mounts") as mounts_file:
                return any("tmpfs /tmp tmpfs rw" in entry for entry in mounts_file)
        except (FileNotFoundError, OSError, PermissionError):
            # no /proc (e.g. macOS) or not readable
            return False

    return await asyncio.to_thread(_check)
581
582
async def get_free_space(folder: str) -> float:
    """Return free space on given folderpath in GB."""

    def _free_gb(path: str) -> float:
        try:
            usage = shutil.disk_usage(path)
        except (FileNotFoundError, OSError, PermissionError):
            # unknown/unreadable path: report no free space
            return 0.0
        return usage.free / float(1 << 30)

    return await asyncio.to_thread(_free_gb, folder)
595
596
async def get_free_space_percentage(folder: str) -> float:
    """Return free space on given folderpath in percentage."""

    def _free_pct(path: str) -> float:
        try:
            usage = shutil.disk_usage(path)
        except (FileNotFoundError, OSError, PermissionError):
            # unknown/unreadable path: report no free space
            return 0.0
        return usage.free / usage.total * 100

    return await asyncio.to_thread(_free_pct, folder)
609
610
611async def has_enough_space(folder: str, size: int) -> bool:
612 """Check if folder has enough free space."""
613 return await get_free_space(folder) > size
614
615
def divide_chunks(data: bytes, chunk_size: int) -> Iterator[bytes]:
    """Chunk bytes data into smaller chunks."""
    offset = 0
    while offset < len(data):
        yield data[offset : offset + chunk_size]
        offset += chunk_size
620
621
async def remove_file(file_path: str) -> None:
    """Remove file path (if it exists)."""
    exists = await asyncio.to_thread(os.path.exists, file_path)
    if not exists:
        return
    await asyncio.to_thread(os.remove, file_path)
    LOGGER.log(VERBOSE_LOG_LEVEL, "Removed file: %s", file_path)
628
629
def get_primary_ip_address_from_zeroconf(discovery_info: AsyncServiceInfo) -> str | None:
    """Get primary IP address from zeroconf discovery info."""
    # prefer IPv4, skipping loopback (127.x) and APIPA (169.254.x) addresses
    for address in discovery_info.parsed_addresses(IPVersion.V4Only):
        if not address.startswith(("127", "169.254")):
            return address
    # fall back to IPv6, skipping loopback and link-local addresses
    for address in discovery_info.parsed_addresses(IPVersion.V6Only):
        if not address.startswith(("::1", "fe80")):
            return address
    return None
647
648
def get_port_from_zeroconf(discovery_info: AsyncServiceInfo) -> int | None:
    """Get port from zeroconf discovery info."""
    port: int | None = discovery_info.port
    return port
652
653
async def close_async_generator(agen: AsyncGenerator[Any, None]) -> None:
    """Force close an async generator."""
    # start (and immediately cancel) one iteration to unblock the generator
    pending = asyncio.create_task(agen.__anext__())
    pending.cancel()
    with suppress(asyncio.CancelledError, StopAsyncIteration):
        await pending
    await agen.aclose()
661
662
async def detect_charset(data: bytes, fallback: str = "utf-8") -> str:
    """Detect charset of raw data."""
    try:
        result: ResultDict = await asyncio.to_thread(chardet.detect, data)
        encoding = result["encoding"] if result else None
        # only trust confident detections, otherwise use the fallback
        if encoding and result["confidence"] > 0.75:
            return encoding
    except Exception as err:
        LOGGER.debug("Failed to detect charset: %s", err)
    return fallback
673
674
def merge_dict(
    base_dict: dict[Any, Any],
    new_dict: dict[Any, Any],
    allow_overwite: bool = False,
) -> dict[Any, Any]:
    """
    Merge new_dict into base_dict without overwriting existing values.

    Nested dicts, tuples and lists (with matching types on both sides) are
    merged recursively. When allow_overwite is True, existing non-container
    values are overwritten by the new values.
    """
    final_dict = base_dict.copy()
    for key, value in new_dict.items():
        current = final_dict.get(key)
        # elif chain is required here: previously separate `if` statements let
        # the fallback overwrite branch clobber a just-merged container when
        # allow_overwite was True. The isinstance checks on the existing value
        # also prevent merging a container into a mismatched type.
        if current and isinstance(value, dict) and isinstance(current, dict):
            final_dict[key] = merge_dict(current, value)
        elif current and isinstance(value, tuple) and isinstance(current, tuple):
            final_dict[key] = merge_tuples(current, value)
        elif current and isinstance(value, list) and isinstance(current, list):
            final_dict[key] = merge_lists(current, value)
        elif not current or allow_overwite:
            final_dict[key] = value
    return final_dict
692
693
def merge_tuples(base: tuple[Any, ...], new: tuple[Any, ...]) -> tuple[Any, ...]:
    """Merge 2 tuples, dropping base items that also appear in new."""
    deduped = tuple(item for item in base if item not in new)
    return deduped + tuple(new)
697
698
def merge_lists(base: list[Any], new: list[Any]) -> list[Any]:
    """Merge 2 lists, dropping base items that also appear in new."""
    deduped = [item for item in base if item not in new]
    return deduped + list(new)
702
703
def percentage(part: float, whole: float) -> int:
    """Calculate part as an integer percentage of whole (0 when whole is 0)."""
    if not whole:
        # avoid ZeroDivisionError for an empty/zero total
        return 0
    return int(100 * float(part) / float(whole))
707
708
def validate_announcement_chime_url(url: str) -> bool:
    """Validate announcement chime URL format."""
    stripped = url.strip() if url else ""
    if not stripped:
        # empty URL means "no chime" and is considered valid
        return True
    if url == ANNOUNCE_ALERT_FILE:
        # built-in chime file is valid
        return True
    try:
        parsed = urlparse(stripped)
        if parsed.scheme not in ("http", "https"):
            return False
        if not parsed.netloc:
            return False
        # require a recognized audio file extension in the path
        valid_suffixes = (".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac")
        return parsed.path.lower().endswith(valid_suffixes)
    except Exception:
        return False
733
734
async def get_mac_address(ip_address: str) -> str | None:
    """Get MAC address for given IP address via ARP lookup."""
    try:
        from getmac import get_mac_address as getmac_lookup  # noqa: PLC0415

        result: str | None = await asyncio.to_thread(getmac_lookup, ip=ip_address)
    except ImportError:
        LOGGER.debug("getmac module not available, cannot resolve MAC from IP")
        return None
    except Exception as err:
        LOGGER.debug("Failed to resolve MAC address for %s: %s", ip_address, err)
        return None
    return result
747
748
def is_locally_administered_mac(mac_address: str) -> bool:
    """
    Check if a MAC address is locally administered (virtual/randomized).

    Locally administered addresses have bit 1 of the first octet set to 1;
    devices often use them for virtual/protocol-specific interfaces (e.g.
    AirPlay, DLNA) instead of the real hardware MAC.

    :param mac_address: MAC address in any common format (with :, -, or no separator).
    :return: True if locally administered, False if globally unique (real hardware MAC).
    """
    normalized = mac_address.upper().replace(":", "").replace("-", "")
    if len(normalized) < 2:
        return False
    try:
        first_octet = int(normalized[:2], 16)
    except ValueError:
        # not valid hex
        return False
    # bit 1 (second-least-significant bit) of the first octet
    return bool(first_octet & 0x02)
771
772
def normalize_mac_for_matching(mac_address: str) -> str:
    """
    Normalize a MAC address for device matching by masking out the locally-administered bit.

    Some protocols (like AirPlay) report a variant of the hardware MAC with
    bit 1 of the first octet set, e.g. 54:78:C9:E6:0D:A0 vs 56:78:C9:E6:0D:A0.
    Clearing that bit lets both variants match the same device.

    :param mac_address: MAC address in any common format (with :, -, or no separator).
    :return: Normalized MAC address in lowercase without separators, with the
             locally-administered bit cleared.
    """
    cleaned = mac_address.lower().replace(":", "").replace("-", "")
    if len(cleaned) != 12:
        # invalid MAC length, return as-is
        return cleaned
    try:
        first_octet = int(cleaned[:2], 16)
    except ValueError:
        # invalid hex, return as-is
        return cleaned
    # clear bit 1 (the locally-administered bit) and rebuild the MAC
    return format(first_octet & ~0x02, "02x") + cleaned[2:]
805
806
807def is_valid_mac_address(mac_address: str | None) -> bool:
808 """
809 Check if a MAC address is valid and usable for device identification.
810
811 Invalid MAC addresses include:
812 - None or empty strings
813 - Null MAC: 00:00:00:00:00:00
814 - Broadcast MAC: ff:ff:ff:ff:ff:ff
815 - Any MAC that doesn't follow the expected pattern
816
817 :param mac_address: MAC address to validate.
818 :return: True if valid and usable, False otherwise.
819 """
820 if not mac_address:
821 return False
822
823 # Normalize MAC address (remove separators and convert to lowercase)
824 normalized = mac_address.lower().replace(":", "").replace("-", "")
825
826 # Check for invalid/reserved MAC addresses
827 if normalized in ("000000000000", "ffffffffffff"):
828 return False
829
830 # Check length and hex validity
831 if len(normalized) != 12:
832 return False
833
834 try:
835 int(normalized, 16)
836 return True
837 except ValueError:
838 return False
839
840
841def normalize_ip_address(ip_address: str | None) -> str | None:
842 """
843 Normalize IP address for comparison.
844
845 Handles IPv6-mapped IPv4 addresses (e.g., ::ffff:192.168.1.64 -> 192.168.1.64).
846
847 :param ip_address: IP address to normalize.
848 :return: Normalized IP address or None if invalid.
849 """
850 if not ip_address:
851 return None
852
853 # Handle IPv6-mapped IPv4 addresses
854 if ip_address.startswith("::ffff:"):
855 # Extract the IPv4 part
856 return ip_address[7:]
857
858 return ip_address
859
860
async def resolve_real_mac_address(reported_mac: str | None, ip_address: str | None) -> str | None:
    """
    Resolve the real MAC address for a device.

    Some devices report different virtual MAC addresses per protocol (AirPlay,
    DLNA, Chromecast). When the reported MAC is missing or locally administered
    (virtual), try to resolve the actual hardware MAC via ARP.

    :param reported_mac: The MAC address reported by the protocol.
    :param ip_address: The IP address of the device (for ARP lookup).
    :return: The real MAC address if found, or None if it couldn't be resolved.
    """
    if not ip_address:
        return None
    needs_lookup = not reported_mac or is_locally_administered_mac(reported_mac)
    if needs_lookup:
        arp_mac = await get_mac_address(ip_address)
        if arp_mac and is_valid_mac_address(arp_mac):
            return arp_mac.upper()
    return None
883
884
class TaskManager:
    """
    Helper class to run many tasks at once.

    This is basically an alternative to asyncio.TaskGroup but this will not
    cancel all operations when one of the tasks fails.
    Logging of exceptions is done by the mass.create_task helper.
    """

    def __init__(self, mass: MusicAssistant, limit: int = 0):
        """
        Initialize the TaskManager.

        :param mass: MusicAssistant instance used to schedule the tasks.
        :param limit: max concurrent tasks for create_task_with_limit (0 = unlimited).
        """
        self.mass = mass
        # all tasks scheduled through this manager; awaited on context exit
        self._tasks: list[asyncio.Task[None]] = []
        # optional concurrency limiter, only used by create_task_with_limit
        self._semaphore = asyncio.Semaphore(limit) if limit else None

    def create_task(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]:
        """Create a new task and add it to the manager."""
        task = self.mass.create_task(coro)
        self._tasks.append(task)
        return task

    async def create_task_with_limit(self, coro: Coroutine[Any, Any, None]) -> None:
        """Create a new task with semaphore limit."""
        assert self._semaphore is not None

        def task_done_callback(_task: asyncio.Task[None]) -> None:
            assert self._semaphore is not None  # for type checking
            # drop the finished task (so __aexit__ won't await it) and free a slot
            self._tasks.remove(task)
            self._semaphore.release()

        # blocks until a slot is free when the limit is reached
        await self._semaphore.acquire()
        task: asyncio.Task[None] = self.create_task(coro)
        task.add_done_callback(task_done_callback)

    async def __aenter__(self) -> Self:
        """Enter context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Exit context manager."""
        # wait for all remaining tasks; failures are logged, not raised here
        if len(self._tasks) > 0:
            await asyncio.wait(self._tasks)
        self._tasks.clear()
        return None
934
935
# Generic return type / parameter spec used by typed decorators in this module.
_R = TypeVar("_R")
_P = ParamSpec("_P")
938
939
940def lock[**P, R]( # type: ignore[valid-type]
941 func: Callable[_P, Awaitable[_R]],
942) -> Callable[_P, Coroutine[Any, Any, _R]]:
943 """Call async function using a Lock."""
944
945 @functools.wraps(func)
946 async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
947 """Call async function using the throttler with retries."""
948 if not (func_lock := getattr(func, "lock", None)):
949 func_lock = asyncio.Lock()
950 func.lock = func_lock # type: ignore[attr-defined]
951 async with func_lock:
952 return await func(*args, **kwargs)
953
954 return wrapper
955
956
class TimedAsyncGenerator:
    """
    Async iterable that times out after a given time.

    A timeout of 0 (the default) disables the timeout entirely; previously 0
    was passed straight to asyncio.wait_for, causing an immediate TimeoutError.

    Source: https://medium.com/@dmitry8912/implementing-timeouts-in-pythons-asynchronous-generators-f7cbaa6dc1e9
    """

    def __init__(self, iterable: AsyncIterator[Any], timeout: float = 0):
        """
        Initialize the AsyncTimedIterable.

        Args:
            iterable: The async iterable to wrap.
            timeout: The timeout in seconds for each iteration (0 = no timeout).
                Fractional timeouts are honored (no longer truncated to int).
        """

        class AsyncTimedIterator:
            def __init__(self) -> None:
                self._iterator = iterable.__aiter__()

            async def __anext__(self) -> Any:
                # None disables the timeout in asyncio.wait_for
                result = await asyncio.wait_for(
                    self._iterator.__anext__(), timeout if timeout else None
                )
                if not result:
                    # NOTE: falsy items (0, "", b"") also end the iteration here,
                    # preserving the original behavior
                    raise StopAsyncIteration
                return result

        self._factory = AsyncTimedIterator

    def __aiter__(self):  # type: ignore[no-untyped-def]
        """Return the async iterator."""
        return self._factory()
988
989
def guard_single_request[ProviderT: "Provider | CoreController", **P, R](
    func: Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
    """
    Guard single request to a function.

    Deduplicates concurrent calls with identical arguments: all callers await
    the same underlying task (scheduled via mass.create_task with
    abort_existing=False and a deterministic task_id).
    """

    @functools.wraps(func)
    async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
        mass = self.mass
        # create a task_id dynamically based on the function and args/kwargs
        # NOTE(review): func.__class__.__name__ is always "function" here —
        # presumably self.__class__.__name__ was intended; confirm before changing.
        cache_key_parts = [func.__class__.__name__, func.__name__, *args]
        for key in sorted(kwargs.keys()):
            cache_key_parts.append(f"{key}{kwargs[key]}")
        task_id = ".".join(map(str, cache_key_parts))
        task: asyncio.Task[R] = mass.create_task(
            func,
            self,
            *args,
            task_id=task_id,
            abort_existing=False,
            eager_start=True,
            **kwargs,
        )
        return await task

    return wrapper
1015