/
/
/
1"""Various (server-only) tools and helpers."""
2
3from __future__ import annotations
4
5import asyncio
6import functools
7import importlib
8import logging
9import os
10import re
11import shutil
12import socket
13import urllib.error
14import urllib.request
15from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Coroutine
16from contextlib import suppress
17from functools import lru_cache
18from importlib.metadata import PackageNotFoundError
19from importlib.metadata import version as pkg_version
20from pathlib import Path
21from types import TracebackType
22from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Self, TypeVar, cast
23from urllib.parse import urlparse
24
25import chardet
26import ifaddr
27from music_assistant_models.enums import AlbumType
28from zeroconf import IPVersion
29
30from music_assistant.constants import (
31 ANNOUNCE_ALERT_FILE,
32 LIVE_INDICATORS,
33 SOUNDTRACK_INDICATORS,
34 VERBOSE_LOG_LEVEL,
35)
36from music_assistant.helpers.process import check_output
37
38if TYPE_CHECKING:
39 from collections.abc import Iterator
40
41 from chardet.resultdict import ResultDict
42 from zeroconf.asyncio import AsyncServiceInfo
43
44 from music_assistant.mass import MusicAssistant
45 from music_assistant.models import ProviderModuleType
46 from music_assistant.models.core_controller import CoreController
47 from music_assistant.models.provider import Provider
48
49from dataclasses import fields, is_dataclass
50
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Extra wheel index handed to the installer through --find-links (see install_package).
HA_WHEELS = "https://wheels.home-assistant.io/musllinux/"

# Generic type variable used by the helpers in this module.
T = TypeVar("T")
# Signature of a plain callback: takes no arguments and returns nothing.
CALLBACK_TYPE = Callable[[], None]
57
58
def get_total_system_memory() -> float:
    """
    Return the amount of physical memory of this system in GB.

    Returns 0.0 when the amount can not be determined, which callers can
    treat as a conservative "unknown" value.
    """
    bytes_per_gb = 1024**3
    try:
        # these sysconf keys are available on Linux and macOS
        page_size = os.sysconf("SC_PAGE_SIZE")
        page_count = os.sysconf("SC_PHYS_PAGES")
    except (AttributeError, ValueError):
        # sysconf (or one of the keys) is not available, e.g. on Windows
        return 0.0
    return (page_size * page_count) / bytes_per_gb
69
70
# Detects whether a stream title line contains a `title=` or `artist=` key.
keyword_pattern = re.compile("title=|artist=")
# Captures the double-quoted value that follows `title=` (non-greedy).
title_pattern = re.compile(r"title=\"(?P<title>.*?)\"")
# Captures the double-quoted value that follows `artist=` (non-greedy).
artist_pattern = re.compile(r"artist=\"(?P<artist>.*?)\"")
# Matches a scheme-less host name such as `example.com` or `www.example.com`,
# optionally wrapped in parentheses; the last label must be 2-3 word characters.
dot_com_pattern = re.compile(r"(?P<netloc>\(?\w+\.(?:\w+\.)?(\w{2,3})\)?)")
# Matches advert markers: `ad_`/`advertisement_`, a line that is exactly `AD <digits>`,
# or the text `ADBREAK` (all case-insensitive).
ad_pattern = re.compile(r"((ad|advertisement)_)|^AD\s\d+$|ADBREAK", flags=re.IGNORECASE)
# Matches `<title> By: <artist>` (case-insensitive) so the two parts can be swapped.
title_artist_order_pattern = re.compile(r"(?P<title>.+)\sBy:\s(?P<artist>.+)", flags=re.IGNORECASE)
# Matches a run of two or more whitespace characters.
multi_space_pattern = re.compile(r"\s{2,}")
# Matches a line ending in whitespace followed by non-word characters;
# group 1 is the part of the line that should be kept.
end_junk_pattern = re.compile(r"(.+?)(\s\W+)$")
79
VERSION_PARTS = (
    # list of common version strings
    # (matched as lowercase substrings by parse_title_and_version)
    "version",
    "live",
    "edit",
    "remix",
    "mix",
    "acoustic",
    "instrumental",
    "karaoke",
    "remaster",
    "versie",
    "unplugged",
    "disco",
    "akoestisch",
    "deluxe",
)
IGNORE_TITLE_PARTS = (
    # strings that may be stripped off a title part
    # (most importantly the featuring parts);
    # matched as lowercase prefixes by parse_title_and_version
    "feat.",
    "featuring",
    "ft.",
    "with ",
    "explicit",
)
WITH_TITLE_WORDS = (
    # words that, when following "with", indicate this is part of the song title
    # and not a featuring credit.
    "someone",
    "the",
    "u",
    "you",
    "no",
)
115
116
def filename_from_string(string: str) -> str:
    """
    Reduce an arbitrary string to characters that are safe in a filename.

    Alphanumeric characters, spaces, dots and underscores are kept, every
    other character is dropped and trailing whitespace is removed.
    """
    allowed_extras = (" ", ".", "_")
    kept = [char for char in string if char in allowed_extras or char.isalnum()]
    return "".join(kept).rstrip()
121
122
def try_parse_int(possible_int: Any, default: int | None = 0) -> int | None:
    """
    Try to parse an int.

    The value is converted through float first, so numeric strings with a
    fractional part (e.g. "3.7") are accepted and truncated towards zero.

    Args:
        possible_int: The value to convert.
        default: Value returned when the conversion is not possible.

    Returns:
        The parsed integer, or ``default`` when the value can not be converted.
    """
    try:
        return int(float(possible_int))
    except (TypeError, ValueError, OverflowError):
        # OverflowError is raised for infinite values such as "inf"
        return default
129
130
def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float | None:
    """
    Convert the given value to a float.

    Returns ``default`` when the value can not be converted.
    """
    try:
        parsed = float(possible_float)
    except (TypeError, ValueError):
        return default
    return parsed
137
138
def try_parse_bool(possible_bool: Any) -> bool:
    """
    Interpret the given value as a boolean.

    Real booleans are returned unchanged; any other value is only True
    when it equals one of the accepted truthy markers.
    """
    truthy_markers = ["true", "True", "1", "on", "ON", 1]
    return possible_bool if isinstance(possible_bool, bool) else possible_bool in truthy_markers
144
145
def try_parse_duration(duration_str: str) -> float:
    """
    Parse a duration string (HH:MM:SS, MM:SS or SS) into seconds.

    Everything after the last dot is treated as the fractional part of a second.
    """
    fraction = 0.0
    if "." in duration_str:
        fraction = float("0." + duration_str.split(".")[-1])
    # drop the fractional part (dot or comma separated) before splitting the clock fields
    clock_fields = duration_str.split(".")[0].split(",")[0].split(":")
    if len(clock_fields) == 3:
        hours, minutes, secs = (int(field) for field in clock_fields)
        whole_seconds = 3600 * hours + 60 * minutes + secs
    elif len(clock_fields) == 2:
        minutes, secs = (int(field) for field in clock_fields)
        whole_seconds = 60 * minutes + secs
    else:
        whole_seconds = int(clock_fields[0])
    return whole_seconds + fraction
157
158
def parse_title_and_version(title: str, track_version: str | None = None) -> tuple[str, str]:
    """
    Split a (track) title into the bare title and its version.

    Parts between round/square brackets and everything after " - " are inspected:
    featuring/explicit parts are removed from the title and the first part that
    contains a known version keyword is returned as the version.
    """
    version = track_version if track_version else ""
    bracket_chars = str.maketrans("", "", "()[]-")
    for pattern in (r"\(.*?\)", r"\[.*?\]", r" - .*"):
        for raw_part in re.findall(pattern, title):
            # content of the part without brackets/dashes, used for all checks
            normalized = raw_part.translate(bracket_chars).strip().lower()

            # first ignore-prefix (featuring/explicit) this part starts with, if any
            matched_prefix = next(
                (prefix for prefix in IGNORE_TITLE_PARTS if normalized.startswith(prefix)), None
            )
            if matched_prefix is not None:
                keep_in_title = False
                if matched_prefix == "with ":
                    # "with <title word>" (e.g. "with you") belongs to the title itself
                    next_word = (
                        normalized[len("with ") :].split()[0]
                        if len(normalized) > len("with ")
                        else ""
                    )
                    keep_in_title = next_word in WITH_TITLE_WORDS
                if not keep_in_title:
                    # strip the featuring/explicit part from the title
                    title = title.replace(raw_part, "").strip()
                    continue

            if any(keyword in normalized for keyword in VERSION_PARTS):
                # keep the original casing of the version in the output
                version = raw_part.strip("()[]- ").strip()
                title = title.replace(raw_part, "").strip()
                return title, version
    return title, version
198
199
def infer_album_type(title: str, version: str) -> AlbumType:
    """
    Guess the album type from the title and version.

    Live indicators are checked before soundtrack indicators;
    UNKNOWN is returned when neither matches.
    """
    haystack = f"{title} {version}".lower()
    if any(re.search(indicator, haystack) for indicator in LIVE_INDICATORS):
        return AlbumType.LIVE
    if any(re.search(indicator, haystack) for indicator in SOUNDTRACK_INDICATORS):
        return AlbumType.SOUNDTRACK
    return AlbumType.UNKNOWN
210
211
def strip_ads(line: str) -> str:
    """Replace a line that matches the advert pattern with "Advert"."""
    return "Advert" if ad_pattern.search(line) else line
217
218
def strip_url(line: str) -> str:
    """
    Remove URLs from line.

    The line is split on whitespace and every word that parses as a URL with
    both a scheme and a network location is dropped.
    """
    kept_words = []
    for word in line.split():
        parsed = urlparse(word)
        if parsed.scheme and parsed.netloc:
            # a full URL: leave it out
            continue
        kept_words.append(word)
    return " ".join(kept_words).rstrip()
224
225
def strip_dotcom(line: str) -> str:
    """Remove scheme-less host names (e.g. "example.com") from line."""
    return re.sub(dot_com_pattern, "", line)
229
230
def strip_end_junk(line: str) -> str:
    """Remove trailing non-word characters from the end of line."""
    # group 1 of the pattern is the part of the line in front of the junk
    return re.sub(end_junk_pattern, r"\1", line)
234
235
def swap_title_artist_order(line: str) -> str:
    """Rewrite "<title> By: <artist>" into "<artist> - <title>"."""
    swapped_format = r"\g<artist> - \g<title>"
    return re.sub(title_artist_order_pattern, swapped_format, line)
239
240
def strip_multi_space(line: str) -> str:
    """Collapse every run of multiple whitespace characters into one space."""
    return re.sub(multi_space_pattern, " ", line)
244
245
def multi_strip(line: str) -> str:
    """
    Run all cleanup steps on line.

    The order matters: ads, URLs, host names, trailing junk,
    title/artist order and finally repeated whitespace.
    """
    cleaned = strip_ads(line)
    cleaned = strip_url(cleaned)
    cleaned = strip_dotcom(cleaned)
    cleaned = strip_end_junk(cleaned)
    cleaned = swap_title_artist_order(cleaned)
    cleaned = strip_multi_space(cleaned)
    return cleaned.rstrip()
251
252
def clean_stream_title(line: str) -> str:
    """
    Clean up the streamtitle of a radio stream.

    Lines without `title=`/`artist=` keys are simply cleaned as a whole.
    Otherwise the quoted title and artist values are extracted, cleaned and
    combined into "<artist> - <title>" (or whichever of the two is present).
    """
    if keyword_pattern.search(line) is None:
        return multi_strip(line)

    title_match = title_pattern.search(line)
    title = multi_strip(title_match.group("title")) if title_match else ""

    artist = ""
    artist_match = artist_pattern.search(line)
    if artist_match:
        candidate = multi_strip(artist_match.group("artist"))
        # an artist that merely repeats the title carries no information
        if candidate and candidate != title:
            artist = candidate

    if not title:
        # either just the artist or an empty string when nothing was found
        return artist
    if not artist or " - " in title:
        # the title already contains the artist (or there is none)
        return title
    return f"{artist} - {title}"
282
283
async def get_ip_addresses(include_ipv6: bool = False) -> tuple[str, ...]:
    """
    Return all IP-addresses of all network interfaces.

    Loopback, APIPA and IPv6 link-local addresses are filtered out.
    The result is ordered by likelihood of being the address a client should use:
    the address of the default route first, then 192.168.x.x, then other
    private-looking ranges and finally everything else.

    Args:
        include_ipv6: Also return IPv6 addresses.

    Returns:
        Tuple of IP addresses (as strings), best candidate first.
    """

    def call() -> tuple[str, ...]:
        result: list[tuple[int, str]] = []
        # try to get the primary IP address
        # this is the IP address of the default route
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as _sock:
            _sock.settimeout(0)
            try:
                # doesn't even have to be reachable
                _sock.connect(("10.254.254.254", 1))
                primary_ip = _sock.getsockname()[0]
            except Exception:
                primary_ip = ""
        # get all IP addresses of all network interfaces
        for adapter in ifaddr.get_adapters():
            for ip in adapter.ips:
                if ip.is_IPv6 and not include_ipv6:
                    continue
                # ifaddr reports IPv6 addresses as an (address, flowinfo, scope_id)
                # tuple, only IPv4 addresses are plain strings
                ip_str = str(ip.ip[0]) if isinstance(ip.ip, tuple) else str(ip.ip)
                if ip_str.startswith(("127", "169.254")):
                    # filter out IPv4 loopback/APIPA address
                    continue
                if ip_str.startswith(("::1", "::ffff:", "fe80")):
                    # filter out IPv6 loopback/link-local address
                    continue
                if ip_str == primary_ip:
                    score = 10
                elif ip_str.startswith(("192.168.",)):
                    # we rank the 192.168 range a bit higher as its most
                    # often used as the private network subnet
                    score = 2
                elif ip_str.startswith(("172.", "10.", "192.")):
                    # we rank the 172 range a bit lower as its most
                    # often used as the private docker network
                    score = 1
                else:
                    score = 0
                result.append((score, ip_str))
        # stable sort: addresses with an equal score keep their adapter order
        result.sort(key=lambda x: x[0], reverse=True)
        return tuple(ip[1] for ip in result)

    return await asyncio.to_thread(call)
331
332
async def get_primary_ip_address() -> str | None:
    """
    Return the primary IP address of the system.

    This is the first (highest ranked) IPv4 address reported by
    get_ip_addresses, or None when no usable address was found.
    """
    ip_addresses = await get_ip_addresses(include_ipv6=False)
    return ip_addresses[0] if ip_addresses else None
335
336
async def is_port_in_use(port: int) -> bool:
    """
    Return True when the given TCP port can not be bound on all interfaces.

    The (blocking) bind attempt runs in a worker thread.
    """

    def _probe_port() -> bool:
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with probe:
            # Set SO_REUSEADDR to match asyncio.start_server behavior
            # This allows binding to ports in TIME_WAIT state
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                probe.bind(("0.0.0.0", port))
            except OSError:
                return True
            return False

    return await asyncio.to_thread(_probe_port)
352
353
async def select_free_port(range_start: int, range_end: int) -> int:
    """
    Return the first port within the range that is not in use.

    The range is half-open: range_end itself is not tried.
    Raises OSError when every port in the range is in use.
    """
    for candidate in range(range_start, range_end):
        in_use = await is_port_in_use(candidate)
        if in_use:
            continue
        return candidate
    raise OSError("No free port available")
361
362
async def get_ip_from_host(dns_name: str) -> str | None:
    """
    Resolve the given dns name to its (first) IPv4 address.

    The lookup runs in a worker thread and returns None when it fails.
    """

    def _lookup() -> str | None:
        # fail gracefully: any lookup error simply results in None
        with suppress(Exception):
            return socket.gethostbyname(dns_name)
        return None

    return await asyncio.to_thread(_lookup)
374
375
async def get_ip_pton(ip_string: str) -> bytes:
    """
    Return the packed binary form of an IP address.

    The string is parsed as IPv4 first, with IPv6 as fallback.
    """
    try:
        packed = await asyncio.to_thread(socket.inet_pton, socket.AF_INET, ip_string)
    except OSError:
        # not a valid IPv4 address, so it has to be IPv6
        packed = await asyncio.to_thread(socket.inet_pton, socket.AF_INET6, ip_string)
    return packed
382
383
async def get_folder_size(folderpath: str) -> float:
    """
    Return the combined size (in GB) of all files below folderpath.

    The folder is walked recursively in a worker thread.
    """

    def _sum_file_sizes(root: str) -> float:
        byte_count = sum(
            Path(os.path.join(directory, filename)).stat().st_size
            for directory, _subdirs, filenames in os.walk(root)
            for filename in filenames
        )
        return byte_count / float(1 << 30)

    return await asyncio.to_thread(_sum_file_sizes, folderpath)
396
397
def get_changed_keys(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> set[str]:
    """
    Return the keys whose value differs between the two dicts.

    The comparison itself is done by get_changed_dict_values.
    """
    # TODO: Check with Marcel whether we should calculate new dicts based on ignore_keys
    changes = get_changed_dict_values(dict1, dict2, recursive)
    return set(changes)
407
408
def get_changed_dict_values(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dicts and return dict of changed values.

    dict key is the changed key, value is tuple of old and new values.

    Args:
        dict1: The old dict.
        dict2: The new dict.
        recursive: Descend into values that are a dict on both sides and
            report their changes as "<key>.<subkey>".

    Returns:
        Dict with the changed keys of dict2; keys that only exist in dict1
        are not reported (except when dict2 is empty, see below).
    """
    if not dict1 and not dict2:
        return {}
    if not dict1:
        return {key: (None, value) for key, value in dict2.items()}
    if not dict2:
        # NOTE(review): the old value is reported in the "new" slot here, kept
        # as-is for backwards compatibility - confirm whether callers rely on it
        return {key: (None, value) for key, value in dict1.items()}
    changed_values: dict[str, tuple[Any, Any]] = {}
    for key, value in dict2.items():
        # check for new keys first, dict1[key] below would raise a KeyError otherwise
        if key not in dict1:
            changed_values[key] = (None, value)
            continue
        old_value = dict1[key]
        if recursive and isinstance(value, dict) and isinstance(old_value, dict):
            changed_subvalues = get_changed_dict_values(old_value, value, recursive)
            for subkey, subvalue in changed_subvalues.items():
                changed_values[f"{key}.{subkey}"] = subvalue
            continue
        if old_value != value:
            changed_values[key] = (old_value, value)
    return changed_values
438
439
def get_changed_dataclass_values(
    obj1: T,
    obj2: T,
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Return the fields whose value differs between 2 dataclass instances.

    The result maps the field name to a tuple of (old value, new value).
    With recursive enabled, fields holding a dataclass or a dict on both
    sides are compared in depth and reported as "<field>.<subfield>".

    Raises ValueError when one of the objects is not a dataclass.
    """
    if not is_dataclass(obj1) or not is_dataclass(obj2):
        raise ValueError("Both objects must be dataclass instances")

    differences: dict[str, tuple[Any, Any]] = {}
    for spec in fields(obj1):
        name = spec.name
        before = getattr(obj1, name, None)
        after = getattr(obj2, name, None)

        # in-depth comparison result, stays None for a plain comparison
        nested: dict[str, tuple[Any, Any]] | None = None
        if recursive:
            if is_dataclass(before) and is_dataclass(after):
                nested = get_changed_dataclass_values(before, after, recursive)
            elif isinstance(before, dict) and isinstance(after, dict):
                nested = get_changed_dict_values(before, after, recursive=recursive)

        if nested is not None:
            for sub_name, sub_change in nested.items():
                differences[f"{name}.{sub_name}"] = sub_change
        elif before != after:
            differences[name] = (before, after)
    return differences
470
471
472def empty_queue[T](q: asyncio.Queue[T]) -> None:
473 """Empty an asyncio Queue."""
474 for _ in range(q.qsize()):
475 try:
476 q.get_nowait()
477 q.task_done()
478 except (asyncio.QueueEmpty, ValueError):
479 pass
480
481
async def install_package(package: str) -> None:
    """
    Install a python package with `uv pip install`.

    Raises RuntimeError (including the installer output) when the install failed.
    """
    LOGGER.debug("Installing python package %s", package)
    command = ("uv", "pip", "install", "--no-cache", "--find-links", HA_WHEELS, package)
    return_code, output = await check_output(*command)
    if return_code == 0:
        return
    raise RuntimeError(f"Failed to install package {package}\n{output.decode()}")
490
491
async def get_package_version(pkg_name: str) -> str | None:
    """
    Look up the version of an installed (python) package.

    The metadata lookup runs in a worker thread.
    Returns None when the package is not installed.
    """
    with suppress(PackageNotFoundError):
        return await asyncio.to_thread(pkg_version, pkg_name)
    return None
502
503
async def is_hass_supervisor() -> bool:
    """
    Return if we're running inside the HA Supervisor (e.g. HAOS).

    The supervisor is detected by requesting http://supervisor/core (in a
    worker thread): only a 401 (unauthorized) answer counts as a match.
    """

    def _check() -> bool:
        try:
            # use the response as context manager so the connection gets closed
            with urllib.request.urlopen("http://supervisor/core", timeout=1):
                # a successful (unauthenticated) request is not the supervisor
                return False
        except urllib.error.URLError as err:
            # this should return a 401 unauthorized if it exists
            # (plain URLErrors, e.g. an unknown host, have no status code)
            return getattr(err, "code", 999) == 401
        except Exception:
            return False

    return await asyncio.to_thread(_check)
518
519
async def load_provider_module(domain: str, requirements: list[str]) -> ProviderModuleType:
    """
    Import the module of the given provider domain.

    Pinned requirements whose installed version differs are installed first.
    When the import fails, all requirements are (re)installed and the import
    is attempted once more.
    """

    # NOTE(review): this helper (and its cache) is created again on every call
    @lru_cache
    def _import_provider(provider_domain: str) -> ProviderModuleType:
        module = importlib.import_module(f".{provider_domain}", "music_assistant.providers")
        return cast("ProviderModuleType", module)

    # ensure module requirements are met
    # (unpinned requirements are skipped, we should really get rid of those)
    pinned_requirements = (req for req in requirements if "==" in req)
    for requirement in pinned_requirements:
        package_name, wanted_version = requirement.split("==", 1)
        installed_version = await get_package_version(package_name)
        # "0.0.0" indicates an editable install, which is left alone
        if installed_version in ("0.0.0", wanted_version):
            continue
        await install_package(requirement)

    # try to load the module
    try:
        return await asyncio.to_thread(_import_provider, domain)
    except ImportError:
        # (re)install ALL requirements
        for requirement in requirements:
            await install_package(requirement)
    # try loading the provider again to be safe
    # this will fail if something else is wrong (as it should)
    return await asyncio.to_thread(_import_provider, domain)
552
553
async def has_tmpfs_mount() -> bool:
    """
    Return True when /tmp is mounted as a (read-write) tmpfs.

    The check reads /proc/mounts in a worker thread and returns
    False when that file can not be read.
    """

    def _scan_mounts() -> bool:
        """Look for the tmpfs entry of /tmp in /proc/mounts."""
        try:
            with open("/proc/mounts") as mounts:
                return any("tmpfs /tmp tmpfs rw" in entry for entry in mounts)
        except (FileNotFoundError, OSError, PermissionError):
            return False

    return await asyncio.to_thread(_scan_mounts)
569
570
async def get_free_space(folder: str) -> float:
    """
    Return the free disk space (in GB) of the filesystem that holds folder.

    Returns 0.0 when the disk usage can not be determined.
    """

    def _free_gb(path: str) -> float:
        """Return the free space for path in GB (blocking)."""
        try:
            usage = shutil.disk_usage(path)
        except (FileNotFoundError, OSError, PermissionError):
            return 0.0
        return usage.free / float(1 << 30)

    return await asyncio.to_thread(_free_gb, folder)
583
584
async def get_free_space_percentage(folder: str) -> float:
    """
    Return free space on given folderpath in percentage.

    Args:
        folder: Path on the filesystem to inspect.

    Returns:
        Percentage (0-100) of the filesystem that is free, or 0.0 when the
        disk usage can not be determined.
    """

    def _get_free_space_percentage(folder: str) -> float:
        """Return the free share of the filesystem as percentage (blocking)."""
        try:
            res = shutil.disk_usage(folder)
        except OSError:
            # also covers FileNotFoundError and PermissionError (OSError subclasses)
            return 0.0
        if not res.total:
            # some (pseudo) filesystems report a total size of zero
            return 0.0
        return res.free / res.total * 100

    return await asyncio.to_thread(_get_free_space_percentage, folder)
597
598
async def has_enough_space(folder: str, size: int) -> bool:
    """Return True when the free space of folder (as reported by get_free_space) exceeds size."""
    free_space = await get_free_space(folder)
    return free_space > size
602
603
def divide_chunks(data: bytes, chunk_size: int) -> Iterator[bytes]:
    """
    Yield consecutive slices of data of (at most) chunk_size bytes.

    The last chunk is shorter when the data does not divide evenly.
    """
    yield from (data[start : start + chunk_size] for start in range(0, len(data), chunk_size))
608
609
async def remove_file(file_path: str) -> None:
    """
    Delete the file at file_path.

    Nothing happens when the path does not exist.
    The filesystem calls run in a worker thread.
    """
    path_exists = await asyncio.to_thread(os.path.exists, file_path)
    if path_exists:
        await asyncio.to_thread(os.remove, file_path)
        LOGGER.log(VERBOSE_LOG_LEVEL, "Removed file: %s", file_path)
616
617
def get_primary_ip_address_from_zeroconf(discovery_info: AsyncServiceInfo) -> str | None:
    """
    Return the first usable IPv4 address from zeroconf discovery info.

    Loopback (127.x) and APIPA (169.254.x) addresses are skipped;
    None is returned when no other address is present.
    """
    ipv4_addresses = discovery_info.parsed_addresses(IPVersion.V4Only)
    unusable_prefixes = ("127", "169.254")
    return next(
        (address for address in ipv4_addresses if not address.startswith(unusable_prefixes)),
        None,
    )
629
630
def get_port_from_zeroconf(discovery_info: AsyncServiceInfo) -> int | None:
    """Return the port that is advertised in the zeroconf discovery info."""
    advertised_port = discovery_info.port
    return advertised_port
634
635
async def close_async_generator(agen: AsyncGenerator[Any, None]) -> None:
    """
    Forcefully close an async generator.

    A task requesting the next item is created and cancelled right away,
    after which the generator itself is closed.
    """
    pending = asyncio.create_task(agen.__anext__())
    pending.cancel()
    try:
        await pending
    except (asyncio.CancelledError, StopAsyncIteration):
        # expected: the request was cancelled or the generator was already exhausted
        pass
    await agen.aclose()
643
644
async def detect_charset(data: bytes, fallback: str = "utf-8") -> str:
    """
    Detect the character encoding of raw data.

    The detection runs in a worker thread. The fallback is returned when
    detection fails or when the confidence is not above 0.75.
    """
    try:
        guess: ResultDict = await asyncio.to_thread(chardet.detect, data)
        is_confident = bool(guess) and bool(guess["encoding"]) and guess["confidence"] > 0.75
        if is_confident:
            encoding = guess["encoding"]
            assert isinstance(encoding, str)  # for type checking
            return encoding
    except Exception as err:
        LOGGER.debug("Failed to detect charset: %s", err)
    return fallback
655
656
def merge_dict(
    base_dict: dict[Any, Any],
    new_dict: dict[Any, Any],
    allow_overwite: bool = False,
) -> dict[Any, Any]:
    """
    Merge new_dict into (a copy of) base_dict.

    Values that are a dict, tuple or list are merged with the existing
    (non-empty) value instead of replacing it. Any other value is only set
    when the key has no (truthy) value yet, unless allow_overwite is set.

    Args:
        base_dict: The dict to start from (not modified).
        new_dict: The dict with values to add.
        allow_overwite: Let values of new_dict replace existing values.

    Returns:
        A new dict with the merged result.
    """
    final_dict = base_dict.copy()
    for key, value in new_dict.items():
        existing = final_dict.get(key)
        # exactly one branch may run per key: a merged container must not be
        # overwritten again by the plain-value rule at the end of the chain
        if existing and isinstance(value, dict) and isinstance(existing, dict):
            final_dict[key] = merge_dict(existing, value, allow_overwite)
        elif existing and isinstance(value, tuple):
            final_dict[key] = merge_tuples(existing, value)
        elif existing and isinstance(value, list):
            final_dict[key] = merge_lists(existing, value)
        elif not existing or allow_overwite:
            final_dict[key] = value
    return final_dict
674
675
def merge_tuples(base: tuple[Any, ...], new: tuple[Any, ...]) -> tuple[Any, ...]:
    """
    Combine 2 tuples without duplicating items.

    Items of base that are not in new come first, followed by all items of new.
    """
    only_in_base = tuple(item for item in base if item not in new)
    return only_in_base + tuple(new)
679
680
def merge_lists(base: list[Any], new: list[Any]) -> list[Any]:
    """
    Combine 2 lists without duplicating items.

    Items of base that are not in new come first, followed by all items of new.
    """
    merged = [item for item in base if item not in new]
    merged.extend(new)
    return merged
684
685
def percentage(part: float, whole: float) -> int:
    """Return how many percent part is of whole, truncated to an int."""
    ratio = 100 * float(part) / float(whole)
    return int(ratio)
689
690
def validate_announcement_chime_url(url: str) -> bool:
    """
    Check whether url is acceptable as announcement chime.

    Valid values are an empty url, the built-in chime file or a http(s) URL
    with a host whose path ends in a known audio file extension.
    """
    if not url or not url.strip():
        return True  # Empty URL is valid

    if url == ANNOUNCE_ALERT_FILE:
        return True  # Built-in chime file is valid

    audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac")
    try:
        parts = urlparse(url.strip())
        if parts.scheme not in ("http", "https") or not parts.netloc:
            return False
        return parts.path.lower().endswith(audio_extensions)
    except Exception:
        return False
715
716
async def get_mac_address(ip_address: str) -> str | None:
    """
    Look up the MAC address that belongs to the given IP address.

    The lookup is delegated to the getmac package and runs in a worker thread.
    """
    # imported locally: the function of getmac has the same name as this helper
    import getmac  # noqa: PLC0415

    return await asyncio.to_thread(getmac.get_mac_address, ip=ip_address)
722
723
724class TaskManager:
725 """
726 Helper class to run many tasks at once.
727
728 This is basically an alternative to asyncio.TaskGroup but this will not
729 cancel all operations when one of the tasks fails.
730 Logging of exceptions is done by the mass.create_task helper.
731 """
732
733 def __init__(self, mass: MusicAssistant, limit: int = 0):
734 """Initialize the TaskManager."""
735 self.mass = mass
736 self._tasks: list[asyncio.Task[None]] = []
737 self._semaphore = asyncio.Semaphore(limit) if limit else None
738
739 def create_task(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]:
740 """Create a new task and add it to the manager."""
741 task = self.mass.create_task(coro)
742 self._tasks.append(task)
743 return task
744
745 async def create_task_with_limit(self, coro: Coroutine[Any, Any, None]) -> None:
746 """Create a new task with semaphore limit."""
747 assert self._semaphore is not None
748
749 def task_done_callback(_task: asyncio.Task[None]) -> None:
750 assert self._semaphore is not None # for type checking
751 self._tasks.remove(task)
752 self._semaphore.release()
753
754 await self._semaphore.acquire()
755 task: asyncio.Task[None] = self.create_task(coro)
756 task.add_done_callback(task_done_callback)
757
758 async def __aenter__(self) -> Self:
759 """Enter context manager."""
760 return self
761
762 async def __aexit__(
763 self,
764 exc_type: type[BaseException] | None,
765 exc_val: BaseException | None,
766 exc_tb: TracebackType | None,
767 ) -> bool | None:
768 """Exit context manager."""
769 if len(self._tasks) > 0:
770 await asyncio.wait(self._tasks)
771 self._tasks.clear()
772 return None
773
774
# Return type of a decorated function (used in the signature of `lock`).
_R = TypeVar("_R")
# Parameters of a decorated function (used in the signature of `lock`).
_P = ParamSpec("_P")
777
778
779def lock[**P, R]( # type: ignore[valid-type]
780 func: Callable[_P, Awaitable[_R]],
781) -> Callable[_P, Coroutine[Any, Any, _R]]:
782 """Call async function using a Lock."""
783
784 @functools.wraps(func)
785 async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
786 """Call async function using the throttler with retries."""
787 if not (func_lock := getattr(func, "lock", None)):
788 func_lock = asyncio.Lock()
789 func.lock = func_lock # type: ignore[attr-defined]
790 async with func_lock:
791 return await func(*args, **kwargs)
792
793 return wrapper
794
795
class TimedAsyncGenerator:
    """
    Async iterable that applies a timeout to every iteration step.

    Note that iteration also ends as soon as the wrapped iterator
    produces a falsy item.

    Source: https://medium.com/@dmitry8912/implementing-timeouts-in-pythons-asynchronous-generators-f7cbaa6dc1e9
    """

    def __init__(self, iterable: AsyncIterator[Any], timeout: int = 0):
        """
        Initialize the AsyncTimedIterable.

        Args:
            iterable: The async iterable to wrap.
            timeout: The timeout in seconds for each iteration.
        """

        class AsyncTimedIterator:
            def __init__(self) -> None:
                self._iterator = iterable.__aiter__()

            async def __anext__(self) -> Any:
                # asyncio.wait_for raises a timeout error when the next item takes too long
                next_item = self._iterator.__anext__()
                item = await asyncio.wait_for(next_item, int(timeout))
                if item:
                    return item
                # a falsy item is treated as the end of the stream
                raise StopAsyncIteration

        # every __aiter__ call creates a fresh iterator over the wrapped iterable
        self._factory = AsyncTimedIterator

    def __aiter__(self):  # type: ignore[no-untyped-def]
        """Create a new timed iterator."""
        return self._factory()
827
828
829def guard_single_request[ProviderT: "Provider | CoreController", **P, R](
830 func: Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]],
831) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
832 """Guard single request to a function."""
833
834 @functools.wraps(func)
835 async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
836 mass = self.mass
837 # create a task_id dynamically based on the function and args/kwargs
838 cache_key_parts = [func.__class__.__name__, func.__name__, *args]
839 for key in sorted(kwargs.keys()):
840 cache_key_parts.append(f"{key}{kwargs[key]}")
841 task_id = ".".join(map(str, cache_key_parts))
842 task: asyncio.Task[R] = mass.create_task(
843 func,
844 self,
845 *args,
846 task_id=task_id,
847 abort_existing=False,
848 **kwargs,
849 )
850 return await task
851
852 return wrapper
853