/
/
/
1"""Various (server-only) tools and helpers."""
2
3from __future__ import annotations
4
5import asyncio
6import functools
7import importlib
8import logging
9import os
10import re
11import shutil
12import socket
13import urllib.error
14import urllib.request
15from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Coroutine
16from contextlib import suppress
17from functools import lru_cache
18from importlib.metadata import PackageNotFoundError
19from importlib.metadata import version as pkg_version
20from pathlib import Path
21from types import TracebackType
22from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Self, TypeVar, cast
23from urllib.parse import urlparse
24
25import chardet
26import ifaddr
27from music_assistant_models.enums import AlbumType
28from zeroconf import IPVersion
29
30from music_assistant.constants import (
31 ANNOUNCE_ALERT_FILE,
32 LIVE_INDICATORS,
33 SOUNDTRACK_INDICATORS,
34 VERBOSE_LOG_LEVEL,
35)
36from music_assistant.helpers.process import check_output
37
38if TYPE_CHECKING:
39 from collections.abc import Iterator
40
41 from chardet.resultdict import ResultDict
42 from zeroconf.asyncio import AsyncServiceInfo
43
44 from music_assistant.mass import MusicAssistant
45 from music_assistant.models import ProviderModuleType
46 from music_assistant.models.core_controller import CoreController
47 from music_assistant.models.provider import Provider
48
49from dataclasses import fields, is_dataclass
50
LOGGER = logging.getLogger(__name__)

# Index of Home Assistant's prebuilt musllinux python wheels, used as an
# extra package source when installing provider requirements.
HA_WHEELS = "https://wheels.home-assistant.io/musllinux/"

T = TypeVar("T")  # generic typevar shared by helpers in this module
CALLBACK_TYPE = Callable[[], None]  # signature of a simple zero-argument callback
57
58
def get_total_system_memory() -> float:
    """
    Get total system memory in GB.

    :return: total physical memory in GB, or 0.0 when it cannot be determined.
    """
    try:
        # Works on Linux and macOS
        total_memory_bytes = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")
    except (AttributeError, ValueError, OSError):
        # sysconf not available (e.g. Windows) or the query itself failed
        # (os.sysconf may raise OSError as well).
        # Return a conservative default to disable buffering by default
        return 0.0
    return total_memory_bytes / (1024**3)  # convert bytes to GB
69
70
# Precompiled regexes used by the radio streamtitle cleanup helpers below.
keyword_pattern = re.compile("title=|artist=")  # does the line carry key=value metadata?
title_pattern = re.compile(r"title=\"(?P<title>.*?)\"")  # quoted title value
artist_pattern = re.compile(r"artist=\"(?P<artist>.*?)\"")  # quoted artist value
# scheme-less (bare) hostnames such as example.com or (radio.example.nl)
dot_com_pattern = re.compile(r"(?P<netloc>\(?\w+\.(?:\w+\.)?(\w{2,3})\)?)")
# common advertisement markers found in stream titles
ad_pattern = re.compile(r"((ad|advertisement)_)|^AD\s\d+$|ADBREAK", flags=re.IGNORECASE)
# "<title> By: <artist>" ordering, swapped into "<artist> - <title>"
title_artist_order_pattern = re.compile(r"(?P<title>.+)\sBy:\s(?P<artist>.+)", flags=re.IGNORECASE)
multi_space_pattern = re.compile(r"\s{2,}")  # runs of 2+ whitespace characters
end_junk_pattern = re.compile(r"(.+?)(\s\W+)$")  # trailing non-word junk

VERSION_PARTS = (
    # list of common version strings
    "version",
    "live",
    "edit",
    "remix",
    "mix",
    "acoustic",
    "instrumental",
    "karaoke",
    "remaster",
    "versie",
    "unplugged",
    "disco",
    "akoestisch",
    "deluxe",
)
IGNORE_TITLE_PARTS = (
    # strings that may be stripped off a title part
    # (most important the featuring parts)
    "feat.",
    "featuring",
    "ft.",
    "with ",
    "explicit",
)
WITH_TITLE_WORDS = (
    # words that, when following "with", indicate this is part of the song title
    # not a featuring credit.
    "someone",
    "the",
    "u",
    "you",
    "no",
)
115
116
def filename_from_string(string: str) -> str:
    """Create filename from unsafe string."""
    # keep alphanumerics plus a small set of filesystem-safe extras
    allowed_extra = (" ", ".", "_")
    safe_chars = [char for char in string if char.isalnum() or char in allowed_extra]
    return "".join(safe_chars).rstrip()
121
122
def try_parse_int(possible_int: Any, default: int | None = 0) -> int | None:
    """Try to parse an int, returning default when parsing fails."""
    try:
        # go through float first so strings like "3.7" parse too
        parsed = int(float(possible_int))
    except (TypeError, ValueError):
        return default
    return parsed
129
130
def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float | None:
    """Try to parse a float, returning default when parsing fails."""
    try:
        parsed = float(possible_float)
    except (TypeError, ValueError):
        return default
    return parsed
137
138
def try_parse_bool(possible_bool: Any) -> bool:
    """Try to parse a bool."""
    if isinstance(possible_bool, bool):
        return possible_bool
    # anything equal to one of these counts as True; everything else is False
    truthy_values = ("true", "True", "1", "on", "ON", 1)
    return any(possible_bool == candidate for candidate in truthy_values)
144
145
146def try_parse_duration(duration_str: str) -> float:
147 """Try to parse a duration in seconds from a duration (HH:MM:SS) string."""
148 milliseconds = float("0." + duration_str.split(".")[-1]) if "." in duration_str else 0.0
149 duration_parts = duration_str.split(".")[0].split(",")[0].split(":")
150 if len(duration_parts) == 3:
151 seconds = sum(x * int(t) for x, t in zip([3600, 60, 1], duration_parts, strict=False))
152 elif len(duration_parts) == 2:
153 seconds = sum(x * int(t) for x, t in zip([60, 1], duration_parts, strict=False))
154 else:
155 seconds = int(duration_parts[0])
156 return seconds + milliseconds
157
158
def parse_title_and_version(title: str, track_version: str | None = None) -> tuple[str, str]:
    """
    Try to parse version from the title.

    Inspects bracketed or dash-suffixed parts (e.g. "(Live)", "[Remix]",
    " - Acoustic"): featuring/explicit parts are stripped from the title,
    known version keywords are extracted as the version.

    :param title: the raw track title.
    :param track_version: already-known version, used when none is found in the title.
    :return: tuple of (cleaned title, version string - may be empty).
    """
    version = track_version or ""
    # candidate parts: "(...)", "[...]" and a trailing " - ..." suffix
    for regex in (r"\(.*?\)", r"\[.*?\]", r" - .*"):
        for title_part in re.findall(regex, title):
            # Extract the content without brackets/dashes for checking
            clean_part = title_part.translate(str.maketrans("", "", "()[]-")).strip().lower()

            # Check if this should be ignored (featuring/explicit parts)
            should_ignore = False
            for ignore_str in IGNORE_TITLE_PARTS:
                if clean_part.startswith(ignore_str):
                    # Special handling for "with " - check if followed by title words
                    if ignore_str == "with ":
                        # Extract the word after "with "
                        after_with = (
                            clean_part[len("with ") :].split()[0]
                            if len(clean_part) > len("with ")
                            else ""
                        )
                        if after_with in WITH_TITLE_WORDS:
                            # This is part of the title (e.g., "with you"), don't ignore
                            break
                    # Remove this part from the title
                    title = title.replace(title_part, "").strip()
                    should_ignore = True
                    break

            if should_ignore:
                continue

            # Check if this part is a version
            for version_str in VERSION_PARTS:
                if version_str in clean_part:
                    # Preserve original casing for output
                    version = title_part.strip("()[]- ").strip()
                    title = title.replace(title_part, "").strip()
                    # first match wins: return immediately
                    return title, version
    return title, version
198
199
def infer_album_type(title: str, version: str) -> AlbumType:
    """Infer album type by looking for live or soundtrack indicators."""
    haystack = f"{title} {version}".lower()
    if any(re.search(pattern, haystack) for pattern in LIVE_INDICATORS):
        return AlbumType.LIVE
    if any(re.search(pattern, haystack) for pattern in SOUNDTRACK_INDICATORS):
        return AlbumType.SOUNDTRACK
    return AlbumType.UNKNOWN
210
211
def strip_ads(line: str) -> str:
    """Strip Ads from line (replace a detected ad with a generic marker)."""
    return "Advert" if ad_pattern.search(line) else line
217
218
def strip_url(line: str) -> str:
    """
    Strip URL from line.

    Words that parse as a full URL (both scheme and netloc present) are
    dropped; everything else is kept.
    """
    kept_words: list[str] = []
    for word in line.split():
        # parse each word only once instead of twice
        parsed = urlparse(word)
        if not parsed.scheme or not parsed.netloc:
            kept_words.append(word)
    return " ".join(kept_words).rstrip()
224
225
def strip_dotcom(line: str) -> str:
    """Strip scheme-less netloc (bare hostnames like example.com) from line."""
    return dot_com_pattern.sub("", line)
229
230
def strip_end_junk(line: str) -> str:
    """Strip non-word info from end of line, keeping the leading content."""
    return end_junk_pattern.sub(r"\1", line)
234
235
def swap_title_artist_order(line: str) -> str:
    """Swap "<title> By: <artist>" in line into "<artist> - <title>"."""
    return title_artist_order_pattern.sub(r"\g<artist> - \g<title>", line)
239
240
def strip_multi_space(line: str) -> str:
    """Collapse runs of 2+ whitespace characters in line into a single space."""
    return multi_space_pattern.sub(" ", line)
244
245
def multi_strip(line: str) -> str:
    """
    Strip assorted junk from line.

    Applies, in order: ad replacement, URL removal, bare-domain removal,
    trailing-junk removal, "title By: artist" reordering and whitespace
    collapsing, then strips trailing whitespace.
    """
    return strip_multi_space(
        swap_title_artist_order(strip_end_junk(strip_dotcom(strip_url(strip_ads(line)))))
    ).rstrip()
251
252
def clean_stream_title(line: str) -> str:
    """Strip junk text from radio streamtitle."""
    # no title=/artist= metadata at all: just clean the whole line
    if not keyword_pattern.search(line):
        return multi_strip(line)

    title: str = ""
    artist: str = ""
    if title_match := title_pattern.search(line):
        title = multi_strip(title_match.group("title"))
    if artist_match := artist_pattern.search(line):
        candidate = multi_strip(artist_match.group("artist"))
        # ignore an artist that merely duplicates the title
        if candidate and candidate != title:
            artist = candidate

    if not (title or artist):
        return ""
    # title already contains "artist - title", or there is no artist at all
    if title and (" - " in title or not artist):
        return title
    if title:
        return f"{artist} - {title}"
    if artist:
        return artist
    return line
282
283
async def get_ip_addresses(include_ipv6: bool = False) -> tuple[str, ...]:
    """
    Return all IP-addresses of all network interfaces.

    Addresses are ordered by preference: the primary (default-route)
    address first, then 192.168.x ranges, then other private ranges,
    then everything else.

    :param include_ipv6: also include (non link-local) IPv6 addresses.
    """

    def call() -> tuple[str, ...]:
        # collected as (score, address); higher score = more preferred
        result: list[tuple[int, str]] = []
        # try to get the primary IP address
        # this is the IP address of the default route
        _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        _sock.settimeout(0)
        try:
            # doesn't even have to be reachable (UDP connect sends no packets)
            _sock.connect(("10.254.254.254", 1))
            primary_ip = _sock.getsockname()[0]
        except Exception:
            primary_ip = ""
        finally:
            _sock.close()
        # get all IP addresses of all network interfaces
        adapters = ifaddr.get_adapters()
        for adapter in adapters:
            for ip in adapter.ips:
                if ip.is_IPv6 and not include_ipv6:
                    continue
                # ifaddr returns IPv6 addresses as (address, flowinfo, scope_id) tuples
                ip_str = ip.ip[0] if isinstance(ip.ip, tuple) else ip.ip
                if ip_str.startswith(("127", "169.254")):
                    # filter out IPv4 loopback/APIPA address
                    continue
                if ip_str.startswith(("::1", "::ffff:", "fe80")):
                    # filter out IPv6 loopback/link-local address
                    continue
                if ip_str == primary_ip:
                    score = 10
                elif ip_str.startswith(("192.168.",)):
                    # we rank the 192.168 range a bit higher as its most
                    # often used as the private network subnet
                    score = 2
                elif ip_str.startswith(("172.", "10.", "192.")):
                    # we rank the 172 range a bit lower as its most
                    # often used as the private docker network
                    score = 1
                else:
                    score = 0
                result.append((score, ip_str))
        result.sort(key=lambda x: x[0], reverse=True)
        return tuple(ip[1] for ip in result)

    # ifaddr/socket calls are blocking: run them in a worker thread
    return await asyncio.to_thread(call)
332
333
async def get_primary_ip_address() -> str | None:
    """
    Return the primary IP address of the system.

    :return: the source address of the default route, or None when it
        cannot be determined (e.g. no network at all).
    """

    def _detect() -> str | None:
        # connecting a UDP socket sends no traffic; the OS simply selects
        # the source address of the default route for us
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.settimeout(0)
        try:
            # target doesn't even have to be reachable
            sock.connect(("10.254.254.254", 1))
            return str(sock.getsockname()[0])
        except Exception:
            return None
        finally:
            sock.close()

    return await asyncio.to_thread(_detect)
336
337
async def is_port_in_use(port: int) -> bool:
    """Check if port is in use."""

    def _check() -> bool:
        # A port counts as free when it can be bound on at least one
        # address family (supports single-stack and dual-stack hosts).
        candidates = ((socket.AF_INET, "0.0.0.0"), (socket.AF_INET6, "::"))
        for family, bind_addr in candidates:
            try:
                with socket.socket(family, socket.SOCK_STREAM) as sock:
                    # SO_REUSEADDR mirrors asyncio.start_server behavior so
                    # ports in TIME_WAIT state are considered free as well
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    sock.bind((bind_addr, port))
            except OSError:
                continue
            return False
        return True

    return await asyncio.to_thread(_check)
357
358
async def select_free_port(range_start: int, range_end: int) -> int:
    """Automatically find available port within range."""
    for candidate in range(range_start, range_end):
        if await is_port_in_use(candidate):
            continue
        return candidate
    raise OSError("No free port available")
366
367
async def get_ip_from_host(dns_name: str) -> str | None:
    """Resolve (first) IP-address for given dns name."""

    def _lookup() -> str | None:
        # fail gracefully on any resolver error
        try:
            return socket.gethostbyname(dns_name)
        except Exception:
            return None

    return await asyncio.to_thread(_lookup)
379
380
async def get_ip_pton(ip_string: str) -> bytes:
    """Return socket pton (packed binary form) for a local ip."""

    def _pack() -> bytes:
        # try IPv4 first, fall back to IPv6
        try:
            return socket.inet_pton(socket.AF_INET, ip_string)
        except OSError:
            return socket.inet_pton(socket.AF_INET6, ip_string)

    return await asyncio.to_thread(_pack)
387
388
def format_ip_for_url(ip_address: str) -> str:
    """Wrap IPv6 addresses in brackets for use in URLs (RFC 2732)."""
    is_ipv6 = ":" in ip_address
    return f"[{ip_address}]" if is_ipv6 else ip_address
394
395
async def get_folder_size(folderpath: str) -> float:
    """Return folder size in gb."""

    def _calc(path: str) -> float:
        byte_count = 0
        for root, _subdirs, files in os.walk(path):
            byte_count += sum(Path(root, name).stat().st_size for name in files)
        # bytes -> GB
        return byte_count / float(1 << 30)

    return await asyncio.to_thread(_calc, folderpath)
408
409
def get_changed_keys(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> set[str]:
    """Compare 2 dicts and return set of changed keys."""
    # TODO: consider supporting an ignore_keys filter here
    changed = get_changed_dict_values(dict1, dict2, recursive)
    return set(changed)
419
420
def get_changed_dict_values(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dicts and return dict of changed values.

    dict key is the changed key, value is tuple of old and new values.
    Keys only present in dict2 are reported as (None, new_value).

    :param dict1: the old/previous mapping.
    :param dict2: the new/current mapping.
    :param recursive: also diff nested dict values, using dotted keys.
    """
    if not dict1 and not dict2:
        return {}
    if not dict1:
        # everything in dict2 is new
        return {key: (None, value) for key, value in dict2.items()}
    if not dict2:
        # everything in dict1 disappeared: old value first, None as new value
        return {key: (value, None) for key, value in dict1.items()}
    changed_values: dict[str, tuple[Any, Any]] = {}
    for key, value in dict2.items():
        if key not in dict1:
            # new key - must be checked before indexing dict1[key] below
            changed_values[key] = (None, value)
            continue
        if recursive and isinstance(value, dict) and isinstance(dict1[key], dict):
            changed_subvalues = get_changed_dict_values(dict1[key], value, recursive)
            for subkey, subvalue in changed_subvalues.items():
                changed_values[f"{key}.{subkey}"] = subvalue
            continue
        if dict1[key] != value:
            changed_values[key] = (dict1[key], value)
    return changed_values
450
451
def get_changed_dataclass_values(
    obj1: T,
    obj2: T,
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dataclass instances of the same type and return dict of changed field values.

    dict key is the changed field name, value is tuple of old and new values.
    """
    if not (is_dataclass(obj1) and is_dataclass(obj2)):
        raise ValueError("Both objects must be dataclass instances")

    changes: dict[str, tuple[Any, Any]] = {}
    for fld in fields(obj1):
        old = getattr(obj1, fld.name, None)
        new = getattr(obj2, fld.name, None)
        if recursive and is_dataclass(old) and is_dataclass(new):
            nested = get_changed_dataclass_values(old, new, recursive)
        elif recursive and isinstance(old, dict) and isinstance(new, dict):
            nested = get_changed_dict_values(old, new, recursive=recursive)
        else:
            if old != new:
                changes[fld.name] = (old, new)
            continue
        # flatten nested changes into dotted field names
        for sub_name, sub_change in nested.items():
            changes[f"{fld.name}.{sub_name}"] = sub_change
    return changes
482
483
484def empty_queue[T](q: asyncio.Queue[T]) -> None:
485 """Empty an asyncio Queue."""
486 for _ in range(q.qsize()):
487 try:
488 q.get_nowait()
489 q.task_done()
490 except (asyncio.QueueEmpty, ValueError):
491 pass
492
493
async def install_package(package: str) -> None:
    """Install package with pip, raise when install failed."""
    LOGGER.debug("Installing python package %s", package)
    # use the HA wheels index as extra source for prebuilt packages
    cmd = ("uv", "pip", "install", "--no-cache", "--find-links", HA_WHEELS, package)
    return_code, output = await check_output(*cmd)
    if return_code == 0:
        return
    raise RuntimeError(f"Failed to install package {package}\n{output.decode()}")
502
503
async def get_package_version(pkg_name: str) -> str | None:
    """
    Return the version of an installed (python) package.

    Will return None if the package is not found.
    """

    def _lookup() -> str | None:
        try:
            return pkg_version(pkg_name)
        except PackageNotFoundError:
            return None

    return await asyncio.to_thread(_lookup)
514
515
async def is_hass_supervisor() -> bool:
    """Return if we're running inside the HA Supervisor (e.g. HAOS)."""
    # Fast path: the supervisor always injects this token into the
    # environment, so without it we can bail out immediately.
    if not os.environ.get("SUPERVISOR_TOKEN"):
        return False

    # Token exists, verify the supervisor is actually reachable
    def _probe() -> bool:
        try:
            urllib.request.urlopen("http://supervisor/core", timeout=1)
        except urllib.error.URLError as err:
            # a reachable supervisor answers 401 unauthorized here
            return getattr(err, "code", 999) == 401
        except Exception:
            return False
        return False

    return await asyncio.to_thread(_probe)
535
536
async def load_provider_module(domain: str, requirements: list[str]) -> ProviderModuleType:
    """
    Return module for given provider domain and make sure the requirements are met.

    :param domain: provider domain, imported as music_assistant.providers.<domain>.
    :param requirements: pip requirement strings; only pinned ones (name==version)
        are version-checked and (re)installed when needed.
    :raises RuntimeError: when installing a requirement fails.
    """

    # NOTE(review): the cache below is recreated on every call (the function
    # is defined inside this coroutine), so it only dedupes the up-to-two
    # import attempts within a single invocation - confirm that's intended.
    @lru_cache
    def _get_provider_module(domain: str) -> ProviderModuleType:
        return cast(
            "ProviderModuleType", importlib.import_module(f".{domain}", "music_assistant.providers")
        )

    # ensure module requirements are met
    for requirement in requirements:
        if "==" not in requirement:
            # we should really get rid of unpinned requirements
            continue
        package_name, version = requirement.split("==", 1)
        installed_version = await get_package_version(package_name)
        if installed_version == "0.0.0":
            # ignore editable installs
            continue
        if installed_version != version:
            await install_package(requirement)

    # try to load the module
    try:
        return await asyncio.to_thread(_get_provider_module, domain)
    except ImportError:
        # (re)install ALL requirements
        for requirement in requirements:
            await install_package(requirement)
        # try loading the provider again to be safe
        # this will fail if something else is wrong (as it should)
        return await asyncio.to_thread(_get_provider_module, domain)
569
570
async def has_tmpfs_mount() -> bool:
    """Check if we have a tmpfs mount."""

    def _scan_mounts() -> bool:
        # /proc/mounts only exists on Linux; treat any read error as "no"
        try:
            with open("/proc/mounts") as mounts:
                return any("tmpfs /tmp tmpfs rw" in entry for entry in mounts)
        except OSError:
            return False

    return await asyncio.to_thread(_scan_mounts)
586
587
async def get_free_space(folder: str) -> float:
    """Return free space on given folderpath in GB."""

    def _free_gb(path: str) -> float:
        try:
            usage = shutil.disk_usage(path)
        except OSError:
            # invalid/missing path: report no free space
            return 0.0
        return usage.free / float(1 << 30)

    return await asyncio.to_thread(_free_gb, folder)
600
601
async def get_free_space_percentage(folder: str) -> float:
    """Return free space on given folderpath in percentage."""

    def _free_pct(path: str) -> float:
        try:
            usage = shutil.disk_usage(path)
        except OSError:
            # invalid/missing path: report no free space
            return 0.0
        return usage.free / usage.total * 100

    return await asyncio.to_thread(_free_pct, folder)
614
615
async def has_enough_space(folder: str, size: int) -> bool:
    """
    Check if folder has enough free space.

    :param folder: folder path to check.
    :param size: required free space, compared against get_free_space which
        returns GB - NOTE(review): callers passing bytes here would nearly
        always get True; confirm the expected unit.
    """
    return await get_free_space(folder) > size
619
620
def divide_chunks(data: bytes, chunk_size: int) -> Iterator[bytes]:
    """Yield successive chunk_size-sized slices of data (last may be shorter)."""
    for start in range(0, len(data), chunk_size):
        yield data[start : start + chunk_size]
625
626
async def remove_file(file_path: str) -> None:
    """Remove file path (if it exists)."""

    def _remove() -> bool:
        # remove in a single syscall instead of exists()+remove() to avoid
        # a race when the file disappears between the check and the removal
        try:
            os.remove(file_path)
        except FileNotFoundError:
            return False
        return True

    if await asyncio.to_thread(_remove):
        LOGGER.log(VERBOSE_LOG_LEVEL, "Removed file: %s", file_path)
633
634
def get_primary_ip_address_from_zeroconf(discovery_info: AsyncServiceInfo) -> str | None:
    """Get primary IP address from zeroconf discovery info."""
    # prefer IPv4, skipping loopback (127.x) and APIPA (169.254.x) addresses
    for address in discovery_info.parsed_addresses(IPVersion.V4Only):
        if not address.startswith(("127", "169.254")):
            return address
    # fall back to IPv6, skipping loopback and link-local addresses
    for address in discovery_info.parsed_addresses(IPVersion.V6Only):
        if not address.startswith(("::1", "fe80")):
            return address
    return None
652
653
def get_port_from_zeroconf(discovery_info: AsyncServiceInfo) -> int | None:
    """Get port from zeroconf discovery info (None when not advertised)."""
    return discovery_info.port
657
658
async def close_async_generator(agen: AsyncGenerator[Any, None]) -> None:
    """
    Force close an async generator.

    Schedules the generator's __anext__ as a task and cancels it right away
    so a generator suspended mid-iteration gets a chance to unwind, then
    finalizes it with aclose().
    """
    task = asyncio.create_task(agen.__anext__())
    task.cancel()
    # the cancelled task may raise either CancelledError or, when the
    # generator was already exhausted, StopAsyncIteration
    with suppress(asyncio.CancelledError, StopAsyncIteration):
        await task
    await agen.aclose()
666
667
async def detect_charset(data: bytes, fallback: str = "utf-8") -> str:
    """Detect charset of raw data, returning fallback when detection fails."""
    try:
        result: ResultDict = await asyncio.to_thread(chardet.detect, data)
        # only trust sufficiently confident detections
        if result and result["encoding"] and result["confidence"] > 0.75:
            encoding = result["encoding"]
            assert isinstance(encoding, str)  # narrow for type checking
            return encoding
    except Exception as err:
        LOGGER.debug("Failed to detect charset: %s", err)
    return fallback
678
679
def parse_optional_bool(value: Any) -> bool | None:
    """Parse an optional boolean value from various input types."""
    if value is None:
        return None
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in ("true", "1", "yes", "on"):
            return True
        if normalized in ("false", "0", "no", "off"):
            return False
        # unrecognized string
        return None
    if isinstance(value, (int, float)):
        return bool(value)
    return None
695
696
def merge_dict(
    base_dict: dict[Any, Any],
    new_dict: dict[Any, Any],
    allow_overwite: bool = False,  # (sic) keyword kept for backward compatibility
) -> dict[Any, Any]:
    """
    Merge dict without overwriting existing values.

    Nested dicts are merged recursively, tuples/lists are concatenated
    (existing items not present in the new value first). Plain values are
    only replaced when missing/falsy, or when allow_overwite is set.
    """
    final_dict = base_dict.copy()
    for key, value in new_dict.items():
        existing = final_dict.get(key)
        # the branches below are mutually exclusive: without the elif chain
        # a merged dict/tuple/list would be clobbered again by the final
        # overwrite branch when allow_overwite is set
        if existing and isinstance(value, dict):
            final_dict[key] = merge_dict(existing, value, allow_overwite)
        elif existing and isinstance(value, tuple):
            final_dict[key] = merge_tuples(existing, value)
        elif existing and isinstance(value, list):
            final_dict[key] = merge_lists(existing, value)
        elif not existing or allow_overwite:
            final_dict[key] = value
    return final_dict
714
715
def merge_tuples(base: tuple[Any, ...], new: tuple[Any, ...]) -> tuple[Any, ...]:
    """Merge 2 tuples (base items not present in new, followed by new)."""
    preserved = [item for item in base if item not in new]
    return (*preserved, *new)
719
720
def merge_lists(base: list[Any], new: list[Any]) -> list[Any]:
    """Merge 2 lists (base items not present in new, followed by new)."""
    merged = [item for item in base if item not in new]
    merged.extend(new)
    return merged
724
725
def percentage(part: float, whole: float) -> int:
    """Calculate percentage (0 when whole is 0, avoiding ZeroDivisionError)."""
    if not whole:
        return 0
    return int(100 * float(part) / float(whole))
729
730
def validate_announcement_chime_url(url: str) -> bool:
    """Validate announcement chime URL format."""
    if not url or not url.strip():
        # an empty URL means "no custom chime" and is always accepted
        return True
    if url == ANNOUNCE_ALERT_FILE:
        # the bundled default chime is always accepted
        return True
    try:
        parsed = urlparse(url.strip())
    except Exception:
        return False
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        return False
    # require a recognizable audio file extension in the path
    audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac")
    return parsed.path.lower().endswith(audio_extensions)
755
756
async def get_mac_address(ip_address: str) -> str | None:
    """Get MAC address for given IP address via ARP lookup."""
    try:
        # imported lazily: getmac is an optional dependency
        from getmac import get_mac_address as getmac_lookup  # noqa: PLC0415

        return await asyncio.to_thread(getmac_lookup, ip=ip_address)
    except ImportError:
        LOGGER.debug("getmac module not available, cannot resolve MAC from IP")
    except Exception as err:
        LOGGER.debug("Failed to resolve MAC address for %s: %s", ip_address, err)
    return None
769
770
def is_locally_administered_mac(mac_address: str) -> bool:
    """
    Check if a MAC address is locally administered (virtual/randomized).

    Locally administered addresses have bit 1 of the first octet set to 1.
    These are often used by devices for virtual interfaces or protocol-specific
    addresses (e.g., AirPlay, DLNA may use different virtual MACs than the real hardware MAC).

    :param mac_address: MAC address in any common format (with :, -, or no separator).
    :return: True if locally administered, False if globally unique (real hardware MAC).
    """
    digits = mac_address.upper().replace(":", "").replace("-", "")
    if len(digits) < 2:
        # not even one full octet present
        return False
    try:
        first_octet = int(digits[:2], 16)
    except ValueError:
        # not valid hex
        return False
    return bool(first_octet & 0x02)
793
794
def normalize_mac_for_matching(mac_address: str) -> str:
    """
    Normalize a MAC address for device matching by masking out the locally-administered bit.

    Bit 1 of the first octet is cleared so protocol-specific virtual MAC
    variants (e.g. AirPlay reporting 56:78:C9:E6:0D:A0 for the hardware MAC
    54:78:C9:E6:0D:A0) resolve to the same normalized value.

    :param mac_address: MAC address in any common format (with :, -, or no separator).
    :return: Normalized MAC address in lowercase without separators, with the
             locally-administered bit cleared.
    """
    digits = mac_address.lower().replace(":", "").replace("-", "")
    if len(digits) != 12:
        # not a full MAC: return the cleaned string untouched
        return digits
    try:
        first_octet = int(digits[:2], 16) & ~0x02  # clear bit 1
    except ValueError:
        # not valid hex: return the cleaned string untouched
        return digits
    return f"{first_octet:02x}{digits[2:]}"
827
828
def is_valid_mac_address(mac_address: str | None) -> bool:
    """
    Check if a MAC address is valid and usable for device identification.

    Rejected values: None/empty strings, the null MAC (00:00:00:00:00:00),
    the broadcast MAC (ff:ff:ff:ff:ff:ff), and anything that is not exactly
    12 hex digits after removing separators.

    :param mac_address: MAC address to validate.
    :return: True if valid and usable, False otherwise.
    """
    if not mac_address:
        return False
    normalized = mac_address.lower().replace(":", "").replace("-", "")
    if normalized in ("000000000000", "ffffffffffff"):
        # null/broadcast addresses are reserved
        return False
    if len(normalized) != 12:
        return False
    try:
        int(normalized, 16)
    except ValueError:
        return False
    return True
861
862
def normalize_ip_address(ip_address: str | None) -> str | None:
    """
    Normalize IP address for comparison.

    Handles IPv6-mapped IPv4 addresses (e.g. ::ffff:192.168.1.64 -> 192.168.1.64).

    :param ip_address: IP address to normalize.
    :return: Normalized IP address or None if invalid.
    """
    if not ip_address:
        return None
    # strip the IPv6-mapped prefix so ::ffff:1.2.3.4 compares equal to 1.2.3.4
    return ip_address.removeprefix("::ffff:")
881
882
async def resolve_real_mac_address(reported_mac: str | None, ip_address: str | None) -> str | None:
    """
    Resolve the real MAC address for a device.

    Some devices report different virtual MAC addresses per protocol (AirPlay, DLNA,
    Chromecast). When the reported MAC is missing or looks locally administered
    (virtual), try to resolve the actual hardware MAC via ARP.

    :param reported_mac: The MAC address reported by the protocol.
    :param ip_address: The IP address of the device (for ARP lookup).
    :return: The real MAC address if found, or None if it couldn't be resolved.
    """
    if not ip_address:
        return None
    if reported_mac and not is_locally_administered_mac(reported_mac):
        # the reported MAC already looks like a real (globally unique)
        # hardware address: nothing to resolve
        return None
    real_mac = await get_mac_address(ip_address)
    if real_mac and is_valid_mac_address(real_mac):
        return real_mac.upper()
    return None
905
906
class TaskManager:
    """
    Helper class to run many tasks at once.

    This is basically an alternative to asyncio.TaskGroup but this will not
    cancel all operations when one of the tasks fails.
    Logging of exceptions is done by the mass.create_task helper.
    """

    def __init__(self, mass: MusicAssistant, limit: int = 0):
        """
        Initialize the TaskManager.

        :param mass: the MusicAssistant instance (used to create tasks).
        :param limit: max number of concurrently running tasks (0 = unlimited).
        """
        self.mass = mass
        # all tasks created through this manager that have not yet been awaited
        self._tasks: list[asyncio.Task[None]] = []
        # only allocated when a concurrency limit was requested
        self._semaphore = asyncio.Semaphore(limit) if limit else None

    def create_task(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]:
        """Create a new task and add it to the manager."""
        task = self.mass.create_task(coro)
        self._tasks.append(task)
        return task

    async def create_task_with_limit(self, coro: Coroutine[Any, Any, None]) -> None:
        """
        Create a new task with semaphore limit.

        Blocks until a semaphore slot is free; requires limit > 0 at init.
        """
        assert self._semaphore is not None

        def task_done_callback(_task: asyncio.Task[None]) -> None:
            assert self._semaphore is not None  # for type checking
            # NOTE: closes over `task` assigned below; the callback can only
            # fire after create_task returned, so the name is bound by then
            self._tasks.remove(task)
            self._semaphore.release()

        await self._semaphore.acquire()
        task: asyncio.Task[None] = self.create_task(coro)
        task.add_done_callback(task_done_callback)

    async def __aenter__(self) -> Self:
        """Enter context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Exit context manager: wait for all remaining tasks to finish."""
        if len(self._tasks) > 0:
            await asyncio.wait(self._tasks)
        self._tasks.clear()
        # never suppress exceptions from the with-body
        return None
956
957
# TypeVar/ParamSpec pair used for typing async-decorator signatures
_R = TypeVar("_R")
_P = ParamSpec("_P")
960
961
962def lock[**P, R]( # type: ignore[valid-type]
963 func: Callable[_P, Awaitable[_R]],
964) -> Callable[_P, Coroutine[Any, Any, _R]]:
965 """Call async function using a Lock."""
966
967 @functools.wraps(func)
968 async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
969 """Call async function using the throttler with retries."""
970 if not (func_lock := getattr(func, "lock", None)):
971 func_lock = asyncio.Lock()
972 func.lock = func_lock # type: ignore[attr-defined]
973 async with func_lock:
974 return await func(*args, **kwargs)
975
976 return wrapper
977
978
class TimedAsyncGenerator:
    """
    Async iterable that times out after a given time.

    Source: https://medium.com/@dmitry8912/implementing-timeouts-in-pythons-asynchronous-generators-f7cbaa6dc1e9
    """

    def __init__(self, iterable: AsyncIterator[Any], timeout: int = 0):
        """
        Initialize the AsyncTimedIterable.

        Args:
            iterable: The async iterable to wrap.
            timeout: The timeout in seconds for each iteration.
                NOTE(review): the value is truncated with int() and a
                timeout of 0 makes wait_for expire (almost) immediately -
                confirm the intended default.
        """

        class AsyncTimedIterator:
            # thin wrapper that applies the timeout to every __anext__ call
            def __init__(self) -> None:
                self._iterator = iterable.__aiter__()

            async def __anext__(self) -> Any:
                result = await asyncio.wait_for(self._iterator.__anext__(), int(timeout))
                # NOTE(review): any falsy item (0, "", b"") also ends the
                # iteration here, not just exhaustion - confirm intended
                if not result:
                    raise StopAsyncIteration
                return result

        self._factory = AsyncTimedIterator

    def __aiter__(self):  # type: ignore[no-untyped-def]
        """Return the async iterator."""
        return self._factory()
1010
1011
1012def guard_single_request[ProviderT: "Provider | CoreController", **P, R](
1013 func: Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]],
1014) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
1015 """Guard single request to a function."""
1016
1017 @functools.wraps(func)
1018 async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
1019 mass = self.mass
1020 # create a task_id dynamically based on the function and args/kwargs
1021 cache_key_parts = [func.__class__.__name__, func.__name__, *args]
1022 for key in sorted(kwargs.keys()):
1023 cache_key_parts.append(f"{key}{kwargs[key]}")
1024 task_id = ".".join(map(str, cache_key_parts))
1025 task: asyncio.Task[R] = mass.create_task(
1026 func,
1027 self,
1028 *args,
1029 task_id=task_id,
1030 abort_existing=False,
1031 eager_start=True,
1032 **kwargs,
1033 )
1034 return await task
1035
1036 return wrapper
1037