music-assistant-server

11.4 KBPY
ssl.py
11.4 KB364 lines • python
1"""SSL helpers for the webserver controller."""
2
3from __future__ import annotations
4
5import asyncio
6import contextlib
7import logging
8import ssl
9import subprocess
10import tempfile
11from dataclasses import dataclass
12from pathlib import Path
13
14import aiofiles
15
# Module-level logger; create_server_ssl_context falls back to it when no logger is passed.
LOGGER = logging.getLogger(name=__name__)
17
18
@dataclass()
class SSLCertificateInfo:
    """Outcome of checking an SSL certificate / private key pair.

    Produced by the verification helpers in this module and rendered for
    humans by ``format_certificate_info``.
    """

    # True when the certificate and key could be loaded together.
    is_valid: bool
    # Public key algorithm: "RSA", "ECDSA", or "Unknown".
    key_type: str
    # Certificate subject as printed by openssl (empty when unavailable).
    subject: str
    # Certificate "notAfter" date as printed by openssl (empty when unavailable).
    expiry: str
    # True when the certificate is already past its expiry date.
    is_expired: bool
    # True when the certificate expires within the next 30 days.
    is_expiring_soon: bool
    # Human-readable failure reason when ``is_valid`` is False; None otherwise.
    error_message: str | None = None
30
31
async def get_ssl_content(value: str) -> str:
    """Resolve an SSL setting to PEM text.

    The setting holds either an absolute filesystem path (anything starting
    with ``/``) or the PEM data itself. Surrounding whitespace is stripped
    first in both cases.

    :param value: Absolute path to a PEM file, or inline PEM content.
    :return: The file's contents for a path, otherwise the stripped input.
    :raises FileNotFoundError: If ``value`` is a path that does not exist.
    :raises ValueError: If ``value`` is a path that is not a regular file.
    """
    text = value.strip()
    # Only absolute paths are read from disk; inline PEM ("-----BEGIN ...") and
    # anything else (including relative paths) is returned unchanged.
    is_path = text.startswith("/") and not text.startswith("-----BEGIN")
    if not is_path:
        return text

    pem_path = Path(text)
    if not pem_path.exists():
        raise FileNotFoundError(f"SSL file not found: {text}")
    if not pem_path.is_file():
        raise ValueError(f"SSL path is not a file: {text}")

    async with aiofiles.open(pem_path) as pem_file:
        pem_text: str = await pem_file.read()
    return pem_text
55
56
57def _run_openssl_command(args: list[str]) -> subprocess.CompletedProcess[str]:
58    """Run an openssl command synchronously.
59
60    :param args: List of arguments for the openssl command (excluding 'openssl').
61    :return: CompletedProcess result.
62    """
63    return subprocess.run(  # noqa: S603
64        ["openssl", *args],  # noqa: S607
65        capture_output=True,
66        text=True,
67        timeout=10,
68        check=False,
69    )
70
71
async def create_server_ssl_context(
    certificate: str,
    private_key: str,
    logger: logging.Logger | None = None,
) -> ssl.SSLContext | None:
    """Build a server-side SSL context from a certificate and private key.

    Each input may be an absolute PEM file path or inline PEM text (resolved
    via ``get_ssl_content``). Failures are logged instead of raised so the
    caller can fall back to serving without TLS.

    :param certificate: Certificate as a file path or PEM content.
    :param private_key: Private key as a file path or PEM content.
    :param logger: Logger for status/error messages; the module logger is
        used when omitted.
    :return: The configured context, or None if SSL could not be set up.
    """
    log = logger or LOGGER
    if not (certificate and private_key):
        log.error(
            "SSL is enabled but certificate or private key is missing. "
            "Server will start without SSL."
        )
        return None

    # Temp files created so far; every one is removed in ``finally``.
    written: list[str] = []
    try:
        pem_blobs = [
            await get_ssl_content(certificate),
            await get_ssl_content(private_key),
        ]
        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)

        # load_cert_chain only accepts file paths, so the PEM data is put on disk.
        for blob in pem_blobs:
            with tempfile.NamedTemporaryFile(mode="w", suffix=".pem", delete=False) as tmp:
                tmp.write(blob)
                written.append(tmp.name)

        cert_file, key_file = written
        context.load_cert_chain(cert_file, key_file)
        log.info("SSL/TLS enabled for server")
        return context
    except Exception as err:
        log.exception("Failed to create SSL context: %s. Server will start without SSL.", err)
        return None
    finally:
        for leftover in written:
            with contextlib.suppress(Exception):
                Path(leftover).unlink()
128
129
async def verify_ssl_certificate(certificate: str, private_key: str) -> SSLCertificateInfo:
    """Check that a certificate and private key are valid and belong together.

    Both inputs may be absolute PEM file paths or inline PEM text. Problems are
    reported through the returned object rather than raised.

    :param certificate: Certificate as a file path or PEM content.
    :param private_key: Private key as a file path or PEM content.
    :return: SSLCertificateInfo; on failure ``is_valid`` is False and
        ``error_message`` explains why.
    """

    def failure(message: str) -> SSLCertificateInfo:
        # Uniform "invalid" result that carries only the error message.
        return SSLCertificateInfo(
            is_valid=False,
            key_type="Unknown",
            subject="",
            expiry="",
            is_expired=False,
            is_expiring_soon=False,
            error_message=message,
        )

    if not (certificate and private_key):
        return failure("Both certificate and private key are required.")

    try:
        cert_content = await get_ssl_content(certificate)
    except FileNotFoundError as err:
        return failure(f"Certificate file not found: {err}")
    except Exception as err:
        return failure(f"Error loading certificate: {err}")

    try:
        key_content = await get_ssl_content(private_key)
    except FileNotFoundError as err:
        return failure(f"Private key file not found: {err}")
    except Exception as err:
        return failure(f"Error loading private key: {err}")

    # Loading the pair into an SSLContext is the actual validity check.
    try:
        return await _verify_ssl_with_temp_files(cert_content, key_content)
    except ssl.SSLError as err:
        return failure(_format_ssl_error(err))
    except Exception as err:
        return failure(f"Verification failed: {err}")
218
219
220async def _verify_ssl_with_temp_files(cert_content: str, key_content: str) -> SSLCertificateInfo:
221    """Verify SSL using temporary files.
222
223    :param cert_content: Certificate PEM content.
224    :param key_content: Private key PEM content.
225    :return: SSLCertificateInfo with verification results.
226    """
227    cert_path = None
228    key_path = None
229    try:
230        with tempfile.NamedTemporaryFile(mode="w", suffix=".pem", delete=False) as cert_file:
231            cert_file.write(cert_content)
232            cert_path = cert_file.name
233
234        with tempfile.NamedTemporaryFile(mode="w", suffix=".pem", delete=False) as key_file:
235            key_file.write(key_content)
236            key_path = key_file.name
237
238        # Test loading into SSL context
239        test_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
240        test_ctx.load_cert_chain(cert_path, key_path)
241
242        # Get certificate details using openssl
243        return await _get_certificate_details(cert_path)
244    finally:
245        # Clean up temp files
246        if cert_path:
247            with contextlib.suppress(Exception):
248                Path(cert_path).unlink()
249        if key_path:
250            with contextlib.suppress(Exception):
251                Path(key_path).unlink()
252
253
async def _get_certificate_details(cert_path: str) -> SSLCertificateInfo:
    """Collect subject, expiry status and key type of a certificate via ``openssl``.

    Within this module the function is only reached after the certificate/key
    pair has loaded into an ``SSLContext``, so the result always has
    ``is_valid=True``. When openssl cannot be run (binary missing, timeout) or
    cannot parse the file, the affected details fall back to empty strings,
    ``"Unknown"`` or ``False`` instead of raising. Without this, a pair that
    loaded fine would be reported as invalid by the caller.

    :param cert_path: Path to a PEM certificate file.
    :return: SSLCertificateInfo with certificate details.
    """

    async def run_x509(*args: str) -> subprocess.CompletedProcess[str] | None:
        # Run ``openssl x509 -in <cert> -noout <args>`` off the event loop.
        # Return None when openssl cannot be executed at all.
        try:
            return await asyncio.to_thread(
                _run_openssl_command,
                ["x509", "-in", cert_path, "-noout", *args],
            )
        except (OSError, subprocess.SubprocessError) as err:
            LOGGER.debug("Could not run openssl to inspect certificate: %s", err)
            return None

    # Get certificate info
    result = await run_x509("-subject", "-dates", "-issuer")
    if result is None or result.returncode != 0:
        return SSLCertificateInfo(
            is_valid=True,
            key_type="Unknown",
            subject="",
            expiry="",
            is_expired=False,
            is_expiring_soon=False,
        )

    # Parse certificate info. Strip only the leading label; the value itself
    # may legitimately contain "subject=" or "notAfter=".
    expiry = ""
    subject = ""
    for line in result.stdout.splitlines():
        if line.startswith("notAfter="):
            expiry = line.removeprefix("notAfter=")
        elif line.startswith("subject="):
            subject = line.removeprefix("subject=")

    # ``-checkend N`` exits non-zero when the certificate expires within N seconds.
    expiry_check = await run_x509("-checkend", "0")
    is_expired = expiry_check is not None and expiry_check.returncode != 0

    expiring_soon_check = await run_x509("-checkend", str(30 * 24 * 60 * 60))
    is_expiring_soon = expiring_soon_check is not None and expiring_soon_check.returncode != 0

    # Detect key type from the public key algorithm in the text dump.
    key_type_result = await run_x509("-text")
    dump = key_type_result.stdout if key_type_result is not None else ""
    key_type = "Unknown"
    if "rsaEncryption" in dump:
        key_type = "RSA"
    elif "id-ecPublicKey" in dump:
        key_type = "ECDSA"

    return SSLCertificateInfo(
        is_valid=True,
        key_type=key_type,
        subject=subject,
        expiry=expiry,
        is_expired=is_expired,
        is_expiring_soon=is_expiring_soon,
    )
317
318
319def _format_ssl_error(e: ssl.SSLError) -> str:
320    """Format an SSL error into a user-friendly message.
321
322    :param e: The SSL error.
323    :return: User-friendly error message.
324    """
325    error_msg = str(e)
326    if "PEM lib" in error_msg:
327        return (
328            "Invalid certificate or key format. "
329            "Make sure both are valid PEM format and the key is not encrypted."
330        )
331    if "key values mismatch" in error_msg.lower():
332        return (
333            "Certificate and private key do not match. "
334            "Please verify you're using the correct key for this certificate."
335        )
336    return f"SSL Error: {error_msg}"
337
338
def format_certificate_info(info: SSLCertificateInfo) -> str:
    """Render an SSLCertificateInfo as multi-line, human-readable text.

    An invalid result collapses to a single ``Error: ...`` line. A valid result
    starts with a status line (EXPIRED takes precedence over EXPIRING SOON,
    otherwise VALID), followed by the key type and, when non-empty, the
    subject and expiry date.

    :param info: The certificate info to format.
    :return: Human-readable string.
    """
    if not info.is_valid:
        return f"Error: {info.error_message}"

    if info.is_expired:
        status, warning = "EXPIRED", " (Certificate has expired!)"
    elif info.is_expiring_soon:
        status, warning = "EXPIRING SOON", " (Certificate expires within 30 days)"
    else:
        status, warning = "VALID", ""

    report = [
        f"Certificate verification: {status}{warning}",
        f"Key type: {info.key_type}",
    ]
    # Optional details are skipped when empty (e.g. openssl could not supply them).
    report.extend(
        f"{label}: {value}"
        for label, value in (("Subject", info.subject), ("Expires", info.expiry))
        if value
    )
    return "\n".join(report)
364