music-assistant-server

5 KB · PY
throttle_retry.py
5 KB · 139 lines • python
1"""Context manager using asyncio_throttle that catches and re-raises RetriesExhausted."""
2
3import asyncio
4import functools
5import logging
6import time
7from collections import deque
8from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
9from contextlib import asynccontextmanager
10from contextvars import ContextVar
11from types import TracebackType
12from typing import TYPE_CHECKING, Any, Concatenate
13
14from music_assistant_models.errors import ResourceTemporarilyUnavailable, RetriesExhausted
15
16from music_assistant.constants import MASS_LOGGER_NAME
17
18if TYPE_CHECKING:
19    from music_assistant.models.provider import Provider
20
# Module-level logger, named as a child of the Music Assistant logger namespace.
# NOTE(review): not referenced by the code in this module as shown (the retry
# decorator logs through ``self.logger``) — confirm whether other code uses it.
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.throttle_retry")

# Flag checked by ThrottlerManager.acquire(): when True, the throttler is skipped
# and a delay of 0 is yielded. Toggled for a scope by ThrottlerManager.bypass().
# Being a ContextVar, its value belongs to the current context (e.g. the running
# asyncio task) rather than being shared process-wide.
BYPASS_THROTTLER: ContextVar[bool] = ContextVar("BYPASS_THROTTLER", default=False)
24
25
26class Throttler:
27    """asyncio_throttle (https://github.com/hallazzang/asyncio-throttle).
28
29    With improvements:
30    - Accurate sleep without "busy waiting" (PR #4)
31    - Return the delay caused by acquire()
32    """
33
34    def __init__(self, rate_limit: int, period: float = 1.0) -> None:
35        """Initialize the Throttler."""
36        self.rate_limit = rate_limit
37        self.period = period
38        self._task_logs: deque[float] = deque()
39
40    def _flush(self) -> None:
41        now = time.monotonic()
42        while self._task_logs:
43            if now - self._task_logs[0] > self.period:
44                self._task_logs.popleft()
45            else:
46                break
47
48    async def acquire(self) -> float:
49        """Acquire a free slot from the Throttler, returns the throttled time."""
50        cur_time = time.monotonic()
51        start_time = cur_time
52        while True:
53            self._flush()
54            if len(self._task_logs) < self.rate_limit:
55                break
56            # sleep the exact amount of time until the oldest task can be flushed
57            time_to_release = self._task_logs[0] + self.period - cur_time
58            await asyncio.sleep(time_to_release)
59            cur_time = time.monotonic()
60
61        self._task_logs.append(cur_time)
62        return cur_time - start_time  # exactly 0 if not throttled
63
64    async def __aenter__(self) -> float:
65        """Wait until the lock is acquired, return the time delay."""
66        return await self.acquire()
67
68    async def __aexit__(
69        self,
70        exc_type: type[BaseException] | None,
71        exc_val: BaseException | None,
72        exc_tb: TracebackType | None,
73    ) -> bool | None:
74        """Nothing to do on exit."""
75
76
class ThrottlerManager:
    """Throttler manager that extends asyncio Throttle by retrying.

    Bundles a Throttler with the retry settings (``retry_attempts`` and
    ``initial_backoff``) that ``throttle_with_retries`` reads.
    """

    def __init__(
        self, rate_limit: int, period: float = 1, retry_attempts: int = 5, initial_backoff: int = 5
    ) -> None:
        """Initialize the ThrottlerManager.

        rate_limit and period configure the underlying Throttler (at most
        rate_limit acquisitions per period seconds). retry_attempts is the
        total number of attempts and initial_backoff the first wait (in
        seconds) used by ``throttle_with_retries``.
        """
        self.retry_attempts = retry_attempts
        self.initial_backoff = initial_backoff
        self.throttler = Throttler(rate_limit, period)

    @asynccontextmanager
    async def acquire(self) -> AsyncGenerator[float, None]:
        """Acquire a free slot from the Throttler, yielding the throttled time.

        When BYPASS_THROTTLER is set in the current context the throttler is
        skipped entirely and 0.0 is yielded.
        """
        if BYPASS_THROTTLER.get():
            yield 0.0
        else:
            yield await self.throttler.acquire()

    @asynccontextmanager
    async def bypass(self) -> AsyncGenerator[None, None]:
        """Bypass the throttler for the duration of the ``async with`` block."""
        # Set the flag before entering ``try`` so ``token`` is always bound
        # when the ``finally`` clause restores the previous value.
        token = BYPASS_THROTTLER.set(True)
        try:
            yield None
        finally:
            BYPASS_THROTTLER.reset(token)
104
105
106def throttle_with_retries[ProviderT: "Provider", **P, R](
107    func: Callable[Concatenate[ProviderT, P], Awaitable[R]],
108) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
109    """Call async function using the throttler with retries."""
110
111    @functools.wraps(func)
112    async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
113        """Call async function using the throttler with retries."""
114        # the trottler attribute must be present on the class
115        throttler: ThrottlerManager = self.throttler  # type: ignore[attr-defined]
116        backoff_time = throttler.initial_backoff
117        async with throttler.acquire() as delay:
118            if delay != 0:
119                self.logger.debug(
120                    "%s was delayed for %.3f secs due to throttling", func.__name__, delay
121                )
122            for attempt in range(throttler.retry_attempts):
123                try:
124                    return await func(self, *args, **kwargs)
125                except ResourceTemporarilyUnavailable as e:
126                    backoff_time = e.backoff_time or backoff_time
127                    self.logger.info(
128                        f"Attempt {attempt + 1}/{throttler.retry_attempts} failed: {e}"
129                    )
130                    if attempt < throttler.retry_attempts - 1:
131                        self.logger.info(f"Retrying in {backoff_time} seconds...")
132                        await asyncio.sleep(backoff_time)
133                        backoff_time *= 2
134            else:  # noqa: PLW0120
135                msg = f"Retries exhausted, failed after {throttler.retry_attempts} attempts"
136                raise RetriesExhausted(msg)
137
138    return wrapper
139