/
/
/
1"""Provides a simple stateless caching system."""
2
3from __future__ import annotations
4
5import asyncio
6import functools
7import logging
8import os
9import time
10from collections import OrderedDict
11from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterator, MutableMapping
12from contextlib import asynccontextmanager
13from contextvars import ContextVar
14from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, cast, get_type_hints
15
16from music_assistant_models.config_entries import ConfigEntry, ConfigValueType
17from music_assistant_models.enums import ConfigEntryType
18
19from music_assistant.constants import DB_TABLE_CACHE, DB_TABLE_SETTINGS, MASS_LOGGER_NAME
20from music_assistant.helpers.api import parse_value
21from music_assistant.helpers.database import DatabaseConnection
22from music_assistant.helpers.json import async_json_loads, json_dumps
23from music_assistant.models.core_controller import CoreController
24
25if TYPE_CHECKING:
26 from music_assistant_models.config_entries import CoreConfig
27
28 from music_assistant import MusicAssistant
29 from music_assistant.models.provider import Provider
30
# module logger, a child of the main Music Assistant logger
LOGGER = logging.getLogger(MASS_LOGGER_NAME + ".cache")
# key of the "clear cache" config action
CONF_CLEAR_CACHE = "clear_cache"
# default lifetime of a cache entry in seconds (30 days)
DEFAULT_CACHE_EXPIRATION = 30 * 24 * 60 * 60
# schema version stored in the settings table; compared at startup to decide on migration
DB_SCHEMA_VERSION = 7

# context flag: while True, cache lookups that allow bypassing return their default
BYPASS_CACHE: ContextVar[bool] = ContextVar("BYPASS_CACHE", default=False)
37
38
class CacheController(CoreController):
    """Basic cache controller using both memory and database.

    Reads check a bounded in-memory cache first and fall back to the cache
    table in the database. Writes always go to the memory cache and, unless
    the entry expires within 30 minutes, to the database as well.
    """

    domain: str = "cache"

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize core controller."""
        super().__init__(mass)
        self.database: DatabaseConnection | None = None
        self._mem_cache = MemoryCache(500)
        # handle of the pending (hourly) cleanup timer, kept so close() can cancel it
        self._cleanup_timer: asyncio.TimerHandle | None = None
        self.manifest.name = "Cache controller"
        self.manifest.description = (
            "Music Assistant's core controller for caching data throughout the application."
        )
        self.manifest.icon = "memory"

    async def get_config_entries(
        self,
        action: str | None = None,
        values: dict[str, ConfigValueType] | None = None,
    ) -> tuple[ConfigEntry, ...]:
        """Return all Config Entries for this core module (if any).

        When called with the clear-cache action, the cache is cleared first and a
        confirmation label is returned instead of the action button.
        """
        if action == CONF_CLEAR_CACHE:
            await self.clear()
            return (
                ConfigEntry(
                    key=CONF_CLEAR_CACHE,
                    type=ConfigEntryType.LABEL,
                    label="The cache has been cleared",
                ),
            )
        return (
            ConfigEntry(
                key=CONF_CLEAR_CACHE,
                type=ConfigEntryType.ACTION,
                label="Clear cache",
                description="Reset/clear all items in the cache. ",
            ),
        )

    async def setup(self, config: CoreConfig) -> None:
        """Async initialize of cache module."""
        self.logger.info("Initializing cache controller...")
        await self._setup_database()
        self.__schedule_cleanup_task()

    async def close(self) -> None:
        """Cleanup on exit: stop the periodic cleanup and close the database."""
        # the cleanup timer reschedules itself forever; without cancelling it here it
        # would keep spawning auto_cleanup tasks against a closed database
        if self._cleanup_timer is not None:
            self._cleanup_timer.cancel()
            self._cleanup_timer = None
        if self.database:
            await self.database.close()

    async def get(
        self,
        key: str,
        provider: str = "default",
        category: int = 0,
        checksum: str | int | None = None,
        default: Any = None,
        allow_bypass: bool = True,
    ) -> Any:
        """Get object from cache and return the results.

        - key: the (unique) lookup key of the cache object as reference
        - provider: optional provider id to group cache objects
        - category: optional category to group cache objects
        - checksum: optional argument to check if the checksum in the
          cache object matches the checksum provided
        - default: value to return if no cache object is found
        - allow_bypass: if True, return default while BYPASS_CACHE is set
        """
        assert self.database is not None
        assert key, "No key provided"
        if allow_bypass and BYPASS_CACHE.get():
            return default
        cur_time = int(time.time())
        if checksum is not None and not isinstance(checksum, str):
            checksum = str(checksum)
        # try memory cache first; entries are (data, checksum, expires) tuples
        memory_key = f"{provider}/{category}/{key}"
        cache_data = self._mem_cache.get(memory_key)
        if cache_data and (not checksum or cache_data[1] == checksum) and cache_data[2] >= cur_time:
            return cache_data[0]
        # fall back to db cache
        if (
            (
                db_row := await self.database.get_row(
                    DB_TABLE_CACHE, {"category": category, "provider": provider, "key": key}
                )
            )
            and db_row["expires"] >= cur_time
            and (not checksum or db_row["checksum"] == checksum)
        ):
            try:
                data = await async_json_loads(db_row["data"])
            except Exception as exc:
                LOGGER.error(
                    "Error parsing cache data for %s: %s",
                    memory_key,
                    str(exc),
                    # only attach the traceback when debug logging is enabled
                    exc_info=exc if self.logger.isEnabledFor(logging.DEBUG) else None,
                )
            else:
                # also store in memory cache for faster access
                self._mem_cache[memory_key] = (
                    data,
                    db_row["checksum"],
                    db_row["expires"],
                )
                return data
        return default

    async def set(
        self,
        key: str,
        data: Any,
        expiration: int = DEFAULT_CACHE_EXPIRATION,
        provider: str = "default",
        category: int = 0,
        checksum: str | None = None,
        persistent: bool = False,
    ) -> None:
        """
        Set data in cache.

        - key: the (unique) lookup key of the cache object as reference
        - data: the actual data to store in the cache
        - expiration: time in seconds the cache object should be valid
        - provider: optional provider id to group cache objects
        - category: optional category to group cache objects
        - checksum: optional argument to store with the cache object
        - persistent: if True the cache object will not be deleted when clearing the cache

        Entries expiring within 30 minutes are only kept in the memory cache.
        An empty key is silently ignored.
        """
        assert self.database is not None
        if not key:
            return
        if checksum is not None:
            checksum = str(checksum)
        expires = int(time.time() + expiration)
        memory_key = f"{provider}/{category}/{key}"
        self._mem_cache[memory_key] = (data, checksum, expires)
        if (expires - time.time()) < 1800:
            # do not cache items in db with short expiration
            return
        # serialize off the event loop; payloads may be large
        data = await asyncio.to_thread(json_dumps, data)
        await self.database.insert_or_replace(
            DB_TABLE_CACHE,
            {
                "category": category,
                "provider": provider,
                "key": key,
                "expires": expires,
                "checksum": checksum,
                "data": data,
                "persistent": persistent,
            },
        )

    async def delete(
        self, key: str | None, category: int | None = None, provider: str | None = None
    ) -> None:
        """Delete data from cache.

        Database rows are matched on every filter that is not None. The memory
        cache is keyed on the full provider/category/key combination, so unless
        all three are given the whole memory cache is cleared instead.
        """
        assert self.database is not None
        match: dict[str, str | int] = {}
        if key is not None:
            match["key"] = key
        if category is not None:
            match["category"] = category
        if provider is not None:
            match["provider"] = provider
        if key is not None and category is not None and provider is not None:
            self._mem_cache.pop(f"{provider}/{category}/{key}", None)
        else:
            self._mem_cache.clear()
        # NOTE(review): with all filters None the match is empty; confirm what
        # database.delete does with an empty match (likely: delete every row)
        await self.database.delete(DB_TABLE_CACHE, match)

    async def clear(
        self,
        key_filter: str | None = None,
        category_filter: int | None = None,
        provider_filter: str | None = None,
        include_persistent: bool = False,
    ) -> None:
        """Clear all/partial items from cache.

        - key_filter: only delete rows whose key contains this substring
        - category_filter: only delete rows with exactly this category
        - provider_filter: only delete rows whose provider contains this substring
        - include_persistent: also delete rows that were stored as persistent

        The memory cache is always cleared completely.
        """
        assert self.database is not None
        self._mem_cache.clear()
        self.logger.info("Clearing database...")
        query_parts: list[str] = []
        if category_filter is not None:
            query_parts.append(f"category = {int(category_filter)}")
        if provider_filter is not None:
            # double single quotes so values like "O'Brien" cannot break out of the literal
            escaped_provider = provider_filter.replace("'", "''")
            query_parts.append(f"provider LIKE '%{escaped_provider}%'")
        if key_filter is not None:
            escaped_key = key_filter.replace("'", "''")
            query_parts.append(f"key LIKE '%{escaped_key}%'")
        if not include_persistent:
            query_parts.append("persistent = 0")
        query = "WHERE " + " AND ".join(query_parts) if query_parts else None
        await self.database.delete(DB_TABLE_CACHE, query=query)
        self.logger.info("Clearing database DONE")

    async def auto_cleanup(self) -> None:
        """Run scheduled auto cleanup task.

        Resets the memory cache and deletes every expired row (persistent or not)
        from the database.
        """
        assert self.database is not None
        self.logger.debug("Running automatic cleanup...")
        # simply reset the memory cache
        self._mem_cache.clear()
        cur_timestamp = int(time.time())
        # one statement instead of fetching every row (with its data payload)
        # and deleting expired ones individually
        await self.database.delete(DB_TABLE_CACHE, query=f"WHERE expires < {cur_timestamp}")
        self.logger.debug("Automatic cleanup finished")

    @asynccontextmanager
    async def handle_refresh(self, bypass: bool) -> AsyncGenerator[None, None]:
        """Set the cache bypass flag for the duration of the context."""
        # set before entering try so the finally block always has a valid token
        token = BYPASS_CACHE.set(bypass)
        try:
            yield None
        finally:
            BYPASS_CACHE.reset(token)

    async def _setup_database(self) -> None:
        """Initialize database: open it, ensure tables, migrate, index and vacuum."""
        db_path = os.path.join(self.mass.cache_path, "cache.db")
        self.database = DatabaseConnection(db_path)
        await self.database.setup()

        # always create db tables if they don't exist to prevent errors trying to access them later
        await self.__create_database_tables()
        try:
            if db_row := await self.database.get_row(DB_TABLE_SETTINGS, {"key": "version"}):
                prev_version = int(db_row["value"])
            else:
                prev_version = 0
        except (KeyError, ValueError):
            prev_version = 0

        # version 0 means no stored version; only known older/newer versions migrate
        if prev_version not in (0, DB_SCHEMA_VERSION):
            LOGGER.warning(
                "Performing database migration from %s to %s",
                prev_version,
                DB_SCHEMA_VERSION,
            )
            try:
                await self.__migrate_database(prev_version)
            except Exception as err:
                # it's only a cache: on failure start over with an empty table
                LOGGER.warning("Cache database migration failed: %s, resetting cache", err)
                await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_CACHE}")
                await self.__create_database_tables()

        # store current schema version
        await self.database.insert_or_replace(
            DB_TABLE_SETTINGS,
            {"key": "version", "value": str(DB_SCHEMA_VERSION), "type": "str"},
        )
        await self.__create_database_indexes()
        # compact db (vacuum) at startup
        self.logger.debug("Compacting database...")
        try:
            await self.database.vacuum()
        except Exception as err:
            self.logger.warning("Database vacuum failed: %s", str(err))
        else:
            self.logger.debug("Compacting database done")

    async def __create_database_tables(self) -> None:
        """Create database table(s)."""
        assert self.database is not None
        await self.database.execute(
            f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_SETTINGS}(
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    type TEXT
                );"""
        )
        await self.database.execute(
            f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_CACHE}(
                    [id] INTEGER PRIMARY KEY AUTOINCREMENT,
                    [category] INTEGER NOT NULL DEFAULT 0,
                    [key] TEXT NOT NULL,
                    [provider] TEXT NOT NULL,
                    [expires] INTEGER NOT NULL,
                    [data] TEXT NULL,
                    [checksum] TEXT NULL,
                    [persistent] INTEGER NOT NULL DEFAULT 0,
                    UNIQUE(category, key, provider)
                    )"""
        )

        await self.database.commit()

    async def __create_database_indexes(self) -> None:
        """Create database indexes."""
        assert self.database is not None
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_idx "
            f"ON {DB_TABLE_CACHE}(category);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_key_idx ON {DB_TABLE_CACHE}(key);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_provider_idx "
            f"ON {DB_TABLE_CACHE}(provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_key_idx "
            f"ON {DB_TABLE_CACHE}(category,key);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_provider_idx "
            f"ON {DB_TABLE_CACHE}(category,provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_key_provider_idx "
            f"ON {DB_TABLE_CACHE}(category,key,provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_key_provider_idx "
            f"ON {DB_TABLE_CACHE}(key,provider);"
        )
        await self.database.commit()

    async def __migrate_database(self, prev_version: int) -> None:
        """Perform a database migration."""
        assert self.database is not None
        if prev_version <= 6:
            # clear spotify cache entries to fix bloated cache from playlist pagination bug
            await self.database.delete(DB_TABLE_CACHE, query="WHERE provider LIKE '%spotify%'")
        await self.database.commit()

    def __schedule_cleanup_task(self) -> None:
        """Run the cleanup task now and reschedule it to run again in an hour."""
        self.mass.create_task(self.auto_cleanup())
        # reschedule self; keep the handle so close() can cancel it
        self._cleanup_timer = self.mass.loop.call_later(3600, self.__schedule_cleanup_task)
377
378
379Param = ParamSpec("Param")
380RetType = TypeVar("RetType")
381
382
383ProviderT = TypeVar("ProviderT", bound="Provider | CoreController")
384P = ParamSpec("P")
385R = TypeVar("R")
386
387
def use_cache(
    expiration: int = DEFAULT_CACHE_EXPIRATION,
    category: int = 0,
    persistent: bool = False,
    cache_checksum: str | None = None,
    allow_bypass: bool = True,
) -> Callable[
    [Callable[Concatenate[ProviderT, P], Awaitable[R]]],
    Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]],
]:
    """Return decorator that can be used to cache a method's result.

    The cache key is built from the method name plus its positional and keyword
    arguments. Entries are grouped under the instance's ``instance_id``, or its
    ``domain`` when it has no ``instance_id``.

    - expiration: time in seconds the cached result should be valid
    - category: category to store the cached result under
    - persistent: if True the cache object will not be deleted when clearing the cache
    - cache_checksum: checksum stored with, and required of, cached results
    - allow_bypass: if True, cache lookups honor the BYPASS_CACHE flag
    """

    def _decorator(
        func: Callable[Concatenate[ProviderT, P], Awaitable[R]],
    ) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
        @functools.cache
        def _return_type() -> Any:
            # Resolved lazily on the first cache hit: with postponed annotations the
            # referenced types may not be resolvable at decoration time. Memoized so
            # annotations are not re-evaluated on every hit (failures are retried).
            return get_type_hints(func)["return"]

        @functools.wraps(func)
        async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
            cache = self.mass.cache
            provider_id = getattr(self, "instance_id", self.domain)

            # create a cache key dynamically based on the (remaining) args/kwargs
            cache_key_parts = [func.__name__, *args]
            for key in sorted(kwargs.keys()):
                cache_key_parts.append(f"{key}{kwargs[key]}")
            cache_key = ".".join(map(str, cache_key_parts))
            # try to retrieve data from the cache
            cachedata = await cache.get(
                cache_key,
                provider=provider_id,
                checksum=cache_checksum,
                category=category,
                allow_bypass=allow_bypass,
            )
            if cachedata is not None:
                # convert the cached (possibly deserialized) data back to the declared type
                return cast("R", parse_value(func.__name__, cachedata, _return_type()))
            # get data from method/provider
            result = await func(self, *args, **kwargs)
            # store result in cache (but don't await)
            self.mass.create_task(
                cache.set(
                    key=cache_key,
                    data=result,
                    expiration=expiration,
                    provider=provider_id,
                    category=category,
                    checksum=cache_checksum,
                    persistent=persistent,
                )
            )
            return result

        return wrapper

    return _decorator
443
444
class MemoryCache(MutableMapping[str, Any]):
    """Simple limited in-memory cache implementation.

    Holds at most ``maxlen`` items in least-recently-used order. Reads through
    ``get()`` or ``[]`` and writes mark an item as most recently used. Inserting
    a new key into a full cache evicts the least recently used item.
    Membership tests (``in``) do not affect recency.
    """

    def __init__(self, maxlen: int) -> None:
        """Initialize with the maximum number of items to retain."""
        self._maxlen = maxlen
        self.d: OrderedDict[str, Any] = OrderedDict()

    @property
    def maxlen(self) -> int:
        """Return max length."""
        return self._maxlen

    def get(self, key: str, default: Any = None) -> Any:
        """Return item (marking it as most recently used) or default."""
        try:
            value = self.d[key]
        except KeyError:
            return default
        # refresh recency so frequently read entries are not evicted first
        self.d.move_to_end(key)
        return value

    def pop(self, key: str, default: Any = None) -> Any:
        """Pop item from collection, returning default (None) when missing."""
        return self.d.pop(key, default)

    def __contains__(self, key: object) -> bool:
        """Return whether key is cached, without touching its recency."""
        return key in self.d

    def __getitem__(self, key: str) -> Any:
        """Get item and mark it as most recently used; KeyError if missing."""
        self.d.move_to_end(key)
        return self.d[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Set item, evicting least recently used item(s) when full."""
        if key in self.d:
            self.d.move_to_end(key)
        else:
            if self._maxlen <= 0:
                # a zero-capacity cache stores nothing (and must not popitem on empty)
                return
            while len(self.d) >= self._maxlen:
                self.d.popitem(last=False)
        self.d[key] = value

    def __delitem__(self, key: str) -> None:
        """Delete item."""
        del self.d[key]

    def __iter__(self) -> Iterator[str]:
        """Iterate keys from least to most recently used."""
        return self.d.__iter__()

    def __len__(self) -> int:
        """Return length."""
        return len(self.d)

    def clear(self) -> None:
        """Clear cache."""
        self.d.clear()
494