/
/
/
1"""Various (server-only) tools and helpers."""
2
from __future__ import annotations

import asyncio
import functools
import importlib
import ipaddress
import logging
import os
import re
import shutil
import socket
import urllib.error
import urllib.request
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Coroutine
from contextlib import suppress
from dataclasses import fields, is_dataclass
from functools import lru_cache
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as pkg_version
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Self, TypeVar, cast
from urllib.parse import urlparse

import chardet
import ifaddr
from music_assistant_models.enums import AlbumType
from zeroconf import IPVersion

from music_assistant.constants import (
    ANNOUNCE_ALERT_FILE,
    LIVE_INDICATORS,
    SOUNDTRACK_INDICATORS,
    VERBOSE_LOG_LEVEL,
)
from music_assistant.helpers.process import check_output

if TYPE_CHECKING:
    from collections.abc import Iterator

    from chardet.resultdict import ResultDict
    from zeroconf.asyncio import AsyncServiceInfo

    from music_assistant.mass import MusicAssistant
    from music_assistant.models import ProviderModuleType
    from music_assistant.models.core_controller import CoreController
    from music_assistant.models.provider import Provider
50
# Wheel index passed to the installer via --find-links (see install_package).
HA_WHEELS = "https://wheels.home-assistant.io/musllinux/"

# Module logger.
LOGGER = logging.getLogger(__name__)

# Generic type variable used by helpers in this module.
T = TypeVar("T")
# A callback that takes no arguments and returns nothing.
CALLBACK_TYPE = Callable[[], None]
57
58
def get_total_system_memory() -> float:
    """
    Get total system memory in GB.

    Uses ``os.sysconf`` (available on Linux and macOS). When the value cannot
    be determined, 0.0 is returned as a conservative default (intended to
    disable buffering by default).
    """
    try:
        page_size = os.sysconf("SC_PAGE_SIZE")
        page_count = os.sysconf("SC_PHYS_PAGES")
    except (AttributeError, ValueError, OSError):
        # AttributeError: no os.sysconf (e.g. Windows); ValueError: unknown
        # name; OSError: name known but not supported by this host.
        return 0.0
    if page_size <= 0 or page_count <= 0:
        # sysconf returns -1 when a value is not defined on this system
        return 0.0
    return page_size * page_count / (1024**3)  # Convert to GB
69
70
# Pre-compiled patterns used to clean up radio stream titles.
# Detects key=value style metadata in a stream title.
keyword_pattern = re.compile("title=|artist=")
# Extract the quoted value of title="..." / artist="...".
title_pattern = re.compile(r"title=\"(?P<title>.*?)\"")
artist_pattern = re.compile(r"artist=\"(?P<artist>.*?)\"")
# Scheme-less host names such as example.com (optionally in parentheses).
dot_com_pattern = re.compile(r"(?P<netloc>\(?\w+\.(?:\w+\.)?(\w{2,3})\)?)")
# Advertisement markers: "ad_"/"advertisement_", a whole "AD <n>" line, or ADBREAK.
ad_pattern = re.compile(r"((ad|advertisement)_)|^AD\s\d+$|ADBREAK", re.IGNORECASE)
# "<title> By: <artist>" ordering.
title_artist_order_pattern = re.compile(r"(?P<title>.+)\sBy:\s(?P<artist>.+)", re.IGNORECASE)
# Runs of two or more whitespace characters.
multi_space_pattern = re.compile(r"\s{2,}")
# Whitespace followed by only non-word characters at the end of a line.
end_junk_pattern = re.compile(r"(.+?)(\s\W+)$")
79
# Common version keywords (including Dutch "versie" and "akoestisch").
# A bracketed or dashed title part containing one of these is a version.
VERSION_PARTS = (
    "version", "live", "edit", "remix", "mix", "acoustic", "instrumental",
    "karaoke", "remaster", "versie", "unplugged", "disco", "akoestisch", "deluxe"
)

# Title parts that may be stripped off (most importantly featuring credits).
IGNORE_TITLE_PARTS = ("feat.", "featuring", "ft.", "with ", "explicit")

# Words that, directly after "with", mark the part as belonging to the song
# title rather than being a featuring credit.
WITH_TITLE_WORDS = ("someone", "the", "u", "you", "no")
115
116
def filename_from_string(string: str) -> str:
    """
    Create a filename from an unsafe string.

    Only alphanumeric characters, spaces, dots and underscores are kept;
    trailing whitespace is removed.
    """
    allowed_extra = (" ", ".", "_")
    safe_chars = [char for char in string if char.isalnum() or char in allowed_extra]
    return "".join(safe_chars).rstrip()
121
122
123def try_parse_int(possible_int: Any, default: int | None = 0) -> int | None:
124 """Try to parse an int."""
125 try:
126 return int(float(possible_int))
127 except (TypeError, ValueError):
128 return default
129
130
131def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float | None:
132 """Try to parse a float."""
133 try:
134 return float(possible_float)
135 except (TypeError, ValueError):
136 return default
137
138
def try_parse_bool(possible_bool: Any) -> bool:
    """
    Try to parse a bool.

    Real booleans are returned unchanged. Otherwise only "true", "True", "1",
    "on", "ON" and the integer 1 count as True.
    """
    truthy_values = ("true", "True", "1", "on", "ON", 1)
    return possible_bool if isinstance(possible_bool, bool) else possible_bool in truthy_values
144
145
146def try_parse_duration(duration_str: str) -> float:
147 """Try to parse a duration in seconds from a duration (HH:MM:SS) string."""
148 milliseconds = float("0." + duration_str.split(".")[-1]) if "." in duration_str else 0.0
149 duration_parts = duration_str.split(".")[0].split(",")[0].split(":")
150 if len(duration_parts) == 3:
151 seconds = sum(x * int(t) for x, t in zip([3600, 60, 1], duration_parts, strict=False))
152 elif len(duration_parts) == 2:
153 seconds = sum(x * int(t) for x, t in zip([60, 1], duration_parts, strict=False))
154 else:
155 seconds = int(duration_parts[0])
156 return seconds + milliseconds
157
158
def parse_title_and_version(title: str, track_version: str | None = None) -> tuple[str, str]:
    """
    Try to parse version from the title.

    Parenthesized parts, bracketed parts and a trailing " - ..." part are
    inspected, in that order. Parts starting with one of IGNORE_TITLE_PARTS
    (featuring credits, "explicit") are removed from the title. The exception
    is "with <word>", where <word> is in WITH_TITLE_WORDS. The first remaining
    part that contains one of VERSION_PARTS is removed from the title and
    returned as the version.

    :param title: The raw title, possibly containing version/featuring info.
    :param track_version: Version to return when none is found in the title.
    :return: Tuple of (cleaned title, version).
    """
    version = track_version or ""
    for regex in (r"\(.*?\)", r"\[.*?\]", r" - .*"):
        # findall runs on the title as already modified by earlier iterations
        for title_part in re.findall(regex, title):
            # Extract the content without brackets/dashes for checking
            # (note: this also drops "-" characters inside the text itself)
            clean_part = title_part.translate(str.maketrans("", "", "()[]-")).strip().lower()

            # Check if this should be ignored (featuring/explicit parts)
            should_ignore = False
            for ignore_str in IGNORE_TITLE_PARTS:
                if clean_part.startswith(ignore_str):
                    # Special handling for "with " - check if followed by title words
                    if ignore_str == "with ":
                        # Extract the word after "with "
                        after_with = (
                            clean_part[len("with ") :].split()[0]
                            if len(clean_part) > len("with ")
                            else ""
                        )
                        if after_with in WITH_TITLE_WORDS:
                            # This is part of the title (e.g., "with you"), don't ignore;
                            # the part still goes through the version check below
                            break
                    # Remove this part from the title (all occurrences)
                    title = title.replace(title_part, "").strip()
                    should_ignore = True
                    break

            if should_ignore:
                continue

            # Check if this part is a version (substring match, so e.g.
            # "alive" also contains "live")
            for version_str in VERSION_PARTS:
                if version_str in clean_part:
                    # Preserve original casing for output
                    version = title_part.strip("()[]- ").strip()
                    title = title.replace(title_part, "").strip()
                    # only the first version part found is used
                    return title, version
    return title, version
198
199
def infer_album_type(title: str, version: str) -> AlbumType:
    """
    Infer the album type from its title and version.

    The lowercased "title version" string is matched against the live
    indicator patterns first, then the soundtrack patterns. Returns
    AlbumType.UNKNOWN when nothing matches.
    """
    haystack = f"{title} {version}".lower()
    if any(re.search(pattern, haystack) for pattern in LIVE_INDICATORS):
        return AlbumType.LIVE
    if any(re.search(pattern, haystack) for pattern in SOUNDTRACK_INDICATORS):
        return AlbumType.SOUNDTRACK
    return AlbumType.UNKNOWN
210
211
def strip_ads(line: str) -> str:
    """Replace the whole line with "Advert" when it matches an advertisement marker."""
    return "Advert" if ad_pattern.search(line) else line
217
218
def strip_url(line: str) -> str:
    """Remove every whitespace-separated token that is a full URL (scheme and host)."""

    def _is_full_url(token: str) -> bool:
        parsed = urlparse(token)
        return bool(parsed.scheme and parsed.netloc)

    kept_tokens = [token for token in line.split() if not _is_full_url(token)]
    return " ".join(kept_tokens).rstrip()
224
225
def strip_dotcom(line: str) -> str:
    """Remove scheme-less host names (such as example.com) from the line."""
    return re.sub(dot_com_pattern, "", line)
229
230
def strip_end_junk(line: str) -> str:
    """Drop trailing whitespace plus non-word characters from the end of the line."""
    return re.sub(end_junk_pattern, r"\1", line)
234
235
def swap_title_artist_order(line: str) -> str:
    """Rewrite "<title> By: <artist>" as "<artist> - <title>"."""
    return re.sub(title_artist_order_pattern, r"\g<artist> - \g<title>", line)
239
240
def strip_multi_space(line: str) -> str:
    """Collapse every run of two or more whitespace characters into one space."""
    return re.sub(multi_space_pattern, " ", line)
244
245
def multi_strip(line: str) -> str:
    """Run all stream-title cleanup steps on the line and trim trailing whitespace."""
    # The order of the steps matters: ads first, whitespace collapse last.
    cleanup_steps = (
        strip_ads,
        strip_url,
        strip_dotcom,
        strip_end_junk,
        swap_title_artist_order,
        strip_multi_space,
    )
    cleaned = line
    for step in cleanup_steps:
        cleaned = step(cleaned)
    return cleaned.rstrip()
251
252
def clean_stream_title(line: str) -> str:
    """
    Strip junk text from a radio stream title.

    Lines without title=/artist= metadata are cleaned with multi_strip.
    Otherwise the quoted title and artist values are extracted and cleaned.
    The result is:
    - the title alone, when it already contains " - " or no distinct artist is found;
    - "artist - title" otherwise;
    - the artist, when no title is found;
    - "" when neither is found.
    """
    if not keyword_pattern.search(line):
        return multi_strip(line)

    title = ""
    artist = ""
    if match := title_pattern.search(line):
        title = multi_strip(match.group("title"))
    if match := artist_pattern.search(line):
        possible_artist = multi_strip(match.group("artist"))
        # ignore an artist value that merely repeats the title
        if possible_artist and possible_artist != title:
            artist = possible_artist

    if title:
        if " - " in title or not artist:
            return title
        return f"{artist} - {title}"
    # no title: return the artist, or "" when nothing usable was found
    return artist
282
283
async def get_ip_addresses(include_ipv6: bool = False) -> tuple[str, ...]:
    """
    Return all IP-adresses of all network interfaces, best candidates first.

    Ranking:
    - the address the OS uses for the default route;
    - 192.168.x.x addresses;
    - other addresses starting with 172., 10. or 192.;
    - everything else.

    Loopback, APIPA/link-local and IPv4-mapped IPv6 addresses are skipped.
    IPv6 addresses are only included when ``include_ipv6`` is set.
    """

    def _default_route_ip() -> str:
        # Connecting a UDP socket sends nothing; it only makes the OS pick the
        # local address for the route (the target doesn't have to be reachable).
        probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        probe.settimeout(0)
        try:
            probe.connect(("10.254.254.254", 1))
            return probe.getsockname()[0]
        except Exception:
            return ""
        finally:
            probe.close()

    def _rank(address: str, primary: str) -> int:
        if address == primary:
            return 10
        if address.startswith(("192.168.",)):
            # most often used as the private network subnet
            return 2
        if address.startswith(("172.", "10.", "192.")):
            # 172.x is often the private docker network, so rank lower
            return 1
        return 0

    def _collect() -> tuple[str, ...]:
        primary = _default_route_ip()
        ranked: list[tuple[int, str]] = []
        for adapter in ifaddr.get_adapters():
            for entry in adapter.ips:
                if entry.is_IPv6 and not include_ipv6:
                    continue
                # ifaddr reports IPv6 as (address, flowinfo, scope_id) tuples
                address = entry.ip[0] if isinstance(entry.ip, tuple) else entry.ip
                # skip IPv4 loopback/APIPA and IPv6 loopback/mapped/link-local
                if address.startswith(("127", "169.254", "::1", "::ffff:", "fe80")):
                    continue
                ranked.append((_rank(address, primary), address))
        ranked.sort(key=lambda item: item[0], reverse=True)
        return tuple(address for _score, address in ranked)

    return await asyncio.to_thread(_collect)
332
333
async def get_primary_ip_address() -> str | None:
    """
    Return the primary IP address of the system.

    This is the best-ranked address from get_ip_addresses: the address used
    for the default route when it is one of the interface addresses.
    Returns None when no usable (non-loopback, non-link-local) IPv4 address
    exists.
    """
    ip_addresses = await get_ip_addresses()
    return ip_addresses[0] if ip_addresses else None
336
337
async def is_port_in_use(port: int) -> bool:
    """
    Check if port is in use.

    The port counts as free when a TCP socket can be bound to it on either
    the IPv4 wildcard (0.0.0.0) or the IPv6 wildcard (::). This supports
    both single-stack and dual-stack systems.
    """

    def _can_bind(family: socket.AddressFamily, host: str) -> bool:
        try:
            with socket.socket(family, socket.SOCK_STREAM) as probe:
                # SO_REUSEADDR matches asyncio.start_server behavior, so ports
                # lingering in TIME_WAIT are not reported as busy
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                probe.bind((host, port))
        except OSError:
            return False
        return True

    def _check() -> bool:
        wildcards = ((socket.AF_INET, "0.0.0.0"), (socket.AF_INET6, "::"))
        return not any(_can_bind(family, host) for family, host in wildcards)

    return await asyncio.to_thread(_check)
357
358
async def select_free_port(range_start: int, range_end: int) -> int:
    """
    Return the first free port in ``range(range_start, range_end)``.

    ``range_end`` itself is not checked.

    :raises OSError: when every port in the range is in use.
    """
    for candidate in range(range_start, range_end):
        in_use = await is_port_in_use(candidate)
        if not in_use:
            return candidate
    msg = "No free port available"
    raise OSError(msg)
366
367
368async def get_ip_from_host(dns_name: str) -> str | None:
369 """Resolve (first) IP-address for given dns name."""
370
371 def _resolve() -> str | None:
372 try:
373 return socket.gethostbyname(dns_name)
374 except Exception:
375 # fail gracefully!
376 return None
377
378 return await asyncio.to_thread(_resolve)
379
380
async def get_ip_pton(ip_string: str) -> bytes:
    """
    Return the packed binary form (inet_pton) of an IP address string.

    The string is parsed as IPv4 first and as IPv6 when that fails.

    :raises OSError: when it is neither a valid IPv4 nor IPv6 address.
    """
    try:
        packed = await asyncio.to_thread(socket.inet_pton, socket.AF_INET, ip_string)
    except OSError:
        packed = await asyncio.to_thread(socket.inet_pton, socket.AF_INET6, ip_string)
    return packed
387
388
def format_ip_for_url(ip_address: str) -> str:
    """
    Wrap IPv6 addresses in brackets for use in URLs (RFC 2732).

    IPv4 addresses and host names are returned unchanged. So are IPv6
    addresses that are already bracketed, which would otherwise become
    "[[...]]".
    """
    already_bracketed = ip_address.startswith("[") and ip_address.endswith("]")
    if ":" in ip_address and not already_bracketed:
        return f"[{ip_address}]"
    return ip_address
394
395
async def get_folder_size(folderpath: str) -> float:
    """
    Return the total size of all files below a folder in GB (2**30 bytes).

    Files that disappear during the walk and dangling symlinks are skipped
    instead of aborting the whole calculation. A missing folder yields 0.0.
    """

    def _get_folder_size(root: str) -> float:
        total_bytes = 0
        for dirpath, _dirnames, filenames in os.walk(root):
            for filename in filenames:
                try:
                    total_bytes += Path(dirpath, filename).stat().st_size
                except OSError:
                    # vanished file or broken symlink: nothing to count
                    continue
        return total_bytes / float(1 << 30)

    return await asyncio.to_thread(_get_folder_size, folderpath)
408
409
def get_changed_keys(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> set[str]:
    """
    Compare 2 dicts and return the set of changed keys.

    These are the keys reported by get_changed_dict_values (dotted keys for
    nested changes when ``recursive`` is set).
    """
    # TODO: Check with Marcel whether we should calculate new dicts based on ignore_keys
    changes = get_changed_dict_values(dict1, dict2, recursive)
    return set(changes)
419
420
def get_changed_dict_values(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dicts and return dict of changed values.

    dict key is the changed key; the value is a tuple of the old (``dict1``)
    and new (``dict2``) values. Keys only present in ``dict2`` are reported
    as ``(None, new)``. When ``recursive`` is set, nested dicts present on
    both sides are compared key by key and reported with dotted keys
    (e.g. ``"parent.child"``).

    NOTE: keys that exist only in ``dict1`` are reported (as ``(old, None)``)
    only when ``dict2`` is empty.
    """
    if not dict1 and not dict2:
        return {}
    if not dict1:
        return {key: (None, value) for key, value in dict2.items()}
    if not dict2:
        # everything was removed: old value first, new value is None
        return {key: (value, None) for key, value in dict1.items()}
    changed_values: dict[str, tuple[Any, Any]] = {}
    for key, new_value in dict2.items():
        # check for new keys first so dict1[key] below cannot raise KeyError
        if key not in dict1:
            changed_values[key] = (None, new_value)
            continue
        old_value = dict1[key]
        if recursive and isinstance(new_value, dict) and isinstance(old_value, dict):
            nested = get_changed_dict_values(old_value, new_value, recursive)
            for subkey, subvalue in nested.items():
                changed_values[f"{key}.{subkey}"] = subvalue
            continue
        if old_value != new_value:
            changed_values[key] = (old_value, new_value)
    return changed_values
450
451
def get_changed_dataclass_values(
    obj1: T,
    obj2: T,
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dataclass instances of the same type and return dict of changed field values.

    dict key is the changed field name; the value is a tuple of old and new
    values. With ``recursive``, nested dataclass instances and dicts are
    compared field by field and reported with dotted keys.

    :raises ValueError: when either object is not a dataclass instance.
    """

    def _is_dataclass_instance(obj: Any) -> bool:
        # is_dataclass() is also True for dataclass *classes*; accept instances only
        return is_dataclass(obj) and not isinstance(obj, type)

    if not (_is_dataclass_instance(obj1) and _is_dataclass_instance(obj2)):
        raise ValueError("Both objects must be dataclass instances")

    changed_values: dict[str, tuple[Any, Any]] = {}
    for field in fields(obj1):
        val1 = getattr(obj1, field.name, None)
        val2 = getattr(obj2, field.name, None)
        if recursive and _is_dataclass_instance(val1) and _is_dataclass_instance(val2):
            sub_changes = get_changed_dataclass_values(val1, val2, recursive)
            for sub_field, sub_value in sub_changes.items():
                changed_values[f"{field.name}.{sub_field}"] = sub_value
            continue
        if recursive and isinstance(val1, dict) and isinstance(val2, dict):
            sub_changes = get_changed_dict_values(val1, val2, recursive=recursive)
            for sub_field, sub_value in sub_changes.items():
                changed_values[f"{field.name}.{sub_field}"] = sub_value
            continue
        if val1 != val2:
            changed_values[field.name] = (val1, val2)
    return changed_values
482
483
484def empty_queue[T](q: asyncio.Queue[T]) -> None:
485 """Empty an asyncio Queue."""
486 for _ in range(q.qsize()):
487 try:
488 q.get_nowait()
489 q.task_done()
490 except (asyncio.QueueEmpty, ValueError):
491 pass
492
493
async def install_package(package: str) -> None:
    """
    Install a python package with ``uv pip``; raise when the install failed.

    HA_WHEELS is passed via ``--find-links`` as an extra wheel source.

    :raises RuntimeError: when the installer exits with a non-zero code. The
        message includes the installer output.
    """
    LOGGER.debug("Installing python package %s", package)
    args = ["uv", "pip", "install", "--no-cache", "--find-links", HA_WHEELS, package]
    return_code, output = await check_output(*args)
    if return_code != 0:
        # Decode leniently: non-UTF-8 installer output must not replace the
        # actual install failure with a UnicodeDecodeError.
        msg = f"Failed to install package {package}\n{output.decode(errors='replace')}"
        raise RuntimeError(msg)
502
503
504async def get_package_version(pkg_name: str) -> str | None:
505 """
506 Return the version of an installed (python) package.
507
508 Will return None if the package is not found.
509 """
510 try:
511 return await asyncio.to_thread(pkg_version, pkg_name)
512 except PackageNotFoundError:
513 return None
514
515
async def is_hass_supervisor() -> bool:
    """
    Return if we're running inside the HA Supervisor (e.g. HAOS).

    Probes http://supervisor/core with a 1 second timeout. Only an HTTP 401
    (unauthorized) answer counts as a positive result.
    """

    def _check() -> bool:
        try:
            # use the response as a context manager so it (and its socket) is
            # closed if the request unexpectedly succeeds
            with urllib.request.urlopen("http://supervisor/core", timeout=1):
                pass
        except urllib.error.URLError as err:
            # this should return a 401 unauthorized if it exists
            return getattr(err, "code", 999) == 401
        except Exception:
            return False
        return False

    return await asyncio.to_thread(_check)
530
531
@lru_cache
def _import_provider_module(domain: str) -> ProviderModuleType:
    """Import (and memoize) the provider package ``music_assistant.providers.<domain>``."""
    return cast(
        "ProviderModuleType", importlib.import_module(f".{domain}", "music_assistant.providers")
    )


async def load_provider_module(domain: str, requirements: list[str]) -> ProviderModuleType:
    """
    Return module for given provider domain and make sure the requirements are met.

    Pinned requirements (``name==version``) are installed when the installed
    version differs. Unpinned requirements are skipped, and editable installs
    (reported as version 0.0.0) are left alone. If the import still fails with
    ImportError, all requirements are (re)installed and the import is retried
    once; a second failure propagates.
    """
    # ensure module requirements are met
    for requirement in requirements:
        if "==" not in requirement:
            # we should really get rid of unpinned requirements
            continue
        package_name, version = requirement.split("==", 1)
        # Strip extras ("pkg[extra]") and environment markers ("; python_version...").
        # Otherwise the version lookup/comparison never matches, which would
        # reinstall the package on every load.
        package_name = package_name.split("[", 1)[0].strip()
        version = version.split(";", 1)[0].strip()
        installed_version = await get_package_version(package_name)
        if installed_version == "0.0.0":
            # ignore editable installs
            continue
        if installed_version != version:
            await install_package(requirement)

    # try to load the module
    try:
        return await asyncio.to_thread(_import_provider_module, domain)
    except ImportError:
        # (re)install ALL requirements
        for requirement in requirements:
            await install_package(requirement)
        # try loading the provider again to be safe
        # this will fail if something else is wrong (as it should)
        return await asyncio.to_thread(_import_provider_module, domain)
564
565
async def has_tmpfs_mount() -> bool:
    """Check if /tmp is a read-write tmpfs mount (according to /proc/mounts)."""

    def _scan_mounts() -> bool:
        """Return True when /proc/mounts lists /tmp as a rw tmpfs; False on any OS error."""
        try:
            with open("/proc/mounts") as mounts:
                return any("tmpfs /tmp tmpfs rw" in entry for entry in mounts)
        except OSError:
            return False

    return await asyncio.to_thread(_scan_mounts)
581
582
async def get_free_space(folder: str) -> float:
    """Return free space on given folderpath in GB (0.0 when it cannot be determined)."""

    def _free_gb(path: str) -> float:
        try:
            usage = shutil.disk_usage(path)
        except OSError:
            return 0.0
        return usage.free / float(1 << 30)

    return await asyncio.to_thread(_free_gb, folder)
595
596
async def get_free_space_percentage(folder: str) -> float:
    """
    Return free space on given folderpath in percentage.

    Returns 0.0 when the usage cannot be determined or the filesystem reports
    a total size of 0.
    """

    def _free_percentage(path: str) -> float:
        """Return free space on ``path`` as a percentage of its total size."""
        try:
            usage = shutil.disk_usage(path)
        except OSError:
            return 0.0
        if usage.total <= 0:
            # some pseudo filesystems report a total size of 0
            return 0.0
        return usage.free / usage.total * 100

    return await asyncio.to_thread(_free_percentage, folder)
609
610
async def has_enough_space(folder: str, size: int) -> bool:
    """Return True when the folder has strictly more than ``size`` GB of free space."""
    free_gb = await get_free_space(folder)
    return free_gb > size
614
615
def divide_chunks(data: bytes, chunk_size: int) -> Iterator[bytes]:
    """Yield consecutive slices of ``chunk_size`` bytes; the last one may be shorter."""
    offsets = range(0, len(data), chunk_size)
    yield from (data[start : start + chunk_size] for start in offsets)
620
621
async def remove_file(file_path: str) -> None:
    """
    Remove file path (if it exists).

    A missing file is not an error. Other errors from ``os.remove`` (e.g. the
    path is a directory, or permission is denied) propagate.
    """
    try:
        # Remove directly instead of checking existence first. This avoids
        # the race where the file disappears between the check and the removal.
        await asyncio.to_thread(os.remove, file_path)
    except FileNotFoundError:
        return
    LOGGER.log(VERBOSE_LOG_LEVEL, "Removed file: %s", file_path)
628
629
def get_primary_ip_address_from_zeroconf(discovery_info: AsyncServiceInfo) -> str | None:
    """
    Get primary IP address from zeroconf discovery info.

    Returns the first IPv4 address that is not loopback (127.x) or APIPA
    (169.254.x). Otherwise returns the first IPv6 address that is not
    loopback (::1) or link-local (fe80...). Returns None when neither exists.
    """
    usable_ipv4 = (
        address
        for address in discovery_info.parsed_addresses(IPVersion.V4Only)
        if not address.startswith(("127", "169.254"))
    )
    if (first_ipv4 := next(usable_ipv4, None)) is not None:
        return first_ipv4
    # fall back to IPv6 addresses if no usable IPv4 address found
    usable_ipv6 = (
        address
        for address in discovery_info.parsed_addresses(IPVersion.V6Only)
        if not address.startswith(("::1", "fe80"))
    )
    return next(usable_ipv6, None)
647
648
def get_port_from_zeroconf(discovery_info: AsyncServiceInfo) -> int | None:
    """Return the service port advertised in the zeroconf discovery info (may be None)."""
    advertised_port: int | None = discovery_info.port
    return advertised_port
652
653
async def close_async_generator(agen: AsyncGenerator[Any, None]) -> None:
    """
    Force close an async generator.

    Schedules one ``__anext__`` step as a task and cancels it right away. The
    resulting CancelledError / StopAsyncIteration is swallowed, and then
    ``aclose()`` is awaited on the generator.
    """
    # Wrap the next step in a task so it can be cancelled; cancel() is called
    # before the task has had a chance to run.
    # NOTE(review): the intent appears to be forcing the generator out of a
    # pending step before aclose() - confirm against callers.
    task = asyncio.create_task(agen.__anext__())
    task.cancel()
    with suppress(asyncio.CancelledError, StopAsyncIteration):
        await task
    await agen.aclose()
661
662
async def detect_charset(data: bytes, fallback: str = "utf-8") -> str:
    """
    Detect the charset of raw data using chardet.

    The detected encoding is only used when chardet's confidence is above
    0.75. Otherwise, or when detection fails, ``fallback`` is returned.
    """
    try:
        result: ResultDict = await asyncio.to_thread(chardet.detect, data)
        encoding = result["encoding"] if result else None
        if encoding and result["confidence"] > 0.75:
            assert isinstance(encoding, str)  # for type checking
            return encoding
    except Exception as err:
        LOGGER.debug("Failed to detect charset: %s", err)
    return fallback
673
674
def merge_dict(
    base_dict: dict[Any, Any],
    new_dict: dict[Any, Any],
    allow_overwite: bool = False,
) -> dict[Any, Any]:
    """
    Merge dict without overwriting existing values.

    Returns a new (shallow-copied) dict; ``base_dict`` is not modified.
    For every key in ``new_dict``:
    - a dict/dict, tuple/tuple or list/list pair with a truthy base value is
      merged (dicts recursively, honouring ``allow_overwite``);
    - otherwise the new value is used when the base value is missing/falsy,
      or when ``allow_overwite`` is set.
    """
    final_dict = base_dict.copy()
    for key, value in new_dict.items():
        existing = final_dict.get(key)
        # A single if/elif chain, so a freshly merged value is never clobbered
        # by the allow_overwite branch afterwards.
        if existing and isinstance(existing, dict) and isinstance(value, dict):
            final_dict[key] = merge_dict(existing, value, allow_overwite)
        elif existing and isinstance(existing, tuple) and isinstance(value, tuple):
            final_dict[key] = merge_tuples(existing, value)
        elif existing and isinstance(existing, list) and isinstance(value, list):
            final_dict[key] = merge_lists(existing, value)
        elif not existing or allow_overwite:
            final_dict[key] = value
    return final_dict
692
693
def merge_tuples(base: tuple[Any, ...], new: tuple[Any, ...]) -> tuple[Any, ...]:
    """Merge 2 tuples: items of ``base`` not in ``new``, followed by all items of ``new``."""
    return (*(item for item in base if item not in new), *new)
697
698
def merge_lists(base: list[Any], new: list[Any]) -> list[Any]:
    """Merge 2 lists: items of ``base`` not in ``new``, followed by all items of ``new``."""
    merged = [item for item in base if item not in new]
    merged.extend(new)
    return merged
702
703
def percentage(part: float, whole: float) -> int:
    """
    Calculate what percentage ``part`` is of ``whole``, truncated to an int.

    Returns 0 when ``whole`` is 0 instead of raising ZeroDivisionError.
    """
    whole_value = float(whole)
    if whole_value == 0:
        return 0
    return int(100 * float(part) / whole_value)
707
708
def validate_announcement_chime_url(url: str) -> bool:
    """
    Validate announcement chime URL format.

    Empty or blank values and the built-in chime file are valid. Anything
    else must be an http(s) URL with a host, whose path ends in one of the
    supported audio extensions (case-insensitive).
    """
    if not url or not url.strip():
        return True  # Empty URL is valid
    if url == ANNOUNCE_ALERT_FILE:
        return True  # Built-in chime file is valid

    audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac")
    try:
        parsed = urlparse(url.strip())
        if parsed.scheme not in ("http", "https") or not parsed.netloc:
            return False
        return parsed.path.lower().endswith(audio_extensions)
    except Exception:
        return False
733
734
async def get_mac_address(ip_address: str) -> str | None:
    """
    Get MAC address for given IP address via ARP lookup.

    Uses the optional ``getmac`` package. Returns None when it is not
    installed or the lookup raises.
    """
    try:
        from getmac import get_mac_address as getmac_lookup  # noqa: PLC0415

        mac: str | None = await asyncio.to_thread(getmac_lookup, ip=ip_address)
    except ImportError:
        LOGGER.debug("getmac module not available, cannot resolve MAC from IP")
        mac = None
    except Exception as err:
        LOGGER.debug("Failed to resolve MAC address for %s: %s", ip_address, err)
        mac = None
    return mac
747
748
def is_locally_administered_mac(mac_address: str) -> bool:
    """
    Check if a MAC address is locally administered (virtual/randomized).

    Locally administered addresses have bit 1 (value 0x02) of the first
    octet set. Devices often use them for virtual interfaces or
    protocol-specific addresses (e.g. AirPlay and DLNA may use different
    virtual MACs than the real hardware MAC).

    :param mac_address: MAC address with ":" or "-" separators, or none.
    :return: True if locally administered; False if globally unique (real
        hardware MAC) or unparsable.
    """
    compact = mac_address.upper().replace(":", "").replace("-", "")
    if len(compact) < 2:
        return False
    try:
        first_octet = int(compact[:2], 16)
    except ValueError:
        return False
    return (first_octet & 0x02) != 0
771
772
773def is_valid_mac_address(mac_address: str | None) -> bool:
774 """
775 Check if a MAC address is valid and usable for device identification.
776
777 Invalid MAC addresses include:
778 - None or empty strings
779 - Null MAC: 00:00:00:00:00:00
780 - Broadcast MAC: ff:ff:ff:ff:ff:ff
781 - Any MAC that doesn't follow the expected pattern
782
783 :param mac_address: MAC address to validate.
784 :return: True if valid and usable, False otherwise.
785 """
786 if not mac_address:
787 return False
788
789 # Normalize MAC address (remove separators and convert to lowercase)
790 normalized = mac_address.lower().replace(":", "").replace("-", "")
791
792 # Check for invalid/reserved MAC addresses
793 if normalized in ("000000000000", "ffffffffffff"):
794 return False
795
796 # Check length and hex validity
797 if len(normalized) != 12:
798 return False
799
800 try:
801 int(normalized, 16)
802 return True
803 except ValueError:
804 return False
805
806
807def normalize_ip_address(ip_address: str | None) -> str | None:
808 """
809 Normalize IP address for comparison.
810
811 Handles IPv6-mapped IPv4 addresses (e.g., ::ffff:192.168.1.64 -> 192.168.1.64).
812
813 :param ip_address: IP address to normalize.
814 :return: Normalized IP address or None if invalid.
815 """
816 if not ip_address:
817 return None
818
819 # Handle IPv6-mapped IPv4 addresses
820 if ip_address.startswith("::ffff:"):
821 # Extract the IPv4 part
822 return ip_address[7:]
823
824 return ip_address
825
826
async def resolve_real_mac_address(reported_mac: str | None, ip_address: str | None) -> str | None:
    """
    Resolve the real MAC address for a device.

    Some devices report different virtual MAC addresses per protocol (AirPlay,
    DLNA, Chromecast). When no MAC is reported, or the reported one is locally
    administered (virtual), the hardware MAC is looked up via ARP.

    :param reported_mac: The MAC address reported by the protocol.
    :param ip_address: The IP address of the device (for ARP lookup).
    :return: The resolved MAC address in upper case. Returns None when no
        lookup was needed or it did not yield a valid MAC.
    """
    if not ip_address:
        return None
    if reported_mac and not is_locally_administered_mac(reported_mac):
        # a globally unique (hardware) MAC was reported; nothing to resolve
        return None
    looked_up = await get_mac_address(ip_address)
    if looked_up and is_valid_mac_address(looked_up):
        return looked_up.upper()
    return None
849
850
class TaskManager:
    """
    Helper class to run many tasks at once.

    This is basically an alternative to asyncio.TaskGroup but this will not
    cancel all operations when one of the tasks fails.
    Logging of exceptions is done by the mass.create_task helper.
    Leaving the async context waits for all tracked tasks to finish.
    """

    def __init__(self, mass: MusicAssistant, limit: int = 0):
        """
        Initialize the TaskManager.

        :param mass: The MusicAssistant instance whose create_task is used.
        :param limit: Maximum number of tasks started via create_task_with_limit
            that may run concurrently. 0 disables the semaphore, in which case
            create_task_with_limit must not be used.
        """
        self.mass = mass
        self._tasks: list[asyncio.Task[None]] = []
        self._semaphore = asyncio.Semaphore(limit) if limit else None

    def create_task(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]:
        """Create a new task and add it to the manager."""
        task = self.mass.create_task(coro)
        self._tasks.append(task)
        return task

    async def create_task_with_limit(self, coro: Coroutine[Any, Any, None]) -> None:
        """
        Create a new task with semaphore limit.

        Waits for a free slot first; the slot is released when the task is done.
        """
        assert self._semaphore is not None
        semaphore = self._semaphore
        await semaphore.acquire()
        try:
            task: asyncio.Task[None] = self.create_task(coro)
        except BaseException:
            # don't leak the slot when the task could not be created
            semaphore.release()
            raise

        def task_done_callback(done_task: asyncio.Task[None]) -> None:
            # __aexit__ may already have cleared the list
            with suppress(ValueError):
                self._tasks.remove(done_task)
            semaphore.release()

        task.add_done_callback(task_done_callback)

    async def __aenter__(self) -> Self:
        """Enter context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Exit context manager: wait for all tracked tasks, then forget them."""
        if len(self._tasks) > 0:
            await asyncio.wait(self._tasks)
        self._tasks.clear()
        return None
900
901
902_R = TypeVar("_R")
903_P = ParamSpec("_P")
904
905
906def lock[**P, R]( # type: ignore[valid-type]
907 func: Callable[_P, Awaitable[_R]],
908) -> Callable[_P, Coroutine[Any, Any, _R]]:
909 """Call async function using a Lock."""
910
911 @functools.wraps(func)
912 async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
913 """Call async function using the throttler with retries."""
914 if not (func_lock := getattr(func, "lock", None)):
915 func_lock = asyncio.Lock()
916 func.lock = func_lock # type: ignore[attr-defined]
917 async with func_lock:
918 return await func(*args, **kwargs)
919
920 return wrapper
921
922
class TimedAsyncGenerator:
    """
    Async iterable that times out after a given time.

    Each step of the wrapped iterator must complete within ``timeout``
    seconds, otherwise TimeoutError is raised. Iteration also stops at the
    first falsy item (e.g. an empty ``b""`` chunk).

    Source: https://medium.com/@dmitry8912/implementing-timeouts-in-pythons-asynchronous-generators-f7cbaa6dc1e9
    """

    def __init__(self, iterable: AsyncIterator[Any], timeout: float = 0):
        """
        Initialize the AsyncTimedIterable.

        Args:
            iterable: The async iterable to wrap.
            timeout: The timeout in seconds for each iteration (fractions are
                honoured). 0 (the default) or a negative value disables it.
        """
        seconds = float(timeout)
        # asyncio.wait_for with a timeout <= 0 times out immediately, which
        # made the default unusable; treat it as "no timeout" instead.
        step_timeout: float | None = seconds if seconds > 0 else None

        class AsyncTimedIterator:
            def __init__(self) -> None:
                self._iterator = iterable.__aiter__()

            async def __anext__(self) -> Any:
                result = await asyncio.wait_for(self._iterator.__anext__(), step_timeout)
                if not result:
                    # a falsy item (e.g. an empty chunk) ends the iteration
                    raise StopAsyncIteration
                return result

        self._factory = AsyncTimedIterator

    def __aiter__(self):  # type: ignore[no-untyped-def]
        """Return the async iterator."""
        return self._factory()
954
955
956def guard_single_request[ProviderT: "Provider | CoreController", **P, R](
957 func: Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]],
958) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
959 """Guard single request to a function."""
960
961 @functools.wraps(func)
962 async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
963 mass = self.mass
964 # create a task_id dynamically based on the function and args/kwargs
965 cache_key_parts = [func.__class__.__name__, func.__name__, *args]
966 for key in sorted(kwargs.keys()):
967 cache_key_parts.append(f"{key}{kwargs[key]}")
968 task_id = ".".join(map(str, cache_key_parts))
969 task: asyncio.Task[R] = mass.create_task(
970 func,
971 self,
972 *args,
973 task_id=task_id,
974 abort_existing=False,
975 **kwargs,
976 )
977 return await task
978
979 return wrapper
980