/
/
/
1"""Database helpers and logic."""
2
from __future__ import annotations

import asyncio
import logging
import os
import re
import time
from collections.abc import Mapping
from contextlib import asynccontextmanager
from sqlite3 import OperationalError
from typing import TYPE_CHECKING, Any, cast

import aiosqlite

from music_assistant.constants import MASS_LOGGER_NAME
17
18if TYPE_CHECKING:
19 from collections.abc import AsyncGenerator
20
# Module logger, registered as the "database" child of the Music Assistant logger namespace.
LOGGER = logging.getLogger(MASS_LOGGER_NAME + ".database")
22
23
24class _UnsetType:
25 """Sentinel value to indicate a field should use the database default."""
26
27 _instance: _UnsetType | None = None
28
29 def __new__(cls) -> _UnsetType:
30 """Create singleton instance."""
31 if cls._instance is None:
32 cls._instance = super().__new__(cls)
33 return cls._instance
34
35 def __repr__(self) -> str:
36 """Return string representation."""
37 return "UNSET"
38
39 def __bool__(self) -> bool:
40 """Return False for boolean context."""
41 return False
42
43
# Shared sentinel instance: values equal to UNSET are dropped before INSERT/UPSERT
# (so the column default applies) and before UPDATE (so the column is left as is).
UNSET: _UnsetType = _UnsetType()

# Query timing/error logging in debug_query is only enabled when the
# PYTHONDEVMODE environment variable is exactly "1".
ENABLE_DEBUG = os.getenv("PYTHONDEVMODE") == "1"
47
48
@asynccontextmanager
async def debug_query(
    sql_query: str,
    query_params: dict[str, Any] | None = None,
    slow_threshold: float = 0.5,
) -> AsyncGenerator[None]:
    """Time the processing of an SQL query and log failures or slow runs.

    Does nothing but yield unless ENABLE_DEBUG is set.

    :param sql_query: The SQL text being executed (used for logging only).
    :param query_params: Named parameters bound to the query; their values are
        inlined into the logged SQL of a slow query.
    :param slow_threshold: Duration in seconds above which the query is logged as slow.
    :raises OperationalError: Re-raised unchanged after logging the failing query.
    """
    if not ENABLE_DEBUG:
        yield
        return
    # perf_counter is monotonic, so wall-clock adjustments cannot distort the timing
    time_start = time.perf_counter()
    try:
        yield
    except OperationalError as err:
        LOGGER.error("%s\n%s", err, sql_query)
        raise
    finally:
        process_time = time.perf_counter() - time_start
        if process_time > slow_threshold:
            # log slow queries with their parameter values inlined; substitute the
            # longest names first so ":id" cannot clobber the prefix of ":id_list"
            params = query_params or {}
            for key in sorted(params, key=len, reverse=True):
                sql_query = sql_query.replace(f":{key}", repr(params[key]))
            LOGGER.warning("SQL Query took %.2f seconds!\n%s", process_time, sql_query)
70
71
72def query_params(query: str, params: dict[str, Any] | None) -> tuple[str, dict[str, Any]]:
73 """Extend query parameters support."""
74 if params is None:
75 return (query, {})
76 count = 0
77 result_query = query
78 result_params = {}
79 for key, value in params.items():
80 # add support for a list within the query params
81 # recreates the params as (:_param_0, :_param_1) etc
82 if isinstance(value, list | tuple):
83 subparams = []
84 for subval in value:
85 subparam_name = f"_param_{count}"
86 result_params[subparam_name] = subval
87 subparams.append(subparam_name)
88 count += 1
89 params_str = ",".join(f":{x}" for x in subparams)
90 result_query = result_query.replace(f" :{key}", f" ({params_str})")
91 else:
92 result_params[key] = value
93 return (result_query, result_params)
94
95
class DatabaseConnection:
    """Class that holds the (connection to the) database with some convenience helper functions.

    Table, column and ORDER BY arguments are interpolated directly into the SQL text;
    only values are bound as parameters, so those identifiers must come from trusted code.
    """

    _db: aiosqlite.Connection

    def __init__(self, db_path: str) -> None:
        """Initialize class.

        :param db_path: Database location handed to ``aiosqlite.connect`` in ``setup``.
        """
        self.db_path = db_path

    async def setup(self) -> None:
        """Open the connection and apply the connection-wide PRAGMA settings."""
        self._db = await aiosqlite.connect(self.db_path)
        # return rows as aiosqlite.Row so columns can be read by name
        self._db.row_factory = aiosqlite.Row
        # setup some default settings for more performance
        await self.execute("PRAGMA analysis_limit=10000;")
        await self.execute("PRAGMA locking_mode=exclusive;")
        await self.execute("PRAGMA journal_mode=WAL;")
        await self.execute("PRAGMA journal_size_limit = 6144000;")
        await self.execute("PRAGMA synchronous=normal;")
        await self.execute("PRAGMA temp_store=memory;")
        await self.execute("PRAGMA mmap_size = 30000000000;")
        await self.execute("PRAGMA cache_size = -64000;")
        await self.commit()

    async def close(self) -> None:
        """Run ``PRAGMA optimize``, commit and close the connection."""
        await self.execute("PRAGMA optimize;")
        await self.commit()
        await self._db.close()

    async def get_rows(
        self,
        table: str,
        match: dict[str, Any] | None = None,
        order_by: str | None = None,
        limit: int = 500,
        offset: int = 0,
    ) -> list[Mapping[str, Any]]:
        """Get rows for given table, optionally filtered, ordered and paginated.

        :param table: Table to read from.
        :param match: Column/value pairs that must all be equal; None or empty means no filter.
        :param order_by: Raw ORDER BY expression.
        :param limit: Maximum number of rows; a falsy value disables LIMIT/OFFSET.
        :param offset: Number of rows to skip (only applied together with a limit).
        """
        sql_query = f"SELECT * FROM {table}"
        # an empty match dict would otherwise produce a dangling "WHERE"
        if match:
            sql_query += " WHERE " + " AND ".join(f"{x} = :{x}" for x in match)
        if order_by is not None:
            sql_query += f" ORDER BY {order_by}"
        if limit:
            sql_query += f" LIMIT {limit} OFFSET {offset}"
        async with debug_query(sql_query, match):
            return cast(
                "list[Mapping[str, Any]]", await self._db.execute_fetchall(sql_query, match)
            )

    async def get_rows_from_query(
        self,
        query: str,
        params: dict[str, Any] | None = None,
        limit: int = 500,
        offset: int = 0,
    ) -> list[Mapping[str, Any]]:
        """Get all rows for given custom query.

        List/tuple values in ``params`` are expanded by ``query_params``; a falsy
        ``limit`` disables the appended LIMIT/OFFSET clause.
        """
        if limit:
            query += f" LIMIT {limit} OFFSET {offset}"
        _query, _params = query_params(query, params)
        async with debug_query(_query, _params):
            return cast("list[Mapping[str, Any]]", await self._db.execute_fetchall(_query, _params))

    async def get_count_from_query(
        self,
        query: str,
        params: dict[str, Any] | None = None,
    ) -> int:
        """Get row count for given custom query (wrapped in a ``SELECT count()``)."""
        query = f"SELECT count() FROM ({query})"
        _query, _params = query_params(query, params)
        async with debug_query(_query, _params):
            async with self._db.execute(_query, _params) as cursor:
                if result := await cursor.fetchone():
                    assert isinstance(result[0], int)  # for type checking
                    return result[0]
        return 0

    async def get_count(
        self,
        table: str,
    ) -> int:
        """Get row count for given table."""
        query = f"SELECT count(*) FROM {table}"
        async with debug_query(query):
            async with self._db.execute(query) as cursor:
                if result := await cursor.fetchone():
                    assert isinstance(result[0], int)  # for type checking
                    return result[0]
        return 0

    async def search(
        self, table: str, search: str, column: str = "name"
    ) -> list[Mapping[str, Any]]:
        """Search table for rows whose ``column`` contains ``search`` (SQL LIKE %search%)."""
        sql_query = f"SELECT * FROM {table} WHERE {table}.{column} LIKE :search"
        params = {"search": f"%{search}%"}
        async with debug_query(sql_query, params):
            return cast(
                "list[Mapping[str, Any]]", await self._db.execute_fetchall(sql_query, params)
            )

    async def get_row(self, table: str, match: dict[str, Any]) -> Mapping[str, Any] | None:
        """Get the first row for given table where all columns match keys/values, or None."""
        sql_query = f"SELECT * FROM {table} WHERE "
        sql_query += " AND ".join(f"{table}.{x} = :{x}" for x in match)
        async with debug_query(sql_query, match), self._db.execute(sql_query, match) as cursor:
            return cast("Mapping[str, Any] | None", await cursor.fetchone())

    async def insert(
        self,
        table: str,
        values: dict[str, Any],
        allow_replace: bool = False,
    ) -> int:
        """Insert data in given table, commit, and return the new row id.

        :param table: Table to insert into.
        :param values: Column/value pairs; ``UNSET`` values are left to the column default.
        :param allow_replace: Use ``INSERT OR REPLACE`` instead of a plain ``INSERT``.
        """
        # Filter out UNSET values so database defaults are used
        values = {k: v for k, v in values.items() if v is not UNSET}
        verb = "INSERT OR REPLACE INTO" if allow_replace else "INSERT INTO"
        if values:
            columns = ",".join(values)
            placeholders = ",".join(f":{x}" for x in values)
            sql_query = f"{verb} {table}({columns}) VALUES ({placeholders})"
        else:
            # every column falls back to its default; "() VALUES ()" is not valid SQLite
            sql_query = f"{verb} {table} DEFAULT VALUES"
        row_id = await self._db.execute_insert(sql_query, values)
        await self._db.commit()
        assert row_id is not None  # for type checking
        assert isinstance(row_id[0], int)  # for type checking
        return row_id[0]

    async def insert_or_replace(self, table: str, values: dict[str, Any]) -> int:
        """Insert or replace data in given table."""
        return await self.insert(table=table, values=values, allow_replace=True)

    async def upsert(self, table: str, values: dict[str, Any]) -> None:
        """Insert data in given table, updating all given columns on any conflict, then commit."""
        # Filter out UNSET values so database defaults are used
        values = {k: v for k, v in values.items() if v is not UNSET}
        keys = tuple(values.keys())
        sql_query = (
            f"INSERT INTO {table}({','.join(keys)}) VALUES ({','.join(f':{x}' for x in keys)})"
        )
        sql_query += f" ON CONFLICT DO UPDATE SET {','.join(f'{x}=:{x}' for x in keys)}"
        await self._db.execute(sql_query, values)
        await self._db.commit()

    async def update(
        self,
        table: str,
        match: dict[str, Any],
        values: dict[str, Any],
    ) -> Mapping[str, Any]:
        """Update record and return the (first) matching row as re-read afterwards.

        :param table: Table to update.
        :param match: Column/value pairs identifying the row(s) to update.
        :param values: Column/value pairs to set; ``UNSET`` values are skipped.
        """
        # Filter out UNSET values so those fields are not updated
        values = {k: v for k, v in values.items() if v is not UNSET}
        if values:
            set_clause = ",".join(f"{x}=:{x}" for x in values)
            # bind WHERE values under their own names: sharing ":col" between SET and WHERE
            # would let the new value replace the lookup value when a matched column changes
            where_clause = " AND ".join(f"{x} = :_match_{x}" for x in match)
            params = {**values, **{f"_match_{k}": v for k, v in match.items()}}
            await self.execute(f"UPDATE {table} SET {set_clause} WHERE {where_clause}", params)
            await self._db.commit()
        # return updated item, looked up by the new value of any matched column that changed
        lookup = {k: values.get(k, v) for k, v in match.items()}
        updated_item = await self.get_row(table, lookup)
        assert updated_item is not None  # for type checking
        return updated_item

    async def delete(
        self, table: str, match: dict[str, Any] | None = None, query: str | None = None
    ) -> None:
        """Delete data in given table and commit.

        :param table: Table to delete from.
        :param match: Column/value pairs that must all be equal for a row to be deleted.
        :param query: Raw SQL condition, with or without a leading ``WHERE`` keyword;
            mutually exclusive with ``match``.

        When neither ``match`` nor ``query`` is given, all rows of the table are deleted.
        """
        assert not (match and query), "Cannot use both match and query"
        sql_query = f"DELETE FROM {table}"
        if match:
            sql_query += " WHERE " + " AND ".join(f"{x} = :{x}" for x in match)
        elif query:
            # only treat WHERE as already present when the clause starts with it; a "where"
            # elsewhere (sub-select, string literal) is part of the condition itself
            if re.match(r"\s*where\b", query, re.IGNORECASE):
                sql_query += f" {query}"
            else:
                sql_query += f" WHERE {query}"
        await self.execute(sql_query, match)
        await self._db.commit()

    async def delete_where_query(self, table: str, query: str | None = None) -> None:
        """Delete data in given table using given where clause (without the WHERE keyword).

        NOTE(review): ``query`` is interpolated as-is, so None yields ``WHERE None``.
        """
        sql_query = f"DELETE FROM {table} WHERE {query}"
        await self.execute(sql_query)
        await self._db.commit()

    async def execute(self, query: str, values: dict[str, Any] | None = None) -> Any:
        """Execute command on the database and return the resulting cursor (no commit)."""
        return await self._db.execute(query, values)

    async def commit(self) -> None:
        """Commit the current transaction."""
        return await self._db.commit()

    async def iter_items(
        self,
        table: str,
        match: dict[str, Any] | None = None,
    ) -> AsyncGenerator[Mapping[str, Any], None]:
        """Iterate all items within a table, fetching them in pages of 500 rows."""
        limit: int = 500
        offset: int = 0
        while True:
            next_items = await self.get_rows(
                table=table,
                match=match,
                offset=offset,
                limit=limit,
            )
            for item in next_items:
                yield item
            if len(next_items) < limit:
                break
            await asyncio.sleep(0)  # yield to eventloop
            offset += limit

    async def vacuum(self) -> None:
        """Run vacuum command on database."""
        await self._db.execute("VACUUM")
        await self._db.commit()
318