Source code for matrixctl.handlers.oidc

"""An OIDC Client for Matrix Authentication Service.

If you are reading this, you might ask yourself:
"Why yet another OIDC client?" Unfortunately, the answer is that it was
easier to write one from scratch (for this specific use case) than to use any
of the existing ones I found. Either their documentation was in a devastating
state or they had been unmaintained for years. It's pretty sad.

Nevertheless, this should not stop us from having nice things, too.
So here we are with yet another OIDC client.

"""

from __future__ import annotations

import base64
import hashlib
import http.server
import json
import logging
import secrets
import socketserver
import threading
import time
import typing as t
import urllib.parse
import webbrowser

from pathlib import Path
from urllib.parse import parse_qs
from urllib.parse import urlparse

import httpx

from xdg_base_dirs import xdg_data_home

from matrixctl.typehints import JsonDict


__author__: str = "Michael Sasser"
__email__: str = "Michael@MichaelSasser.org"

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)


class OidcTCPServer(socketserver.TCPServer):
    """TCP server wrapper for handling OIDC authentication callbacks.

    This server listens for incoming HTTP requests containing the OIDC
    authorization code and stores it for later retrieval.

    """

    # Allow re-binding the callback port while a socket from a previous run
    # is still in TIME_WAIT. Without this, starting the browser flow again
    # shortly after a previous one can fail with "Address already in use".
    # http.server.HTTPServer enables the same flag.
    allow_reuse_address: bool = True

    def __init__(
        self,
        server_address: tuple[str, int],
        RequestHandlerClass: type[socketserver.BaseRequestHandler],  # noqa: N803
    ) -> None:
        """Initialize the OidcTCPServer.

        This method only adds the ``auth_code`` attribute to the server
        instance; it is expected to be set by the request handler once a
        callback carrying an authorization code arrives.

        Parameters
        ----------
        server_address : tuple[str, int]
            The address on which the server listens.
        RequestHandlerClass : type[socketserver.BaseRequestHandler]
            The request handler class.

        """
        super().__init__(server_address, RequestHandlerClass)
        # None until an authorization code has been received.
        self.auth_code: str | None = None
class TokenManager:
    """Manager for OIDC tokens handling.

    It supports both client credentials and authorization code flows.
    Tokens are cached on disk as JSON, keyed by a normalized name, and are
    reused until they expire. Expired tokens are refreshed when a refresh
    token is available.

    Parameters
    ----------
    token_endpoint : str
        OIDC provider's token endpoint URL
    client_id : str
        Client ID for OIDC authentication
    client_secret : str
        Client secret for OIDC authentication
    auth_endpoint : str | None, optional
        Authorization endpoint URL for the user authentication flow, by
        default None. :meth:`get_user_token` requires it whenever no cached
        or refreshable token is available.
    userinfo_endpoint : str | None, optional
        Userinfo endpoint URL, by default None
    jwks_uri : str | None, optional
        JWKS endpoint URL, by default None. It is stored but not used by
        this class.
    cache_path : Path | None, optional
        Path to the token cache file, by default
        ``xdg_data_home() / "matrixctl" / "oidc_token_cache.json"``

    """

    # Seconds get_user_token() waits for the browser to deliver the
    # authorization code before raising TimeoutError.
    wait_for_auth_code_timeout: int = 300

    # Timeout in seconds applied to every HTTP request sent to the provider.
    http_timeout: float = 10.0

    def __init__(  # noqa: PLR0913
        self,
        token_endpoint: str,
        client_id: str,
        client_secret: str,
        auth_endpoint: str | None = None,
        userinfo_endpoint: str | None = None,
        jwks_uri: str | None = None,
        cache_path: Path | None = None,
    ) -> None:
        """Store the provider/client configuration and reset token state."""
        data_home = xdg_data_home() / "matrixctl"
        self.token_endpoint: str = token_endpoint
        self.client_id: str = client_id
        self.client_secret: str = client_secret
        self.auth_endpoint: str | None = auth_endpoint
        self.userinfo_endpoint: str | None = userinfo_endpoint
        self.jwks_uri: str | None = jwks_uri
        self.cache_path: Path = (
            cache_path or data_home / "oidc_token_cache.json"
        )
        # Token state, populated from the cache or from token responses.
        self.access_token: str | None = None
        self.refresh_token: str | None = None
        self.id_token: str | None = None
        # Absolute expiry as a time.time() timestamp; 0.0 when no valid
        # token is loaded.
        self.expires_at: float = 0.0
        logger.debug(
            "TokenManager initialized with cache path: %s", self.cache_path
        )

    def recall_cached_token(self, key: str) -> bool:
        """Load and validate cached tokens from disk.

        Parameters
        ----------
        key : str
            Cache entry name; normalized with ``strip().lower()``.

        Returns
        -------
        bool
            True if a non-expired token was loaded, False otherwise

        Notes
        -----
        Sets instance attributes:

        - access_token: Loaded access token or None
        - refresh_token: Loaded refresh token or None
        - id_token: Loaded ID token, or the previous value if none is cached
        - expires_at: Expiration timestamp or 0

        The refresh token of an expired entry is kept, so callers can use
        it to obtain a new access token.

        """
        try:
            if not self.cache_path.exists():
                logger.debug(
                    "Cache file does not exist: %s",
                    self.cache_path,
                )
                return False
            with self.cache_path.open(encoding="utf-8") as fp:
                data: t.Any = json.load(fp)
            logger.debug("Cache file exists: %s", self.cache_path)
        except PermissionError:
            logger.exception(
                "Insufficiant permissions to opent the file: %s",
                self.cache_path,
            )
        except IsADirectoryError:
            logger.exception(
                (
                    "The oidc token cache file %s should be a file, not "
                    "a directory"
                ),
                self.cache_path,
            )
        except OSError:
            logger.exception("Failed to open/write to oidc token cache file")
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError (non-UTF-8 file
            # content) are both ValueError subclasses.
            logger.exception(
                (
                    "The oidc token cache file exist, but it's content is not "
                    "valid JSON: %s"
                ),
                self.cache_path,
            )
        else:
            if self._load_cache_entry(data, key):
                return True

        self.access_token = None
        self.expires_at = 0.0
        logger.debug("Token is expired or invalid")
        return False

    def _load_cache_entry(self, data: t.Any, key: str) -> bool:
        """Apply the cache entry for *key*; True if it is not expired."""
        entry = (
            data.get(key.strip().lower()) if isinstance(data, dict) else None
        )
        if not isinstance(entry, dict):
            logger.error(
                (
                    "The oidc token cache file exist, but it's content is not "
                    "does not contain the expected keys. Cache file: %s"
                ),
                self.cache_path,
            )
            return False

        self.access_token = entry.get("access_token")
        self.refresh_token = entry.get("refresh_token")
        self.id_token = entry.get("id_token") or self.id_token
        # A null or non-numeric expiry is treated as "already expired"
        # instead of crashing the comparison below.
        try:
            self.expires_at = float(entry.get("expires_at") or 0.0)
        except (TypeError, ValueError):
            self.expires_at = 0.0
        logger.debug(
            "Recalled refresh token contains: "
            "{access_token: %s, "
            "refresh_token: %s, "
            "id_token: %s, "
            "expires_at: %s}",
            self.access_token is not None,
            self.refresh_token is not None,
            self.id_token is not None,
            bool(self.expires_at),
        )
        if time.time() < self.expires_at:
            logger.debug("Token is not expired on recall")
            return True
        return False

    def _read_cache_file(self) -> dict[str, t.Any]:
        """Return the parsed cache file, or ``{}`` if missing or unusable."""
        try:
            with self.cache_path.open(encoding="utf-8") as fp:
                data = json.load(fp)
        except (OSError, ValueError):
            return {}
        return data if isinstance(data, dict) else {}

    def store_cache_token(
        self,
        access_token: str,
        refresh_token: str | None,
        id_token: str | None,
        expires_in: int,
        key: str,
    ) -> None:
        """Cache tokens to disk with expiration information.

        The entry is merged into the existing cache file, so entries stored
        under other keys are preserved.

        Parameters
        ----------
        access_token : str
            New access token to cache
        refresh_token : str | None
            Optional refresh token to cache
        id_token : str | None
            Optional ID token to cache
        expires_in : int
            Time in seconds until token expiration
        key : str
            Cache entry name; normalized with ``strip().lower()``

        Notes
        -----
        Updates instance attributes:

        - access_token
        - refresh_token
        - id_token
        - expires_at

        Errors while writing the cache file are logged, not raised.

        """
        logger.debug("Started storing cached token: %s", self.cache_path)
        self.access_token = access_token
        self.refresh_token = refresh_token
        self.id_token = id_token
        self.expires_at = time.time() + expires_in
        logger.debug(
            "Stored refresh token contains: "
            "{access_token: %s, "
            "refresh_token: %s, "
            "id_token: %s, "
            "expires_at: %s}",
            self.access_token is not None,
            self.refresh_token is not None,
            self.id_token is not None,
            bool(self.expires_at),
        )

        # Merge instead of overwrite: writing only {key: ...} would silently
        # drop every other cached entry.
        cache = self._read_cache_file()
        cache[key.strip().lower()] = {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "id_token": id_token,
            "expires_at": self.expires_at,
        }
        try:
            self.cache_path.parent.mkdir(parents=True, exist_ok=True)
            # Create with restrictive permissions; chmod also covers files
            # that already existed with broader ones.
            self.cache_path.touch(mode=0o600, exist_ok=True)
            self.cache_path.chmod(0o600)
            with self.cache_path.open("w", encoding="utf-8") as fp:
                json.dump(cache, fp)
        except PermissionError:
            logger.exception(
                "Insufficiant permissions to opent the file: %s",
                self.cache_path,
            )
        except IsADirectoryError:
            logger.exception(
                (
                    "The oidc token cache file %s must be a file, "
                    "not a directory"
                ),
                self.cache_path,
            )
        except OSError:
            logger.exception("Failed to open/write to oidc token cache file")
        logger.debug("Finished storing cached token: %s", self.cache_path)

    def get_user_info(self) -> JsonDict:
        """Retrieve user information from the userinfo endpoint.

        Returns
        -------
        dict[str, Any]
            User claims dictionary

        Raises
        ------
        ValueError
            If userinfo endpoint not configured or no access token
        httpx.HTTPStatusError
            For HTTP request failures

        """
        err_msg: str
        if not self.userinfo_endpoint:
            err_msg = "Userinfo endpoint not configured"
            raise ValueError(err_msg)
        if not self.access_token:
            err_msg = "No access token available"
            raise ValueError(err_msg)
        response = httpx.get(
            self.userinfo_endpoint,
            headers={"Authorization": f"Bearer {self.access_token}"},
            timeout=self.http_timeout,
        )
        _ = response.raise_for_status()
        user_info: JsonDict = t.cast(JsonDict, response.json())
        logger.debug("User info retrieved: %s", user_info)
        return user_info

    def get_payload(self) -> JsonDict:
        """Decode payload from the ID token.

        The token's signature is NOT verified; the middle JWT segment is
        only base64url-decoded and parsed as JSON.

        Returns
        -------
        dict[str, Any]
            Decoded ID token payload

        Raises
        ------
        ValueError
            If no ID token is available, the token does not consist of
            three dot-separated segments, or the payload cannot be decoded

        """
        if not self.id_token:
            err_: str = "No ID token available"
            raise ValueError(err_)
        segments = self.id_token.split(".")
        if len(segments) != 3:  # noqa: PLR2004
            err_ = "Malformed ID token: expected three dot-separated segments"
            raise ValueError(err_)
        # JWT segments are unpadded base64url; restore padding to decode.
        payload_b64 = segments[1] + "=" * (-len(segments[1]) % 4)
        try:
            payload_decoded = base64.urlsafe_b64decode(payload_b64)
            payload: JsonDict = t.cast(JsonDict, json.loads(payload_decoded))
        except json.JSONDecodeError:
            logger.exception(
                ("Unable to decode payload. Invalid JSON"),
            )
            raise
        logger.debug("Payload decoded: %s", payload)
        return payload

    def get_client_credentials_token(self) -> str:
        """Get access token using client credentials flow.

        A cached, non-expired token is returned directly. Otherwise a cached
        refresh token is tried before a new token is requested.

        Returns
        -------
        str
            Valid access token

        Raises
        ------
        httpx.HTTPStatusError
            For HTTP request failures
        ValueError
            If token response is invalid

        """
        logger.debug("Started client credentials token request")
        # NOTE(review): this flow shares the "user" cache key with
        # get_user_token(), so either flow may return a token obtained by
        # the other. Confirm this is intended.
        if self.recall_cached_token("user") and self.access_token:
            # recall_cached_token() only returns True for non-expired tokens.
            logger.debug(
                "Recalled cached token, which contains the access token."
            )
            return self.access_token

        # An expired cache entry still provides its refresh token. Before
        # this fix, the refresh branch was only reached for non-expired
        # tokens, i.e. practically never.
        if self.refresh_token:
            logger.debug("Recalled token is expired, refreshing it")
            if refreshed := self.refresh_access_token():
                return refreshed

        try:
            response = httpx.post(
                self.token_endpoint,
                data={
                    "grant_type": "client_credentials",
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                },
                timeout=self.http_timeout,
            )
            _ = response.raise_for_status()
            token_data: dict[str, t.Any] = t.cast(
                dict[str, t.Any], response.json()
            )
        except httpx.HTTPStatusError as e:
            logger.exception(
                "Token request failed: %s %s",
                e.response.status_code,
                e.response.text,
            )
            raise
        except json.JSONDecodeError:
            logger.exception(
                (
                    "The returned client credentials token could not be "
                    "decoded. Invalid JSON"
                ),
            )
            raise

        access_token: str
        if not (access_token := t.cast(str, token_data.get("access_token"))):
            err_msg: str = "No access token in response"
            raise ValueError(err_msg)

        self.store_cache_token(
            access_token,
            token_data.get("refresh_token"),
            token_data.get("id_token") or self.id_token,
            t.cast(int, token_data.get("expires_in", 3600)),
            "user",
        )
        return access_token

    @staticmethod
    def _generate_pkce() -> tuple[str, str]:
        """Generate PKCE code verifier and challenge pair (S256 method).

        Returns
        -------
        tuple[str, str]
            (code_verifier, code_challenge) pair; the challenge is the
            unpadded base64url-encoded SHA-256 digest of the verifier.

        """
        code_verifier = secrets.token_urlsafe(64)
        digest = hashlib.sha256(code_verifier.encode()).digest()
        code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")
        return code_verifier, code_challenge

    def _start_local_server(
        self, port: int = 8298
    ) -> tuple[OidcTCPServer, int]:
        """Start temporary HTTP server for OIDC callback.

        Parameters
        ----------
        port : int, optional
            Port to listen on at 127.0.0.1, by default 8298. ``0`` lets the
            operating system pick a free port.

        Returns
        -------
        tuple[OidcTCPServer, int]
            (server instance, port number actually bound)

        """

        # BaseHTTPRequestHandler rather than SimpleHTTPRequestHandler: the
        # latter also answers HEAD requests for files in the current working
        # directory, which this callback server never needs to expose.
        class CallbackHandler(http.server.BaseHTTPRequestHandler):
            """Handler for OIDC redirect with authorization code capture."""

            def do_GET(self) -> None:  # noqa: N802
                """Handle GET request for OIDC callback."""
                query = parse_qs(urlparse(self.path).query)
                if "code" not in query:
                    self.send_response(400)
                    self.end_headers()
                    _ = self.wfile.write(b"Missing authorization code")
                    return
                # Record the code before answering, so it is captured even
                # if writing the response to the browser fails.
                auth_server = t.cast(OidcTCPServer, self.server)
                auth_server.auth_code = query["code"][0]
                self.send_response(200)
                self.end_headers()
                _ = self.wfile.write(
                    b"Authentication successful! "
                    b"You can close this window."
                )

            def log_message(
                self,
                format: str,  # noqa: A002
                *args: t.Any,
            ) -> None:
                """Send access-log lines to the module logger at debug level.

                The default implementation writes every request line to
                stderr, including the query string with the authorization
                code.
                """
                logger.debug("OIDC callback server: %s", format % args)

        # TODO: Find out if we can do random ports in MAS instead of having
        # a fixed one.
        server = OidcTCPServer(("127.0.0.1", port), CallbackHandler)
        # Daemon thread, so a server that is never shut down cannot keep the
        # process alive.
        thread = threading.Thread(
            target=server.serve_forever,
            name="oidc-callback-server",
            daemon=True,
        )
        thread.start()
        return server, server.server_address[1]

    def _wait_for_auth_code(self, server: OidcTCPServer) -> str:
        """Poll *server* until it captured a code; raise on timeout."""
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + type(self).wait_for_auth_code_timeout
        while not server.auth_code:
            if time.monotonic() >= deadline:
                err_msg: str = "Authorization timed out"
                logger.debug("Browser flow authorization timed out")
                raise TimeoutError(err_msg)
            time.sleep(1)
        logger.debug("Got auth code from from browser flow")
        return server.auth_code

    def get_user_token(self, claims: t.Iterable[str]) -> str:
        """Get access token using authorization code flow with PKCE.

        A cached, non-expired token is returned directly. Otherwise a cached
        refresh token is tried first. Only if both fail is the browser flow
        started: a local callback server is opened on 127.0.0.1, the user is
        sent to the authorization endpoint, and the returned code is
        exchanged for tokens, which are cached under the key "user".

        Parameters
        ----------
        claims : typing.Iterable[str]
            Scopes to request; joined with spaces into ``scope``.

        Returns
        -------
        str
            Valid access token

        Raises
        ------
        TimeoutError
            If the user doesn't complete authentication within
            ``wait_for_auth_code_timeout`` seconds (5 minutes by default)
        httpx.HTTPStatusError
            For HTTP request failures
        ValueError
            If no authorization endpoint is configured or the token
            response is invalid

        """
        self.recall_cached_token("user")
        logger.debug("Recalled cached token for 'user'")
        if self.access_token and time.time() < self.expires_at:
            logger.debug("Recalled acces token exists and is not expired")
            return self.access_token
        logger.debug("Recalled access token was invalid or is expired")

        if self.refresh_token:
            logger.debug("Refresh token exists")
            new_access_token = self.refresh_access_token()
            logger.debug(
                "Using recalled refresh token token to get a new access token"
            )
            if new_access_token:
                logger.debug("Refreshed access token exists")
                return new_access_token

        # Without this check the browser would be sent to "None?...".
        if not self.auth_endpoint:
            err_msg: str = "Authorization endpoint not configured"
            raise ValueError(err_msg)

        code_verifier, code_challenge = self._generate_pkce()
        server, port = self._start_local_server()
        logger.debug("Started local server")
        # Everything after the server start runs inside try/finally so the
        # server is always stopped and its listening socket released.
        try:
            # TODO: make this configurable
            redirect_uri = f"http://127.0.0.1:{port}/callback"
            params = {
                "response_type": "code",
                "client_id": self.client_id,
                "redirect_uri": redirect_uri,
                "scope": " ".join(claims),
                "code_challenge": code_challenge,
                "code_challenge_method": "S256",
                "state": "active",
            }
            # Keep query parameters already present in the endpoint URL.
            separator = "&" if "?" in self.auth_endpoint else "?"
            url: str = (
                f"{self.auth_endpoint}{separator}"
                f"{urllib.parse.urlencode(params)}"
            )
            if webbrowser.open_new_tab(url):
                print("A new tab should have opened in your browser.")
                print(
                    "If not, please visit this URL in your browser "
                    f"manually:\n{url}\n"
                )
            else:
                print(f"Please visit this URL in your browser:\n{url}\n")

            auth_code = self._wait_for_auth_code(server)

            logger.debug("Requesting access token")
            token_data = self._exchange_code(
                code_verifier, auth_code, redirect_uri
            )
            access_token: str = t.cast(str, token_data.get("access_token"))
            logger.debug("Got response from requesting access token")
            if not access_token:
                logger.debug("There was no acccess token in the response")
                err_msg = "No access token in response"
                raise ValueError(err_msg)

            self.store_cache_token(
                access_token,
                token_data.get("refresh_token"),
                token_data.get("id_token") or self.id_token,
                t.cast(int, token_data.get("expires_in", 3600)),
                "user",
            )
            logger.debug("Cached token stored")
            return access_token
        finally:
            logger.debug("Shutting down web server")
            server.shutdown()
            server.server_close()

    def refresh_access_token(self) -> str | None:
        """Refresh access token using refresh token.

        On success the new tokens are cached under the key "user". If the
        response omits a refresh token, the current one is kept.

        Returns
        -------
        str | None
            New access token if successful, None otherwise

        """
        logger.debug("Started refresing access token")
        if not self.refresh_token:
            logger.debug(
                "Unable to refresh access token. "
                "I don not have a refresh token"
            )
            return None
        try:
            response = httpx.post(
                self.token_endpoint,
                data={
                    "grant_type": "refresh_token",
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                    "refresh_token": self.refresh_token,
                },
                timeout=self.http_timeout,
            )
            _ = response.raise_for_status()
            logger.debug("Got refresh token response")
            token_data: JsonDict = t.cast(JsonDict, response.json())
        except httpx.HTTPStatusError as e:
            logger.exception(
                "Token refresh failed: %s %s",
                e.response.status_code,
                e.response.text,
            )
            return None
        except json.JSONDecodeError:
            # Report the endpoint that returned the bad body; the cache
            # path was logged here before, which was misleading.
            logger.exception(
                (
                    "The returned refresh token could not be decoded. Invalid "
                    "JSON: %s"
                ),
                self.token_endpoint,
            )
            return None

        access_token = t.cast(str, token_data.get("access_token"))
        if not access_token:
            logger.error("No access token in response")
            return None
        self.store_cache_token(
            access_token,
            t.cast(
                str,
                token_data.get("refresh_token", self.refresh_token),
            ),
            token_data.get("id_token") or self.id_token,
            t.cast(int, token_data.get("expires_in", 3600)),
            "user",
        )
        logger.debug("Stored refreshed token")
        return access_token

    def _exchange_code(
        self, code_verifier: str, auth_code: str | None, redirect_uri: str
    ) -> dict[str, t.Any]:
        """Exchange an authorization code for tokens at the token endpoint.

        Parameters
        ----------
        code_verifier : str
            PKCE verifier matching the challenge sent to the auth endpoint
        auth_code : str | None
            Authorization code received on the callback
        redirect_uri : str
            The redirect URI used in the authorization request

        Returns
        -------
        dict[str, Any]
            Parsed token response

        Raises
        ------
        ValueError
            If no authorization code is given
        httpx.HTTPStatusError
            For HTTP request failures

        """
        if not auth_code:
            err_msg = "Missing authorization code"
            raise ValueError(err_msg)
        response = httpx.post(
            self.token_endpoint,
            data={
                "grant_type": "authorization_code",
                "client_id": self.client_id,
                "client_secret": self.client_secret,
                "code": auth_code,
                "redirect_uri": redirect_uri,
                "state": "active",
                "code_verifier": code_verifier,
            },
            timeout=self.http_timeout,
        )
        _ = response.raise_for_status()
        return t.cast(dict[str, t.Any], response.json())
def discover_oidc_endpoints(issuer_url: str) -> JsonDict:
    """Retrieve OIDC provider configuration via discovery.

    Parameters
    ----------
    issuer_url : str
        URL to fetch the discovery document from. Trailing slashes are
        stripped, and the URL is otherwise requested as-is.
        NOTE(review): no ``/.well-known/openid-configuration`` suffix is
        appended here, so callers presumably pass the full discovery URL.
        Confirm.

    Returns
    -------
    dict[str, t.Any]
        OIDC provider configuration

    Raises
    ------
    httpx.HTTPStatusError
        For HTTP request failures
    json.JSONDecodeError
        If the response body is not valid JSON
    ValueError
        If the discovery document is not a JSON object

    """
    discovery_url = issuer_url.rstrip("/")
    try:
        response = httpx.get(discovery_url, timeout=10)
        _ = response.raise_for_status()
        oidc_config: t.Any = response.json()
    except httpx.HTTPStatusError as e:
        logger.exception(
            "Discovery request failed: %s %s",
            e.response.status_code,
            e.response.text,
        )
        raise
    except json.JSONDecodeError:
        logger.exception(
            (
                "The discovery request JSON response could not be "
                "decoded. Invalid JSON: %s"
            ),
            response,
        )
        raise

    # Valid JSON is not necessarily a configuration object (e.g. a list or
    # a bare string); callers index into the result as a mapping.
    if not isinstance(oidc_config, dict):
        err_msg: str = "The OIDC discovery document is not a JSON object"
        logger.error("%s: %s", err_msg, discovery_url)
        raise ValueError(err_msg)

    logger.debug("OIDC discovery response: %s", oidc_config)
    return t.cast(JsonDict, oidc_config)