/
/
/
1"""SSL helpers for the webserver controller."""
2
3from __future__ import annotations
4
5import asyncio
6import contextlib
7import logging
8import ssl
9import subprocess
10import tempfile
11from dataclasses import dataclass
12from pathlib import Path
13
14import aiofiles
15
# Module-level logger; callers may pass their own logger where a function accepts one.
LOGGER = logging.getLogger(name=__name__)
17
18
@dataclass
class SSLCertificateInfo:
    """Result of checking an SSL certificate together with its private key.

    The helpers in this module set ``error_message`` when verification fails
    and leave it as ``None`` otherwise.
    """

    # Whether the certificate/key pair was accepted.
    is_valid: bool
    # Public-key algorithm label: "RSA", "ECDSA", or "Unknown".
    key_type: str
    # Certificate subject text ("" when not available).
    subject: str
    # Certificate expiry ("notAfter") text ("" when not available).
    expiry: str
    # Whether the certificate is already past its expiry.
    is_expired: bool
    # Whether the certificate expires within the next 30 days.
    is_expiring_soon: bool
    # Human-readable failure reason, if any.
    error_message: str | None = None
30
31
async def get_ssl_content(value: str) -> str:
    """Get SSL content from either a file path or the raw PEM content.

    A value that starts with ``/`` (after stripping surrounding whitespace) is
    treated as an absolute file path, and the file's text is returned. Any other
    value is treated as inline PEM content and returned stripped.

    :param value: Either an absolute file path or the raw PEM content.
    :return: The PEM content as a string.
    :raises FileNotFoundError: If the file path doesn't exist.
    :raises ValueError: If the path is not a file.
    """
    value = value.strip()
    # PEM content always starts with "-----BEGIN", so a leading "/" alone
    # is enough to tell a path apart from inline content.
    if not value.startswith("/"):
        return value

    path = Path(value)
    if not path.exists():
        raise FileNotFoundError(f"SSL file not found: {value}")
    if not path.is_file():
        raise ValueError(f"SSL path is not a file: {value}")
    # Decode explicitly as UTF-8 (a superset of PEM's ASCII armour) instead of
    # relying on the platform's locale encoding, which may reject some bytes.
    async with aiofiles.open(path, encoding="utf-8") as f:
        content: str = await f.read()
    return content
55
56
57def _run_openssl_command(args: list[str]) -> subprocess.CompletedProcess[str]:
58 """Run an openssl command synchronously.
59
60 :param args: List of arguments for the openssl command (excluding 'openssl').
61 :return: CompletedProcess result.
62 """
63 return subprocess.run( # noqa: S603
64 ["openssl", *args], # noqa: S607
65 capture_output=True,
66 text=True,
67 timeout=10,
68 check=False,
69 )
70
71
async def create_server_ssl_context(
    certificate: str,
    private_key: str,
    logger: logging.Logger | None = None,
) -> ssl.SSLContext | None:
    """Create an SSL context for a server from certificate and private key.

    Failures are logged rather than raised, and ``None`` is returned so the
    caller can start without SSL.

    :param certificate: The SSL certificate (file path or PEM content).
    :param private_key: The SSL private key (file path or PEM content).
    :param logger: Optional logger for messages; defaults to the module logger.
    :return: SSL context if successful, None otherwise.
    """
    log = logger or LOGGER
    if not certificate or not private_key:
        log.error(
            "SSL is enabled but certificate or private key is missing. "
            "Server will start without SSL."
        )
        return None

    cert_path: str | None = None
    key_path: str | None = None
    try:
        # Load certificate and key content (supports both file paths and raw content)
        cert_content = await get_ssl_content(certificate)
        key_content = await get_ssl_content(private_key)

        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)

        # ssl.SSLContext.load_cert_chain requires file paths, so the PEM data is
        # staged in temporary files. Each path is recorded as soon as the file is
        # created (before writing), so the finally-block removes it even if the
        # write fails; otherwise a file that may hold key material would leak.
        with tempfile.NamedTemporaryFile(
            mode="w", encoding="utf-8", suffix=".pem", delete=False
        ) as cert_file:
            cert_path = cert_file.name
            cert_file.write(cert_content)

        with tempfile.NamedTemporaryFile(
            mode="w", encoding="utf-8", suffix=".pem", delete=False
        ) as key_file:
            key_path = key_file.name
            key_file.write(key_content)

        # NOTE(review): no password is passed. Per the ssl docs, an encrypted key
        # makes OpenSSL fall back to its interactive prompt; confirm that is acceptable.
        ssl_context.load_cert_chain(cert_path, key_path)
    except Exception as e:
        log.exception("Failed to create SSL context: %s. Server will start without SSL.", e)
        return None
    finally:
        # Best-effort removal of the staged files once loading has finished.
        for temp_path in (cert_path, key_path):
            if temp_path:
                with contextlib.suppress(OSError):
                    Path(temp_path).unlink()

    log.info("SSL/TLS enabled for server")
    return ssl_context
128
129
def _verification_failure(message: str) -> SSLCertificateInfo:
    """Build the SSLCertificateInfo returned when verification cannot succeed."""
    return SSLCertificateInfo(
        is_valid=False,
        key_type="Unknown",
        subject="",
        expiry="",
        is_expired=False,
        is_expiring_soon=False,
        error_message=message,
    )


async def verify_ssl_certificate(certificate: str, private_key: str) -> SSLCertificateInfo:
    """Verify SSL certificate and private key are valid and match.

    Errors are never raised. Every failure (missing input, unreadable file,
    unloadable or mismatched pair) is returned as an info object with
    ``is_valid=False`` and a human-readable ``error_message``.

    :param certificate: The SSL certificate (file path or PEM content).
    :param private_key: The SSL private key (file path or PEM content).
    :return: SSLCertificateInfo with verification results.
    """
    if not certificate or not private_key:
        return _verification_failure("Both certificate and private key are required.")

    # get_ssl_content's own FileNotFoundError text already reads "SSL file not
    # found: ...", so the path is named directly to avoid a stuttering message.
    try:
        cert_content = await get_ssl_content(certificate)
    except FileNotFoundError:
        return _verification_failure(f"Certificate file not found: {certificate.strip()}")
    except Exception as e:
        return _verification_failure(f"Error loading certificate: {e}")

    try:
        key_content = await get_ssl_content(private_key)
    except FileNotFoundError:
        return _verification_failure(f"Private key file not found: {private_key.strip()}")
    except Exception as e:
        return _verification_failure(f"Error loading private key: {e}")

    try:
        return await _verify_ssl_with_temp_files(cert_content, key_content)
    except ssl.SSLError as e:
        return _verification_failure(_format_ssl_error(e))
    except Exception as e:
        return _verification_failure(f"Verification failed: {e}")
218
219
220async def _verify_ssl_with_temp_files(cert_content: str, key_content: str) -> SSLCertificateInfo:
221 """Verify SSL using temporary files.
222
223 :param cert_content: Certificate PEM content.
224 :param key_content: Private key PEM content.
225 :return: SSLCertificateInfo with verification results.
226 """
227 cert_path = None
228 key_path = None
229 try:
230 with tempfile.NamedTemporaryFile(mode="w", suffix=".pem", delete=False) as cert_file:
231 cert_file.write(cert_content)
232 cert_path = cert_file.name
233
234 with tempfile.NamedTemporaryFile(mode="w", suffix=".pem", delete=False) as key_file:
235 key_file.write(key_content)
236 key_path = key_file.name
237
238 # Test loading into SSL context
239 test_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
240 test_ctx.load_cert_chain(cert_path, key_path)
241
242 # Get certificate details using openssl
243 return await _get_certificate_details(cert_path)
244 finally:
245 # Clean up temp files
246 if cert_path:
247 with contextlib.suppress(Exception):
248 Path(cert_path).unlink()
249 if key_path:
250 with contextlib.suppress(Exception):
251 Path(key_path).unlink()
252
253
async def _get_certificate_details(cert_path: str) -> SSLCertificateInfo:
    """Describe the certificate at *cert_path* using the ``openssl`` command-line tool.

    Within this module it is called only after the certificate/key pair has been
    loaded successfully, so the result always has ``is_valid=True``. When openssl
    cannot report details (non-zero exit, binary not installed, timeout), the
    subject/expiry fields stay empty and the key type stays "Unknown". A usable
    certificate is not turned into a verification failure.

    :param cert_path: Path to the certificate file.
    :return: SSLCertificateInfo with certificate details.
    """
    details_unavailable = SSLCertificateInfo(
        is_valid=True,
        key_type="Unknown",
        subject="",
        expiry="",
        is_expired=False,
        is_expiring_soon=False,
    )
    expiring_soon_seconds = 30 * 24 * 60 * 60  # "expiring soon" means within 30 days

    try:
        info_result = await asyncio.to_thread(
            _run_openssl_command,
            ["x509", "-in", cert_path, "-noout", "-subject", "-dates", "-issuer"],
        )
        if info_result.returncode != 0:
            return details_unavailable

        # `-checkend N` exits non-zero when the certificate expires within N seconds.
        expired_check = await asyncio.to_thread(
            _run_openssl_command,
            ["x509", "-in", cert_path, "-noout", "-checkend", "0"],
        )
        expiring_soon_check = await asyncio.to_thread(
            _run_openssl_command,
            ["x509", "-in", cert_path, "-noout", "-checkend", str(expiring_soon_seconds)],
        )
        text_result = await asyncio.to_thread(
            _run_openssl_command,
            ["x509", "-in", cert_path, "-noout", "-text"],
        )
    except (OSError, subprocess.SubprocessError):
        # openssl missing or timed out: details are best-effort, the pair itself loaded fine.
        return details_unavailable

    subject = ""
    expiry = ""
    for line in info_result.stdout.splitlines():
        # Lines look like "subject=..." / "notAfter=..."; split on the first "=" only.
        field_name, _, field_value = line.partition("=")
        if field_name == "notAfter":
            expiry = field_value.strip()
        elif field_name == "subject":
            subject = field_value.strip()

    # Detect the key type from the public-key algorithm names in openssl's `-text` dump.
    key_type = "Unknown"
    if "rsaEncryption" in text_result.stdout:
        key_type = "RSA"
    elif "id-ecPublicKey" in text_result.stdout:
        key_type = "ECDSA"

    return SSLCertificateInfo(
        is_valid=True,
        key_type=key_type,
        subject=subject,
        expiry=expiry,
        is_expired=expired_check.returncode != 0,
        is_expiring_soon=expiring_soon_check.returncode != 0,
    )
317
318
319def _format_ssl_error(e: ssl.SSLError) -> str:
320 """Format an SSL error into a user-friendly message.
321
322 :param e: The SSL error.
323 :return: User-friendly error message.
324 """
325 error_msg = str(e)
326 if "PEM lib" in error_msg:
327 return (
328 "Invalid certificate or key format. "
329 "Make sure both are valid PEM format and the key is not encrypted."
330 )
331 if "key values mismatch" in error_msg.lower():
332 return (
333 "Certificate and private key do not match. "
334 "Please verify you're using the correct key for this certificate."
335 )
336 return f"SSL Error: {error_msg}"
337
338
def format_certificate_info(info: SSLCertificateInfo) -> str:
    """Render an SSLCertificateInfo as a short multi-line report.

    Invalid results collapse to a single ``Error: ...`` line. Otherwise the
    report starts with a status headline and the key type, followed by the
    subject and expiry lines when those values are non-empty.

    :param info: The certificate info to format.
    :return: Human-readable string.
    """
    if not info.is_valid:
        return f"Error: {info.error_message}"

    # Expired takes precedence over expiring-soon.
    if info.is_expired:
        status, suffix = "EXPIRED", " (Certificate has expired!)"
    elif info.is_expiring_soon:
        status, suffix = "EXPIRING SOON", " (Certificate expires within 30 days)"
    else:
        status, suffix = "VALID", ""

    report = [
        f"Certificate verification: {status}{suffix}",
        f"Key type: {info.key_type}",
    ]
    optional_lines = (
        f"Subject: {info.subject}" if info.subject else None,
        f"Expires: {info.expiry}" if info.expiry else None,
    )
    report.extend(entry for entry in optional_lines if entry is not None)
    return "\n".join(report)
364