/
/
/
1"""Various (server-only) tools and helpers."""
2
3from __future__ import annotations
4
5import asyncio
6import functools
7import importlib
8import logging
9import os
10import re
11import shutil
12import socket
13import urllib.error
14import urllib.request
15from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Coroutine
16from contextlib import suppress
17from functools import lru_cache
18from importlib.metadata import PackageNotFoundError
19from importlib.metadata import version as pkg_version
20from pathlib import Path
21from types import TracebackType
22from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Self, TypeVar, cast
23from urllib.parse import urlparse
24
25import chardet
26import ifaddr
27from music_assistant_models.enums import AlbumType
28from zeroconf import IPVersion
29
30from music_assistant.constants import (
31 ANNOUNCE_ALERT_FILE,
32 LIVE_INDICATORS,
33 SOUNDTRACK_INDICATORS,
34 VERBOSE_LOG_LEVEL,
35)
36from music_assistant.helpers.process import check_output
37
38if TYPE_CHECKING:
39 from collections.abc import Iterator
40
41 from chardet.resultdict import ResultDict
42 from zeroconf.asyncio import AsyncServiceInfo
43
44 from music_assistant.mass import MusicAssistant
45 from music_assistant.models import ProviderModuleType
46 from music_assistant.models.core_controller import CoreController
47 from music_assistant.models.provider import Provider
48
49from dataclasses import fields, is_dataclass
50
# Module-level logger shared by the helpers in this file.
LOGGER = logging.getLogger(__name__)

# Extra wheel location handed to `uv pip install --find-links` by install_package.
HA_WHEELS = "https://wheels.home-assistant.io/musllinux/"

# Generic type variable (used e.g. for the arguments of get_changed_dataclass_values).
T = TypeVar("T")
# Type of a callback that takes no arguments and returns None.
CALLBACK_TYPE = Callable[[], None]
57
58
def get_total_system_memory() -> float:
    """
    Get total system memory in GB (GiB).

    Uses ``os.sysconf`` (available on Linux and macOS). Returns 0.0 when the value
    cannot be determined: ``os.sysconf`` missing (e.g. Windows), the names being
    unsupported, or the platform reporting them as undefined (-1).
    """
    try:
        # Works on Linux and macOS
        page_size = os.sysconf("SC_PAGE_SIZE")
        phys_pages = os.sysconf("SC_PHYS_PAGES")
    except (AttributeError, ValueError, OSError):
        # Fallback if sysconf is not available (e.g., Windows) or the name is unsupported.
        # Return a conservative default to disable buffering by default
        return 0.0
    if page_size <= 0 or phys_pages <= 0:
        # sysconf returns -1 for values the platform does not define
        return 0.0
    return (page_size * phys_pages) / (1024**3)  # Convert to GB
69
70
# Detects "key=value" style stream titles (a line containing title= or artist=).
keyword_pattern = re.compile("title=|artist=")
# Captures the double-quoted value of title="..." (non-greedy, up to the next quote).
title_pattern = re.compile(r"title=\"(?P<title>.*?)\"")
# Captures the double-quoted value of artist="...".
artist_pattern = re.compile(r"artist=\"(?P<artist>.*?)\"")
# Matches scheme-less, domain-like tokens such as "example.com" or "(www.radio.fm)".
dot_com_pattern = re.compile(r"(?P<netloc>\(?\w+\.(?:\w+\.)?(\w{2,3})\)?)")
# Matches advertisement markers (case-insensitive): an "ad_"/"advertisement_" prefix,
# a whole line like "AD 123", or "ADBREAK".
ad_pattern = re.compile(r"((ad|advertisement)_)|^AD\s\d+$|ADBREAK", flags=re.IGNORECASE)
# Matches "<title> By: <artist>" (case-insensitive) so the order can be swapped.
title_artist_order_pattern = re.compile(r"(?P<title>.+)\sBy:\s(?P<artist>.+)", flags=re.IGNORECASE)
# Matches runs of two or more whitespace characters.
multi_space_pattern = re.compile(r"\s{2,}")
# Matches a trailing whitespace + non-word-character run at the end of a line (group 1 = the rest).
end_junk_pattern = re.compile(r"(.+?)(\s\W+)$")
79
VERSION_PARTS = (
    # list of common version strings
    # (parse_title_and_version treats a bracketed/dashed title part that CONTAINS
    # any of these substrings as the version)
    "version",
    "live",
    "edit",
    "remix",
    "mix",
    "acoustic",
    "instrumental",
    "karaoke",
    "remaster",
    "versie",  # Dutch for "version"
    "unplugged",
    "disco",
    "akoestisch",  # Dutch for "acoustic"
    "deluxe",
)
IGNORE_TITLE_PARTS = (
    # strings that may be stripped off a title part
    # (most important the featuring parts)
    # parse_title_and_version removes a title part that STARTS with one of these
    "feat.",
    "featuring",
    "ft.",
    "with ",
    "explicit",
)
WITH_TITLE_WORDS = (
    # words that, when following "with", indicate this is part of the song title
    # not a featuring credit.
    # e.g. "(With You)" is kept, whereas "(with Artist X)" is stripped as a credit
    "someone",
    "the",
    "u",
    "you",
    "no",
)
115
116
def filename_from_string(string: str) -> str:
    """Build a filesystem-friendly name keeping only alphanumerics, spaces, dots and underscores."""
    allowed_extra = {" ", ".", "_"}
    kept = [char for char in string if char.isalnum() or char in allowed_extra]
    return "".join(kept).rstrip()
121
122
123def try_parse_int(possible_int: Any, default: int | None = 0) -> int | None:
124 """Try to parse an int."""
125 try:
126 return int(float(possible_int))
127 except (TypeError, ValueError):
128 return default
129
130
131def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float | None:
132 """Try to parse a float."""
133 try:
134 return float(possible_float)
135 except (TypeError, ValueError):
136 return default
137
138
def try_parse_bool(possible_bool: Any) -> bool:
    """Interpret common truthy spellings ("true", "1", "on", 1, ...) as True; anything else is False."""
    if not isinstance(possible_bool, bool):
        truthy_values = ("true", "True", "1", "on", "ON", 1)
        return possible_bool in truthy_values
    return possible_bool
144
145
146def try_parse_duration(duration_str: str) -> float:
147 """Try to parse a duration in seconds from a duration (HH:MM:SS) string."""
148 milliseconds = float("0." + duration_str.split(".")[-1]) if "." in duration_str else 0.0
149 duration_parts = duration_str.split(".")[0].split(",")[0].split(":")
150 if len(duration_parts) == 3:
151 seconds = sum(x * int(t) for x, t in zip([3600, 60, 1], duration_parts, strict=False))
152 elif len(duration_parts) == 2:
153 seconds = sum(x * int(t) for x, t in zip([60, 1], duration_parts, strict=False))
154 else:
155 seconds = int(duration_parts[0])
156 return seconds + milliseconds
157
158
def parse_title_and_version(title: str, track_version: str | None = None) -> tuple[str, str]:
    """
    Try to parse version from the title.

    Scans "(...)", "[...]" and trailing " - ..." parts of ``title``:
    - a part starting with an IGNORE_TITLE_PARTS entry (featuring credits, "explicit")
      is removed from the title. The exception is "with <word>" where <word> is in
      WITH_TITLE_WORDS (e.g. "(With You)"), which is kept and checked as a version;
    - the first remaining part containing a VERSION_PARTS substring becomes the version.
      It is removed from the title and the function returns immediately.

    Args:
        title: The raw (track/album) title.
        track_version: Version already known from metadata. It is returned as the
            version when none is found in the title.

    Returns:
        Tuple of (cleaned title, version).
    """
    version = track_version or ""
    # Candidate parts: "(...)", "[...]" and a trailing " - ..." suffix, in that order.
    for regex in (r"\(.*?\)", r"\[.*?\]", r" - .*"):
        # findall runs against the title as already modified by earlier passes.
        for title_part in re.findall(regex, title):
            # Extract the content without brackets/dashes for checking
            clean_part = title_part.translate(str.maketrans("", "", "()[]-")).strip().lower()

            # Check if this should be ignored (featuring/explicit parts)
            should_ignore = False
            for ignore_str in IGNORE_TITLE_PARTS:
                if clean_part.startswith(ignore_str):
                    # Special handling for "with " - check if followed by title words
                    if ignore_str == "with ":
                        # Extract the word after "with "
                        after_with = (
                            clean_part[len("with ") :].split()[0]
                            if len(clean_part) > len("with ")
                            else ""
                        )
                        if after_with in WITH_TITLE_WORDS:
                            # This is part of the title (e.g., "with you"), don't ignore
                            break
                    # Remove this part from the title
                    title = title.replace(title_part, "").strip()
                    should_ignore = True
                    break

            if should_ignore:
                continue

            # Check if this part is a version
            # NOTE(review): plain substring match, so e.g. "alive" matches "live" and
            # "credit" matches "edit" - confirm this is acceptable.
            for version_str in VERSION_PARTS:
                if version_str in clean_part:
                    # Preserve original casing for output
                    version = title_part.strip("()[]- ").strip()
                    title = title.replace(title_part, "").strip()
                    # NOTE(review): returning here leaves any later featuring/version parts
                    # in the title, and replace() may leave a double space behind.
                    return title, version
    return title, version
198
199
def infer_album_type(title: str, version: str) -> AlbumType:
    """
    Guess the album type from regex indicators found in the title and version.

    LIVE_INDICATORS are checked before SOUNDTRACK_INDICATORS (case-insensitively,
    via lower-casing); AlbumType.UNKNOWN is returned when nothing matches.
    """
    haystack = f"{title} {version}".lower()
    if any(re.search(pattern, haystack) for pattern in LIVE_INDICATORS):
        return AlbumType.LIVE
    if any(re.search(pattern, haystack) for pattern in SOUNDTRACK_INDICATORS):
        return AlbumType.SOUNDTRACK
    return AlbumType.UNKNOWN
210
211
def strip_ads(line: str) -> str:
    """Replace the whole line with "Advert" when it contains an advertisement marker."""
    return "Advert" if ad_pattern.search(line) else line
217
218
def strip_url(line: str) -> str:
    """Drop every whitespace-separated token that parses as a URL with both a scheme and a host."""
    kept_tokens = []
    for token in line.split():
        parsed = urlparse(token)
        if parsed.scheme and parsed.netloc:
            continue
        kept_tokens.append(token)
    return " ".join(kept_tokens).rstrip()
224
225
def strip_dotcom(line: str) -> str:
    """Remove domain-like tokens without a scheme (e.g. "example.com") from the line."""
    cleaned = dot_com_pattern.sub("", line)
    return cleaned
229
230
def strip_end_junk(line: str) -> str:
    """Drop a trailing run of whitespace followed by non-word characters (e.g. " --")."""
    cleaned = end_junk_pattern.sub(r"\1", line)
    return cleaned
234
235
def swap_title_artist_order(line: str) -> str:
    """Rewrite "<title> By: <artist>" (case-insensitive) as "<artist> - <title>"."""
    swapped = title_artist_order_pattern.sub(r"\g<artist> - \g<title>", line)
    return swapped
239
240
def strip_multi_space(line: str) -> str:
    """Collapse every run of two or more whitespace characters into a single space."""
    collapsed = multi_space_pattern.sub(" ", line)
    return collapsed
244
245
def multi_strip(line: str) -> str:
    """
    Strip assorted junk from line.

    Steps, in order: ads, URLs, bare domains, trailing junk, "By:" order swap,
    multi-space collapse, and finally trailing whitespace.
    """
    cleaned = strip_ads(line)
    cleaned = strip_url(cleaned)
    cleaned = strip_dotcom(cleaned)
    cleaned = strip_end_junk(cleaned)
    cleaned = swap_title_artist_order(cleaned)
    return strip_multi_space(cleaned).rstrip()
251
252
def clean_stream_title(line: str) -> str:
    """
    Strip junk text from radio streamtitle.

    If the line has no ``title=``/``artist=`` keywords, the whole line is cleaned with
    multi_strip. Otherwise the quoted ``title="..."`` and ``artist="..."`` values are
    cleaned and combined as "artist - title". The title alone is returned when it
    already contains " - " or no distinct artist is present. Returns "" when neither
    value yields any text.
    """
    if not keyword_pattern.search(line):
        return multi_strip(line)

    title = ""
    artist = ""
    if match := title_pattern.search(line):
        title = multi_strip(match.group("title"))
    if match := artist_pattern.search(line):
        possible_artist = multi_strip(match.group("artist"))
        # ignore an artist value that merely repeats the title
        if possible_artist and possible_artist != title:
            artist = possible_artist

    if title:
        # a title containing " - " is assumed to already carry the artist
        if " - " in title or not artist:
            return title
        return f"{artist} - {title}"
    # no title: return the artist, or "" when that is empty as well
    return artist
282
283
async def get_ip_addresses(include_ipv6: bool = False) -> tuple[str, ...]:
    """
    Return all IP-addresses of all network interfaces.

    Skips the following addresses:
    - IPv4 loopback (127.x) and APIPA (169.254.x);
    - IPv6 loopback (::1), IPv4-mapped (::ffff:) and link-local (fe80).

    The result is ordered by preference:
    1. the local address of the default route;
    2. 192.168.x.x addresses;
    3. other 172./10./192. addresses;
    4. everything else.

    Args:
        include_ipv6: Also return IPv6 addresses (IPv4 only by default).
    """

    def call() -> tuple[str, ...]:
        result: list[tuple[int, str]] = []
        # try to get the primary IP address
        # this is the IP address of the default route
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as _sock:
            _sock.settimeout(0)
            try:
                # doesn't even have to be reachable
                _sock.connect(("10.254.254.254", 1))
                primary_ip = _sock.getsockname()[0]
            except Exception:
                primary_ip = ""
        # get all IP addresses of all network interfaces
        for adapter in ifaddr.get_adapters():
            for ip in adapter.ips:
                if ip.is_IPv6 and not include_ipv6:
                    continue
                # ifaddr reports IPv6 addresses as an (address, flowinfo, scope_id)
                # tuple; str() of that tuple is not an address and defeats the filters below
                ip_str = ip.ip[0] if isinstance(ip.ip, tuple) else str(ip.ip)
                if ip_str.startswith(("127", "169.254")):
                    # filter out IPv4 loopback/APIPA address
                    continue
                if ip_str.startswith(("::1", "::ffff:", "fe80")):
                    # filter out IPv6 loopback/link-local address
                    continue
                if ip_str == primary_ip:
                    score = 10
                elif ip_str.startswith(("192.168.",)):
                    # we rank the 192.168 range a bit higher as its most
                    # often used as the private network subnet
                    score = 2
                elif ip_str.startswith(("172.", "10.", "192.")):
                    # we rank the 172 range a bit lower as its most
                    # often used as the private docker network
                    score = 1
                else:
                    score = 0
                result.append((score, ip_str))
        # stable sort: equally scored addresses keep their adapter order
        result.sort(key=lambda x: x[0], reverse=True)
        return tuple(ip[1] for ip in result)

    return await asyncio.to_thread(call)
331
332
async def get_primary_ip_address() -> str | None:
    """
    Return the primary IP address of the system.

    This is the most preferred (IPv4) address reported by get_ip_addresses, which
    ranks the local address of the default route first. Returns None when no usable
    address is found.
    """
    ip_addresses = await get_ip_addresses()
    return ip_addresses[0] if ip_addresses else None
335
336
async def is_port_in_use(port: int) -> bool:
    """
    Check if port is in use.

    The port counts as free when a TCP socket can be bound to it on the IPv4
    wildcard address or, failing that, on the IPv6 wildcard address. This supports
    single-stack as well as dual-stack hosts.
    """

    def _probe() -> bool:
        candidates = ((socket.AF_INET, "0.0.0.0"), (socket.AF_INET6, "::"))
        for family, wildcard in candidates:
            try:
                with socket.socket(family, socket.SOCK_STREAM) as probe_sock:
                    # SO_REUSEADDR mirrors asyncio.start_server, so TIME_WAIT ports count as free
                    probe_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    probe_sock.bind((wildcard, port))
            except OSError:
                # family unsupported or bind refused: try the next family
                continue
            return False
        return True

    return await asyncio.to_thread(_probe)
356
357
async def select_free_port(range_start: int, range_end: int) -> int:
    """
    Return the first port in ``[range_start, range_end)`` that is not in use.

    Raises:
        OSError: when every port in the range is taken (or the range is empty).
    """
    port = range_start
    while port < range_end:
        if not await is_port_in_use(port):
            return port
        port += 1
    msg = "No free port available"
    raise OSError(msg)
365
366
367async def get_ip_from_host(dns_name: str) -> str | None:
368 """Resolve (first) IP-address for given dns name."""
369
370 def _resolve() -> str | None:
371 try:
372 return socket.gethostbyname(dns_name)
373 except Exception:
374 # fail gracefully!
375 return None
376
377 return await asyncio.to_thread(_resolve)
378
379
async def get_ip_pton(ip_string: str) -> bytes:
    """
    Return the packed binary form of an IPv4 or IPv6 address string.

    IPv4 is tried first, then IPv6. ``socket.inet_pton`` is a cheap, non-blocking
    conversion, so it is called directly instead of being offloaded to a thread.

    Raises:
        OSError: when ``ip_string`` is neither a valid IPv4 nor a valid IPv6 address.
    """
    try:
        return socket.inet_pton(socket.AF_INET, ip_string)
    except OSError:
        return socket.inet_pton(socket.AF_INET6, ip_string)
386
387
def format_ip_for_url(ip_address: str) -> str:
    """
    Wrap IPv6 addresses in brackets for use in URLs (RFC 2732).

    IPv4 addresses and hostnames are returned unchanged, as are IPv6 addresses that
    are already bracketed, so applying this twice is harmless.
    """
    if ":" in ip_address and not ip_address.startswith("["):
        return f"[{ip_address}]"
    return ip_address
393
394
async def get_folder_size(folderpath: str) -> float:
    """
    Return folder size in GB (GiB, 1 << 30 bytes).

    Sizes of all files below ``folderpath`` are summed recursively. Entries that
    cannot be stat'ed are skipped instead of aborting the whole calculation: files
    removed during the walk, dangling symlinks and permission problems.
    """

    def _get_folder_size(folderpath: str) -> float:
        total_size = 0
        for dirpath, _dirnames, filenames in os.walk(folderpath):
            for _file in filenames:
                try:
                    total_size += Path(dirpath, _file).stat().st_size
                except OSError:
                    # vanished file, broken symlink or no permission: skip it
                    continue
        return total_size / float(1 << 30)

    return await asyncio.to_thread(_get_folder_size, folderpath)
407
408
def get_changed_keys(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> set[str]:
    """
    Compare 2 dicts and return set of changed keys.

    The keys are exactly those returned by get_changed_dict_values (nested keys are
    dot-joined when ``recursive`` is set).
    """
    changed = get_changed_dict_values(dict1, dict2, recursive)
    return set(changed)
419
def get_changed_dict_values(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dicts and return dict of changed values.

    dict key is the changed key, value is tuple of old (dict1) and new (dict2) values.
    Added keys are reported as (None, new) and removed keys as (old, None).

    With ``recursive`` set, a key whose old and new values are both dicts is
    compared recursively. Its changes are reported under dot-joined keys
    (e.g. "outer.inner").
    """
    if not dict1 and not dict2:
        return {}
    if not dict1:
        return {key: (None, value) for key, value in dict2.items()}
    if not dict2:
        # everything was removed: old values on the left, None on the right
        return {key: (value, None) for key, value in dict1.items()}
    changed_values: dict[str, tuple[Any, Any]] = {}
    for key, value in dict2.items():
        # check membership first: dict1[key] would raise KeyError for added keys
        if key not in dict1:
            changed_values[key] = (None, value)
            continue
        old_value = dict1[key]
        if recursive and isinstance(value, dict) and isinstance(old_value, dict):
            for subkey, subvalue in get_changed_dict_values(old_value, value, recursive).items():
                changed_values[f"{key}.{subkey}"] = subvalue
            continue
        if old_value != value:
            changed_values[key] = (old_value, value)
    for key, old_value in dict1.items():
        if key not in dict2:
            changed_values[key] = (old_value, None)
    return changed_values
449
450
def get_changed_dataclass_values(
    obj1: T,
    obj2: T,
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dataclass instances of the same type and return dict of changed field values.

    dict key is the changed field name, value is tuple of old (obj1) and new (obj2) values.
    Fields are taken from ``obj1``; a field missing on either object reads as None.

    With ``recursive`` set, nested dataclasses and dicts are compared field-by-field
    or key-by-key, and reported under dot-joined names (e.g. "outer.inner").

    Raises:
        ValueError: if either object is not a dataclass.
    """
    if not is_dataclass(obj1) or not is_dataclass(obj2):
        raise ValueError("Both objects must be dataclass instances")

    changes: dict[str, tuple[Any, Any]] = {}

    def _add_nested(prefix: str, nested: dict[str, tuple[Any, Any]]) -> None:
        # re-key nested changes as "<prefix>.<nested name>"
        for nested_name, nested_change in nested.items():
            changes[f"{prefix}.{nested_name}"] = nested_change

    for dc_field in fields(obj1):
        name = dc_field.name
        old = getattr(obj1, name, None)
        new = getattr(obj2, name, None)
        if recursive and is_dataclass(old) and is_dataclass(new):
            _add_nested(name, get_changed_dataclass_values(old, new, recursive))
        elif recursive and isinstance(old, dict) and isinstance(new, dict):
            _add_nested(name, get_changed_dict_values(old, new, recursive=recursive))
        elif old != new:
            changes[name] = (old, new)
    return changes
481
482
483def empty_queue[T](q: asyncio.Queue[T]) -> None:
484 """Empty an asyncio Queue."""
485 for _ in range(q.qsize()):
486 try:
487 q.get_nowait()
488 q.task_done()
489 except (asyncio.QueueEmpty, ValueError):
490 pass
491
492
async def install_package(package: str) -> None:
    """
    Install a python package with ``uv pip`` (using HA_WHEELS as an extra find-links source).

    Raises:
        RuntimeError: when the installer exits with a non-zero code; the message
            includes the decoded installer output.
    """
    LOGGER.debug("Installing python package %s", package)
    return_code, output = await check_output(
        "uv", "pip", "install", "--no-cache", "--find-links", HA_WHEELS, package
    )
    if return_code == 0:
        return
    msg = f"Failed to install package {package}\n{output.decode()}"
    raise RuntimeError(msg)
501
502
503async def get_package_version(pkg_name: str) -> str | None:
504 """
505 Return the version of an installed (python) package.
506
507 Will return None if the package is not found.
508 """
509 try:
510 return await asyncio.to_thread(pkg_version, pkg_name)
511 except PackageNotFoundError:
512 return None
513
514
async def is_hass_supervisor() -> bool:
    """
    Return if we're running inside the HA Supervisor (e.g. HAOS).

    Probes ``http://supervisor/core`` (1 second timeout). Only an HTTP 401 response
    counts as "running under the supervisor". Any other outcome returns False:
    success, other status codes, or connection errors.
    """

    def _check() -> bool:
        try:
            # use the response as a context manager so the connection is always closed
            with urllib.request.urlopen("http://supervisor/core", timeout=1):
                pass
        except urllib.error.URLError as err:
            # this should return a 401 unauthorized if it exists
            # (HTTPError carries .code; plain URLErrors don't)
            return getattr(err, "code", 999) == 401
        except Exception:
            return False
        return False

    return await asyncio.to_thread(_check)
529
530
async def load_provider_module(domain: str, requirements: list[str]) -> ProviderModuleType:
    """
    Return module for given provider domain and make sure the requirements are met.

    Pinned requirements ("name==version") are (re)installed when the installed
    version differs. Unpinned requirements are skipped, and so are editable installs
    (reported as version "0.0.0").

    If importing ``music_assistant.providers.<domain>`` then fails with ImportError,
    all requirements are reinstalled and the import is retried once. A second
    failure propagates to the caller.
    """

    def _import_provider_module() -> ProviderModuleType:
        # importlib already caches imported modules in sys.modules; an lru_cache on a
        # function defined inside this coroutine would be recreated per call and never hit.
        return cast(
            "ProviderModuleType", importlib.import_module(f".{domain}", "music_assistant.providers")
        )

    # ensure module requirements are met
    for requirement in requirements:
        if "==" not in requirement:
            # we should really get rid of unpinned requirements
            continue
        package_name, version = requirement.split("==", 1)
        installed_version = await get_package_version(package_name)
        if installed_version == "0.0.0":
            # ignore editable installs
            continue
        if installed_version != version:
            await install_package(requirement)

    # try to load the module
    try:
        return await asyncio.to_thread(_import_provider_module)
    except ImportError:
        # (re)install ALL requirements
        for requirement in requirements:
            await install_package(requirement)
    # try loading the provider again to be safe
    # this will fail if something else is wrong (as it should)
    return await asyncio.to_thread(_import_provider_module)
563
564
async def has_tmpfs_mount() -> bool:
    """Check if we have a tmpfs mount: /proc/mounts lists a "tmpfs /tmp tmpfs rw" entry."""

    def _scan_mounts() -> bool:
        """Return True when any /proc/mounts line contains "tmpfs /tmp tmpfs rw"."""
        try:
            with open("/proc/mounts") as mounts:
                return any("tmpfs /tmp tmpfs rw" in entry for entry in mounts)
        except OSError:
            # FileNotFoundError/PermissionError are OSError subclasses: no /proc or no access
            return False

    return await asyncio.to_thread(_scan_mounts)
580
581
async def get_free_space(folder: str) -> float:
    """Return free space on the given folder path in GB (GiB); 0.0 when it cannot be determined."""

    def _free_gib(path: str) -> float:
        try:
            usage = shutil.disk_usage(path)
        except OSError:
            return 0.0
        return usage.free / float(1 << 30)

    return await asyncio.to_thread(_free_gib, folder)
594
595
async def get_free_space_percentage(folder: str) -> float:
    """
    Return free space on the given folder path as a percentage (0-100).

    Returns 0.0 in two cases: the usage cannot be determined, or the filesystem
    reports a total size of 0 (which would otherwise raise ZeroDivisionError).
    """

    def _get_free_space_percentage(folder: str) -> float:
        try:
            res = shutil.disk_usage(folder)
        except OSError:
            return 0.0
        if not res.total:
            # some pseudo filesystems report a zero total size
            return 0.0
        return res.free / res.total * 100

    return await asyncio.to_thread(_get_free_space_percentage, folder)
608
609
async def has_enough_space(folder: str, size: int) -> bool:
    """
    Check if folder has enough free space.

    ``size`` is in GB, the unit get_free_space reports. The result is True only when
    the free space is strictly greater than ``size``.
    """
    free_gb = await get_free_space(folder)
    return free_gb > size
613
614
def divide_chunks(data: bytes, chunk_size: int) -> Iterator[bytes]:
    """Yield consecutive slices of ``data`` of at most ``chunk_size`` bytes (the last may be shorter)."""
    offsets = range(0, len(data), chunk_size)
    yield from (data[start : start + chunk_size] for start in offsets)
619
620
async def remove_file(file_path: str) -> None:
    """
    Remove file path (if it exists).

    A missing file is not an error, including one that disappears concurrently.
    Other OS errors from ``os.remove`` propagate, e.g. permission denied or the
    path being a directory.
    """
    try:
        # remove directly instead of exists()+remove(), which races and skips broken symlinks
        await asyncio.to_thread(os.remove, file_path)
    except FileNotFoundError:
        return
    LOGGER.log(VERBOSE_LOG_LEVEL, "Removed file: %s", file_path)
627
628
def get_primary_ip_address_from_zeroconf(discovery_info: AsyncServiceInfo) -> str | None:
    """
    Get primary IP address from zeroconf discovery info.

    Returns the first advertised IPv4 address that is not loopback (127.x) or APIPA
    (169.254.x). If there is none, returns the first IPv6 address that is not
    loopback (::1) or link-local (fe80...). Returns None when nothing usable is
    advertised.
    """
    usable_v4 = (
        addr
        for addr in discovery_info.parsed_addresses(IPVersion.V4Only)
        if not addr.startswith(("127", "169.254"))
    )
    if (first_v4 := next(usable_v4, None)) is not None:
        return first_v4
    # fall back to IPv6 only when no usable IPv4 address exists
    usable_v6 = (
        addr
        for addr in discovery_info.parsed_addresses(IPVersion.V6Only)
        if not addr.startswith(("::1", "fe80"))
    )
    return next(usable_v6, None)
646
647
def get_port_from_zeroconf(discovery_info: AsyncServiceInfo) -> int | None:
    """Return the service port advertised in the zeroconf discovery info (None if unset)."""
    advertised_port = discovery_info.port
    return advertised_port
651
652
async def close_async_generator(agen: AsyncGenerator[Any, None]) -> None:
    """
    Force close an async generator.

    Steps:
    1. Schedule ``agen.__anext__()`` as a task and cancel it right away.
    2. Await that task, suppressing CancelledError and StopAsyncIteration.
    3. Call ``agen.aclose()``.

    Any other exception raised by the task propagates to the caller.

    Args:
        agen: The async generator to close.
    """
    # NOTE(review): the task is cancelled before it gets a chance to run; confirm this
    # has the intended effect on generators suspended mid-iteration or running elsewhere.
    task = asyncio.create_task(agen.__anext__())
    task.cancel()
    with suppress(asyncio.CancelledError, StopAsyncIteration):
        await task
    await agen.aclose()
660
661
async def detect_charset(data: bytes, fallback: str = "utf-8") -> str:
    """
    Detect charset of raw data with chardet.

    The detected encoding is used only when chardet reports a confidence above 0.75.
    Otherwise, or when detection raises, ``fallback`` is returned (failures are
    logged at debug level).
    """
    try:
        detected: ResultDict = await asyncio.to_thread(chardet.detect, data)
        encoding = detected["encoding"] if detected else None
        if encoding and detected["confidence"] > 0.75:
            assert isinstance(encoding, str)  # for type checking
            return encoding
    except Exception as err:
        LOGGER.debug("Failed to detect charset: %s", err)
    return fallback
672
673
def merge_dict(
    base_dict: dict[Any, Any],
    new_dict: dict[Any, Any],
    allow_overwite: bool = False,
) -> dict[Any, Any]:
    """
    Merge dict without overwriting existing values.

    For every key in ``new_dict``:
    - if the existing value is a non-empty dict and the new value is a dict, both are
      merged recursively (with the same ``allow_overwite`` setting);
    - if the existing value is truthy and the new value is a tuple/list, they are
      combined with merge_tuples/merge_lists;
    - otherwise the new value is stored when the existing value is missing/falsy,
      or when ``allow_overwite`` is set.

    ``base_dict`` itself is not modified (a shallow copy is returned).
    """
    final_dict = base_dict.copy()
    for key, value in new_dict.items():
        existing = final_dict.get(key)
        # a single if/elif chain: each key is handled by exactly one rule, so a merged
        # result is never overwritten afterwards by the allow_overwite fallback
        if existing and isinstance(value, dict) and isinstance(existing, dict):
            final_dict[key] = merge_dict(existing, value, allow_overwite)
        elif existing and isinstance(value, tuple):
            final_dict[key] = merge_tuples(existing, value)
        elif existing and isinstance(value, list):
            final_dict[key] = merge_lists(existing, value)
        elif not existing or allow_overwite:
            final_dict[key] = value
    return final_dict
691
692
def merge_tuples(base: tuple[Any, ...], new: tuple[Any, ...]) -> tuple[Any, ...]:
    """Merge 2 tuples: items of ``base`` not present in ``new``, followed by all items of ``new``."""
    kept = [item for item in base if item not in new]
    return (*kept, *new)
696
697
def merge_lists(base: list[Any], new: list[Any]) -> list[Any]:
    """Merge 2 lists: items of ``base`` not present in ``new``, followed by all items of ``new``."""
    merged = [item for item in base if item not in new]
    merged.extend(new)
    return merged
701
702
def percentage(part: float, whole: float) -> int:
    """
    Calculate the percentage of ``part`` relative to ``whole``, truncated to an int.

    Returns 0 when ``whole`` is 0 instead of raising ZeroDivisionError.
    """
    whole_value = float(whole)
    if whole_value == 0:
        return 0
    return int(100 * float(part) / whole_value)
706
707
def validate_announcement_chime_url(url: str) -> bool:
    """
    Validate announcement chime URL format.

    Accepted values are:
    - an empty or blank string;
    - the built-in ANNOUNCE_ALERT_FILE;
    - an http(s) URL with a host whose path ends in .mp3, .wav, .flac, .ogg, .m4a
      or .aac (case-insensitive).
    """
    if not url or not url.strip():
        return True  # Empty URL is valid
    if url == ANNOUNCE_ALERT_FILE:
        return True  # Built-in chime file is valid

    try:
        parsed = urlparse(url.strip())
    except Exception:
        return False
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        return False
    return parsed.path.lower().endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac"))
732
733
async def get_mac_address(ip_address: str) -> str | None:
    """Look up the MAC address for ``ip_address`` with the getmac package, in a worker thread."""
    # deferred import (noqa marks it as intentional); aliased to avoid shadowing this function
    from getmac import get_mac_address as _lookup_mac  # noqa: PLC0415

    return await asyncio.to_thread(_lookup_mac, ip=ip_address)
739
740
741class TaskManager:
742 """
743 Helper class to run many tasks at once.
744
745 This is basically an alternative to asyncio.TaskGroup but this will not
746 cancel all operations when one of the tasks fails.
747 Logging of exceptions is done by the mass.create_task helper.
748 """
749
750 def __init__(self, mass: MusicAssistant, limit: int = 0):
751 """Initialize the TaskManager."""
752 self.mass = mass
753 self._tasks: list[asyncio.Task[None]] = []
754 self._semaphore = asyncio.Semaphore(limit) if limit else None
755
756 def create_task(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]:
757 """Create a new task and add it to the manager."""
758 task = self.mass.create_task(coro)
759 self._tasks.append(task)
760 return task
761
762 async def create_task_with_limit(self, coro: Coroutine[Any, Any, None]) -> None:
763 """Create a new task with semaphore limit."""
764 assert self._semaphore is not None
765
766 def task_done_callback(_task: asyncio.Task[None]) -> None:
767 assert self._semaphore is not None # for type checking
768 self._tasks.remove(task)
769 self._semaphore.release()
770
771 await self._semaphore.acquire()
772 task: asyncio.Task[None] = self.create_task(coro)
773 task.add_done_callback(task_done_callback)
774
775 async def __aenter__(self) -> Self:
776 """Enter context manager."""
777 return self
778
779 async def __aexit__(
780 self,
781 exc_type: type[BaseException] | None,
782 exc_val: BaseException | None,
783 exc_tb: TracebackType | None,
784 ) -> bool | None:
785 """Exit context manager."""
786 if len(self._tasks) > 0:
787 await asyncio.wait(self._tasks)
788 self._tasks.clear()
789 return None
790
791
# Generic return-type variable and parameter specification for typing decorators.
_R = TypeVar("_R")
_P = ParamSpec("_P")
794
795
796def lock[**P, R]( # type: ignore[valid-type]
797 func: Callable[_P, Awaitable[_R]],
798) -> Callable[_P, Coroutine[Any, Any, _R]]:
799 """Call async function using a Lock."""
800
801 @functools.wraps(func)
802 async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
803 """Call async function using the throttler with retries."""
804 if not (func_lock := getattr(func, "lock", None)):
805 func_lock = asyncio.Lock()
806 func.lock = func_lock # type: ignore[attr-defined]
807 async with func_lock:
808 return await func(*args, **kwargs)
809
810 return wrapper
811
812
class TimedAsyncGenerator:
    """
    Async iterable that times out after a given time.

    Each step of the wrapped iterator must produce an item within ``timeout`` seconds.
    Otherwise the TimeoutError raised by asyncio.wait_for propagates to the consumer.
    NOTE: iteration also stops at the first falsy item (e.g. b"", 0, None).

    Source: https://medium.com/@dmitry8912/implementing-timeouts-in-pythons-asynchronous-generators-f7cbaa6dc1e9
    """

    def __init__(self, iterable: AsyncIterator[Any], timeout: float = 0):
        """
        Initialize the AsyncTimedIterable.

        Args:
            iterable: The async iterable to wrap.
            timeout: The timeout in seconds for each iteration (fractions allowed).
        """

        class AsyncTimedIterator:
            def __init__(self) -> None:
                self._iterator = iterable.__aiter__()

            async def __anext__(self) -> Any:
                # float() instead of int(): int() truncated e.g. 0.5 to 0 (an instant timeout)
                result = await asyncio.wait_for(self._iterator.__anext__(), float(timeout))
                if not result:
                    # a falsy item is treated as the end of the stream
                    raise StopAsyncIteration
                return result

        self._factory = AsyncTimedIterator

    def __aiter__(self):  # type: ignore[no-untyped-def]
        """Return a fresh async iterator over the wrapped iterable."""
        return self._factory()
844
845
846def guard_single_request[ProviderT: "Provider | CoreController", **P, R](
847 func: Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]],
848) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
849 """Guard single request to a function."""
850
851 @functools.wraps(func)
852 async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
853 mass = self.mass
854 # create a task_id dynamically based on the function and args/kwargs
855 cache_key_parts = [func.__class__.__name__, func.__name__, *args]
856 for key in sorted(kwargs.keys()):
857 cache_key_parts.append(f"{key}{kwargs[key]}")
858 task_id = ".".join(map(str, cache_key_parts))
859 task: asyncio.Task[R] = mass.create_task(
860 func,
861 self,
862 *args,
863 task_id=task_id,
864 abort_existing=False,
865 **kwargs,
866 )
867 return await task
868
869 return wrapper
870