/
/
/
1"""Provides a simple stateless caching system."""
2
3from __future__ import annotations
4
5import asyncio
6import functools
7import logging
8import os
9import time
10from collections import OrderedDict
11from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterator, MutableMapping
12from contextlib import asynccontextmanager
13from contextvars import ContextVar
14from pathlib import Path
15from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, cast, get_type_hints
16
17from music_assistant_models.config_entries import ConfigEntry, ConfigValueType
18from music_assistant_models.enums import ConfigEntryType
19
20from music_assistant.constants import DB_TABLE_CACHE, DB_TABLE_SETTINGS, MASS_LOGGER_NAME
21from music_assistant.helpers.api import parse_value
22from music_assistant.helpers.database import DatabaseConnection
23from music_assistant.helpers.json import async_json_loads, json_dumps
24from music_assistant.models.core_controller import CoreController
25
26if TYPE_CHECKING:
27 from music_assistant_models.config_entries import CoreConfig
28
29 from music_assistant import MusicAssistant
30 from music_assistant.models.provider import Provider
31
# Module-level logger (child of the main Music Assistant logger).
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.cache")
# Config entry key for the "clear cache" action button.
CONF_CLEAR_CACHE = "clear_cache"
DEFAULT_CACHE_EXPIRATION = 86400 * 30  # 30 days
# Bump this when the cache db schema changes to trigger a migration at startup.
DB_SCHEMA_VERSION = 7
# Maximum allowed on-disk size of the cache database (incl. -wal/-shm files).
MAX_CACHE_DB_SIZE_MB = 2048

# Context variable used to temporarily bypass cache reads (e.g. forced refresh).
BYPASS_CACHE: ContextVar[bool] = ContextVar("BYPASS_CACHE", default=False)
39
40
class CacheController(CoreController):
    """Basic cache controller using both memory and database."""

    domain: str = "cache"

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize core controller."""
        super().__init__(mass)
        self.database: DatabaseConnection | None = None
        # small in-memory LRU cache in front of the (slower) database cache
        self._mem_cache = MemoryCache(500)
        self.manifest.name = "Cache controller"
        self.manifest.description = (
            "Music Assistant's core controller for caching data throughout the application."
        )
        self.manifest.icon = "memory"

    async def get_config_entries(
        self,
        action: str | None = None,
        values: dict[str, ConfigValueType] | None = None,
    ) -> tuple[ConfigEntry, ...]:
        """Return all Config Entries for this core module (if any)."""
        if action == CONF_CLEAR_CACHE:
            # the "clear cache" action was invoked from the UI
            await self.clear()
            return (
                ConfigEntry(
                    key=CONF_CLEAR_CACHE,
                    type=ConfigEntryType.LABEL,
                    label="The cache has been cleared",
                ),
            )
        return (
            ConfigEntry(
                key=CONF_CLEAR_CACHE,
                type=ConfigEntryType.ACTION,
                label="Clear cache",
                description="Reset/clear all items in the cache. ",
            ),
        )

    async def setup(self, config: CoreConfig) -> None:
        """Async initialize of cache module."""
        self.logger.info("Initializing cache controller...")
        await self._setup_database()
        self.__schedule_cleanup_task()

    async def close(self) -> None:
        """Cleanup on exit."""
        if self.database:
            await self.database.close()

    async def get(
        self,
        key: str,
        provider: str = "default",
        category: int = 0,
        checksum: str | int | None = None,
        default: Any = None,
        allow_bypass: bool = True,
    ) -> Any:
        """Get object from cache and return the results.

        - key: the (unique) lookup key of the cache object as reference
        - provider: optional provider id to group cache objects
        - category: optional category to group cache objects
        - checksum: optional argument to check if the checksum in the
          cache object matches the checksum provided
        - default: value to return if no cache object is found
        - allow_bypass: if True, honor the BYPASS_CACHE contextvar and
          return the default value while a cache bypass is active
        """
        assert self.database is not None
        assert key, "No key provided"
        if allow_bypass and BYPASS_CACHE.get():
            return default
        cur_time = int(time.time())
        if checksum is not None and not isinstance(checksum, str):
            checksum = str(checksum)
        # try memory cache first
        memory_key = f"{provider}/{category}/{key}"
        cache_data = self._mem_cache.get(memory_key)
        # memory cache entry is a tuple of (data, checksum, expires)
        if cache_data and (not checksum or cache_data[1] == checksum) and cache_data[2] >= cur_time:
            return cache_data[0]
        # fall back to db cache
        if (
            (
                db_row := await self.database.get_row(
                    DB_TABLE_CACHE, {"category": category, "provider": provider, "key": key}
                )
            )
            and db_row["expires"] >= cur_time
            and (not checksum or db_row["checksum"] == checksum)
        ):
            try:
                data = await async_json_loads(db_row["data"])
            except Exception as exc:
                LOGGER.error(
                    "Error parsing cache data for %s: %s",
                    memory_key,
                    str(exc),
                    # named constant instead of the magic number 10
                    exc_info=exc if self.logger.isEnabledFor(logging.DEBUG) else None,
                )
            else:
                # also store in memory cache for faster access
                self._mem_cache[memory_key] = (
                    data,
                    db_row["checksum"],
                    db_row["expires"],
                )
                return data
        return default

    async def set(
        self,
        key: str,
        data: Any,
        expiration: int = DEFAULT_CACHE_EXPIRATION,
        provider: str = "default",
        category: int = 0,
        checksum: str | None = None,
        persistent: bool = False,
    ) -> None:
        """
        Set data in cache.

        - key: the (unique) lookup key of the cache object as reference
        - data: the actual data to store in the cache
        - expiration: time in seconds the cache object should be valid
        - provider: optional provider id to group cache objects
        - category: optional category to group cache objects
        - checksum: optional argument to store with the cache object
        - persistent: if True the cache object will not be deleted when clearing the cache
        """
        assert self.database is not None
        if not key:
            return
        if checksum is not None:
            checksum = str(checksum)
        expires = int(time.time() + expiration)
        memory_key = f"{provider}/{category}/{key}"
        self._mem_cache[memory_key] = (data, checksum, expires)
        if (expires - time.time()) < 1800:
            # do not cache items in db with short expiration
            return
        # serialize off the event loop; json_dumps may be expensive for large objects
        data = await asyncio.to_thread(json_dumps, data)
        await self.database.insert_or_replace(
            DB_TABLE_CACHE,
            {
                "category": category,
                "provider": provider,
                "key": key,
                "expires": expires,
                "checksum": checksum,
                "data": data,
                "persistent": persistent,
            },
        )

    async def delete(
        self, key: str | None, category: int | None = None, provider: str | None = None
    ) -> None:
        """Delete data from cache."""
        assert self.database is not None
        match: dict[str, str | int] = {}
        if key is not None:
            match["key"] = key
        if category is not None:
            match["category"] = category
        if provider is not None:
            match["provider"] = provider
        if key is not None and category is not None and provider is not None:
            # fully qualified: we can evict the exact memory cache entry
            self._mem_cache.pop(f"{provider}/{category}/{key}", None)
        else:
            # partial match: safest to drop the entire memory cache
            self._mem_cache.clear()
        await self.database.delete(DB_TABLE_CACHE, match)

    async def clear(
        self,
        key_filter: str | None = None,
        category_filter: int | None = None,
        provider_filter: str | None = None,
        include_persistent: bool = False,
    ) -> None:
        """Clear all/partial items from cache."""
        assert self.database is not None
        self._mem_cache.clear()
        self.logger.info("Clearing database...")
        query_parts: list[str] = []
        if category_filter is not None:
            query_parts.append(f"category = {category_filter}")
        if provider_filter is not None:
            # escape single quotes to prevent breaking (or injecting into)
            # the SQL statement built below
            safe_provider = provider_filter.replace("'", "''")
            query_parts.append(f"provider LIKE '%{safe_provider}%'")
        if key_filter is not None:
            safe_key = key_filter.replace("'", "''")
            query_parts.append(f"key LIKE '%{safe_key}%'")
        if not include_persistent:
            query_parts.append("persistent = 0")
        query = "WHERE " + " AND ".join(query_parts) if query_parts else None
        await self.database.delete(DB_TABLE_CACHE, query=query)
        self.logger.info("Clearing database DONE")

    async def auto_cleanup(self) -> None:
        """Run scheduled auto cleanup task."""
        assert self.database is not None
        self.logger.debug("Running automatic cleanup...")
        # simply reset the memory cache
        self._mem_cache.clear()
        cur_timestamp = int(time.time())
        cleaned_records = 0
        for db_row in await self.database.get_rows(DB_TABLE_CACHE):
            # clean up db cache object only if expired
            if db_row["expires"] < cur_timestamp:
                await self.database.delete(DB_TABLE_CACHE, {"id": db_row["id"]})
                cleaned_records += 1
            await asyncio.sleep(0)  # yield to eventloop
        self.logger.debug("Automatic cleanup finished (cleaned up %s records)", cleaned_records)

    @asynccontextmanager
    async def handle_refresh(self, bypass: bool) -> AsyncGenerator[None, None]:
        """Handle the cache bypass."""
        # set the token BEFORE entering the try block so that the finally
        # clause can never run with an unassigned token
        token = BYPASS_CACHE.set(bypass)
        try:
            yield None
        finally:
            BYPASS_CACHE.reset(token)

    async def _check_and_reset_oversized_cache(self) -> bool:
        """Check cache database size and remove it if it exceeds the max size.

        Returns True if the cache database was removed.
        """
        db_path = os.path.join(self.mass.cache_path, "cache.db")
        # also include the write ahead log and shared memory db files
        db_files = [db_path + suffix for suffix in ("", "-wal", "-shm")]

        def _get_db_size() -> float:
            """Return combined size of all db files in MB."""
            total = 0
            for path in db_files:
                # EAFP: stat directly to avoid an exists/stat race
                try:
                    total += Path(path).stat().st_size
                except FileNotFoundError:
                    continue
            return total / (1024 * 1024)

        db_size_mb = await asyncio.to_thread(_get_db_size)
        if db_size_mb <= MAX_CACHE_DB_SIZE_MB:
            return False
        self.logger.warning(
            "Cache database size %.2f MB exceeds maximum of %d MB, removing cache database",
            db_size_mb,
            MAX_CACHE_DB_SIZE_MB,
        )
        for path in db_files:
            if await asyncio.to_thread(os.path.exists, path):
                await asyncio.to_thread(os.remove, path)
        return True

    async def _setup_database(self) -> None:
        """Initialize database."""
        cache_was_reset = await self._check_and_reset_oversized_cache()
        db_path = os.path.join(self.mass.cache_path, "cache.db")
        self.database = DatabaseConnection(db_path)
        await self.database.setup()

        # always create db tables if they don't exist to prevent errors trying to access them later
        await self.__create_database_tables()

        if not cache_was_reset:
            try:
                if db_row := await self.database.get_row(DB_TABLE_SETTINGS, {"key": "version"}):
                    prev_version = int(db_row["value"])
                else:
                    prev_version = 0
            except (KeyError, ValueError):
                # settings table missing/corrupt: treat as fresh install
                prev_version = 0

            if prev_version not in (0, DB_SCHEMA_VERSION):
                LOGGER.warning(
                    "Performing database migration from %s to %s",
                    prev_version,
                    DB_SCHEMA_VERSION,
                )
                try:
                    await self.__migrate_database(prev_version)
                except Exception as err:
                    # migration failed: drop and recreate rather than crash
                    LOGGER.warning("Cache database migration failed: %s, resetting cache", err)
                    await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_CACHE}")
                    await self.__create_database_tables()

        # store current schema version
        await self.database.insert_or_replace(
            DB_TABLE_SETTINGS,
            {"key": "version", "value": str(DB_SCHEMA_VERSION), "type": "str"},
        )
        await self.__create_database_indexes()

        if not cache_was_reset:
            # compact db (vacuum) at startup
            self.logger.debug("Compacting database...")
            try:
                await self.database.vacuum()
            except Exception as err:
                self.logger.warning("Database vacuum failed: %s", str(err))
            else:
                self.logger.debug("Compacting database done")

    async def __create_database_tables(self) -> None:
        """Create database table(s)."""
        assert self.database is not None
        await self.database.execute(
            f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_SETTINGS}(
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    type TEXT
                );"""
        )
        await self.database.execute(
            f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_CACHE}(
                    [id] INTEGER PRIMARY KEY AUTOINCREMENT,
                    [category] INTEGER NOT NULL DEFAULT 0,
                    [key] TEXT NOT NULL,
                    [provider] TEXT NOT NULL,
                    [expires] INTEGER NOT NULL,
                    [data] TEXT NULL,
                    [checksum] TEXT NULL,
                    [persistent] INTEGER NOT NULL DEFAULT 0,
                    UNIQUE(category, key, provider)
                )"""
        )

        await self.database.commit()

    async def __create_database_indexes(self) -> None:
        """Create database indexes."""
        assert self.database is not None
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_idx "
            f"ON {DB_TABLE_CACHE}(category);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_key_idx ON {DB_TABLE_CACHE}(key);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_provider_idx "
            f"ON {DB_TABLE_CACHE}(provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_key_idx "
            f"ON {DB_TABLE_CACHE}(category,key);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_provider_idx "
            f"ON {DB_TABLE_CACHE}(category,provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_key_provider_idx "
            f"ON {DB_TABLE_CACHE}(category,key,provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_key_provider_idx "
            f"ON {DB_TABLE_CACHE}(key,provider);"
        )
        await self.database.commit()

    async def __migrate_database(self, prev_version: int) -> None:
        """Perform a database migration."""
        assert self.database is not None
        if prev_version <= 6:
            # clear spotify cache entries to fix bloated cache from playlist pagination bug
            await self.database.delete(DB_TABLE_CACHE, query="WHERE provider LIKE '%spotify%'")
        await self.database.commit()

    def __schedule_cleanup_task(self) -> None:
        """Schedule the cleanup task."""
        self.mass.create_task(self.auto_cleanup())
        # reschedule self (hourly)
        self.mass.loop.call_later(3600, self.__schedule_cleanup_task)
413
414
# NOTE(review): `Param` and `RetType` are unused within this module — they may
# be imported elsewhere, so they are kept for backward compatibility.
Param = ParamSpec("Param")
RetType = TypeVar("RetType")


# Type variables used by the `use_cache` decorator: the decorated method must be
# bound to a Provider or CoreController instance (which exposes `mass.cache`).
ProviderT = TypeVar("ProviderT", bound="Provider | CoreController")
P = ParamSpec("P")
R = TypeVar("R")
422
423
def use_cache(
    expiration: int = DEFAULT_CACHE_EXPIRATION,
    category: int = 0,
    persistent: bool = False,
    cache_checksum: str | None = None,
    allow_bypass: bool = True,
) -> Callable[
    [Callable[Concatenate[ProviderT, P], Awaitable[R]]],
    Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]],
]:
    """Return decorator that can be used to cache a method's result."""

    def _decorator(
        func: Callable[Concatenate[ProviderT, P], Awaitable[R]],
    ) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
        @functools.wraps(func)
        async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
            cache = self.mass.cache
            # prefer the (unique) instance id, fall back to the provider domain
            provider_id = getattr(self, "instance_id", self.domain)

            # derive a cache key from the method name plus all call arguments
            key_parts: list[Any] = [func.__name__]
            key_parts.extend(args)
            key_parts.extend(f"{name}{kwargs[name]}" for name in sorted(kwargs))
            cache_key = ".".join(str(part) for part in key_parts)

            # a cache hit is returned after parsing it back into the annotated type
            hit = await cache.get(
                cache_key,
                provider=provider_id,
                checksum=cache_checksum,
                category=category,
                allow_bypass=allow_bypass,
            )
            if hit is not None:
                hints = get_type_hints(func)
                return cast("R", parse_value(func.__name__, hit, hints["return"]))

            # cache miss: call the wrapped method for fresh data
            fresh = await func(self, *args, **kwargs)
            # persist the result in the background (fire-and-forget)
            self.mass.create_task(
                cache.set(
                    key=cache_key,
                    data=fresh,
                    expiration=expiration,
                    provider=provider_id,
                    category=category,
                    checksum=cache_checksum,
                    persistent=persistent,
                )
            )
            return fresh

        return wrapper

    return _decorator
479
480
class MemoryCache(MutableMapping[str, Any]):
    """Simple limited in-memory (LRU) cache implementation.

    Keeps at most `maxlen` items; once full, the least recently used
    item is evicted to make room for a new one.
    """

    def __init__(self, maxlen: int) -> None:
        """Initialize.

        - maxlen: maximum number of items to keep (0 disables caching).
        """
        self._maxlen = maxlen
        # OrderedDict tracks recency: oldest items first, newest last
        self.d: OrderedDict[str, Any] = OrderedDict()

    @property
    def maxlen(self) -> int:
        """Return max length."""
        return self._maxlen

    def get(self, key: str, default: Any = None) -> Any:
        """Return item or default."""
        try:
            # refresh recency on reads too, consistent with __getitem__,
            # so frequently read items are not evicted first
            self.d.move_to_end(key)
        except KeyError:
            return default
        return self.d[key]

    def pop(self, key: str, default: Any = None) -> Any:
        """Pop item from collection."""
        return self.d.pop(key, default)

    def __getitem__(self, key: str) -> Any:
        """Get item (raises KeyError if missing) and mark it recently used."""
        self.d.move_to_end(key)
        return self.d[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Set item, evicting the least recently used entry when full."""
        if self._maxlen <= 0:
            # caching disabled; previously this raised KeyError on insert
            return
        if key in self.d:
            self.d.move_to_end(key)
        elif len(self.d) >= self._maxlen:
            # >= (instead of ==) also recovers if the dict ever exceeded maxlen
            self.d.popitem(last=False)
        self.d[key] = value

    def __delitem__(self, key: str) -> None:
        """Delete item."""
        del self.d[key]

    def __iter__(self) -> Iterator[str]:
        """Iterate items."""
        return self.d.__iter__()

    def __len__(self) -> int:
        """Return length."""
        return len(self.d)

    def clear(self) -> None:
        """Clear cache."""
        self.d.clear()
530