music-assistant-server

11.4 KBPY
database.py
11.4 KB318 lines • python
1"""Database helpers and logic."""
2
3from __future__ import annotations
4
import asyncio
import logging
import os
import re
import time
from collections.abc import Mapping
from contextlib import asynccontextmanager
from sqlite3 import OperationalError
from typing import TYPE_CHECKING, Any, cast
13
14import aiosqlite
15
16from music_assistant.constants import MASS_LOGGER_NAME
17
18if TYPE_CHECKING:
19    from collections.abc import AsyncGenerator
20
# Dedicated logger for this module, a child of the MASS_LOGGER_NAME logger.
LOGGER = logging.getLogger(MASS_LOGGER_NAME + ".database")
22
23
class _UnsetType:
    """Marker type meaning "do not supply this column; let the database decide".

    Only one instance ever exists (see ``__new__``). It is distinct from None,
    which is a real value, and it evaluates as false in boolean context.
    """

    _instance: _UnsetType | None = None

    def __new__(cls) -> _UnsetType:
        """Hand out the shared instance, creating it on the first call."""
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def __repr__(self) -> str:
        """Show the sentinel by its public name."""
        return "UNSET"

    def __bool__(self) -> bool:
        """Make the sentinel falsy."""
        return False


# The one sentinel instance; insert/upsert/update drop any value that `is UNSET`.
UNSET: _UnsetType = _UnsetType()

# Query timing/diagnostics in debug_query are only active in Python dev mode.
ENABLE_DEBUG = os.getenv("PYTHONDEVMODE") == "1"
47
48
49@asynccontextmanager
50async def debug_query(
51    sql_query: str, query_params: dict[str, Any] | None = None
52) -> AsyncGenerator[None]:
53    """Time the processing time of an sql query."""
54    if not ENABLE_DEBUG:
55        yield
56        return
57    time_start = time.time()
58    try:
59        yield
60    except OperationalError as err:
61        LOGGER.error(f"{err}\n{sql_query}")
62        raise
63    finally:
64        process_time = time.time() - time_start
65        if process_time > 0.5:
66            # log slow queries
67            for key, value in (query_params or {}).items():
68                sql_query = sql_query.replace(f":{key}", repr(value))
69            LOGGER.warning("SQL Query took %s seconds! (\n%s\n", process_time, sql_query)
70
71
72def query_params(query: str, params: dict[str, Any] | None) -> tuple[str, dict[str, Any]]:
73    """Extend query parameters support."""
74    if params is None:
75        return (query, {})
76    count = 0
77    result_query = query
78    result_params = {}
79    for key, value in params.items():
80        # add support for a list within the query params
81        # recreates the params as (:_param_0, :_param_1) etc
82        if isinstance(value, list | tuple):
83            subparams = []
84            for subval in value:
85                subparam_name = f"_param_{count}"
86                result_params[subparam_name] = subval
87                subparams.append(subparam_name)
88                count += 1
89            params_str = ",".join(f":{x}" for x in subparams)
90            result_query = result_query.replace(f" :{key}", f" ({params_str})")
91        else:
92            result_params[key] = value
93    return (result_query, result_params)
94
95
class DatabaseConnection:
    """Class that holds the (connection to the) database with some convenience helper functions.

    Wraps one aiosqlite connection, opened in setup() and closed in close().
    Table names, column names and ORDER BY expressions are interpolated verbatim
    into the SQL text, so they must come from trusted code. Values are always
    bound as named parameters.
    """

    _db: aiosqlite.Connection

    def __init__(self, db_path: str) -> None:
        """Initialize class.

        :param db_path: Path passed to aiosqlite.connect(); the connection itself
            is only opened by setup().
        """
        self.db_path = db_path

    async def setup(self) -> None:
        """Perform async initialization: open the connection and apply PRAGMA settings."""
        self._db = await aiosqlite.connect(self.db_path)
        # rows come back as aiosqlite.Row; the helpers below expose them as Mapping
        self._db.row_factory = aiosqlite.Row
        # setup some default settings for more performance
        await self.execute("PRAGMA analysis_limit=10000;")
        await self.execute("PRAGMA locking_mode=exclusive;")
        await self.execute("PRAGMA journal_mode=WAL;")
        await self.execute("PRAGMA journal_size_limit = 6144000;")
        await self.execute("PRAGMA synchronous=normal;")
        await self.execute("PRAGMA temp_store=memory;")
        await self.execute("PRAGMA mmap_size = 30000000000;")
        await self.execute("PRAGMA cache_size = -64000;")
        await self.commit()

    async def close(self) -> None:
        """Close db connection on exit, running PRAGMA optimize first."""
        await self.execute("PRAGMA optimize;")
        await self.commit()
        await self._db.close()

    async def get_rows(
        self,
        table: str,
        match: dict[str, Any] | None = None,
        order_by: str | None = None,
        limit: int = 500,
        offset: int = 0,
    ) -> list[Mapping[str, Any]]:
        """Get all rows for given table.

        :param table: Table name (interpolated verbatim).
        :param match: Optional column -> value equality filters, combined with AND.
            None or an empty dict means no filter.
        :param order_by: Optional ORDER BY expression (interpolated verbatim).
        :param limit: Maximum number of rows; a falsy limit omits LIMIT/OFFSET.
        :param offset: Rows to skip; only applied together with limit.
        :return: The matching rows.
        """
        sql_query = f"SELECT * FROM {table}"
        # an empty dict must not produce a bare "WHERE" (SQL syntax error)
        if match:
            sql_query += " WHERE " + " AND ".join(f"{x} = :{x}" for x in match)
        if order_by is not None:
            sql_query += f" ORDER BY {order_by}"
        if limit:
            sql_query += f" LIMIT {limit} OFFSET {offset}"
        async with debug_query(sql_query, match):
            return cast(
                "list[Mapping[str, Any]]", await self._db.execute_fetchall(sql_query, match)
            )

    async def get_rows_from_query(
        self,
        query: str,
        params: dict[str, Any] | None = None,
        limit: int = 500,
        offset: int = 0,
    ) -> list[Mapping[str, Any]]:
        """Get all rows for given custom query.

        :param query: SQL SELECT text with named placeholders; list/tuple values
            are expanded by query_params().
        :param params: Placeholder values, or None.
        :param limit: Maximum number of rows; a falsy limit omits LIMIT/OFFSET.
        :param offset: Rows to skip; only applied together with limit.
        :return: The resulting rows.
        """
        if limit:
            query += f" LIMIT {limit} OFFSET {offset}"
        _query, _params = query_params(query, params)
        async with debug_query(_query, _params):
            return cast("list[Mapping[str, Any]]", await self._db.execute_fetchall(_query, _params))

    async def get_count_from_query(
        self,
        query: str,
        params: dict[str, Any] | None = None,
    ) -> int:
        """Get row count for given custom query.

        :param query: SQL SELECT text; it is wrapped as a subquery of SELECT count().
        :param params: Placeholder values, or None.
        :return: The number of rows the query yields (0 if no result row).
        """
        query = f"SELECT count() FROM ({query})"
        _query, _params = query_params(query, params)
        async with debug_query(_query, _params):
            async with self._db.execute(_query, _params) as cursor:
                if result := await cursor.fetchone():
                    assert isinstance(result[0], int)  # for type checking
                    return result[0]
            return 0

    async def get_count(
        self,
        table: str,
    ) -> int:
        """Get row count for given table.

        :param table: Table name (interpolated verbatim).
        :return: Total number of rows in the table.
        """
        query = f"SELECT count(*) FROM {table}"
        async with debug_query(query):
            async with self._db.execute(query) as cursor:
                if result := await cursor.fetchone():
                    assert isinstance(result[0], int)  # for type checking
                    return result[0]
            return 0

    async def search(
        self, table: str, search: str, column: str = "name"
    ) -> list[Mapping[str, Any]]:
        """Search table by column.

        Performs a substring LIKE match. Any % or _ inside `search` is not
        escaped, so it acts as a LIKE wildcard.

        :param table: Table name (interpolated verbatim).
        :param search: Text to look for anywhere in the column.
        :param column: Column to search (interpolated verbatim).
        :return: All matching rows (no LIMIT applied).
        """
        sql_query = f"SELECT * FROM {table} WHERE {table}.{column} LIKE :search"
        params = {"search": f"%{search}%"}
        async with debug_query(sql_query, params):
            return cast(
                "list[Mapping[str, Any]]", await self._db.execute_fetchall(sql_query, params)
            )

    async def get_row(self, table: str, match: dict[str, Any]) -> Mapping[str, Any] | None:
        """Get single row for given table where column matches keys/values.

        :param table: Table name (interpolated verbatim).
        :param match: Column -> value equality filters, combined with AND.
        :return: The first matching row, or None if there is none.
        """
        sql_query = f"SELECT * FROM {table} WHERE "
        sql_query += " AND ".join(f"{table}.{x} = :{x}" for x in match)
        async with debug_query(sql_query, match), self._db.execute(sql_query, match) as cursor:
            return cast("Mapping[str, Any] | None", await cursor.fetchone())

    async def insert(
        self,
        table: str,
        values: dict[str, Any],
        allow_replace: bool = False,
    ) -> int:
        """Insert data in given table and commit.

        :param table: Table name (interpolated verbatim).
        :param values: Column -> value; UNSET values are omitted so the column
            default applies.
        :param allow_replace: Use INSERT OR REPLACE instead of plain INSERT.
        :return: The row id reported by aiosqlite's execute_insert().
        """
        # Filter out UNSET values so database defaults are used
        values = {k: v for k, v in values.items() if v is not UNSET}
        keys = tuple(values.keys())
        if allow_replace:
            sql_query = f"INSERT OR REPLACE INTO {table}({','.join(keys)})"
        else:
            sql_query = f"INSERT INTO {table}({','.join(keys)})"
        sql_query += f" VALUES ({','.join(f':{x}' for x in keys)})"
        row_id = await self._db.execute_insert(sql_query, values)
        await self._db.commit()
        assert row_id is not None  # for type checking
        assert isinstance(row_id[0], int)  # for type checking
        return row_id[0]

    async def insert_or_replace(self, table: str, values: dict[str, Any]) -> int:
        """Insert or replace data in given table (insert() with allow_replace=True)."""
        return await self.insert(table=table, values=values, allow_replace=True)

    async def upsert(self, table: str, values: dict[str, Any]) -> None:
        """Upsert data in given table and commit.

        Uses ON CONFLICT DO UPDATE without a conflict target, so any uniqueness
        conflict turns the insert into an update of all supplied columns.

        :param table: Table name (interpolated verbatim).
        :param values: Column -> value; UNSET values are omitted.
        """
        # Filter out UNSET values so database defaults are used
        values = {k: v for k, v in values.items() if v is not UNSET}
        keys = tuple(values.keys())
        sql_query = (
            f"INSERT INTO {table}({','.join(keys)}) VALUES ({','.join(f':{x}' for x in keys)})"
        )
        sql_query += f" ON CONFLICT DO UPDATE SET {','.join(f'{x}=:{x}' for x in keys)}"
        await self._db.execute(sql_query, values)
        await self._db.commit()

    async def update(
        self,
        table: str,
        match: dict[str, Any],
        values: dict[str, Any],
    ) -> Mapping[str, Any]:
        """Update record(s) matching `match` with `values`, commit, and return the row.

        Values set to UNSET are skipped. If nothing is left to update, no UPDATE
        statement is issued and the current row is returned.

        :param table: Table name (interpolated verbatim).
        :param match: Column -> value equality filters (AND-ed) selecting the row(s).
        :param values: Column -> new value.
        :return: The first row matching `match` after the update. A match column
            that was itself updated is looked up by its new value.
        """
        # Filter out UNSET values so those fields are not updated
        values = {k: v for k, v in values.items() if v is not UNSET}
        if values:
            set_clause = ",".join(f"{x}=:{x}" for x in values)
            # WHERE values get their own placeholder names: if a column is in both
            # match and values, a shared ":col" would make the WHERE clause filter
            # on the *new* value and update the wrong rows.
            where_clause = " AND ".join(f"{x} = :_match_{x}" for x in match)
            sql_query = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
            params = {**values, **{f"_match_{x}": y for x, y in match.items()}}
            await self.execute(sql_query, params)
            await self._db.commit()
        # return updated item; columns changed by this update are matched on the new value
        lookup = {k: values.get(k, v) for k, v in match.items()}
        updated_item = await self.get_row(table, lookup)
        assert updated_item is not None  # for type checking
        return updated_item

    async def delete(
        self, table: str, match: dict[str, Any] | None = None, query: str | None = None
    ) -> None:
        """Delete data in given table and commit.

        Exactly one filter style may be used:
        - `match`: column -> value equality filters, AND-ed and bound as parameters.
        - `query`: raw SQL. If it does not contain "where" (case-insensitive),
          "WHERE " is prepended; otherwise it is appended verbatim.
        With neither, every row of the table is deleted.
        """
        assert not (match and query), "Cannot use both match and query"
        sql_query = f"DELETE FROM {table} "
        if match:
            sql_query += " WHERE " + " AND ".join(f"{x} = :{x}" for x in match)
        elif query and "where" not in query.lower():
            sql_query += "WHERE " + query
        elif query:
            sql_query += query
        await self.execute(sql_query, match)
        await self._db.commit()

    async def delete_where_query(self, table: str, query: str | None = None) -> None:
        """Delete data in given table using given where clause (raw SQL) and commit."""
        # NOTE(review): query defaults to None, which would yield "WHERE None";
        # callers presumably always pass a clause.
        sql_query = f"DELETE FROM {table} WHERE {query}"
        await self.execute(sql_query)
        await self._db.commit()

    async def execute(self, query: str, values: dict[str, Any] | None = None) -> Any:
        """Execute command on the database (no commit) and return the aiosqlite cursor."""
        return await self._db.execute(query, values)

    async def commit(self) -> None:
        """Commit the current transaction."""
        return await self._db.commit()

    async def iter_items(
        self,
        table: str,
        match: dict[str, Any] | None = None,
    ) -> AsyncGenerator[Mapping[str, Any], None]:
        """Iterate all items within a table, fetching 500 rows per page.

        NOTE(review): pages use LIMIT/OFFSET without ORDER BY. This relies on
        SQLite returning rows in a stable order between calls, and rows inserted
        or deleted during iteration may be skipped or repeated.
        """
        limit: int = 500
        offset: int = 0
        while True:
            next_items = await self.get_rows(
                table=table,
                match=match,
                offset=offset,
                limit=limit,
            )
            for item in next_items:
                yield item
            if len(next_items) < limit:
                break
            await asyncio.sleep(0)  # yield to eventloop
            offset += limit

    async def vacuum(self) -> None:
        """Run vacuum command on database."""
        await self._db.execute("VACUUM")
        await self._db.commit()
318