/
/
/
1"""Various (server-only) tools and helpers."""
2
3from __future__ import annotations
4
5import asyncio
6import functools
7import importlib
8import logging
9import os
10import re
11import shutil
12import socket
13import urllib.error
14import urllib.request
15from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Coroutine
16from contextlib import suppress
17from functools import lru_cache
18from importlib.metadata import PackageNotFoundError
19from importlib.metadata import version as pkg_version
20from pathlib import Path
21from types import TracebackType
22from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Self, TypeVar, cast
23from urllib.parse import urlparse
24
25import chardet
26import ifaddr
27from music_assistant_models.enums import AlbumType
28from zeroconf import IPVersion
29
30from music_assistant.constants import (
31 ANNOUNCE_ALERT_FILE,
32 LIVE_INDICATORS,
33 SOUNDTRACK_INDICATORS,
34 VERBOSE_LOG_LEVEL,
35)
36from music_assistant.helpers.process import check_output
37
38if TYPE_CHECKING:
39 from collections.abc import Iterator
40
41 from chardet.resultdict import ResultDict
42 from zeroconf.asyncio import AsyncServiceInfo
43
44 from music_assistant.mass import MusicAssistant
45 from music_assistant.models import ProviderModuleType
46 from music_assistant.models.core_controller import CoreController
47 from music_assistant.models.provider import Provider
48
49from dataclasses import fields, is_dataclass
50
LOGGER = logging.getLogger(__name__)

# Base URL of the Home Assistant musllinux wheels index, used by install_package.
HA_WHEELS = "https://wheels.home-assistant.io/musllinux/"

# Generic type variable used by the dict/dataclass comparison helpers below.
T = TypeVar("T")
# Signature of a simple "unsubscribe"/cleanup callback.
CALLBACK_TYPE = Callable[[], None]
57
58
def get_total_system_memory() -> float:
    """
    Return total system memory in GB.

    Returns 0.0 when the amount cannot be determined (e.g. on Windows),
    which acts as a conservative default to disable buffering by default.
    """
    try:
        # Works on Linux and macOS: page size * number of physical pages
        total_memory_bytes = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")
        return total_memory_bytes / (1024**3)  # Convert bytes to GB
    except (AttributeError, ValueError, OSError):
        # AttributeError: os.sysconf not available on this platform
        # ValueError: the sysconf name is not recognized
        # OSError: the sysconf call itself failed
        return 0.0
69
70
# Pre-compiled regexes used by the radio streamtitle cleanup helpers below.
keyword_pattern = re.compile("title=|artist=")  # quick check: line has title=/artist= keys
title_pattern = re.compile(r"title=\"(?P<title>.*?)\"")  # extract quoted title value
artist_pattern = re.compile(r"artist=\"(?P<artist>.*?)\"")  # extract quoted artist value
dot_com_pattern = re.compile(r"(?P<netloc>\(?\w+\.(?:\w+\.)?(\w{2,3})\)?)")  # bare domain names
ad_pattern = re.compile(r"((ad|advertisement)_)|^AD\s\d+$|ADBREAK", flags=re.IGNORECASE)
title_artist_order_pattern = re.compile(r"(?P<title>.+)\sBy:\s(?P<artist>.+)", flags=re.IGNORECASE)
multi_space_pattern = re.compile(r"\s{2,}")  # runs of 2+ whitespace characters
end_junk_pattern = re.compile(r"(.+?)(\s\W+)$")  # trailing non-word junk after a space
79
VERSION_PARTS = (
    # list of common version strings
    # (matched as lowercase substring of a bracketed/dashed title part)
    "version",
    "live",
    "edit",
    "remix",
    "mix",
    "acoustic",
    "instrumental",
    "karaoke",
    "remaster",
    "versie",
    "unplugged",
    "disco",
    "akoestisch",
    "deluxe",
)
IGNORE_TITLE_PARTS = (
    # strings that may be stripped off a title part
    # (most important the featuring parts)
    "feat.",
    "featuring",
    "ft.",
    "with ",
    "explicit",
)
WITH_TITLE_WORDS = (
    # words that, when following "with", indicate this is part of the song title
    # not a featuring credit.
    "someone",
    "the",
    "u",
    "you",
    "no",
)
115
116
def filename_from_string(string: str) -> str:
    """Create filename from unsafe string."""
    # keep alphanumerics plus a small set of filename-safe characters
    allowed_extra = {" ", ".", "_"}
    safe_chars = [char for char in string if char.isalnum() or char in allowed_extra]
    return "".join(safe_chars).rstrip()
121
122
def try_parse_int(possible_int: Any, default: int | None = 0) -> int | None:
    """Try to parse an int."""
    # go through float() first so strings like "3.7" parse as well
    try:
        result: int | None = int(float(possible_int))
    except (TypeError, ValueError):
        result = default
    return result
129
130
def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float | None:
    """Try to parse a float."""
    try:
        result: float | None = float(possible_float)
    except (TypeError, ValueError):
        # unparsable input yields the caller-provided default
        result = default
    return result
137
138
def try_parse_bool(possible_bool: Any) -> bool:
    """Try to parse a bool."""
    if isinstance(possible_bool, bool):
        return possible_bool
    # membership test against the exact accepted truthy values;
    # anything else (including other casings like "On") is False
    truthy_values = ("true", "True", "1", "on", "ON", 1)
    return possible_bool in truthy_values
144
145
146def try_parse_duration(duration_str: str) -> float:
147 """Try to parse a duration in seconds from a duration (HH:MM:SS) string."""
148 milliseconds = float("0." + duration_str.split(".")[-1]) if "." in duration_str else 0.0
149 duration_parts = duration_str.split(".")[0].split(",")[0].split(":")
150 if len(duration_parts) == 3:
151 seconds = sum(x * int(t) for x, t in zip([3600, 60, 1], duration_parts, strict=False))
152 elif len(duration_parts) == 2:
153 seconds = sum(x * int(t) for x, t in zip([60, 1], duration_parts, strict=False))
154 else:
155 seconds = int(duration_parts[0])
156 return seconds + milliseconds
157
158
def parse_title_and_version(title: str, track_version: str | None = None) -> tuple[str, str]:
    """
    Try to parse version from the title.

    Inspects bracketed "(...)"/"[...]" parts and " - ..." suffixes of the title:
    featuring/explicit parts are stripped, and the first part containing a known
    version keyword (see VERSION_PARTS) is extracted as the version string.

    :param title: the raw track title.
    :param track_version: optional already-known version used as fallback.
    :return: tuple of (cleaned title, version string - may be empty).
    """
    version = track_version or ""
    for regex in (r"\(.*?\)", r"\[.*?\]", r" - .*"):
        for title_part in re.findall(regex, title):
            # Extract the content without brackets/dashes for checking
            clean_part = title_part.translate(str.maketrans("", "", "()[]-")).strip().lower()

            # Check if this should be ignored (featuring/explicit parts)
            should_ignore = False
            for ignore_str in IGNORE_TITLE_PARTS:
                if clean_part.startswith(ignore_str):
                    # Special handling for "with " - check if followed by title words
                    if ignore_str == "with ":
                        # Extract the word after "with "
                        after_with = (
                            clean_part[len("with ") :].split()[0]
                            if len(clean_part) > len("with ")
                            else ""
                        )
                        if after_with in WITH_TITLE_WORDS:
                            # This is part of the title (e.g., "with you"), don't ignore
                            break
                    # Remove this part from the title
                    title = title.replace(title_part, "").strip()
                    should_ignore = True
                    break

            if should_ignore:
                continue

            # Check if this part is a version
            for version_str in VERSION_PARTS:
                if version_str in clean_part:
                    # Preserve original casing for output
                    version = title_part.strip("()[]- ").strip()
                    title = title.replace(title_part, "").strip()
                    # first matching version part wins
                    return title, version
    return title, version
198
199
def infer_album_type(title: str, version: str) -> AlbumType:
    """Infer album type by looking for live or soundtrack indicators."""
    haystack = f"{title} {version}".lower()
    # live indicators are checked first and take precedence
    if any(re.search(indicator, haystack) for indicator in LIVE_INDICATORS):
        return AlbumType.LIVE
    if any(re.search(indicator, haystack) for indicator in SOUNDTRACK_INDICATORS):
        return AlbumType.SOUNDTRACK
    return AlbumType.UNKNOWN
210
211
def strip_ads(line: str) -> str:
    """Strip Ads from line."""
    # replace the whole line when it matches a known advertisement marker
    return "Advert" if ad_pattern.search(line) else line
217
218
def strip_url(line: str) -> str:
    """
    Strip URL from line.

    Removes every whitespace-separated token that parses as a full URL
    (both scheme and netloc present); other tokens are kept as-is.
    """
    kept_words: list[str] = []
    for word in line.split():
        # parse each token only once (the original parsed every token twice)
        parsed = urlparse(word)
        if parsed.scheme and parsed.netloc:
            continue
        kept_words.append(word)
    return " ".join(kept_words).rstrip()
224
225
def strip_dotcom(line: str) -> str:
    """Strip scheme-less netloc from line."""
    # remove bare domain names such as "example.com" that carry no scheme
    cleaned = dot_com_pattern.sub("", line)
    return cleaned
229
230
def strip_end_junk(line: str) -> str:
    """Strip non-word info from end of line."""
    # keep only the leading captured group, dropping the trailing junk
    cleaned = end_junk_pattern.sub(r"\1", line)
    return cleaned
234
235
def swap_title_artist_order(line: str) -> str:
    """Swap title/artist order in line."""
    # rewrite "<title> By: <artist>" into "<artist> - <title>"
    swapped = title_artist_order_pattern.sub(r"\g<artist> - \g<title>", line)
    return swapped
239
240
def strip_multi_space(line: str) -> str:
    """Strip multi-whitespace from line."""
    # collapse any run of 2+ whitespace characters into a single space
    collapsed = multi_space_pattern.sub(" ", line)
    return collapsed
244
245
def multi_strip(line: str) -> str:
    """Strip assorted junk from line."""
    # apply the cleanup helpers in order, then drop trailing whitespace
    cleaned = strip_ads(line)
    cleaned = strip_url(cleaned)
    cleaned = strip_dotcom(cleaned)
    cleaned = strip_end_junk(cleaned)
    cleaned = swap_title_artist_order(cleaned)
    cleaned = strip_multi_space(cleaned)
    return cleaned.rstrip()
251
252
def clean_stream_title(line: str) -> str:
    """
    Strip junk text from radio streamtitle.

    Lines containing explicit title=/artist= metadata are parsed into
    "artist - title"; any other line is simply run through multi_strip.
    """
    title: str = ""
    artist: str = ""

    # no title=/artist= keys present: just clean up the raw line
    if not keyword_pattern.search(line):
        return multi_strip(line)

    if match := title_pattern.search(line):
        title = multi_strip(match.group("title"))

    if match := artist_pattern.search(line):
        possible_artist = multi_strip(match.group("artist"))
        # ignore an artist value that merely repeats the title
        if possible_artist and possible_artist != title:
            artist = possible_artist

    if not title and not artist:
        return ""

    if title:
        # title already looks like "artist - title", or no artist known
        if re.search(" - ", title) or not artist:
            return title
        if artist:
            return f"{artist} - {title}"

    if artist:
        return artist

    return line
282
283
async def get_ip_addresses(include_ipv6: bool = False) -> tuple[str, ...]:
    """
    Return all IP-addresses of all network interfaces.

    Addresses are ranked so the most likely "primary" address comes first.

    :param include_ipv6: also include (non link-local) IPv6 addresses.
    :return: tuple of IP address strings, best candidates first.
    """

    def call() -> tuple[str, ...]:
        # collect (score, ip) pairs; a higher score means "more likely primary"
        result: list[tuple[int, str]] = []
        # try to get the primary IP address
        # this is the IP address of the default route
        _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        _sock.settimeout(0)
        try:
            # doesn't even have to be reachable
            _sock.connect(("10.254.254.254", 1))
            primary_ip = _sock.getsockname()[0]
        except Exception:
            primary_ip = ""
        finally:
            _sock.close()
        # get all IP addresses of all network interfaces
        adapters = ifaddr.get_adapters()
        for adapter in adapters:
            for ip in adapter.ips:
                if ip.is_IPv6 and not include_ipv6:
                    continue
                # ifaddr returns IPv6 addresses as (address, flowinfo, scope_id) tuples
                ip_str = ip.ip[0] if isinstance(ip.ip, tuple) else ip.ip
                if ip_str.startswith(("127", "169.254")):
                    # filter out IPv4 loopback/APIPA address
                    continue
                if ip_str.startswith(("::1", "::ffff:", "fe80")):
                    # filter out IPv6 loopback/link-local address
                    continue
                if ip_str == primary_ip:
                    score = 10
                elif ip_str.startswith(("192.168.",)):
                    # we rank the 192.168 range a bit higher as its most
                    # often used as the private network subnet
                    score = 2
                elif ip_str.startswith(("172.", "10.", "192.")):
                    # we rank the 172 range a bit lower as its most
                    # often used as the private docker network
                    score = 1
                else:
                    score = 0
                result.append((score, ip_str))
        # sort best candidates (highest score) first
        result.sort(key=lambda x: x[0], reverse=True)
        return tuple(ip[1] for ip in result)

    return await asyncio.to_thread(call)
332
333
async def get_primary_ip_address() -> str | None:
    """
    Return the primary IP address of the system.

    Uses the same default-route trick as get_ip_addresses: a UDP socket
    "connected" to an arbitrary external address reveals the local address
    the OS would use for the default route. No packets are sent.

    :return: the primary IPv4 address, or None when it cannot be determined.
    """

    def _get_primary_ip() -> str | None:
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.settimeout(0)
        try:
            # the target doesn't even have to be reachable
            sock.connect(("10.254.254.254", 1))
            return cast("str", sock.getsockname()[0])
        except Exception:
            return None
        finally:
            sock.close()

    return await asyncio.to_thread(_get_primary_ip)
336
337
async def is_port_in_use(port: int) -> bool:
    """Check if port is in use."""

    def _check_port() -> bool:
        # Try both IPv4 and IPv6 to support single-stack and dual-stack systems.
        # A port counts as free when it can be bound on at least one address family.
        for family, bind_addr in ((socket.AF_INET, "0.0.0.0"), (socket.AF_INET6, "::")):
            try:
                probe = socket.socket(family, socket.SOCK_STREAM)
            except OSError:
                # address family not available on this system
                continue
            with probe:
                try:
                    # SO_REUSEADDR matches asyncio.start_server behavior and
                    # allows binding to ports still in TIME_WAIT state
                    probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    probe.bind((bind_addr, port))
                except OSError:
                    continue
                return False
        return True

    return await asyncio.to_thread(_check_port)
357
358
async def select_free_port(range_start: int, range_end: int) -> int:
    """Automatically find available port within range."""
    # probe ports in order and return the first one that is not in use
    for candidate in range(range_start, range_end):
        if await is_port_in_use(candidate):
            continue
        return candidate
    raise OSError("No free port available")
366
367
async def get_ip_from_host(dns_name: str) -> str | None:
    """Resolve (first) IP-address for given dns name."""

    def _lookup() -> str | None:
        # fail gracefully: any resolver error simply yields None
        try:
            return socket.gethostbyname(dns_name)
        except Exception:
            return None

    return await asyncio.to_thread(_lookup)
379
380
async def get_ip_pton(ip_string: str) -> bytes:
    """Return socket pton for a local ip."""

    def _pton() -> bytes:
        # try IPv4 first; an OSError means the address is not IPv4, retry as IPv6
        try:
            return socket.inet_pton(socket.AF_INET, ip_string)
        except OSError:
            return socket.inet_pton(socket.AF_INET6, ip_string)

    return await asyncio.to_thread(_pton)
387
388
def format_ip_for_url(ip_address: str) -> str:
    """Wrap IPv6 addresses in brackets for use in URLs (RFC 2732)."""
    # a colon can only appear in IPv6 literals
    is_ipv6 = ":" in ip_address
    return f"[{ip_address}]" if is_ipv6 else ip_address
394
395
async def get_folder_size(folderpath: str) -> float:
    """Return folder size in gb."""

    def _calc_size(path: str) -> float:
        # sum the size of every file below the folder (bytes), then convert to GB
        total_bytes = sum(
            Path(dirpath, filename).stat().st_size
            for dirpath, _dirnames, filenames in os.walk(path)
            for filename in filenames
        )
        return total_bytes / float(1 << 30)

    return await asyncio.to_thread(_calc_size, folderpath)
408
409
def get_changed_keys(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> set[str]:
    """Compare 2 dicts and return set of changed keys."""
    # TODO: Check with Marcel whether we should calculate new dicts based on ignore_keys
    changed = get_changed_dict_values(dict1, dict2, recursive)
    return set(changed)
419
420
def get_changed_dict_values(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dicts and return dict of changed values.

    dict key is the changed key, value is tuple of old and new values.
    Nested dict changes are flattened into dotted keys when recursive is True.
    NOTE: keys removed in dict2 are only reported when dict2 is completely
    empty; otherwise only keys present in dict2 are compared (existing behavior).
    """
    if not dict1 and not dict2:
        return {}
    if not dict1:
        # everything in dict2 is new
        return {key: (None, value) for key, value in dict2.items()}
    if not dict2:
        # everything in dict1 was removed: old value first, new value None
        return {key: (value, None) for key, value in dict1.items()}
    changed_values: dict[str, tuple[Any, Any]] = {}
    for key, value in dict2.items():
        if key not in dict1:
            # new key - checked first to avoid a KeyError on the lookups below
            changed_values[key] = (None, value)
            continue
        if recursive and isinstance(value, dict) and isinstance(dict1[key], dict):
            # flatten changed sub-values into dotted keys
            changed_subvalues = get_changed_dict_values(dict1[key], value, recursive)
            for subkey, subvalue in changed_subvalues.items():
                changed_values[f"{key}.{subkey}"] = subvalue
            continue
        if dict1[key] != value:
            changed_values[key] = (dict1[key], value)
    return changed_values
450
451
def get_changed_dataclass_values(
    obj1: T,
    obj2: T,
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dataclass instances of the same type and return dict of changed field values.

    dict key is the changed field name, value is tuple of old and new values.
    """
    if not (is_dataclass(obj1) and is_dataclass(obj2)):
        raise ValueError("Both objects must be dataclass instances")

    changes: dict[str, tuple[Any, Any]] = {}
    for fld in fields(obj1):
        old_value = getattr(obj1, fld.name, None)
        new_value = getattr(obj2, fld.name, None)
        if recursive and is_dataclass(old_value) and is_dataclass(new_value):
            # flatten nested dataclass changes into dotted field names
            nested = get_changed_dataclass_values(old_value, new_value, recursive)
            for sub_name, sub_change in nested.items():
                changes[f"{fld.name}.{sub_name}"] = sub_change
        elif recursive and isinstance(old_value, dict) and isinstance(new_value, dict):
            # flatten nested dict changes into dotted field names
            nested = get_changed_dict_values(old_value, new_value, recursive=recursive)
            for sub_name, sub_change in nested.items():
                changes[f"{fld.name}.{sub_name}"] = sub_change
        elif old_value != new_value:
            changes[fld.name] = (old_value, new_value)
    return changes
482
483
484def empty_queue[T](q: asyncio.Queue[T]) -> None:
485 """Empty an asyncio Queue."""
486 for _ in range(q.qsize()):
487 try:
488 q.get_nowait()
489 q.task_done()
490 except (asyncio.QueueEmpty, ValueError):
491 pass
492
493
async def install_package(package: str) -> None:
    """
    Install package with pip, raise when install failed.

    Installs via `uv pip`, using the Home Assistant wheels index as an
    extra package source.

    :param package: the (pip) package specifier to install.
    :raises RuntimeError: if the installer exited with a non-zero code.
    """
    LOGGER.debug("Installing python package %s", package)
    args = ["uv", "pip", "install", "--no-cache", "--find-links", HA_WHEELS, package]
    return_code, output = await check_output(*args)
    if return_code != 0:
        msg = f"Failed to install package {package}\n{output.decode()}"
        raise RuntimeError(msg)
502
503
async def get_package_version(pkg_name: str) -> str | None:
    """
    Return the version of an installed (python) package.

    Will return None if the package is not found.
    """

    def _lookup() -> str | None:
        # importlib.metadata does blocking filesystem I/O, hence the executor
        try:
            return pkg_version(pkg_name)
        except PackageNotFoundError:
            return None

    return await asyncio.to_thread(_lookup)
514
515
async def is_hass_supervisor() -> bool:
    """Return if we're running inside the HA Supervisor (e.g. HAOS)."""

    def _check() -> bool:
        try:
            # the supervisor API host is only resolvable inside HAOS
            urllib.request.urlopen("http://supervisor/core", timeout=1)
        except urllib.error.URLError as err:
            # this should return a 401 unauthorized if it exists
            return getattr(err, "code", 999) == 401
        except Exception:
            # any other error: assume we are not running under the supervisor
            return False
        # a plain successful response is not the (authenticated) supervisor API
        return False

    return await asyncio.to_thread(_check)
530
531
@lru_cache
def _get_provider_module(domain: str) -> ProviderModuleType:
    """Import (and cache) the module for the given provider domain."""
    # NOTE: module-level so the lru_cache is actually shared between calls;
    # the previous implementation recreated the cached closure on every call,
    # so the cache never produced a hit.
    return cast(
        "ProviderModuleType", importlib.import_module(f".{domain}", "music_assistant.providers")
    )


async def load_provider_module(domain: str, requirements: list[str]) -> ProviderModuleType:
    """
    Return module for given provider domain and make sure the requirements are met.

    :param domain: the provider domain to load.
    :param requirements: pinned pip requirements ("pkg==version") for the provider.
    :raises RuntimeError: when a requirement fails to install.
    """
    # ensure module requirements are met
    for requirement in requirements:
        if "==" not in requirement:
            # we should really get rid of unpinned requirements
            continue
        package_name, version = requirement.split("==", 1)
        installed_version = await get_package_version(package_name)
        if installed_version == "0.0.0":
            # ignore editable installs
            continue
        if installed_version != version:
            await install_package(requirement)

    # try to load the module
    try:
        return await asyncio.to_thread(_get_provider_module, domain)
    except ImportError:
        # (re)install ALL requirements
        for requirement in requirements:
            await install_package(requirement)
        # try loading the provider again to be safe
        # this will fail if something else is wrong (as it should)
        return await asyncio.to_thread(_get_provider_module, domain)
564
565
async def has_tmpfs_mount() -> bool:
    """Check if we have a tmpfs mount."""

    def _scan_mounts() -> bool:
        # /proc/mounts only exists on Linux; treat any read error as "no tmpfs"
        try:
            with open("/proc/mounts") as mounts_file:
                return any("tmpfs /tmp tmpfs rw" in mount_line for mount_line in mounts_file)
        except OSError:
            return False

    return await asyncio.to_thread(_scan_mounts)
581
582
async def get_free_space(folder: str) -> float:
    """Return free space on given folderpath in GB."""

    def _free_gb(path: str) -> float:
        try:
            usage = shutil.disk_usage(path)
        except OSError:
            # unknown folder / permission problems simply report no free space
            return 0.0
        return usage.free / float(1 << 30)

    return await asyncio.to_thread(_free_gb, folder)
595
596
async def get_free_space_percentage(folder: str) -> float:
    """Return free space on given folderpath in percentage."""

    def _free_percent(path: str) -> float:
        # percentage of total volume capacity that is still free
        try:
            usage = shutil.disk_usage(path)
        except OSError:
            return 0.0
        return usage.free / usage.total * 100

    return await asyncio.to_thread(_free_percent, folder)
609
610
async def has_enough_space(folder: str, size: int) -> bool:
    """Check if folder has enough free space."""
    # NOTE(review): get_free_space returns GB, so `size` is effectively
    # compared in GB here - confirm callers pass GB and not bytes.
    free_space = await get_free_space(folder)
    return free_space > size
614
615
def divide_chunks(data: bytes, chunk_size: int) -> Iterator[bytes]:
    """Chunk bytes data into smaller chunks."""
    # lazily slice the data into consecutive chunk_size-sized pieces
    offset = 0
    while offset < len(data):
        yield data[offset : offset + chunk_size]
        offset += chunk_size
620
621
async def remove_file(file_path: str) -> None:
    """Remove file path (if it exists)."""
    # both the existence check and the removal are blocking filesystem calls
    file_exists = await asyncio.to_thread(os.path.exists, file_path)
    if not file_exists:
        return
    await asyncio.to_thread(os.remove, file_path)
    LOGGER.log(VERBOSE_LOG_LEVEL, "Removed file: %s", file_path)
628
629
def get_primary_ip_address_from_zeroconf(discovery_info: AsyncServiceInfo) -> str | None:
    """Get primary IP address from zeroconf discovery info."""
    # prefer a usable IPv4 address (skip loopback and APIPA ranges)
    for ipv4_address in discovery_info.parsed_addresses(IPVersion.V4Only):
        if not ipv4_address.startswith(("127", "169.254")):
            return ipv4_address
    # fall back to IPv6 addresses (skip loopback and link-local)
    for ipv6_address in discovery_info.parsed_addresses(IPVersion.V6Only):
        if not ipv6_address.startswith(("::1", "fe80")):
            return ipv6_address
    return None
647
648
def get_port_from_zeroconf(discovery_info: AsyncServiceInfo) -> int | None:
    """Get port from zeroconf discovery info."""
    # the service info already carries the announced port (None when absent)
    announced_port = discovery_info.port
    return announced_port
652
653
async def close_async_generator(agen: AsyncGenerator[Any, None]) -> None:
    """Force close an async generator."""
    # schedule a single step and cancel it immediately, so the generator is
    # resumed (and may run cleanup) before being closed
    next_step = asyncio.create_task(agen.__anext__())
    next_step.cancel()
    with suppress(asyncio.CancelledError, StopAsyncIteration):
        await next_step
    await agen.aclose()
661
662
async def detect_charset(data: bytes, fallback: str = "utf-8") -> str:
    """Detect charset of raw data."""
    # any failure (including low-confidence detection) yields the fallback
    try:
        detection: ResultDict = await asyncio.to_thread(chardet.detect, data)
        encoding = detection["encoding"] if detection else None
        if encoding and detection["confidence"] > 0.75:
            assert isinstance(encoding, str)  # for type checking
            return encoding
    except Exception as err:
        LOGGER.debug("Failed to detect charset: %s", err)
    return fallback
673
674
def parse_optional_bool(value: Any) -> bool | None:
    """Parse an optional boolean value from various input types."""
    if value is None or isinstance(value, bool):
        # already None or a real bool: pass through unchanged
        return value
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in ("true", "1", "yes", "on"):
            return True
        if normalized in ("false", "0", "no", "off"):
            return False
        # unrecognized strings are treated as unknown
        return None
    if isinstance(value, (int, float)):
        return bool(value)
    return None
690
691
def merge_dict(
    base_dict: dict[Any, Any],
    new_dict: dict[Any, Any],
    allow_overwite: bool = False,
) -> dict[Any, Any]:
    """
    Merge dict without overwriting existing values.

    Dict/tuple/list values present on both sides are merged recursively.
    With allow_overwite=True, leaf values from new_dict win (the flag is now
    propagated into nested dict merges, and a merged result is no longer
    clobbered by the raw new value afterwards).

    NOTE: "existing" is determined by truthiness, so a falsy existing value
    (0, "", False, empty container) is still replaced - kept as-is for
    backward compatibility.
    """
    final_dict = base_dict.copy()
    for key, value in new_dict.items():
        existing = final_dict.get(key)
        if existing and isinstance(value, dict):
            # deep-merge, forwarding the overwrite flag (previously dropped)
            final_dict[key] = merge_dict(existing, value, allow_overwite)
        elif existing and isinstance(value, tuple):
            final_dict[key] = merge_tuples(existing, value)
        elif existing and isinstance(value, list):
            final_dict[key] = merge_lists(existing, value)
        elif not existing or allow_overwite:
            final_dict[key] = value
    return final_dict
709
710
def merge_tuples(base: tuple[Any, ...], new: tuple[Any, ...]) -> tuple[Any, ...]:
    """Merge 2 tuples."""
    # keep base items not present in new, then append all new items
    deduped_base = tuple(item for item in base if item not in new)
    return deduped_base + tuple(new)
714
715
def merge_lists(base: list[Any], new: list[Any]) -> list[Any]:
    """Merge 2 lists."""
    # keep base items not present in new, then append all new items
    deduped_base = [item for item in base if item not in new]
    return deduped_base + list(new)
719
720
def percentage(part: float, whole: float) -> int:
    """
    Calculate (truncated) percentage of part relative to whole.

    Returns 0 when whole is 0 to avoid a ZeroDivisionError.
    """
    if not whole:
        # guard against division by zero for empty/zero totals
        return 0
    return int(100 * float(part) / float(whole))
724
725
def validate_announcement_chime_url(url: str) -> bool:
    """Validate announcement chime URL format."""
    if not url or not url.strip():
        # empty/whitespace-only URL is valid (means: no custom chime)
        return True

    if url == ANNOUNCE_ALERT_FILE:
        # the built-in chime file is always valid
        return True

    try:
        parsed = urlparse(url.strip())

        # only http(s) URLs with a host are accepted
        if parsed.scheme not in ("http", "https") or not parsed.netloc:
            return False

        # require a known audio file extension on the path
        audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac")
        return parsed.path.lower().endswith(audio_extensions)

    except Exception:
        return False
750
751
async def get_mac_address(ip_address: str) -> str | None:
    """
    Get MAC address for given IP address via ARP lookup.

    :param ip_address: the IP address to resolve.
    :return: the MAC address string, or None when it could not be resolved.
    """
    try:
        # import lazily: getmac is an optional dependency
        from getmac import get_mac_address as getmac_lookup  # noqa: PLC0415

        # run in executor so the lookup does not block the event loop
        return await asyncio.to_thread(getmac_lookup, ip=ip_address)
    except ImportError:
        LOGGER.debug("getmac module not available, cannot resolve MAC from IP")
        return None
    except Exception as err:
        LOGGER.debug("Failed to resolve MAC address for %s: %s", ip_address, err)
        return None
764
765
def is_locally_administered_mac(mac_address: str) -> bool:
    """
    Check if a MAC address is locally administered (virtual/randomized).

    Locally administered addresses have bit 1 of the first octet set to 1.
    These are often used by devices for virtual interfaces or protocol-specific
    addresses (e.g., AirPlay, DLNA may use different virtual MACs than the real hardware MAC).

    :param mac_address: MAC address in any common format (with :, -, or no separator).
    :return: True if locally administered, False if globally unique (real hardware MAC).
    """
    # normalize: drop separators, uppercase
    normalized = mac_address.upper().replace(":", "").replace("-", "")
    if len(normalized) < 2:
        return False

    try:
        first_octet = int(normalized[:2], 16)
    except ValueError:
        # not valid hex
        return False
    # bit 1 (0x02) of the first octet marks a locally administered address
    return bool(first_octet & 0x02)
788
789
def normalize_mac_for_matching(mac_address: str) -> str:
    """
    Normalize a MAC address for device matching by masking out the locally-administered bit.

    Some protocols (like AirPlay) report a locally-administered MAC address variant where
    bit 1 of the first octet is set. For example:
    - Real hardware MAC: 54:78:C9:E6:0D:A0 (first byte 0x54 = 01010100)
    - AirPlay reports: 56:78:C9:E6:0D:A0 (first byte 0x56 = 01010110)

    These represent the same device but differ only in the locally-administered bit.
    This function normalizes the MAC by clearing bit 1 of the first octet, allowing
    both variants to match the same device.

    :param mac_address: MAC address in any common format (with :, -, or no separator).
    :return: Normalized MAC address in lowercase without separators, with the
             locally-administered bit cleared.
    """
    # normalize: drop separators, lowercase
    normalized = mac_address.lower().replace(":", "").replace("-", "")
    if len(normalized) != 12:
        # invalid MAC length, return as-is
        return normalized

    try:
        first_octet = int(normalized[:2], 16)
    except ValueError:
        # invalid hex, return as-is
        return normalized
    # clear bit 1 (the locally-administered bit) and rebuild the address
    cleared_octet = first_octet & ~0x02
    return f"{cleared_octet:02x}{normalized[2:]}"
822
823
def is_valid_mac_address(mac_address: str | None) -> bool:
    """
    Check if a MAC address is valid and usable for device identification.

    Invalid MAC addresses include:
    - None or empty strings
    - Null MAC: 00:00:00:00:00:00
    - Broadcast MAC: ff:ff:ff:ff:ff:ff
    - Any MAC that doesn't follow the expected pattern

    :param mac_address: MAC address to validate.
    :return: True if valid and usable, False otherwise.
    """
    if not mac_address:
        return False

    # normalize: drop separators, lowercase
    normalized = mac_address.lower().replace(":", "").replace("-", "")

    # reject the reserved null and broadcast addresses
    if normalized in ("000000000000", "ffffffffffff"):
        return False

    # a MAC is exactly 12 hex characters
    if len(normalized) != 12:
        return False
    try:
        int(normalized, 16)
    except ValueError:
        return False
    return True
856
857
def normalize_ip_address(ip_address: str | None) -> str | None:
    """
    Normalize IP address for comparison.

    Handles IPv6-mapped IPv4 addresses (e.g., ::ffff:192.168.1.64 -> 192.168.1.64).

    :param ip_address: IP address to normalize.
    :return: Normalized IP address or None if invalid.
    """
    if not ip_address:
        return None
    # strip the IPv6-mapped prefix so mapped IPv4 addresses compare equal;
    # removeprefix is a no-op for any other address
    return ip_address.removeprefix("::ffff:")
876
877
async def resolve_real_mac_address(reported_mac: str | None, ip_address: str | None) -> str | None:
    """
    Resolve the real MAC address for a device.

    Some devices report different virtual MAC addresses per protocol (AirPlay, DLNA,
    Chromecast). This function tries to resolve the actual hardware MAC via ARP
    when the reported MAC appears to be locally administered (virtual).

    :param reported_mac: The MAC address reported by the protocol.
    :param ip_address: The IP address of the device (for ARP lookup).
    :return: The real MAC address if found, or None if it couldn't be resolved.
    """
    if not ip_address:
        return None

    # only bother with an ARP lookup when the reported MAC is missing or virtual
    needs_lookup = not reported_mac or is_locally_administered_mac(reported_mac)
    if not needs_lookup:
        return None

    arp_mac = await get_mac_address(ip_address)
    if arp_mac and is_valid_mac_address(arp_mac):
        return arp_mac.upper()
    return None
900
901
class TaskManager:
    """
    Helper class to run many tasks at once.

    This is basically an alternative to asyncio.TaskGroup but this will not
    cancel all operations when one of the tasks fails.
    Logging of exceptions is done by the mass.create_task helper.
    """

    def __init__(self, mass: MusicAssistant, limit: int = 0):
        """
        Initialize the TaskManager.

        :param mass: the MusicAssistant instance (used to create tasks).
        :param limit: max number of concurrently running tasks (0 = unlimited).
        """
        self.mass = mass
        # tasks created through this manager that have not yet been awaited
        self._tasks: list[asyncio.Task[None]] = []
        # concurrency guard, only present when a limit was given
        self._semaphore = asyncio.Semaphore(limit) if limit else None

    def create_task(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]:
        """Create a new task and add it to the manager."""
        task = self.mass.create_task(coro)
        self._tasks.append(task)
        return task

    async def create_task_with_limit(self, coro: Coroutine[Any, Any, None]) -> None:
        """
        Create a new task with semaphore limit.

        Blocks until a slot is available when the concurrency limit is reached.
        Requires the manager to have been created with a non-zero limit.
        """
        assert self._semaphore is not None

        def task_done_callback(_task: asyncio.Task[None]) -> None:
            assert self._semaphore is not None  # for type checking
            # free the slot and forget the finished task
            self._tasks.remove(task)
            self._semaphore.release()

        await self._semaphore.acquire()
        task: asyncio.Task[None] = self.create_task(coro)
        task.add_done_callback(task_done_callback)

    async def __aenter__(self) -> Self:
        """Enter context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Exit context manager: wait for all remaining tasks to complete."""
        if len(self._tasks) > 0:
            await asyncio.wait(self._tasks)
            self._tasks.clear()
        return None
951
952
# Generic typing helpers for the decorators below.
_R = TypeVar("_R")
_P = ParamSpec("_P")
955
956
957def lock[**P, R]( # type: ignore[valid-type]
958 func: Callable[_P, Awaitable[_R]],
959) -> Callable[_P, Coroutine[Any, Any, _R]]:
960 """Call async function using a Lock."""
961
962 @functools.wraps(func)
963 async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
964 """Call async function using the throttler with retries."""
965 if not (func_lock := getattr(func, "lock", None)):
966 func_lock = asyncio.Lock()
967 func.lock = func_lock # type: ignore[attr-defined]
968 async with func_lock:
969 return await func(*args, **kwargs)
970
971 return wrapper
972
973
class TimedAsyncGenerator:
    """
    Async iterable that times out after a given time.

    Source: https://medium.com/@dmitry8912/implementing-timeouts-in-pythons-asynchronous-generators-f7cbaa6dc1e9
    """

    def __init__(self, iterable: AsyncIterator[Any], timeout: int = 0):
        """
        Initialize the AsyncTimedIterable.

        Args:
            iterable: The async iterable to wrap.
            timeout: The timeout in seconds for each iteration.
        """

        # iterator wrapper that applies the timeout to every __anext__ call
        class AsyncTimedIterator:
            def __init__(self) -> None:
                self._iterator = iterable.__aiter__()

            async def __anext__(self) -> Any:
                # NOTE(review): asyncio.wait_for raises TimeoutError when the
                # source takes longer than `timeout`; with the default of 0 the
                # next item must be (near) immediately available - confirm that
                # callers always pass an explicit timeout.
                result = await asyncio.wait_for(self._iterator.__anext__(), int(timeout))
                # NOTE(review): a falsy item (e.g. b"" or 0) ends iteration here,
                # even if the source generator has more items - confirm intended.
                if not result:
                    raise StopAsyncIteration
                return result

        self._factory = AsyncTimedIterator

    def __aiter__(self):  # type: ignore[no-untyped-def]
        """Return the async iterator."""
        return self._factory()
1005
1006
def guard_single_request[ProviderT: "Provider | CoreController", **P, R](
    func: Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
    """
    Guard single request to a function.

    Builds a deterministic task_id from the function and its arguments and
    schedules the call via mass.create_task with abort_existing=False
    (presumably so concurrent identical calls share one task - the exact
    semantics depend on mass.create_task).
    """

    @functools.wraps(func)
    async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
        mass = self.mass
        # create a task_id dynamically based on the function and args/kwargs
        # NOTE(review): func.__class__.__name__ is always "function" for a plain
        # function object, and `self` is not part of the key either - so two
        # different instances calling with equal args produce the same task_id.
        # Confirm this cross-instance deduplication is intended.
        cache_key_parts = [func.__class__.__name__, func.__name__, *args]
        for key in sorted(kwargs.keys()):
            cache_key_parts.append(f"{key}{kwargs[key]}")
        task_id = ".".join(map(str, cache_key_parts))
        task: asyncio.Task[R] = mass.create_task(
            func,
            self,
            *args,
            task_id=task_id,
            abort_existing=False,
            eager_start=True,
            **kwargs,
        )
        return await task

    return wrapper
1032