music-assistant-server

28.1 KB · PY
util.py
28.1 KB · 853 lines • python
1"""Various (server-only) tools and helpers."""
2
3from __future__ import annotations
4
5import asyncio
6import functools
7import importlib
8import logging
9import os
10import re
11import shutil
12import socket
13import urllib.error
14import urllib.request
15from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Coroutine
16from contextlib import suppress
17from functools import lru_cache
18from importlib.metadata import PackageNotFoundError
19from importlib.metadata import version as pkg_version
20from pathlib import Path
21from types import TracebackType
22from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Self, TypeVar, cast
23from urllib.parse import urlparse
24
25import chardet
26import ifaddr
27from music_assistant_models.enums import AlbumType
28from zeroconf import IPVersion
29
30from music_assistant.constants import (
31    ANNOUNCE_ALERT_FILE,
32    LIVE_INDICATORS,
33    SOUNDTRACK_INDICATORS,
34    VERBOSE_LOG_LEVEL,
35)
36from music_assistant.helpers.process import check_output
37
38if TYPE_CHECKING:
39    from collections.abc import Iterator
40
41    from chardet.resultdict import ResultDict
42    from zeroconf.asyncio import AsyncServiceInfo
43
44    from music_assistant.mass import MusicAssistant
45    from music_assistant.models import ProviderModuleType
46    from music_assistant.models.core_controller import CoreController
47    from music_assistant.models.provider import Provider
48
49from dataclasses import fields, is_dataclass
50
# Module-level logger.
LOGGER = logging.getLogger(__name__)

# Extra wheel location passed to the installer via --find-links (see install_package).
HA_WHEELS = "https://wheels.home-assistant.io/musllinux/"

# Generic type variable (used e.g. in the get_changed_dataclass_values signature).
T = TypeVar("T")
# Type alias for a callback that takes no arguments and returns None.
CALLBACK_TYPE = Callable[[], None]
57
58
def get_total_system_memory() -> float:
    """
    Get total system memory in GB (GiB, bytes / 1024**3).

    Uses ``os.sysconf`` (available on Linux and macOS). Returns 0.0 when the
    value cannot be determined: ``os.sysconf`` missing (e.g. Windows), an
    unknown configuration name, or values reported as undefined (-1).
    """
    try:
        page_size = os.sysconf("SC_PAGE_SIZE")
        phys_pages = os.sysconf("SC_PHYS_PAGES")
    except (AttributeError, ValueError, OSError):
        # Fallback if sysconf is not available (e.g., Windows)
        # Return a conservative default to disable buffering by default
        return 0.0
    if page_size <= 0 or phys_pages <= 0:
        # os.sysconf returns -1 for undefined values; multiplying those would
        # produce a meaningless (possibly positive) size.
        return 0.0
    return page_size * phys_pages / (1024**3)
69
70
# Detects a raw stream title carrying "title=" / "artist=" keys (see clean_stream_title).
keyword_pattern = re.compile("title=|artist=")
# Captures the double-quoted value of title="..." (non-greedy, up to the next quote).
title_pattern = re.compile(r"title=\"(?P<title>.*?)\"")
# Captures the double-quoted value of artist="..." (non-greedy, up to the next quote).
artist_pattern = re.compile(r"artist=\"(?P<artist>.*?)\"")
# Scheme-less host names like "name.tld" or "sub.name.tld" (last label 2-3 word chars),
# optionally wrapped in parentheses.
dot_com_pattern = re.compile(r"(?P<netloc>\(?\w+\.(?:\w+\.)?(\w{2,3})\)?)")
# Advert markers (case-insensitive): "ad_" / "advertisement_" anywhere, the entire string
# being "AD <digits>", or "ADBREAK".
# NOTE(review): the unanchored "ad_" alternative also matches inside words (e.g. "road_trip").
ad_pattern = re.compile(r"((ad|advertisement)_)|^AD\s\d+$|ADBREAK", flags=re.IGNORECASE)
# "<title> By: <artist>" (case-insensitive); swap_title_artist_order rewrites it as "<artist> - <title>".
title_artist_order_pattern = re.compile(r"(?P<title>.+)\sBy:\s(?P<artist>.+)", flags=re.IGNORECASE)
# Two or more consecutive whitespace characters.
multi_space_pattern = re.compile(r"\s{2,}")
# A line ending in whitespace followed by non-word characters; group 1 is the text before it.
end_junk_pattern = re.compile(r"(.+?)(\s\W+)$")
79
VERSION_PARTS = (
    # list of common version strings
    # (parse_title_and_version matches these as substrings of a lower-cased
    # bracketed / " - " title part)
    "version",
    "live",
    "edit",
    "remix",
    "mix",
    "acoustic",
    "instrumental",
    "karaoke",
    "remaster",
    "versie",  # Dutch for "version"
    "unplugged",
    "disco",
    "akoestisch",  # Dutch for "acoustic"
    "deluxe",
)
IGNORE_TITLE_PARTS = (
    # strings that may be stripped off a title part
    # (most important the featuring parts)
    # Matched with startswith() against the lower-cased title part.
    "feat.",
    "featuring",
    "ft.",
    "with ",  # trailing space so words like "without" do not match
    "explicit",
)
WITH_TITLE_WORDS = (
    # words that, when following "with", indicate this is part of the song title
    # not a featuring credit.
    "someone",
    "the",
    "u",
    "you",
    "no",
)
115
116
def filename_from_string(string: str) -> str:
    """
    Create filename from unsafe string.

    Only alphanumeric characters, spaces, dots and underscores are kept;
    trailing whitespace is removed from the result.
    """
    allowed_punctuation = " ._"
    safe_chars = [char for char in string if char.isalnum() or char in allowed_punctuation]
    return "".join(safe_chars).rstrip()
121
122
123def try_parse_int(possible_int: Any, default: int | None = 0) -> int | None:
124    """Try to parse an int."""
125    try:
126        return int(float(possible_int))
127    except (TypeError, ValueError):
128        return default
129
130
131def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float | None:
132    """Try to parse a float."""
133    try:
134        return float(possible_float)
135    except (TypeError, ValueError):
136        return default
137
138
def try_parse_bool(possible_bool: Any) -> bool:
    """
    Try to parse a bool.

    Real booleans are returned as-is; otherwise only the values "true", "True",
    "1", "on", "ON" and the number 1 are considered truthy.
    """
    if not isinstance(possible_bool, bool):
        return possible_bool in ("true", "True", "1", "on", "ON", 1)
    return possible_bool
144
145
146def try_parse_duration(duration_str: str) -> float:
147    """Try to parse a duration in seconds from a duration (HH:MM:SS) string."""
148    milliseconds = float("0." + duration_str.split(".")[-1]) if "." in duration_str else 0.0
149    duration_parts = duration_str.split(".")[0].split(",")[0].split(":")
150    if len(duration_parts) == 3:
151        seconds = sum(x * int(t) for x, t in zip([3600, 60, 1], duration_parts, strict=False))
152    elif len(duration_parts) == 2:
153        seconds = sum(x * int(t) for x, t in zip([60, 1], duration_parts, strict=False))
154    else:
155        seconds = int(duration_parts[0])
156    return seconds + milliseconds
157
158
def _is_ignorable_title_part(clean_part: str) -> bool:
    """Return True if a normalized title part is a featuring/explicit marker to drop."""
    for ignore_str in IGNORE_TITLE_PARTS:
        if not clean_part.startswith(ignore_str):
            continue
        if ignore_str == "with ":
            # "with" followed by e.g. "you"/"the" belongs to the song title,
            # it is not a featuring credit
            words_after = clean_part[len("with ") :].split()
            first_word = words_after[0] if words_after else ""
            if first_word in WITH_TITLE_WORDS:
                return False
        return True
    return False


def parse_title_and_version(title: str, track_version: str | None = None) -> tuple[str, str]:
    """
    Try to parse version from the title.

    Parenthesised, bracketed and `` - ``-suffixed parts of *title* are examined
    in that order. Featuring/explicit markers are removed from the title; the
    first part containing a VERSION_PARTS keyword becomes the version (original
    casing kept) and is removed from the title as well.

    Returns a ``(title, version)`` tuple; *version* falls back to
    *track_version* (or an empty string) when no version part is found.
    """
    version = track_version or ""
    strip_chars = str.maketrans("", "", "()[]-")
    for regex in (r"\(.*?\)", r"\[.*?\]", r" - .*"):
        for title_part in re.findall(regex, title):
            clean_part = title_part.translate(strip_chars).strip().lower()
            if _is_ignorable_title_part(clean_part):
                title = title.replace(title_part, "").strip()
                continue
            if any(version_str in clean_part for version_str in VERSION_PARTS):
                # Preserve original casing for output
                version = title_part.strip("()[]- ").strip()
                return title.replace(title_part, "").strip(), version
    return title, version
198
199
def infer_album_type(title: str, version: str) -> AlbumType:
    """
    Infer album type by looking for live or soundtrack indicators.

    Title and version are combined and lower-cased, then searched with the
    LIVE_INDICATORS patterns first and the SOUNDTRACK_INDICATORS patterns second.
    """
    haystack = f"{title} {version}".lower()
    if any(re.search(pattern, haystack) for pattern in LIVE_INDICATORS):
        return AlbumType.LIVE
    if any(re.search(pattern, haystack) for pattern in SOUNDTRACK_INDICATORS):
        return AlbumType.SOUNDTRACK
    return AlbumType.UNKNOWN
210
211
def strip_ads(line: str) -> str:
    """Return "Advert" when *line* matches the advert pattern, else *line* unchanged."""
    return "Advert" if ad_pattern.search(line) else line
217
218
def strip_url(line: str) -> str:
    """
    Strip URL from line.

    Drops every whitespace-separated token that parses as a URL with both a
    scheme and a network location; the remaining tokens are joined by single
    spaces.
    """
    kept_tokens = []
    for token in line.split():
        parsed = urlparse(token)
        if parsed.scheme and parsed.netloc:
            continue
        kept_tokens.append(token)
    return " ".join(kept_tokens).rstrip()
224
225
def strip_dotcom(line: str) -> str:
    """Remove scheme-less host names (such as ``example.com``) from *line*."""
    cleaned = dot_com_pattern.sub("", line)
    return cleaned
229
230
def strip_end_junk(line: str) -> str:
    """Drop a trailing run of whitespace followed by non-word characters from *line*."""
    keep_leading_text = r"\1"
    return end_junk_pattern.sub(keep_leading_text, line)
234
235
def swap_title_artist_order(line: str) -> str:
    """Rewrite "<title> By: <artist>" into "<artist> - <title>"."""
    replacement = r"\g<artist> - \g<title>"
    return title_artist_order_pattern.sub(replacement, line)
239
240
def strip_multi_space(line: str) -> str:
    """Collapse every run of two or more whitespace characters into one space."""
    single_spaced = multi_space_pattern.sub(" ", line)
    return single_spaced
244
245
def multi_strip(line: str) -> str:
    """
    Strip assorted junk from line.

    Applies, in order: strip_ads, strip_url, strip_dotcom, strip_end_junk and
    swap_title_artist_order, then collapses repeated whitespace and strips
    trailing whitespace.
    """
    result = line
    for step in (strip_ads, strip_url, strip_dotcom, strip_end_junk, swap_title_artist_order):
        result = step(result)
    return strip_multi_space(result).rstrip()
251
252
def clean_stream_title(line: str) -> str:
    """
    Strip junk text from radio streamtitle.

    Lines without ``title=`` / ``artist=`` keys are cleaned as a whole with
    multi_strip(). Otherwise the quoted ``title="..."`` and ``artist="..."``
    values are extracted and cleaned:

    - a title that already contains `` - `` (or has no distinct artist) is
      returned on its own;
    - otherwise ``"<artist> - <title>"`` is returned;
    - without a title the artist (possibly an empty string) is returned.
    """
    if not keyword_pattern.search(line):
        return multi_strip(line)

    title = ""
    artist = ""
    if match := title_pattern.search(line):
        title = multi_strip(match.group("title"))
    if match := artist_pattern.search(line):
        possible_artist = multi_strip(match.group("artist"))
        # ignore an artist value that merely repeats the title
        if possible_artist and possible_artist != title:
            artist = possible_artist

    if title:
        # the title may already be in "artist - title" form
        if " - " in title or not artist:
            return title
        return f"{artist} - {title}"
    # no title: the artist is the only information left (may be "")
    return artist
282
283
async def get_ip_addresses(include_ipv6: bool = False) -> tuple[str, ...]:
    """
    Return all IP-addresses of all network interfaces.

    Loopback, APIPA / link-local and IPv4-mapped IPv6 addresses are skipped;
    IPv6 addresses are only included when *include_ipv6* is set. The result is
    ordered by preference: the address of the default route first, then
    192.168.x.x addresses, then other 10./172./192. addresses, then the rest.
    """

    def call() -> tuple[str, ...]:
        result: list[tuple[int, str]] = []
        # try to get the primary IP address
        # this is the IP address of the default route
        _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        _sock.settimeout(0)
        try:
            # doesn't even have to be reachable
            _sock.connect(("10.254.254.254", 1))
            primary_ip = _sock.getsockname()[0]
        except Exception:
            primary_ip = ""
        finally:
            _sock.close()
        # get all IP addresses of all network interfaces
        for adapter in ifaddr.get_adapters():
            for ip in adapter.ips:
                if ip.is_IPv6 and not include_ipv6:
                    continue
                # ifaddr reports IPv6 addresses as an (address, flowinfo, scope_id)
                # tuple; use only the address so the prefix filters below work and
                # callers receive a plain address string.
                ip_str = ip.ip[0] if isinstance(ip.ip, tuple) else str(ip.ip)
                if ip_str.startswith(("127", "169.254")):
                    # filter out IPv4 loopback/APIPA address
                    continue
                if ip_str.startswith(("::1", "::ffff:", "fe80")):
                    # filter out IPv6 loopback/link-local address
                    continue
                if ip_str == primary_ip:
                    score = 10
                elif ip_str.startswith("192.168."):
                    # we rank the 192.168 range a bit higher as its most
                    # often used as the private network subnet
                    score = 2
                elif ip_str.startswith(("172.", "10.", "192.")):
                    # we rank the 172 range a bit lower as its most
                    # often used as the private docker network
                    score = 1
                else:
                    score = 0
                result.append((score, ip_str))
        # stable sort: equally ranked addresses keep their adapter order
        result.sort(key=lambda x: x[0], reverse=True)
        return tuple(ip[1] for ip in result)

    return await asyncio.to_thread(call)
331
332
async def get_primary_ip_address() -> str | None:
    """
    Return the primary IP address of the system.

    This is the highest-ranked IPv4 address reported by get_ip_addresses(),
    which places the address of the default route first. Returns None when no
    usable address is found.
    """
    ip_addresses = await get_ip_addresses()
    return ip_addresses[0] if ip_addresses else None
335
336
async def is_port_in_use(port: int) -> bool:
    """
    Check if port is in use.

    Tries to bind a TCP socket to ``0.0.0.0:<port>`` in a worker thread; a bind
    failure (OSError) means the port is considered in use.
    """

    def _probe() -> bool:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            # Set SO_REUSEADDR to match asyncio.start_server behavior
            # This allows binding to ports in TIME_WAIT state
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                probe.bind(("0.0.0.0", port))
            except OSError:
                in_use = True
            else:
                in_use = False
        return in_use

    return await asyncio.to_thread(_probe)
352
353
async def select_free_port(range_start: int, range_end: int) -> int:
    """
    Automatically find available port within range.

    Ports from *range_start* up to (but not including) *range_end* are tried in
    ascending order; the first one not in use is returned.

    :raises OSError: if no port in the range is free.
    """
    port = range_start
    while port < range_end:
        if not await is_port_in_use(port):
            return port
        port += 1
    msg = "No free port available"
    raise OSError(msg)
361
362
363async def get_ip_from_host(dns_name: str) -> str | None:
364    """Resolve (first) IP-address for given dns name."""
365
366    def _resolve() -> str | None:
367        try:
368            return socket.gethostbyname(dns_name)
369        except Exception:
370            # fail gracefully!
371            return None
372
373    return await asyncio.to_thread(_resolve)
374
375
async def get_ip_pton(ip_string: str) -> bytes:
    """
    Return socket pton for a local ip.

    The string is packed as IPv4 first; if that fails it is packed as IPv6,
    whose OSError propagates when the string is not a valid address at all.
    """
    try:
        packed = await asyncio.to_thread(socket.inet_pton, socket.AF_INET, ip_string)
    except OSError:
        packed = await asyncio.to_thread(socket.inet_pton, socket.AF_INET6, ip_string)
    return packed
382
383
async def get_folder_size(folderpath: str) -> float:
    """
    Return folder size in gb (bytes / 2**30).

    Sums the sizes of all files below *folderpath* (symlinked files count with
    the size of their target). Files that cannot be stat'ed, e.g. dangling
    symlinks or files removed while walking, are skipped instead of aborting
    the whole calculation.
    """

    def _get_folder_size(folderpath: str) -> float:
        total_size = 0
        for dirpath, _dirnames, filenames in os.walk(folderpath):
            for _file in filenames:
                try:
                    total_size += Path(dirpath, _file).stat().st_size
                except OSError:
                    # dangling symlink / vanished file / no permission
                    continue
        return total_size / float(1 << 30)

    return await asyncio.to_thread(_get_folder_size, folderpath)
396
397
def get_changed_keys(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> set[str]:
    """
    Compare 2 dicts and return set of changed keys.

    Thin wrapper around get_changed_dict_values(); with *recursive* enabled,
    nested changes are reported as dotted keys (``"parent.child"``).
    """
    # TODO: Check with Marcel whether we should calculate new dicts based on ignore_keys
    changed = get_changed_dict_values(dict1, dict2, recursive)
    return set(changed)
407
408
def get_changed_dict_values(
    dict1: dict[str, Any],
    dict2: dict[str, Any],
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dicts and return dict of changed values.

    dict key is the changed key, value is tuple of old and new values.
    *dict1* holds the old values and *dict2* the new ones: added keys are
    reported as ``(None, new)``, removed keys as ``(old, None)``. With
    *recursive* enabled, nested dicts present on both sides are compared
    key-by-key and their changes reported with dotted keys (``"parent.child"``).
    """
    if not dict1 and not dict2:
        return {}
    if not dict1:
        return {key: (None, value) for key, value in dict2.items()}
    if not dict2:
        return {key: (value, None) for key, value in dict1.items()}
    changed_values: dict[str, tuple[Any, Any]] = {}
    for key, value in dict2.items():
        # check membership first: dict1[key] would raise KeyError for new keys
        if key not in dict1:
            changed_values[key] = (None, value)
            continue
        old_value = dict1[key]
        if recursive and isinstance(value, dict) and isinstance(old_value, dict):
            changed_subvalues = get_changed_dict_values(old_value, value, recursive)
            for subkey, subvalue in changed_subvalues.items():
                changed_values[f"{key}.{subkey}"] = subvalue
            continue
        if old_value != value:
            changed_values[key] = (old_value, value)
    # keys that disappeared from the new dict are changes too
    for key, old_value in dict1.items():
        if key not in dict2:
            changed_values[key] = (old_value, None)
    return changed_values
438
439
def get_changed_dataclass_values(
    obj1: T,
    obj2: T,
    recursive: bool = False,
) -> dict[str, tuple[Any, Any]]:
    """
    Compare 2 dataclass instances of the same type and return dict of changed field values.

    dict key is the changed field name, value is tuple of old and new values.
    With *recursive* enabled, nested dataclass instances and dicts are compared
    field-by-field / key-by-key and reported with dotted names (``"field.sub"``).

    :raises ValueError: if either object is not a dataclass instance (passing a
        dataclass *type* is rejected as well).
    """

    def _is_dataclass_instance(obj: Any) -> bool:
        # is_dataclass() is also True for dataclass classes; only instances qualify.
        return is_dataclass(obj) and not isinstance(obj, type)

    if not (_is_dataclass_instance(obj1) and _is_dataclass_instance(obj2)):
        raise ValueError("Both objects must be dataclass instances")

    changed_values: dict[str, tuple[Any, Any]] = {}
    for field in fields(obj1):
        val1 = getattr(obj1, field.name, None)
        val2 = getattr(obj2, field.name, None)
        if recursive and _is_dataclass_instance(val1) and _is_dataclass_instance(val2):
            sub_changes = get_changed_dataclass_values(val1, val2, recursive)
            for sub_field, sub_value in sub_changes.items():
                changed_values[f"{field.name}.{sub_field}"] = sub_value
            continue
        if recursive and isinstance(val1, dict) and isinstance(val2, dict):
            sub_changes = get_changed_dict_values(val1, val2, recursive=recursive)
            for sub_field, sub_value in sub_changes.items():
                changed_values[f"{field.name}.{sub_field}"] = sub_value
            continue
        if val1 != val2:
            changed_values[field.name] = (val1, val2)
    return changed_values
470
471
472def empty_queue[T](q: asyncio.Queue[T]) -> None:
473    """Empty an asyncio Queue."""
474    for _ in range(q.qsize()):
475        try:
476            q.get_nowait()
477            q.task_done()
478        except (asyncio.QueueEmpty, ValueError):
479            pass
480
481
async def install_package(package: str) -> None:
    """
    Install package with pip, raise when install failed.

    Runs ``uv pip install --no-cache`` with HA_WHEELS as an extra
    ``--find-links`` source.

    :raises RuntimeError: if the installer exits with a non-zero code; the
        message includes the installer output.
    """
    LOGGER.debug("Installing python package %s", package)
    args = ["uv", "pip", "install", "--no-cache", "--find-links", HA_WHEELS, package]
    return_code, output = await check_output(*args)
    if return_code != 0:
        # decode leniently so a non-UTF-8 byte in the output cannot replace the
        # install failure with a UnicodeDecodeError
        msg = f"Failed to install package {package}\n{output.decode(errors='replace')}"
        raise RuntimeError(msg)
490
491
492async def get_package_version(pkg_name: str) -> str | None:
493    """
494    Return the version of an installed (python) package.
495
496    Will return None if the package is not found.
497    """
498    try:
499        return await asyncio.to_thread(pkg_version, pkg_name)
500    except PackageNotFoundError:
501        return None
502
503
async def is_hass_supervisor() -> bool:
    """
    Return if we're running inside the HA Supervisor (e.g. HAOS).

    Probes ``http://supervisor/core`` with a 1 second timeout; only an HTTP 401
    (unauthorized) answer counts as running under the Supervisor.
    """

    def _check() -> bool:
        try:
            # use the response as a context manager so it is closed if the
            # request unexpectedly succeeds
            with urllib.request.urlopen("http://supervisor/core", timeout=1):
                pass
        except urllib.error.URLError as err:
            # this should return a 401 unauthorized if it exists;
            # non-HTTP errors (e.g. name resolution) carry no .code
            return getattr(err, "code", 999) == 401
        except Exception:
            return False
        return False

    return await asyncio.to_thread(_check)
518
519
@lru_cache
def _import_provider_module(domain: str) -> ProviderModuleType:
    """Import ``music_assistant.providers.<domain>``, memoizing successful imports."""
    return cast(
        "ProviderModuleType", importlib.import_module(f".{domain}", "music_assistant.providers")
    )


async def load_provider_module(domain: str, requirements: list[str]) -> ProviderModuleType:
    """
    Return module for given provider domain and make sure the requirements are met.

    Only pinned (``pkg==version``) requirements are checked; each one whose
    installed version differs is installed. Editable installs (reported as
    version ``0.0.0``) are left alone. If importing the provider then fails
    with ImportError, all requirements are (re)installed and the import is
    retried once; a second failure propagates to the caller.
    """
    # ensure module requirements are met
    for requirement in requirements:
        if "==" not in requirement:
            # we should really get rid of unpinned requirements
            continue
        package_name, version = requirement.split("==", 1)
        installed_version = await get_package_version(package_name)
        if installed_version == "0.0.0":
            # ignore editable installs
            continue
        if installed_version != version:
            await install_package(requirement)

    # try to load the module
    # (the cached importer lives at module level: an lru_cache on a function
    # defined inside this coroutine would start empty on every call)
    try:
        return await asyncio.to_thread(_import_provider_module, domain)
    except ImportError:
        # (re)install ALL requirements
        for requirement in requirements:
            await install_package(requirement)
    # try loading the provider again to be safe
    # this will fail if something else is wrong (as it should)
    return await asyncio.to_thread(_import_provider_module, domain)
552
553
async def has_tmpfs_mount() -> bool:
    """
    Check if we have a tmpfs mount.

    True when a line of /proc/mounts contains ``tmpfs /tmp tmpfs rw``; False
    otherwise, including when /proc/mounts cannot be read.
    """

    def _scan_mounts() -> bool:
        """Scan /proc/mounts for a tmpfs on /tmp (False if unreadable)."""
        try:
            with open("/proc/mounts") as mounts:
                return any("tmpfs /tmp tmpfs rw" in entry for entry in mounts)
        except (FileNotFoundError, OSError, PermissionError):
            return False

    return await asyncio.to_thread(_scan_mounts)
569
570
async def get_free_space(folder: str) -> float:
    """Return free space on given folderpath in GB (bytes / 2**30); 0.0 if it cannot be determined."""

    def _free_gib(path: str) -> float:
        """Compute the free space of *path* in GB, returning 0.0 on OS errors."""
        try:
            usage = shutil.disk_usage(path)
        except (FileNotFoundError, OSError, PermissionError):
            return 0.0
        return usage.free / float(1 << 30)

    return await asyncio.to_thread(_free_gib, folder)
583
584
async def get_free_space_percentage(folder: str) -> float:
    """
    Return free space on given folderpath in percentage.

    Returns 0.0 when the usage cannot be determined or the filesystem reports
    a total size of zero.
    """

    def _get_free_space(folder: str) -> float:
        """Return free space on given folderpath as a percentage of its total size."""
        try:
            res = shutil.disk_usage(folder)
        except (FileNotFoundError, OSError, PermissionError):
            return 0.0
        if res.total == 0:
            # e.g. pseudo filesystems such as /proc report no size at all;
            # avoid ZeroDivisionError
            return 0.0
        return res.free / res.total * 100

    return await asyncio.to_thread(_get_free_space, folder)
597
598
async def has_enough_space(folder: str, size: int) -> bool:
    """
    Check if folder has enough free space.

    *size* is compared with get_free_space(), so it is expressed in GB; the
    free space must be strictly larger than *size*.
    """
    free_gb = await get_free_space(folder)
    return free_gb > size
602
603
def divide_chunks(data: bytes, chunk_size: int) -> Iterator[bytes]:
    """Lazily yield consecutive *chunk_size*-byte slices of *data* (the last one may be shorter)."""
    yield from (
        data[offset : offset + chunk_size] for offset in range(0, len(data), chunk_size)
    )
608
609
async def remove_file(file_path: str) -> None:
    """
    Remove file path (if it exists).

    A missing path is silently ignored; other errors raised by os.remove
    (e.g. the path is a directory, or permission is denied) propagate.
    """
    try:
        # remove directly instead of checking os.path.exists() first: avoids a
        # race where the file disappears in between, and also removes dangling
        # symlinks (for which os.path.exists() returns False)
        await asyncio.to_thread(os.remove, file_path)
    except (FileNotFoundError, NotADirectoryError):
        return
    LOGGER.log(VERBOSE_LOG_LEVEL, "Removed file: %s", file_path)
616
617
def get_primary_ip_address_from_zeroconf(discovery_info: AsyncServiceInfo) -> str | None:
    """
    Get primary IP address from zeroconf discovery info.

    Returns the first IPv4 address that is neither loopback (``127.*``) nor
    APIPA (``169.254.*``), or None if there is none.
    """
    candidates = (
        address
        for address in discovery_info.parsed_addresses(IPVersion.V4Only)
        if not address.startswith(("127", "169.254"))
    )
    return next(candidates, None)
629
630
def get_port_from_zeroconf(discovery_info: AsyncServiceInfo) -> int | None:
    """Return the service port carried by the zeroconf discovery info (may be None)."""
    port: int | None = discovery_info.port
    return port
634
635
async def close_async_generator(agen: AsyncGenerator[Any, None]) -> None:
    """
    Force close an async generator.

    Wraps the generator's next step in a task and cancels that task right away,
    waits for it to finish (ignoring the resulting CancelledError or
    StopAsyncIteration), and finally calls ``aclose()`` on the generator.
    """
    task = asyncio.create_task(agen.__anext__())
    # cancelled immediately, before the event loop gets a chance to run it
    task.cancel()
    with suppress(asyncio.CancelledError, StopAsyncIteration):
        await task
    await agen.aclose()
643
644
async def detect_charset(data: bytes, fallback: str = "utf-8") -> str:
    """
    Detect charset of raw data.

    Runs chardet in a worker thread and only trusts a detected encoding whose
    confidence is above 0.75; otherwise, or when detection fails, *fallback*
    is returned.
    """
    try:
        detected: ResultDict = await asyncio.to_thread(chardet.detect, data)
        encoding = detected["encoding"] if detected else None
        if encoding and detected["confidence"] > 0.75:
            assert isinstance(encoding, str)  # for type checking
            return encoding
    except Exception as err:
        LOGGER.debug("Failed to detect charset: %s", err)
    return fallback
655
656
def merge_dict(
    base_dict: dict[Any, Any],
    new_dict: dict[Any, Any],
    allow_overwite: bool = False,
) -> dict[Any, Any]:
    """
    Merge dict without overwriting existing values.

    Returns a new dict; *base_dict* itself is not modified. For every key of
    *new_dict*, in this order:

    - a dict value merges recursively into an existing non-empty dict
      (passing *allow_overwite* down);
    - a tuple / list value merges into an existing non-empty value via
      merge_tuples() / merge_lists();
    - otherwise the new value is written only when the existing value is
      missing/falsy, or when *allow_overwite* is set.
    """
    final_dict = base_dict.copy()
    for key, value in new_dict.items():
        existing = final_dict.get(key)
        # an elif chain: a merged result must not be overwritten again below
        if existing and isinstance(value, dict) and isinstance(existing, dict):
            final_dict[key] = merge_dict(existing, value, allow_overwite)
        elif existing and isinstance(value, tuple):
            final_dict[key] = merge_tuples(existing, value)
        elif existing and isinstance(value, list):
            final_dict[key] = merge_lists(existing, value)
        elif not existing or allow_overwite:
            final_dict[key] = value
    return final_dict


def merge_tuples(base: tuple[Any, ...], new: tuple[Any, ...]) -> tuple[Any, ...]:
    """Merge 2 tuples: items of *base* that are not in *new*, followed by all items of *new*."""
    return tuple(x for x in base if x not in new) + tuple(new)


def merge_lists(base: list[Any], new: list[Any]) -> list[Any]:
    """Merge 2 lists: items of *base* that are not in *new*, followed by all items of *new*."""
    return [x for x in base if x not in new] + list(new)
684
685
def percentage(part: float, whole: float) -> int:
    """
    Calculate percentage.

    Returns *part* as an integer percentage of *whole*, truncated toward zero.
    Returns 0 when *whole* is zero instead of raising ZeroDivisionError.
    """
    whole = float(whole)
    if whole == 0:
        return 0
    return int(100 * float(part) / whole)
689
690
def validate_announcement_chime_url(url: str) -> bool:
    """
    Validate announcement chime URL format.

    Accepted are: an empty / blank value, the built-in ANNOUNCE_ALERT_FILE, or
    an http(s) URL with a host whose path ends in .mp3, .wav, .flac, .ogg, .m4a
    or .aac (case-insensitive). Anything else, including unparsable input, is
    rejected.
    """
    if not url or not url.strip():
        return True  # Empty URL is valid
    if url == ANNOUNCE_ALERT_FILE:
        return True  # Built-in chime file is valid

    try:
        parsed = urlparse(url.strip())
        if parsed.scheme not in ("http", "https") or not parsed.netloc:
            return False
        return parsed.path.lower().endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac"))
    except Exception:
        return False
715
716
async def get_mac_address(ip_address: str) -> str | None:
    """Look up the MAC address for *ip_address* with the ``getmac`` package, in a worker thread."""
    from getmac import get_mac_address as _lookup_mac  # noqa: PLC0415

    return await asyncio.to_thread(_lookup_mac, ip=ip_address)
722
723
class TaskManager:
    """
    Helper class to run many tasks at once.

    This is basically an alternative to asyncio.TaskGroup but this will not
    cancel all operations when one of the tasks fails.
    Logging of exceptions is done by the mass.create_task helper.

    Use it as an async context manager: on exit it waits for all tasks it is
    still tracking.
    """

    def __init__(self, mass: MusicAssistant, limit: int = 0) -> None:
        """
        Initialize the TaskManager.

        :param mass: MusicAssistant instance whose create_task() spawns the tasks.
        :param limit: maximum number of concurrently running tasks started via
            create_task_with_limit(); 0 creates no semaphore (and
            create_task_with_limit() then fails its assertion).
        """
        self.mass = mass
        # tasks still tracked by this manager; awaited in __aexit__
        self._tasks: list[asyncio.Task[None]] = []
        self._semaphore = asyncio.Semaphore(limit) if limit else None

    def create_task(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]:
        """Create a new task (via mass.create_task) and add it to the manager."""
        task = self.mass.create_task(coro)
        self._tasks.append(task)
        return task

    async def create_task_with_limit(self, coro: Coroutine[Any, Any, None]) -> None:
        """
        Create a new task with semaphore limit.

        Waits for a free semaphore slot, then starts *coro*. Once the task is
        done it is removed from the tracked tasks and its slot is released.
        """
        assert self._semaphore is not None

        def task_done_callback(_task: asyncio.Task[None]) -> None:
            assert self._semaphore is not None  # for type checking
            # `task` refers to the variable bound below; it is assigned before
            # this callback is registered, so it is always set when this runs
            self._tasks.remove(task)
            self._semaphore.release()

        # NOTE(review): if self.create_task() raises after acquire(), the slot is
        # never released - confirm whether that can happen in practice.
        await self._semaphore.acquire()
        task: asyncio.Task[None] = self.create_task(coro)
        task.add_done_callback(task_done_callback)

    async def __aenter__(self) -> Self:
        """Enter context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Exit context manager.

        Waits until all tracked tasks are done (their exceptions are not
        re-raised here) and then forgets them. Returns None, so an exception
        raised inside the ``async with`` body is not suppressed.
        """
        if len(self._tasks) > 0:
            await asyncio.wait(self._tasks)
            self._tasks.clear()
        return None
773
774
# Module-level TypeVar / ParamSpec for generic callable signatures.
_R = TypeVar("_R")
_P = ParamSpec("_P")
777
778
779def lock[**P, R](  # type: ignore[valid-type]
780    func: Callable[_P, Awaitable[_R]],
781) -> Callable[_P, Coroutine[Any, Any, _R]]:
782    """Call async function using a Lock."""
783
784    @functools.wraps(func)
785    async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
786        """Call async function using the throttler with retries."""
787        if not (func_lock := getattr(func, "lock", None)):
788            func_lock = asyncio.Lock()
789            func.lock = func_lock  # type: ignore[attr-defined]
790        async with func_lock:
791            return await func(*args, **kwargs)
792
793    return wrapper
794
795
class TimedAsyncGenerator:
    """
    Async iterable that times out after a given time.

    Every single step of the wrapped iterator must finish within *timeout*
    seconds, otherwise the timeout error from ``asyncio.wait_for`` propagates to
    the consumer. Iteration also ends at the first falsy item (e.g. an empty
    ``bytes`` chunk).

    Source: https://medium.com/@dmitry8912/implementing-timeouts-in-pythons-asynchronous-generators-f7cbaa6dc1e9
    """

    def __init__(self, iterable: AsyncIterator[Any], timeout: float = 0):
        """
        Initialize the AsyncTimedIterable.

        Args:
            iterable: The async iterable to wrap.
            timeout: The timeout in seconds for each iteration (fractional
                values are honoured).
        """
        # float() instead of int(): keeps numeric coercion but no longer
        # truncates sub-second timeouts such as 0.5 down to 0
        step_timeout = float(timeout)

        class AsyncTimedIterator:
            def __init__(self) -> None:
                self._iterator = iterable.__aiter__()

            async def __anext__(self) -> Any:
                result = await asyncio.wait_for(self._iterator.__anext__(), step_timeout)
                if not result:
                    # a falsy item is treated as end-of-stream
                    raise StopAsyncIteration
                return result

        self._factory = AsyncTimedIterator

    def __aiter__(self):  # type: ignore[no-untyped-def]
        """Return the async iterator."""
        return self._factory()
827
828
829def guard_single_request[ProviderT: "Provider | CoreController", **P, R](
830    func: Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]],
831) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
832    """Guard single request to a function."""
833
834    @functools.wraps(func)
835    async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
836        mass = self.mass
837        # create a task_id dynamically based on the function and args/kwargs
838        cache_key_parts = [func.__class__.__name__, func.__name__, *args]
839        for key in sorted(kwargs.keys()):
840            cache_key_parts.append(f"{key}{kwargs[key]}")
841        task_id = ".".join(map(str, cache_key_parts))
842        task: asyncio.Task[R] = mass.create_task(
843            func,
844            self,
845            *args,
846            task_id=task_id,
847            abort_existing=False,
848            **kwargs,
849        )
850        return await task
851
852    return wrapper
853