/
/
/
1"""Provides a simple stateless caching system."""
2
3from __future__ import annotations
4
5import asyncio
6import functools
7import logging
8import os
9import time
10from collections import OrderedDict
11from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterator, MutableMapping
12from contextlib import asynccontextmanager
13from contextvars import ContextVar
14from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, cast, get_type_hints
15
16from music_assistant_models.config_entries import ConfigEntry, ConfigValueType
17from music_assistant_models.enums import ConfigEntryType
18
19from music_assistant.constants import DB_TABLE_CACHE, DB_TABLE_SETTINGS, MASS_LOGGER_NAME
20from music_assistant.helpers.api import parse_value
21from music_assistant.helpers.database import DatabaseConnection
22from music_assistant.helpers.json import async_json_loads, json_dumps
23from music_assistant.models.core_controller import CoreController
24
25if TYPE_CHECKING:
26 from music_assistant_models.config_entries import CoreConfig
27
28 from music_assistant import MusicAssistant
29 from music_assistant.models.provider import Provider
30
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.cache")
# Config entry key / action id that triggers clearing the cache.
CONF_CLEAR_CACHE = "clear_cache"
# Default lifetime of a cache entry in seconds (30 days).
DEFAULT_CACHE_EXPIRATION = 30 * 24 * 60 * 60
# Stored in the settings table; an older stored version causes the cache
# table to be dropped and recreated at startup.
DB_SCHEMA_VERSION = 6

# Context-local flag: while True, cache lookups that allow bypassing return
# their default instead of cached data.
BYPASS_CACHE: ContextVar[bool] = ContextVar("BYPASS_CACHE", default=False)
37
38
class CacheController(CoreController):
    """Basic cache controller using both memory and database.

    Lookups hit a bounded in-memory cache first and fall back to the
    DB_TABLE_CACHE table. Writes always go to memory and are additionally
    persisted to the database unless the entry expires within 30 minutes.
    """

    domain: str = "cache"

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize core controller."""
        super().__init__(mass)
        # Opened by setup() via _setup_database(); None until then.
        self.database: DatabaseConnection | None = None
        # Maps "provider/category/key" to a (data, checksum, expires) tuple.
        self._mem_cache = MemoryCache(500)
        self.manifest.name = "Cache controller"
        self.manifest.description = (
            "Music Assistant's core controller for caching data throughout the application."
        )
        self.manifest.icon = "memory"

    async def get_config_entries(
        self,
        action: str | None = None,
        values: dict[str, ConfigValueType] | None = None,
    ) -> tuple[ConfigEntry, ...]:
        """Return all Config Entries for this core module (if any).

        When called with the CONF_CLEAR_CACHE action, the (non-persistent)
        cache is cleared first and a confirmation label is returned instead of
        the action entry.
        """
        if action == CONF_CLEAR_CACHE:
            await self.clear()
            return (
                ConfigEntry(
                    key=CONF_CLEAR_CACHE,
                    type=ConfigEntryType.LABEL,
                    label="The cache has been cleared",
                ),
            )
        return (
            ConfigEntry(
                key=CONF_CLEAR_CACHE,
                type=ConfigEntryType.ACTION,
                label="Clear cache",
                description="Reset/clear all items in the cache. ",
            ),
        )

    async def setup(self, config: CoreConfig) -> None:
        """Async initialize of cache module: open the database and start cleanup."""
        self.logger.info("Initializing cache controller...")
        await self._setup_database()
        self.__schedule_cleanup_task()

    async def close(self) -> None:
        """Cleanup on exit (close the database connection if it was opened)."""
        if self.database:
            await self.database.close()

    async def get(
        self,
        key: str,
        provider: str = "default",
        category: int = 0,
        checksum: str | int | None = None,
        default: Any = None,
        allow_bypass: bool = True,
    ) -> Any:
        """Get object from cache and return the results.

        - key: the (unique) lookup key of the cache object as reference
        - provider: optional provider id to group cache objects
        - category: optional category to group cache objects
        - checksum: optional argument to check if the checksum in the
          cache object matches the checksum provided
        - default: value to return if no valid cache object is found
        - allow_bypass: if True, return `default` while the BYPASS_CACHE
          context flag is set

        An entry is only returned if it has not expired and, when a checksum
        is given, its stored checksum matches.
        """
        assert self.database is not None
        assert key, "No key provided"
        if allow_bypass and BYPASS_CACHE.get():
            return default
        cur_time = int(time.time())
        if checksum is not None and not isinstance(checksum, str):
            checksum = str(checksum)
        # try memory cache first
        memory_key = f"{provider}/{category}/{key}"
        cache_data = self._mem_cache.get(memory_key)
        if cache_data and (not checksum or cache_data[1] == checksum) and cache_data[2] >= cur_time:
            return cache_data[0]
        # fall back to db cache
        db_row = await self.database.get_row(
            DB_TABLE_CACHE, {"category": category, "provider": provider, "key": key}
        )
        if not db_row:
            return default
        # Enforce expiry regardless of checksum (same rule as the memory check
        # above); rows are only purged periodically, so they may be stale.
        if db_row["expires"] < cur_time:
            return default
        if checksum and db_row["checksum"] != checksum:
            return default
        try:
            data = await async_json_loads(db_row["data"])
        except Exception as exc:
            LOGGER.error(
                "Error parsing cache data for %s: %s",
                memory_key,
                str(exc),
                exc_info=exc if self.logger.isEnabledFor(logging.DEBUG) else None,
            )
            return default
        # also store in memory cache for faster access
        self._mem_cache[memory_key] = (data, db_row["checksum"], db_row["expires"])
        return data

    async def set(
        self,
        key: str,
        data: Any,
        expiration: int = DEFAULT_CACHE_EXPIRATION,
        provider: str = "default",
        category: int = 0,
        checksum: str | None = None,
        persistent: bool = False,
    ) -> None:
        """
        Set data in cache.

        - key: the (unique) lookup key of the cache object as reference
        - data: the actual data to store in the cache
        - expiration: time in seconds the cache object should be valid
        - provider: optional provider id to group cache objects
        - category: optional category to group cache objects
        - checksum: optional argument to store with the cache object
        - persistent: if True the cache object will not be deleted when clearing the cache

        Entries expiring within 30 minutes are kept in memory only. An empty
        key is silently ignored.
        """
        assert self.database is not None
        if not key:
            return
        if checksum is not None:
            checksum = str(checksum)
        expires = int(time.time() + expiration)
        memory_key = f"{provider}/{category}/{key}"
        self._mem_cache[memory_key] = (data, checksum, expires)
        if (expires - time.time()) < 1800:
            # do not cache items in db with short expiration
            return
        # serialize off the event loop; payloads may be large
        data = await asyncio.to_thread(json_dumps, data)
        await self.database.insert_or_replace(
            DB_TABLE_CACHE,
            {
                "category": category,
                "provider": provider,
                "key": key,
                "expires": expires,
                "checksum": checksum,
                "data": data,
                "persistent": persistent,
            },
        )

    async def delete(
        self, key: str | None, category: int | None = None, provider: str | None = None
    ) -> None:
        """Delete data from cache.

        Database rows matching all given (non-None) criteria are deleted. The
        single memory entry is dropped only when key, category and provider
        are all given (the memory cache is keyed on that full combination);
        otherwise the whole memory cache is cleared.
        NOTE(review): with all criteria None the match dict is empty; whether
        that deletes every row depends on DatabaseConnection.delete - confirm.
        """
        assert self.database is not None
        match: dict[str, str | int] = {}
        if key is not None:
            match["key"] = key
        if category is not None:
            match["category"] = category
        if provider is not None:
            match["provider"] = provider
        if key is not None and category is not None and provider is not None:
            self._mem_cache.pop(f"{provider}/{category}/{key}", None)
        else:
            self._mem_cache.clear()
        await self.database.delete(DB_TABLE_CACHE, match)

    async def clear(
        self,
        key_filter: str | None = None,
        category_filter: int | None = None,
        provider_filter: str | None = None,
        include_persistent: bool = False,
    ) -> None:
        """Clear all/partial items from cache.

        The memory cache is always emptied completely. Database rows are
        deleted when they match all given filters:
        - key_filter / provider_filter: substring match (SQL LIKE '%value%')
        - category_filter: exact category match
        - include_persistent: if False (default) persistent rows are kept
        """
        assert self.database is not None
        self._mem_cache.clear()
        self.logger.info("Clearing database...")
        query_parts: list[str] = []
        if category_filter is not None:
            query_parts.append(f"category = {int(category_filter)}")
        if provider_filter is not None:
            # The filters are interpolated into an SQL string literal: double
            # single quotes so a value can never terminate the literal.
            escaped_provider = provider_filter.replace("'", "''")
            query_parts.append(f"provider LIKE '%{escaped_provider}%'")
        if key_filter is not None:
            escaped_key = key_filter.replace("'", "''")
            query_parts.append(f"key LIKE '%{escaped_key}%'")
        if not include_persistent:
            query_parts.append("persistent = 0")
        query = "WHERE " + " AND ".join(query_parts) if query_parts else None
        await self.database.delete(DB_TABLE_CACHE, query=query)
        self.logger.info("Clearing database DONE")

    async def auto_cleanup(self) -> None:
        """Run scheduled auto cleanup task.

        Empties the memory cache and deletes every database row whose
        expiration timestamp lies in the past (persistent rows included).
        """
        assert self.database is not None
        self.logger.debug("Running automatic cleanup...")
        # simply reset the memory cache
        self._mem_cache.clear()
        cur_timestamp = int(time.time())
        # One bulk DELETE instead of loading every row (with its data payload)
        # into memory and deleting expired rows one statement at a time.
        await self.database.delete(DB_TABLE_CACHE, query=f"WHERE expires < {cur_timestamp}")
        self.logger.debug("Automatic cleanup finished")

    @asynccontextmanager
    async def handle_refresh(self, bypass: bool) -> AsyncGenerator[None, None]:
        """Handle the cache bypass.

        Sets BYPASS_CACHE to `bypass` for the current context while the
        `async with` body runs and restores the previous value afterwards.
        """
        # set before entering try so the token is always bound in finally
        token = BYPASS_CACHE.set(bypass)
        try:
            yield None
        finally:
            BYPASS_CACHE.reset(token)

    async def _setup_database(self) -> None:
        """Initialize database.

        Opens cache.db in the cache path and ensures the tables exist. When
        the stored schema version is older than DB_SCHEMA_VERSION, the cache
        table is dropped and recreated. Indexes are created afterwards and the
        database is vacuumed (a vacuum failure is only logged).
        """
        db_path = os.path.join(self.mass.cache_path, "cache.db")
        self.database = DatabaseConnection(db_path)
        await self.database.setup()

        # always create db tables if they don't exist to prevent errors trying to access them later
        await self.__create_database_tables()
        try:
            if db_row := await self.database.get_row(DB_TABLE_SETTINGS, {"key": "version"}):
                prev_version = int(db_row["value"])
            else:
                prev_version = 0
        except (KeyError, ValueError):
            prev_version = 0

        if prev_version not in (0, DB_SCHEMA_VERSION):
            LOGGER.warning(
                "Performing database migration from %s to %s",
                prev_version,
                DB_SCHEMA_VERSION,
            )

        if prev_version < DB_SCHEMA_VERSION:
            # for now just keep it simple and just recreate the table(s)
            await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_CACHE}")

            # recreate missing table(s)
            await self.__create_database_tables()

        # store current schema version
        await self.database.insert_or_replace(
            DB_TABLE_SETTINGS,
            {"key": "version", "value": str(DB_SCHEMA_VERSION), "type": "str"},
        )
        await self.__create_database_indexes()
        # compact db (vacuum) at startup
        self.logger.debug("Compacting database...")
        try:
            await self.database.vacuum()
        except Exception as err:
            self.logger.warning("Database vacuum failed: %s", str(err))
        else:
            self.logger.debug("Compacting database done")

    async def __create_database_tables(self) -> None:
        """Create the settings and cache tables if they do not exist yet."""
        assert self.database is not None
        await self.database.execute(
            f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_SETTINGS}(
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    type TEXT
                );"""
        )
        await self.database.execute(
            f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_CACHE}(
                    [id] INTEGER PRIMARY KEY AUTOINCREMENT,
                    [category] INTEGER NOT NULL DEFAULT 0,
                    [key] TEXT NOT NULL,
                    [provider] TEXT NOT NULL,
                    [expires] INTEGER NOT NULL,
                    [data] TEXT NULL,
                    [checksum] TEXT NULL,
                    [persistent] INTEGER NOT NULL DEFAULT 0,
                    UNIQUE(category, key, provider)
                    )"""
        )

        await self.database.commit()

    async def __create_database_indexes(self) -> None:
        """Create the lookup indexes on the cache table if they do not exist yet."""
        assert self.database is not None
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_idx "
            f"ON {DB_TABLE_CACHE}(category);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_key_idx ON {DB_TABLE_CACHE}(key);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_provider_idx "
            f"ON {DB_TABLE_CACHE}(provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_key_idx "
            f"ON {DB_TABLE_CACHE}(category,key);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_provider_idx "
            f"ON {DB_TABLE_CACHE}(category,provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_key_provider_idx "
            f"ON {DB_TABLE_CACHE}(category,key,provider);"
        )
        await self.database.execute(
            f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_key_provider_idx "
            f"ON {DB_TABLE_CACHE}(key,provider);"
        )
        await self.database.commit()

    def __schedule_cleanup_task(self) -> None:
        """Start auto_cleanup as a task now and run this again in 3600 seconds."""
        self.mass.create_task(self.auto_cleanup())
        # reschedule self
        self.mass.loop.call_later(3600, self.__schedule_cleanup_task)
366
367
368Param = ParamSpec("Param")
369RetType = TypeVar("RetType")
370
371
372ProviderT = TypeVar("ProviderT", bound="Provider | CoreController")
373P = ParamSpec("P")
374R = TypeVar("R")
375
376
def use_cache(
    expiration: int = DEFAULT_CACHE_EXPIRATION,
    category: int = 0,
    persistent: bool = False,
    cache_checksum: str | None = None,
    allow_bypass: bool = True,
) -> Callable[
    [Callable[Concatenate[ProviderT, P], Awaitable[R]]],
    Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]],
]:
    """Return decorator that can be used to cache a method's result.

    The decorated coroutine method's owner must expose `mass` and either
    `instance_id` or `domain` (used as the cache provider id). The cache key
    is the function name followed by the positional args and the kwargs
    (sorted by name), joined with "." using their str() form.

    - expiration: time in seconds the cached result stays valid
    - category: cache category the result is stored under
    - persistent: if True the cache object will not be deleted when clearing the cache
    - cache_checksum: optional checksum stored with, and required on, the entry
    - allow_bypass: whether the BYPASS_CACHE context flag may skip the lookup

    Cached data is converted back with parse_value() using the decorated
    function's return annotation. A result of None is never served from cache.
    """

    def _decorator(
        func: Callable[Concatenate[ProviderT, P], Awaitable[R]],
    ) -> Callable[Concatenate[ProviderT, P], Coroutine[Any, Any, R]]:
        @functools.cache
        def _return_type() -> Any:
            # Resolved lazily on the first cache hit (annotations may name
            # things not resolvable at decoration time) and memoized, since
            # get_type_hints() evaluates the annotations on every call.
            return get_type_hints(func)["return"]

        @functools.wraps(func)
        async def wrapper(self: ProviderT, *args: P.args, **kwargs: P.kwargs) -> R:
            cache = self.mass.cache
            provider_id = getattr(self, "instance_id", self.domain)

            # create a cache key dynamically based on the (remaining) args/kwargs
            cache_key_parts = [func.__name__, *args]
            for key in sorted(kwargs.keys()):
                cache_key_parts.append(f"{key}{kwargs[key]}")
            cache_key = ".".join(map(str, cache_key_parts))
            # try to retrieve data from the cache
            cachedata = await cache.get(
                cache_key,
                provider=provider_id,
                checksum=cache_checksum,
                category=category,
                allow_bypass=allow_bypass,
            )
            if cachedata is not None:
                return cast("R", parse_value(func.__name__, cachedata, _return_type()))
            # get data from method/provider
            result = await func(self, *args, **kwargs)
            # store result in cache (but don't await)
            self.mass.create_task(
                cache.set(
                    key=cache_key,
                    data=result,
                    expiration=expiration,
                    provider=provider_id,
                    category=category,
                    checksum=cache_checksum,
                    persistent=persistent,
                )
            )
            return result

        return wrapper

    return _decorator
432
433
class MemoryCache(MutableMapping[str, Any]):
    """Simple size-limited in-memory LRU cache.

    Holds at most `maxlen` items. When full, inserting a new key evicts the
    least recently used item. Reads via `get()` / `[]` and writes all count
    as use.
    """

    def __init__(self, maxlen: int) -> None:
        """Initialize.

        - maxlen: maximum number of items kept; a value <= 0 stores nothing.
        """
        self._maxlen = maxlen
        self.d: OrderedDict[str, Any] = OrderedDict()

    @property
    def maxlen(self) -> int:
        """Return max length."""
        return self._maxlen

    def get(self, key: str, default: Any = None) -> Any:
        """Return item (marking it as most recently used) or default."""
        try:
            value = self.d[key]
        except KeyError:
            return default
        # refresh recency so frequently read entries are not evicted first
        self.d.move_to_end(key)
        return value

    def pop(self, key: str, default: Any = None) -> Any:
        """Pop item from collection; return default (None) if missing."""
        return self.d.pop(key, default)

    def __getitem__(self, key: str) -> Any:
        """Get item (marking it as most recently used); KeyError if missing."""
        self.d.move_to_end(key)
        return self.d[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Set item, evicting the least recently used item(s) when full."""
        if self._maxlen <= 0:
            # zero-capacity cache: nothing can be stored (and popitem() on an
            # empty dict would raise)
            return
        if key in self.d:
            self.d.move_to_end(key)
        else:
            while len(self.d) >= self._maxlen:
                self.d.popitem(last=False)
        self.d[key] = value

    def __delitem__(self, key: str) -> None:
        """Delete item."""
        del self.d[key]

    def __iter__(self) -> Iterator[str]:
        """Iterate keys from least to most recently used."""
        return self.d.__iter__()

    def __len__(self) -> int:
        """Return length."""
        return len(self.d)

    def clear(self) -> None:
        """Clear cache."""
        self.d.clear()
483