# Source code for scystream.sdk.database_handling.database_manager

from abc import ABC, abstractmethod
from urllib.parse import urlparse
from pyspark.sql import SparkSession, DataFrame

import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.sql import quoted_name


[docs] class BaseDatabaseOperations(ABC): MAX_TABLE_NAME_LENGTH = 63
[docs] def __init__(self, dsn: str, schema: str | None = None): self.dsn = dsn self.schema = self._normalize_schema(schema)
def _validate_table_name(self, table: str): if len(table) > self.MAX_TABLE_NAME_LENGTH: raise ValueError( f"Table name '{table}' exceeds {self.MAX_TABLE_NAME_LENGTH}\ characters." ) def _validate_read_inputs( self, table: str | None, query: str | None, ): if not table and not query: raise ValueError("Either 'table' or 'query' must be provided.") if table: self._validate_table_name(table) @staticmethod def _normalize_schema(schema: str | None) -> str | None: if schema is None: return None schema = schema.strip() return schema or None
[docs] @abstractmethod def read( self, table: str | None = None, query: str | None = None, ): pass
[docs] @abstractmethod def write( self, table: str, data, mode: str = "overwrite", ): pass
class SparkDatabaseOperations(BaseDatabaseOperations):
    """
    Class to perform PostgreSQL operations using Apache Spark.

    This class provides methods to read from and write to a PostgreSQL
    database using JDBC and Spark's DataFrame API. It requires a
    SparkSession and a PostgresConfig object or the PostgresSettings from
    an input or output for database connectivity.
    """

    def __init__(self, spark: SparkSession, dsn: str, schema: str | None):
        """
        :param spark: An active SparkSession used for all JDBC reads/writes.
        :param dsn: SQLAlchemy-style PostgreSQL connection string.
        :param schema: Optional default schema for table references.
        """
        super().__init__(dsn, schema)
        self.spark_session = spark
        # BUGFIX: _dsn_to_jdbc already returns the driver plus any
        # credentials parsed from the DSN. The previous code re-assigned
        # self.properties to a driver-only dict immediately afterwards,
        # silently dropping user/password.
        self.jdbc_url, self.properties = self._dsn_to_jdbc(dsn)

    def _dsn_to_jdbc(self, dsn: str) -> tuple[str, dict]:
        """
        Convert SQLAlchemy DSN to JDBC URL and connection properties.
        """
        parsed = urlparse(dsn)

        if not parsed.hostname:
            raise ValueError("Invalid DSN: missing hostname")

        # Build JDBC URL (default PostgreSQL port when none is given)
        jdbc_url = (
            f"jdbc:postgresql://{parsed.hostname}:"
            f"{parsed.port or 5432}{parsed.path}"
        )

        # Extract credentials
        properties = {
            "driver": "org.postgresql.Driver",
        }
        if parsed.username:
            properties["user"] = parsed.username
        if parsed.password:
            properties["password"] = parsed.password

        return jdbc_url, properties

    def read(
        self,
        table: str | None = None,
        query: str | None = None,
    ) -> DataFrame:
        """
        Reads data from a PostgreSQL database into a Spark DataFrame.

        This method can either read data from a specified table or execute
        a custom SQL query to retrieve data from the database.

        :param table: The name of the table to read data from. Must be
            provided if `query` is not supplied. (optional)
        :param query: A custom SQL query to run. If provided, this
            overrides the `table` parameter. (optional)
        :return: A Spark DataFrame containing the result of the query or
            table data.
        :rtype: DataFrame
        """
        self._validate_read_inputs(table, query)

        if query:
            # JDBC requires a subquery alias when reading a raw query.
            dbtable_option = f"({query}) AS subquery"
        else:
            dbtable_option = f"{self.schema}.{table}" if self.schema else table

        return (
            self.spark_session.read.format("jdbc")
            .option("url", self.jdbc_url)
            .option("dbtable", dbtable_option)
            .options(**self.properties)
            .load()
        )

    def write(
        self,
        table: str,
        dataframe,
        mode="overwrite",
        schema: str | None = None,
    ):
        """
        Writes a Spark DataFrame to a specified table in a PostgreSQL
        database using JDBC.

        This method writes the provided DataFrame to the target PostgreSQL
        table, with the option to specify the write mode (overwrite,
        append, etc.).

        :param table: The name of the table where data will be written.
        :param dataframe: The Spark DataFrame containing the data to write.
        :param mode: The write mode. Valid options are 'overwrite',
            'append', 'ignore', and 'error'. Defaults to 'overwrite'.
            (optional)
        :param schema: Per-call schema override; falls back to the
            instance-level schema when omitted. (optional)
        :note: Ensure that the schema of the DataFrame matches the schema
            of the target table if the table exists.
        :note: The `mode` parameter controls the behavior when the table
            already exists.
        """
        self._validate_table_name(table)

        # BUGFIX: the previous code tested `self.schema` but interpolated
        # the local `schema` parameter, producing "None.table" whenever
        # only the instance schema was set. Prefer the explicit per-call
        # schema, then fall back to the instance schema.
        target_schema = schema or self.schema
        dbtable_option = f"{target_schema}.{table}" if target_schema else table

        (
            dataframe.write.format("jdbc")
            .option("url", self.jdbc_url)
            .option("dbtable", dbtable_option)
            .options(**self.properties)
            .mode(mode)
            .save()
        )
class PandasDatabaseOperations(BaseDatabaseOperations):
    """
    Database operations using Pandas and SQLAlchemy.

    This class provides a simple interface to read from and write to any
    SQLAlchemy-compatible database using Pandas DataFrames.

    The connection is established via a DSN (Data Source Name), making
    this implementation backend-agnostic.

    Supported databases include (but are not limited to):
    - PostgreSQL
    - MySQL
    - SQLite
    - Snowflake
    - Oracle

    This implementation is best suited for local or small-to-medium sized
    datasets where distributed processing (e.g., Spark) is not required.
    """

    def __init__(self, dsn: str, schema: str | None = None):
        """
        Initialize the PandasDatabaseOperations instance.

        :param dsn: A SQLAlchemy-compatible database connection string
            (DSN), e.g.:
            - postgresql://user:pass@host:5432/db
            - mysql+pymysql://user:pass@host/db
            - sqlite:///local.db
        :param schema: An optional schema used in postgres databases can
            be specified
        :raises ValueError: If the DSN is invalid or connection fails.
        :note: Uses SQLAlchemy's connection pooling with
            `pool_pre_ping=True` to ensure stale connections are
            automatically refreshed.
        """
        super().__init__(dsn, schema)
        self.engine = create_engine(dsn, pool_pre_ping=True)

    def read(
        self,
        table: str | None = None,
        query: str | None = None,
    ) -> pd.DataFrame:
        """
        Read data from the database into a Pandas DataFrame.

        This method supports two modes of operation:
        - Reading all rows from a specified table
        - Executing a custom SQL query

        :param table: The name of the table to read from. Must be provided
            if `query` is not supplied. (optional)
        :param query: A custom SQL query to execute. If provided, this
            overrides the `table` parameter. (optional)
        :raises ValueError: If neither `table` nor `query` is provided.
        :raises ValueError: If the table name exceeds the allowed length.
        :return: A Pandas DataFrame containing the query result.
        :rtype: pandas.DataFrame

        :example:
            >>> db.read(table="users")
            >>> db.read(query="SELECT id, name FROM users WHERE active = true")
        """
        self._validate_read_inputs(table, query)

        # BUGFIX: previously `table` took precedence and clobbered a
        # supplied `query`, contradicting the documented contract (and the
        # Spark implementation). Only build a SELECT from the table name
        # when no explicit query was given.
        # NOTE(review): double-quoted identifiers are ANSI/Postgres style;
        # MySQL needs backticks — confirm before using a schema there.
        if query is None:
            if self.schema:
                query = f'SELECT * FROM "{self.schema}"."{table}"'
            else:
                query = f'SELECT * FROM "{table}"'

        return pd.read_sql(text(query), self.engine)

    def write(
        self,
        table: str,
        data: pd.DataFrame,
        mode: str = "overwrite",
    ):
        """
        Write a Pandas DataFrame to the database.

        This method writes the provided DataFrame to the specified table
        using SQLAlchemy. The behavior when the table already exists is
        controlled via the `mode` parameter.

        :param table: The name of the target table.
        :param data: The Pandas DataFrame to write.
        :param mode: The write mode. Supported options are:
            - 'overwrite': Replace the table if it exists.
            - 'append': Append data to the existing table.
            Defaults to 'overwrite'. (optional)
        :raises ValueError: If the table name exceeds the allowed length.
        :raises ValueError: If an unsupported mode is provided.
        :note:
            - The DataFrame index is not written to the database.
            - Ensure schema compatibility when using `mode='append'`.

        :example:
            >>> db.write("users", df)
            >>> db.write("users", df, mode="append")
        """
        self._validate_table_name(table)

        # Map the public mode names onto pandas' if_exists values.
        if mode == "overwrite":
            if_exists = "replace"
        elif mode == "append":
            if_exists = "append"
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        # Force identifier quoting so mixed-case/reserved names survive.
        table_name = quoted_name(table, quote=True)

        data.to_sql(
            name=table_name,
            con=self.engine,
            schema=self.schema,
            if_exists=if_exists,
            index=False,
        )