Source code for matrixctl.handlers.db

#!/usr/bin/env python
# matrixctl
# Copyright (c) 2020  Michael Sasser <Michael@MichaelSasser.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Talk to the database."""


from __future__ import annotations

import logging
import sys
import typing as t
import urllib.parse

from collections.abc import Iterator
from contextlib import contextmanager

import psycopg
import sshtunnel

from .yaml import YAML


__author__: str = "Michael Sasser"
__email__: str = "Michael@MichaelSasser.org"


# Module-level logger named after this module; ``ssh_tunnel`` and
# ``db_connect`` log their tunnel/connection lifecycle events through it.
logger = logging.getLogger(__name__)


class DBConnectionBuilder(t.NamedTuple):
    """Build the connection URI for a PostgreSQL database.

    ``str(builder)`` yields a URI of the form
    ``scheme://username:password@host:port/database?connect_timeout=timeout``.

    Attributes
    ----------
    host : str
        The database host name or IP address.
    database : str
        The name of the database to connect to.
    username : str
        The database user.
    password : str
        The password of the database user.
    port : int, default: 5432
        The database port.
    timeout : int, default: 10
        Value of the ``connect_timeout`` query parameter.
    scheme : str, default: "postgresql"
        The URI scheme.
    """

    host: str
    database: str
    username: str
    password: str
    port: int = 5432
    timeout: int = 10
    scheme: str = "postgresql"

    def __str__(self) -> str:
        """Build the URI.

        The username, password and database name are percent-encoded, so
        reserved characters in them (e.g. ``@``, ``:``, ``/``, ``#``, ``?``)
        cannot break the structure of the URI. A bare IPv6 address given as
        ``host`` is wrapped in square brackets, as URI syntax requires.

        Parameters
        ----------
        None

        Returns
        -------
        url : str
            The connection URI.
        """
        # safe="" so "/" is encoded as well; unencoded it would end the
        # credentials part or split the database name into a longer path.
        username: str = urllib.parse.quote(self.username, safe="")
        password: str = urllib.parse.quote(self.password, safe="")
        database: str = urllib.parse.quote(self.database, safe="")

        host: str = self.host
        # An unbracketed colon in the host would be read as the port separator.
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"

        url: str = (
            f"{self.scheme}://"
            f"{username}:{password}@{host}:{self.port}"
            f"/{database}"
            f"?connect_timeout={self.timeout}"
        )
        return urllib.parse.urlparse(url).geturl()
@contextmanager
def ssh_tunnel(
    host: str,
    username: str,
    remote_port: int,
    enabled: bool = True,
    port: int = 22,
    # private_key: Path | str | None = None,
) -> Iterator[int | None]:
    """Provide an SSH port forward for the duration of a ``with`` block.

    When enabled, a local port is forwarded through an SSH connection to
    ``host`` and on to ``127.0.0.1:remote_port`` as seen from the SSH server.
    The forwarder is stopped again when the block is left, whether normally
    or through an exception.

    Notes
    -----
    If the tunnel is disabled (``enabled = False``), no SSH connection is
    made and ``None`` is yielded instead of a local bind port.

    No explicit private key is handed to the forwarder (``ssh_pkey=None``);
    a ``private_key`` parameter is planned but currently disabled.

    Examples
    --------
    .. code-block:: python

       with ssh_tunnel("127.0.0.1", "myuser", 5432) as local_bind_port:
           print(f"The local bind port is: {local_bind_port}")
       # The local bind port is: 8765

    Parameters
    ----------
    host : str
        The SSH server, e.g. ``127.0.0.1`` or ``host.domain.tld``.
    username : str
        The SSH user name.
    remote_port : int
        The port of the application on the SSH server that is forwarded.
    enabled : bool, default: True
        ``True`` to create the tunnel, ``False`` to skip it.
    port : int, default: 22
        The SSH port of ``host``.

    Yields
    ------
    local_bind_port : int
        The local port the tunnel is bound to.
    None : None
        When the tunnel is disabled (``enabled = False``).
    """
    if not enabled:
        # Tunnelling is switched off: there is no local bind port to report.
        yield None
        return

    forwarder = sshtunnel.SSHTunnelForwarder(
        ssh_address_or_host=(host, port),
        ssh_username=username,
        remote_bind_address=("127.0.0.1", remote_port),
        ssh_pkey=None,
        logger=logging.getLogger(sshtunnel.__name__),
    )
    try:
        forwarder.start()
        bound_port = forwarder.local_bind_port
        logger.debug("SSH tunnel created using port: %s", bound_port)
        yield forwarder.local_bind_port
    finally:
        # Always tear the forwarder down, even if start() or the body failed.
        forwarder.stop()
        logger.debug("SSH tunnel closed")
@contextmanager
def db_connect(yaml: YAML) -> Iterator[psycopg.Connection]:
    """Connect to the Synapse PostgreSQL database.

    When ``server.database.tunnel`` is truthy, the connection is routed
    through an SSH tunnel to ``server.ssh.address`` and the database is
    reached at ``127.0.0.1`` on the tunnel's local bind port. Otherwise the
    database is contacted directly at ``server.ssh.address`` on
    ``server.database.port``.

    If the ``with`` block finishes normally, the transaction is committed.
    If it raises, the transaction is rolled back, the error is logged and
    the process exits with status ``1``. The connection is always closed.

    Parameters
    ----------
    yaml : matrixctl.handlers.yaml.YAML
        The configuration file handler.

    Yields
    ------
    conn : psycopg.Connection
        A new ``Connection`` instance.

    Raises
    ------
    SystemExit
        With exit status ``1`` when the body of the ``with`` block raises.
    """
    use_tunnel = yaml.get("server", "database", "tunnel")
    # Cast like the SSH port below: sshtunnel expects an int port in
    # remote_bind_address, and the URI builder's port is an int too.
    database_port = int(yaml.get("server", "database", "port"))

    with ssh_tunnel(
        host=yaml.get("server", "ssh", "address"),
        port=int(yaml.get("server", "ssh", "port")),
        username=yaml.get("server", "ssh", "user"),
        remote_port=database_port,
        enabled=use_tunnel,  # skipcq PY-W0069
        # private_key=yaml.get("server", "database", "private_key")
    ) as local_bind_port:
        connection_uri = DBConnectionBuilder(
            host=(
                "127.0.0.1"
                if use_tunnel
                else yaml.get("server", "ssh", "address")
            ),
            # local_bind_port is None when the tunnel is disabled.
            port=int(local_bind_port or database_port),
            username=yaml.get("server", "database", "synapse_user"),
            password=yaml.get("server", "database", "synapse_password"),
            database=yaml.get("server", "database", "synapse_database"),
        )
        conn = psycopg.connect(str(connection_uri))
        try:
            yield conn
        except GeneratorExit:
            # Thrown by close() on this generator (e.g. the context manager is
            # discarded without being exited). It must propagate unchanged;
            # turning it into SystemExit would make close() raise. Nothing is
            # committed: the session ends when the connection is closed below,
            # and PostgreSQL aborts an uncommitted transaction at session end.
            raise
        except BaseException as e:  # skipcq: PYL-W0703
            logger.error("Rollback initiated.BaseException: %s", e)
            try:
                conn.rollback()
            except psycopg.Error as rollback_error:
                # The connection may already be unusable (possibly the very
                # cause of the error); do not let that skip the exit below.
                logger.error("Rollback failed: %s", rollback_error)
            sys.exit(1)
        else:
            conn.commit()
            logger.debug("successful -> commit")
        finally:
            conn.close()
            logger.debug("Connection to the Database has been closed.")
# vim: set ft=python :