/
/
/
1"""Provides a simple stateless caching system."""
2
3from __future__ import annotations
4
5import asyncio
6import functools
7import logging
8import os
9import time
10from collections import OrderedDict
11from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterator, MutableMapping
12from contextlib import asynccontextmanager
13from contextvars import ContextVar
14from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, cast, get_type_hints
15
16from music_assistant_models.config_entries import ConfigEntry, ConfigValueType
17from music_assistant_models.enums import ConfigEntryType
18
19from music_assistant.constants import DB_TABLE_CACHE, DB_TABLE_SETTINGS, MASS_LOGGER_NAME
20from music_assistant.helpers.api import parse_value
21from music_assistant.helpers.database import DatabaseConnection
22from music_assistant.helpers.json import async_json_loads, json_dumps
23from music_assistant.models.core_controller import CoreController
24
25if TYPE_CHECKING:
26 from music_assistant_models.config_entries import CoreConfig
27
28 from music_assistant import MusicAssistant
29 from music_assistant.models.provider import Provider
30
# Module-level logger, a child of the application logger named "<MASS_LOGGER_NAME>.cache".
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.cache")
# Key of the config entry / action name that triggers clearing the cache.
CONF_CLEAR_CACHE = "clear_cache"
# Default lifetime (in seconds) of a cache entry when the caller does not specify one.
DEFAULT_CACHE_EXPIRATION = 86400 * 30  # 30 days
# Schema version of the cache database; compared with the version stored in the
# settings table when the database is set up.
DB_SCHEMA_VERSION = 6

# Context-local flag: while True, cache lookups that allow bypassing return
# their default instead of cached data.
BYPASS_CACHE: ContextVar[bool] = ContextVar("BYPASS_CACHE", default=False)
37
38
class CacheController(CoreController):
    """Basic cache controller using both memory and database.

    Reads consult a bounded in-memory cache first and fall back to the cache
    table of the ``cache.db`` database. Writes always update the memory cache;
    entries that stay valid long enough are additionally serialized to JSON and
    stored in the database.
    """

    domain: str = "cache"

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize core controller."""
        super().__init__(mass)
        self.database: DatabaseConnection | None = None
        self._mem_cache = MemoryCache(500)
        # Handle of the pending cleanup timer, kept so close() can cancel it.
        self._cleanup_timer: asyncio.TimerHandle | None = None
        self.manifest.name = "Cache controller"
        self.manifest.description = (
            "Music Assistant's core controller for caching data throughout the application."
        )
        self.manifest.icon = "memory"

    async def get_config_entries(
        self,
        action: str | None = None,
        values: dict[str, ConfigValueType] | None = None,
    ) -> tuple[ConfigEntry, ...]:
        """Return all Config Entries for this core module (if any).

        When called with the clear-cache action, the cache is cleared first and
        a label confirming that is returned instead of the action entry.
        """
        if action == CONF_CLEAR_CACHE:
            await self.clear()
            return (
                ConfigEntry(
                    key=CONF_CLEAR_CACHE,
                    type=ConfigEntryType.LABEL,
                    label="The cache has been cleared",
                ),
            )
        return (
            ConfigEntry(
                key=CONF_CLEAR_CACHE,
                type=ConfigEntryType.ACTION,
                label="Clear cache",
                description="Reset/clear all items in the cache. ",
            ),
        )

    async def setup(self, config: CoreConfig) -> None:
        """Async initialize of cache module."""
        self.logger.info("Initializing cache controller...")
        await self._setup_database()
        self.__schedule_cleanup_task()

    async def close(self) -> None:
        """Cleanup on exit: stop the periodic cleanup and close the database."""
        # Cancel the pending timer first, otherwise the cleanup keeps
        # rescheduling itself and would run against a closed database.
        if self._cleanup_timer is not None:
            self._cleanup_timer.cancel()
            self._cleanup_timer = None
        if self.database:
            await self.database.close()

    async def get(
        self,
        key: str,
        provider: str = "default",
        category: int = 0,
        checksum: str | int | None = None,
        default: Any = None,
        allow_bypass: bool = True,
    ) -> Any:
        """Get object from cache and return the results.

        - key: the (unique) lookup key of the cache object as reference
        - provider: optional provider id to group cache objects
        - category: optional category to group cache objects
        - checksum: optional argument to check if the checksum in the
            cache object matches the checksum provided
        - default: value to return if no cache object is found
        - allow_bypass: if True and BYPASS_CACHE is set in the current context,
            skip the lookup and return the default

        Returns the cached data, or the default when there is no valid
        (unexpired, checksum-matching, parseable) entry.
        """
        assert self.database is not None
        assert key, "No key provided"
        if allow_bypass and BYPASS_CACHE.get():
            return default
        cur_time = int(time.time())
        # checksums are stored as text, so compare as text
        if checksum is not None and not isinstance(checksum, str):
            checksum = str(checksum)
        # try memory cache first; entries are (data, checksum, expires) tuples
        memory_key = f"{provider}/{category}/{key}"
        cache_data = self._mem_cache.get(memory_key)
        if cache_data and (not checksum or cache_data[1] == checksum) and cache_data[2] >= cur_time:
            return cache_data[0]
        # fall back to db cache
        if (
            (
                db_row := await self.database.get_row(
                    DB_TABLE_CACHE, {"category": category, "provider": provider, "key": key}
                )
            )
            and db_row["expires"] >= cur_time
            and (not checksum or db_row["checksum"] == checksum)
        ):
            try:
                data = await async_json_loads(db_row["data"])
            except Exception as exc:
                # an unreadable row is treated as a cache miss;
                # the traceback is only attached when debug logging is enabled
                LOGGER.error(
                    "Error parsing cache data for %s: %s",
                    memory_key,
                    str(exc),
                    exc_info=exc if self.logger.isEnabledFor(logging.DEBUG) else None,
                )
            else:
                # also store in memory cache for faster access
                self._mem_cache[memory_key] = (
                    data,
                    db_row["checksum"],
                    db_row["expires"],
                )
                return data
        return default

    async def set(
        self,
        key: str,
        data: Any,
        expiration: int = DEFAULT_CACHE_EXPIRATION,
        provider: str = "default",
        category: int = 0,
        checksum: str | None = None,
        persistent: bool = False,
    ) -> None:
        """
        Set data in cache.

        - key: the (unique) lookup key of the cache object as reference
        - data: the actual data to store in the cache
        - expiration: time in seconds the cache object should be valid
        - provider: optional provider id to group cache objects
        - category: optional category to group cache objects
        - checksum: optional argument to store with the cache object
        - persistent: if True the cache object will not be deleted when clearing the cache

        Nothing is stored when the key is empty. Entries that expire in less
        than 1800 seconds are kept in the memory cache only.
        """
        assert self.database is not None
        if not key:
            return
        if checksum is not None:
            checksum = str(checksum)
        expires = int(time.time() + expiration)
        memory_key = f"{provider}/{category}/{key}"
        self._mem_cache[memory_key] = (data, checksum, expires)
        if (expires - time.time()) < 1800:
            # do not cache items in db with short expiration
            return
        # serialize in a worker thread to keep the event loop responsive
        data = await asyncio.to_thread(json_dumps, data)
        await self.database.insert_or_replace(
            DB_TABLE_CACHE,
            {
                "category": category,
                "provider": provider,
                "key": key,
                "expires": expires,
                "checksum": checksum,
                "data": data,
                "persistent": persistent,
            },
        )

    async def delete(
        self, key: str | None, category: int | None = None, provider: str | None = None
    ) -> None:
        """Delete data from cache.

        Only the arguments that are not None are used to match database rows.
        When key, category and provider are all given, just that entry is
        removed from the memory cache; otherwise the whole memory cache is
        cleared because affected entries cannot be addressed individually.
        """
        assert self.database is not None
        match: dict[str, str | int] = {}
        if key is not None:
            match["key"] = key
        if category is not None:
            match["category"] = category
        if provider is not None:
            match["provider"] = provider
        if key is not None and category is not None and provider is not None:
            self._mem_cache.pop(f"{provider}/{category}/{key}", None)
        else:
            self._mem_cache.clear()
        await self.database.delete(DB_TABLE_CACHE, match)

    @staticmethod
    def _sql_quote(value: str) -> str:
        """Return value with single quotes doubled, safe inside a SQL string literal."""
        return f"{value}".replace("'", "''")

    async def clear(
        self,
        key_filter: str | None = None,
        category_filter: int | None = None,
        provider_filter: str | None = None,
        include_persistent: bool = False,
    ) -> None:
        """Clear all/partial items from cache.

        - key_filter: only delete rows whose key contains this text
        - category_filter: only delete rows of this category
        - provider_filter: only delete rows whose provider contains this text
        - include_persistent: also delete rows marked as persistent

        The memory cache is always cleared completely.
        """
        assert self.database is not None
        self._mem_cache.clear()
        self.logger.info("Clearing database...")
        query_parts: list[str] = []
        if category_filter is not None:
            # int() guarantees only a numeric literal ends up in the SQL text
            query_parts.append(f"category = {int(category_filter)}")
        if provider_filter is not None:
            # the filters are embedded in the statement text, so quotes must be
            # escaped to keep a value like "it's" from breaking the query
            query_parts.append(f"provider LIKE '%{self._sql_quote(provider_filter)}%'")
        if key_filter is not None:
            query_parts.append(f"key LIKE '%{self._sql_quote(key_filter)}%'")
        if not include_persistent:
            query_parts.append("persistent = 0")
        query = "WHERE " + " AND ".join(query_parts) if query_parts else None
        await self.database.delete(DB_TABLE_CACHE, query=query)
        self.logger.info("Clearing database DONE")

    async def auto_cleanup(self) -> None:
        """Run scheduled auto cleanup task.

        Resets the memory cache and deletes every database row whose
        expiration timestamp lies in the past.
        """
        assert self.database is not None
        self.logger.debug("Running automatic cleanup...")
        # simply reset the memory cache
        self._mem_cache.clear()
        cur_timestamp = int(time.time())
        cleaned_records = 0
        for db_row in await self.database.get_rows(DB_TABLE_CACHE):
            # clean up db cache object only if expired
            if db_row["expires"] < cur_timestamp:
                await self.database.delete(DB_TABLE_CACHE, {"id": db_row["id"]})
                cleaned_records += 1
            await asyncio.sleep(0)  # yield to eventloop
        self.logger.debug("Automatic cleanup finished (cleaned up %s records)", cleaned_records)

    @asynccontextmanager
    async def handle_refresh(self, bypass: bool) -> AsyncGenerator[None, None]:
        """Handle the cache bypass.

        Sets BYPASS_CACHE to the given value for the duration of the ``async
        with`` body and restores the previous value afterwards.
        """
        # set before entering the try block so the token always exists in finally
        token = BYPASS_CACHE.set(bypass)
        try:
            yield None
        finally:
            BYPASS_CACHE.reset(token)

    async def _setup_database(self) -> None:
        """Initialize database.

        Opens ``cache.db`` in the cache path, makes sure tables and indexes
        exist, recreates the cache table when the stored schema version is
        older than DB_SCHEMA_VERSION, stores the current version and finally
        compacts the database.
        """
        db_path = os.path.join(self.mass.cache_path, "cache.db")
        self.database = DatabaseConnection(db_path)
        await self.database.setup()

        # always create db tables if they don't exist to prevent errors trying to access them later
        await self.__create_database_tables()
        try:
            if db_row := await self.database.get_row(DB_TABLE_SETTINGS, {"key": "version"}):
                prev_version = int(db_row["value"])
            else:
                prev_version = 0
        except (KeyError, ValueError):
            # missing or non-numeric version value: treat as a fresh database
            prev_version = 0

        if prev_version not in (0, DB_SCHEMA_VERSION):
            LOGGER.warning(
                "Performing database migration from %s to %s",
                prev_version,
                DB_SCHEMA_VERSION,
            )

            if prev_version < DB_SCHEMA_VERSION:
                # for now just keep it simple and just recreate the table(s)
                await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_CACHE}")

                # recreate missing table(s)
                await self.__create_database_tables()

        # store current schema version
        await self.database.insert_or_replace(
            DB_TABLE_SETTINGS,
            {"key": "version", "value": str(DB_SCHEMA_VERSION), "type": "str"},
        )
        await self.__create_database_indexes()
        # compact db (vacuum) at startup; a failure here is logged but not fatal
        self.logger.debug("Compacting database...")
        try:
            await self.database.vacuum()
        except Exception as err:
            self.logger.warning("Database vacuum failed: %s", str(err))
        else:
            self.logger.debug("Compacting database done")

    async def __create_database_tables(self) -> None:
        """Create database table(s) if they do not exist yet."""
        assert self.database is not None
        # key/value settings table, used to store the schema version
        await self.database.execute(
            f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_SETTINGS}(
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    type TEXT
                );"""
        )
        # cache table: one row per (category, key, provider) combination
        await self.database.execute(
            f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_CACHE}(
                    [id] INTEGER PRIMARY KEY AUTOINCREMENT,
                    [category] INTEGER NOT NULL DEFAULT 0,
                    [key] TEXT NOT NULL,
                    [provider] TEXT NOT NULL,
                    [expires] INTEGER NOT NULL,
                    [data] TEXT NULL,
                    [checksum] TEXT NULL,
                    [persistent] INTEGER NOT NULL DEFAULT 0,
                    UNIQUE(category, key, provider)
                    )"""
        )

        await self.database.commit()

    async def __create_database_indexes(self) -> None:
        """Create database indexes on the cache table's lookup columns."""
        assert self.database is not None
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_idx "
            f"ON {DB_TABLE_CACHE}(category);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_key_idx ON {DB_TABLE_CACHE}(key);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_provider_idx "
            f"ON {DB_TABLE_CACHE}(provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_key_idx "
            f"ON {DB_TABLE_CACHE}(category,key);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_provider_idx "
            f"ON {DB_TABLE_CACHE}(category,provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_key_provider_idx "
            f"ON {DB_TABLE_CACHE}(category,key,provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_key_provider_idx "
            f"ON {DB_TABLE_CACHE}(key,provider);"
        )
        await self.database.commit()

    def __schedule_cleanup_task(self) -> None:
        """Schedule the cleanup task.

        Starts one cleanup run now and arranges for this method to be called
        again in 3600 seconds.
        """
        self.mass.create_task(self.auto_cleanup())
        # reschedule self; keep the handle so close() can cancel the chain
        self._cleanup_timer = self.mass.loop.call_later(3600, self.__schedule_cleanup_task)
371
# Generic parameter/return placeholders.
# NOTE(review): not referenced by the code visible in this file — presumably
# kept for importers elsewhere; confirm before removing.
Param = ParamSpec("Param")
RetType = TypeVar("RetType")


# Type variables used by the use_cache decorator signature:
# ProviderT is the type of the decorated method's `self`,
# P captures its remaining parameters and R its awaited result.
ProviderT = TypeVar("ProviderT", bound="Provider | CoreController")
P = ParamSpec("P")
R = TypeVar("R")
379
380
def use_cache(
    expiration: int = DEFAULT_CACHE_EXPIRATION,
    category: int = 0,
    persistent: bool = False,
    cache_checksum: str | None = None,
    allow_bypass: bool = True,
) -> Callable[
    [Callable[Concatenate[ProviderT, P], Awaitable[R]]],
    Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]],
]:
    """Return decorator that can be used to cache a method's result.

    The decorated coroutine method must live on an object that exposes
    ``mass`` (with ``cache`` and ``create_task``) and ``domain``; the object's
    ``instance_id`` is used as cache provider when present, else ``domain``.

    - expiration: time in seconds the cached result should be valid
    - category: category under which the result is stored
    - persistent: if True the entry is kept when the cache is cleared
    - cache_checksum: optional checksum stored with and required from the entry
    - allow_bypass: whether an active cache bypass may skip the lookup

    A result of None is never served from cache (it is indistinguishable from
    a cache miss), so the wrapped method is called again in that case.
    """

    def _decorator(
        func: Callable[Concatenate[ProviderT, P], Awaitable[R]],
    ) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
        # Return annotation of func, resolved once on the first cache hit and
        # reused afterwards. It is not resolved at decoration time because the
        # annotation may refer to names that are not resolvable yet.
        return_hint: list[Any] = []

        @functools.wraps(func)
        async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
            cache = self.mass.cache
            provider_id = getattr(self, "instance_id", self.domain)

            # create a cache key dynamically based on the (remaining) args/kwargs;
            # kwargs are sorted so the key does not depend on call-site ordering
            cache_key_parts = [func.__name__, *args]
            cache_key_parts.extend(f"{name}{kwargs[name]}" for name in sorted(kwargs))
            cache_key = ".".join(map(str, cache_key_parts))
            # try to retrieve data from the cache
            cachedata = await cache.get(
                cache_key,
                provider=provider_id,
                checksum=cache_checksum,
                category=category,
                allow_bypass=allow_bypass,
            )
            if cachedata is not None:
                if not return_hint:
                    return_hint.append(get_type_hints(func)["return"])
                # convert the cached (JSON) data back into the declared return type
                return cast("R", parse_value(func.__name__, cachedata, return_hint[0]))
            # get data from method/provider
            result = await func(self, *args, **kwargs)
            # store result in cache (but don't await)
            self.mass.create_task(
                cache.set(
                    key=cache_key,
                    data=result,
                    expiration=expiration,
                    provider=provider_id,
                    category=category,
                    checksum=cache_checksum,
                    persistent=persistent,
                )
            )
            return result

        return wrapper

    return _decorator
436
437
class MemoryCache(MutableMapping[str, Any]):
    """Simple limited in-memory cache implementation.

    Holds at most ``maxlen`` entries. When the cache is full, storing a new key
    evicts the least recently used entry; reading an entry through
    ``cache[key]`` or ``cache.get(key)`` and overwriting it count as use.
    A ``maxlen`` of zero or less disables storing altogether.
    """

    def __init__(self, maxlen: int) -> None:
        """Initialize an empty cache holding at most maxlen entries."""
        self._maxlen = maxlen
        # insertion order doubles as recency order: oldest first, newest last
        self.d: OrderedDict[str, Any] = OrderedDict()

    @property
    def maxlen(self) -> int:
        """Return max length."""
        return self._maxlen

    def get(self, key: str, default: Any = None) -> Any:
        """Return item (marking it as recently used) or default if missing."""
        try:
            # go through __getitem__ so that hits refresh the entry's recency
            return self[key]
        except KeyError:
            return default

    def pop(self, key: str, default: Any = None) -> Any:
        """Pop item from collection, returning default if the key is missing."""
        return self.d.pop(key, default)

    def __getitem__(self, key: str) -> Any:
        """Get item and mark it as most recently used.

        Raises KeyError if the key is not present.
        """
        self.d.move_to_end(key)
        return self.d[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Set item, evicting the least recently used entry when full."""
        if self.maxlen <= 0:
            # a cache without capacity cannot hold anything
            return
        if key in self.d:
            self.d.move_to_end(key)
        else:
            # make room for the new key
            while len(self.d) >= self.maxlen:
                self.d.popitem(last=False)
        self.d[key] = value

    def __delitem__(self, key: str) -> None:
        """Delete item; raises KeyError if the key is not present."""
        del self.d[key]

    def __iter__(self) -> Iterator[str]:
        """Iterate keys from least to most recently used."""
        return iter(self.d)

    def __len__(self) -> int:
        """Return length."""
        return len(self.d)

    def clear(self) -> None:
        """Clear cache."""
        self.d.clear()
487