/
/
/
"""Rate limiting and retry helpers for provider methods.

Provides an adapted asyncio_throttle ``Throttler``, a ``ThrottlerManager`` that adds
retry settings and a context-local bypass switch, and the ``throttle_with_retries``
decorator. The decorator retries calls that raise ``ResourceTemporarilyUnavailable``
and raises ``RetriesExhausted`` once every attempt has failed.
"""
2
3import asyncio
4import functools
5import logging
6import time
7from collections import deque
8from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
9from contextlib import asynccontextmanager
10from contextvars import ContextVar
11from types import TracebackType
12from typing import TYPE_CHECKING, Any, Concatenate
13
14from music_assistant_models.errors import ResourceTemporarilyUnavailable, RetriesExhausted
15
16from music_assistant.constants import MASS_LOGGER_NAME
17
18if TYPE_CHECKING:
19 from music_assistant.models.provider import Provider
20
# Module-level logger, a child of the MASS_LOGGER_NAME logger.
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.throttle_retry")

# Context-local flag: when True, ThrottlerManager.acquire() yields 0 without
# consuming a throttler slot. It is toggled by ThrottlerManager.bypass().
BYPASS_THROTTLER: ContextVar[bool] = ContextVar("BYPASS_THROTTLER", default=False)
24
25
26class Throttler:
27 """asyncio_throttle (https://github.com/hallazzang/asyncio-throttle).
28
29 With improvements:
30 - Accurate sleep without "busy waiting" (PR #4)
31 - Return the delay caused by acquire()
32 """
33
34 def __init__(self, rate_limit: int, period: float = 1.0) -> None:
35 """Initialize the Throttler."""
36 self.rate_limit = rate_limit
37 self.period = period
38 self._task_logs: deque[float] = deque()
39
40 def _flush(self) -> None:
41 now = time.monotonic()
42 while self._task_logs:
43 if now - self._task_logs[0] > self.period:
44 self._task_logs.popleft()
45 else:
46 break
47
48 async def acquire(self) -> float:
49 """Acquire a free slot from the Throttler, returns the throttled time."""
50 cur_time = time.monotonic()
51 start_time = cur_time
52 while True:
53 self._flush()
54 if len(self._task_logs) < self.rate_limit:
55 break
56 # sleep the exact amount of time until the oldest task can be flushed
57 time_to_release = self._task_logs[0] + self.period - cur_time
58 await asyncio.sleep(time_to_release)
59 cur_time = time.monotonic()
60
61 self._task_logs.append(cur_time)
62 return cur_time - start_time # exactly 0 if not throttled
63
64 async def __aenter__(self) -> float:
65 """Wait until the lock is acquired, return the time delay."""
66 return await self.acquire()
67
68 async def __aexit__(
69 self,
70 exc_type: type[BaseException] | None,
71 exc_val: BaseException | None,
72 exc_tb: TracebackType | None,
73 ) -> bool | None:
74 """Nothing to do on exit."""
75
76
class ThrottlerManager:
    """Throttler manager that extends asyncio Throttle by retrying.

    Holds a ``Throttler`` together with the retry settings (``retry_attempts``
    and ``initial_backoff``) read by ``throttle_with_retries``, and offers a
    context-local way to bypass throttling.
    """

    def __init__(
        self, rate_limit: int, period: float = 1, retry_attempts: int = 5, initial_backoff: int = 5
    ) -> None:
        """Initialize the ThrottlerManager.

        rate_limit: maximum number of slots granted per ``period``.
        period: length of the throttling window, in seconds.
        retry_attempts: total number of attempts ``throttle_with_retries`` makes.
        initial_backoff: seconds to wait before the first retry; ``throttle_with_retries``
            doubles it after each retry.
        """
        self.retry_attempts = retry_attempts
        self.initial_backoff = initial_backoff
        self.throttler = Throttler(rate_limit, period)

    @asynccontextmanager
    async def acquire(self) -> AsyncGenerator[float, None]:
        """Acquire a free slot from the Throttler, yielding the throttled time.

        When ``BYPASS_THROTTLER`` is set in the current context, no slot is
        consumed and ``0`` is yielded immediately.
        """
        if BYPASS_THROTTLER.get():
            yield 0
        else:
            yield await self.throttler.acquire()

    @asynccontextmanager
    async def bypass(self) -> AsyncGenerator[None, None]:
        """Bypass the throttler for the duration of the ``async with`` block.

        Sets ``BYPASS_THROTTLER`` to True on entry and restores its previous
        value on exit, even if the block raises.
        """
        # Take the token before entering try/finally so the finally clause can
        # never reference an unbound name.
        token = BYPASS_THROTTLER.set(True)
        try:
            yield None
        finally:
            BYPASS_THROTTLER.reset(token)
104
105
106def throttle_with_retries[ProviderT: "Provider", **P, R](
107 func: Callable[Concatenate[ProviderT, P], Awaitable[R]],
108) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
109 """Call async function using the throttler with retries."""
110
111 @functools.wraps(func)
112 async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
113 """Call async function using the throttler with retries."""
114 # the trottler attribute must be present on the class
115 throttler: ThrottlerManager = self.throttler # type: ignore[attr-defined]
116 backoff_time = throttler.initial_backoff
117 async with throttler.acquire() as delay:
118 if delay != 0:
119 self.logger.debug(
120 "%s was delayed for %.3f secs due to throttling", func.__name__, delay
121 )
122 for attempt in range(throttler.retry_attempts):
123 try:
124 return await func(self, *args, **kwargs)
125 except ResourceTemporarilyUnavailable as e:
126 backoff_time = e.backoff_time or backoff_time
127 self.logger.info(
128 f"Attempt {attempt + 1}/{throttler.retry_attempts} failed: {e}"
129 )
130 if attempt < throttler.retry_attempts - 1:
131 self.logger.info(f"Retrying in {backoff_time} seconds...")
132 await asyncio.sleep(backoff_time)
133 backoff_time *= 2
134 else: # noqa: PLW0120
135 msg = f"Retries exhausted, failed after {throttler.retry_attempts} attempts"
136 raise RetriesExhausted(msg)
137
138 return wrapper
139