music-assistant-server

5.6 KB • PY
images.py
5.6 KB • 158 lines • python
1"""Utilities for image manipulation and retrieval."""
2
3from __future__ import annotations
4
5import asyncio
6import itertools
7import os
8import random
9from base64 import b64decode
10from collections.abc import Iterable
11from io import BytesIO
12from typing import TYPE_CHECKING, cast
13
14import aiofiles
15from aiohttp.client_exceptions import ClientError
16from PIL import Image, UnidentifiedImageError
17
18from music_assistant.helpers.security import is_safe_path
19from music_assistant.helpers.tags import get_embedded_image
20from music_assistant.models.metadata_provider import MetadataProvider
21from music_assistant.models.music_provider import MusicProvider
22from music_assistant.models.plugin import PluginProvider
23
24if TYPE_CHECKING:
25    from music_assistant_models.media_items import MediaItemImage
26    from PIL.Image import Image as ImageClass
27
28    from music_assistant.mass import MusicAssistant
29
30
async def get_image_data(mass: MusicAssistant, path_or_url: str, provider: str) -> bytes:
    """Retrieve the raw image bytes for an image path, url or data uri.

    Resolution order:
    1. Ask the owning provider (if any) to resolve the image. It may return the
       image bytes directly, or a new path/url that is handled by the steps below.
    2. HTTP(S) url: download it.
    3. base64 data uri (``data:image/...``): decode its payload.
    4. Local image file (jpg/jpeg/png extension, case-insensitive): read it.
    5. Otherwise try to extract an embedded image (ffmpeg via get_embedded_image).

    :param mass: The MusicAssistant instance.
    :param path_or_url: Path, url or data uri of the image.
    :param provider: Identifier passed to ``mass.get_provider`` to look up the
        provider that owns the image.
    :raises FileNotFoundError: If the image could not be retrieved.
    """
    # TODO: add local cache here !
    if prov := mass.get_provider(provider):
        assert isinstance(prov, MusicProvider | MetadataProvider | PluginProvider)
        if resolved_image := await prov.resolve_image(path_or_url):
            if isinstance(resolved_image, bytes):
                return resolved_image
            if isinstance(resolved_image, str):
                # provider translated the path into something handled below
                path_or_url = resolved_image
    # handle HTTP location
    if path_or_url.startswith("http"):
        try:
            async with mass.http_session_no_ssl.get(path_or_url, raise_for_status=True) as resp:
                return await resp.read()
        except ClientError as err:
            msg = f"Failed to download image: {path_or_url}"
            raise FileNotFoundError(msg) from err
    # handle base64 embedded images (data:image/<type>;base64,<payload>)
    if path_or_url.startswith("data:image"):
        return b64decode(path_or_url.split(",")[-1])
    # evaluate once; both local-file strategies below require a safe path
    safe_path = is_safe_path(path_or_url)
    # handle FILE location (of type image), extension match is case-insensitive
    if safe_path and path_or_url.lower().endswith(("jpg", "jpeg", "png")):
        if await asyncio.to_thread(os.path.isfile, path_or_url):
            async with aiofiles.open(path_or_url, "rb") as _file:
                return cast("bytes", await _file.read())
    # use ffmpeg for embedded images
    if safe_path and (img_data := await get_embedded_image(path_or_url)):
        return img_data
    msg = f"Image not found: {path_or_url}"
    raise FileNotFoundError(msg)
61
62
async def get_image_thumb(
    mass: MusicAssistant,
    path_or_url: str,
    size: int | None,
    provider: str,
    image_format: str = "PNG",
) -> bytes:
    """Get image data, optionally downscaled and (re-)encoded in image_format.

    The image is fetched with get_image_data. When no size is requested and the
    image is already encoded in the requested format, the original bytes are
    returned untouched. Otherwise the image is shrunk to fit within
    ``size x size`` (Pillow ``thumbnail``: aspect ratio kept, never enlarged)
    and re-encoded in the requested format.

    :param mass: The MusicAssistant instance.
    :param path_or_url: Path, url or data uri of the image.
    :param size: Maximum width/height in pixels; falsy keeps the original size.
    :param provider: Identifier of the provider that owns the image.
    :param image_format: Target format name as understood by Pillow
        (case-insensitive; "JPG" is accepted as an alias for "JPEG").
    :raises FileNotFoundError: If the image can not be retrieved or decoded.
    """
    img_data = await get_image_data(mass, path_or_url, provider)
    if not img_data or not isinstance(img_data, bytes):
        raise FileNotFoundError(f"Image not found: {path_or_url}")

    # Normalize to Pillow's format naming *before* comparing against the data.
    image_format = image_format.upper()
    if image_format == "JPG":
        image_format = "JPEG"

    def _create_image() -> bytes:
        try:
            img = Image.open(BytesIO(img_data))
        except UnidentifiedImageError as err:
            raise FileNotFoundError(f"Invalid image: {path_or_url}") from err
        with img:
            # Image.open is lazy (header only), so detecting the real encoded
            # format is cheap; a substring search over the payload is not reliable.
            if not size and img.format == image_format:
                return img_data
            if size:
                # Use LANCZOS for high quality downsampling
                img.thumbnail((size, size), Image.Resampling.LANCZOS)

            mode = "RGBA" if image_format == "PNG" else "RGB"
            data = BytesIO()
            # Save with high quality settings
            if image_format == "JPEG":
                # For JPEG, use quality=95 for better quality
                img.convert(mode).save(data, image_format, quality=95, optimize=False)
            else:
                # For PNG, disable optimize to preserve quality
                img.convert(mode).save(data, image_format, optimize=False)
            return data.getvalue()

    return await asyncio.to_thread(_create_image)
105
106
async def create_collage(
    mass: MusicAssistant,
    images: Iterable[MediaItemImage],
    dimensions: tuple[int, int] = (1500, 1500),
) -> bytes:
    """Create a basic collage (JPEG) from multiple images.

    The canvas is tiled with 250x250 pixel squares, filled column by column with
    the de-duplicated, shuffled images (cycling when there are fewer images than
    tiles). For each tile up to 5 images are tried: images that can not be
    fetched or decoded are skipped, so a single broken image does not abort the
    whole collage. A tile whose attempts all fail stays white.

    :param mass: The MusicAssistant instance.
    :param images: The images to build the collage from.
    :param dimensions: Width and height of the collage in pixels.
    :raises FileNotFoundError: If no images are provided.
    :return: The collage as JPEG encoded bytes.
    """
    image_size = 250
    max_attempts_per_tile = 5

    # prevent duplicates with a set, then randomize the order
    unique_images = list(set(images))
    if not unique_images:
        # next() on a cycle over an empty list raises StopIteration, which a
        # coroutine turns into an opaque RuntimeError; fail explicitly instead.
        raise FileNotFoundError("No images provided to create a collage")
    random.shuffle(unique_images)
    iter_images = itertools.cycle(unique_images)

    def _new_collage() -> ImageClass:
        return Image.new("RGB", (dimensions[0], dimensions[1]), color=(255, 255, 255))

    collage = await asyncio.to_thread(_new_collage)

    def _add_to_collage(img_data: bytes, coord_x: int, coord_y: int) -> None:
        with Image.open(BytesIO(img_data)) as photo:
            tile = photo.convert("RGB").resize((image_size, image_size))
        collage.paste(tile, (coord_x, coord_y))

    for x_co in range(0, dimensions[0], image_size):
        for y_co in range(0, dimensions[1], image_size):
            for _ in range(max_attempts_per_tile):
                img = next(iter_images)
                try:
                    img_data = await get_image_data(mass, img.path, img.provider)
                except FileNotFoundError:
                    # unavailable image: try the next one for this tile
                    continue
                if not img_data:
                    continue
                try:
                    await asyncio.to_thread(_add_to_collage, img_data, x_co, y_co)
                except OSError:
                    # undecodable/truncated data (UnidentifiedImageError is an OSError)
                    continue
                break

    def _save_collage() -> bytes:
        final_data = BytesIO()
        # the canvas is created in RGB mode, so it can be saved as JPEG directly
        collage.save(final_data, "JPEG", optimize=True)
        return final_data.getvalue()

    return await asyncio.to_thread(_save_collage)
148
149
async def get_icon_string(icon_path: str) -> str:
    """Get an svg icon as a single-line string.

    :param icon_path: Path to the svg file (extension checked case-insensitively).
    :raises ValueError: If icon_path does not have an svg extension.
    :return: The file contents with all newlines removed and surrounding
        whitespace stripped.
    """
    ext = icon_path.rsplit(".", 1)[-1]
    if ext.lower() != "svg":
        # explicit check instead of `assert`, which is stripped under `python -O`
        msg = f"Icon must be an svg file: {icon_path}"
        raise ValueError(msg)
    # svg is XML, whose default encoding is UTF-8; don't depend on the locale
    async with aiofiles.open(icon_path, encoding="utf-8") as _file:
        xml_data = await _file.read()
    assert isinstance(xml_data, str)  # for type checking
    return xml_data.replace("\n", "").strip()
158