diff --git a/requirements.txt b/requirements.txt index 228b8d4..0acb4bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ +sqlalchemy==2.0.43 pytest==8.4.1 pytest-cov==6.2.1 -pdoc==15.0.4 flake8==7.3.0 pdoc==15.0.4 \ No newline at end of file diff --git a/stat_log_db/pyproject.toml b/stat_log_db/pyproject.toml index efd8321..1db96d3 100644 --- a/stat_log_db/pyproject.toml +++ b/stat_log_db/pyproject.toml @@ -9,6 +9,7 @@ description = "" readme = "README.md" requires-python = ">=3.12.10" dependencies = [ + "sqlalchemy==2.0.43" ] [project.optional-dependencies] diff --git a/stat_log_db/src/stat_log_db/__init__.py b/stat_log_db/src/stat_log_db/__init__.py deleted file mode 100644 index 0a3f2c1..0000000 --- a/stat_log_db/src/stat_log_db/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from . import exceptions -from . import parser -from . import db -from . import cli diff --git a/stat_log_db/src/stat_log_db/cli.py b/stat_log_db/src/stat_log_db/cli.py index 55189b7..a8036cc 100644 --- a/stat_log_db/src/stat_log_db/cli.py +++ b/stat_log_db/src/stat_log_db/cli.py @@ -1,8 +1,12 @@ -import os +# import os # import sys +from sqlalchemy import select +from sqlalchemy.orm import Session + # from .parser import create_parser -from .db import MemDB # , FileDB, Database, BaseConnection +from stat_log_db.db import Database as DB +from stat_log_db.modules.log import Log, LogType, LogLevel def main(): @@ -18,21 +22,31 @@ def main(): # print(f"{args=}") - sl_db = MemDB({ - "is_mem": True, - "fkey_constraint": True + sl_db = DB({ + "is_mem": True }) - con = sl_db.init_db(True) - con.create_table("test", [('notes', 'TEXT')], False, True) - con.execute("INSERT INTO test (notes) VALUES (?);", ("Hello world!",)) - con.commit() - con.execute("SELECT * FROM test;") - sql_logs = con.fetchall() - print(sql_logs) - con.drop_table("test", True) - sl_db.close_db() - if sl_db.is_file: - os.remove(sl_db.file_name) + sl_db.init_db() + with Session(sl_db.engine) as session: + info_type = LogType( + name="INFO" + ) + session.add(info_type) + session.commit() + info_level = LogLevel( + name="INFO" + ) + session.add(info_level) + session.commit() + hello_world = Log( + type_id=1, + level_id=1, + message="Hello, World!" + ) + session.add(hello_world) + session.commit() + logs = select(Log).where(Log.id == 1) + for log in session.scalars(logs): + print(f"{log.id=}, {log.type_id=}, {log.level_id=}, {log.message=}") if __name__ == "__main__": diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py index 6b29ea9..35ad0e7 100644 --- a/stat_log_db/src/stat_log_db/db.py +++ b/stat_log_db/src/stat_log_db/db.py @@ -1,11 +1,13 @@ -# from abc import ABC, abstractmethod -import re -import sqlite3 import uuid from typing import Any +# import sqlite3 +from sqlalchemy import create_engine as sqla_create_engine +from sqlalchemy.engine import Engine -from .exceptions import raise_auto_arg_type_error +# from .exceptions import raise_auto_arg_type_error + +from stat_log_db.modules.base import BaseModel class Database(): @@ -14,7 +16,8 @@ def __init__(self, options: dict[str, Any] = {}): valid_options = { "db_name": str, "is_mem": bool, - "fkey_constraint": bool + # "fkey_constraint": bool, + "debug": bool } for opt, opt_type in options.items(): if opt not in valid_options.keys(): @@ -27,9 +30,12 @@ def __init__(self, options: dict[str, Any] = {}): self._is_file: bool = bool(not self._in_memory) self._db_name: str = options.get("db_name", str(uuid.uuid4())) self._db_file_name: str = ":memory:" if self._in_memory else self._db_name.replace(" ", "_") - self._fkey_constraint: bool = options.get("fkey_constraint", True) - # Keep track of active connections (to ensure that they are closed) - self._connections: dict[str, BaseConnection] = dict() + # self._fkey_constraint: bool = options.get("fkey_constraint", True) + self._debug: bool = options.get("debug", False) + # SQLAlchemy engine + self._engine: Engine | None = None + + # region Properties @property def name(self) -> str: @@ -47,433 +53,47 @@ def in_memory(self) -> bool: def is_file(self) -> bool: return self._is_file - @property - def fkey_constraint(self) -> bool: - return self._fkey_constraint - - def check_connection_integrity(self, connection: 'str | BaseConnection', skip_registry_type_check: bool = False): - """ - Check the integrity of a given connection's registration. - The connection to be checked can be passed as an UID string or a connection object (instance of BaseConnection). - """ - if not isinstance(skip_registry_type_check, bool): - raise_auto_arg_type_error("skip_registry_type_check") - connection_is_uid_str = isinstance(connection, str) - connection_is_obj = isinstance(connection, BaseConnection) - if (not connection_is_uid_str) and (not connection_is_obj): - raise_auto_arg_type_error("connection") - if self._connections is None or len(self._connections) == 0: - raise ValueError(f"Connection {connection.uid if connection_is_obj else connection} is not registered, as Connection Registry contains no connections.") - # Check that the registry is of the expected type - if (not skip_registry_type_check) and (not isinstance(self._connections, dict)): - raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}") - # If the passed-in connection is a uid string, - # search the registry keys for that uid string. - # Check that a matching connection is found, - # that it has a valid UID, and that it is registered - # under the uid that it has (registry key = found connection's uid). - if connection_is_uid_str: - if len(connection) == 0: - raise ValueError("Connection UID string is empty.") - found_connection = self._connections.get(connection, None) - if found_connection is None: - raise ValueError(f"Connection '{connection}' is not registered.") - if not isinstance(found_connection.uid, str): - raise TypeError(f"Expected the found connection's uid to be str, got {type(found_connection.uid).__name__} instead.") - if len(found_connection.uid) == 0: - raise ValueError("Found connection's uid string is empty.") - if found_connection.uid != connection: - raise ValueError(f"Connection '{connection}' is registered under non-matching uid: {found_connection.uid}") - # If the passed-in connection is a BaseConnection object, - # check that it has a valid uid and that it's UID is in the registry - elif connection_is_obj: - if not isinstance(connection.uid, str): - raise TypeError(f"Expected the connection's uid to be str, got {type(connection.uid).__name__} instead.") - if connection.uid not in self._connections: - raise ValueError(f"Connection '{connection.uid}' is not registered, or is registered under the wrong uid.") - - def check_connection_registry_integrity(self, skip_registry_type_check: bool = False): - """ - Check the integrity of the connection registry. - If not all connections are registered, no error is raised. - """ - if not isinstance(skip_registry_type_check, bool): - raise_auto_arg_type_error("skip_registry_type_check") - # Check that the registry is of the expected type - if (not skip_registry_type_check) and (not isinstance(self._connections, dict)): - raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}") - # If there are no connections, nothing to check - if len(self._connections) == 0: - return - # Check that all registered connections are registered under a UID of the correct type and are instances of BaseConnection - if any((not isinstance(uid, str)) or (not isinstance(conn, BaseConnection)) for uid, conn in self._connections.items()): - raise TypeError("All connections must be registered by their UID string and be instances of BaseConnection.") - # Perform individual connection integrity checks - for uid in self._connections.keys(): - self.check_connection_integrity(uid, skip_registry_type_check=True) # Registry type already checked - - def _register_connection(self): - """ - Creates a new database connection object and registers it. - Does not open the connection. - """ - connection = BaseConnection(self) - self._connections[connection.uid] = connection - self.check_connection_integrity(connection) - return connection - - def _unregister_connection(self, connection: 'str | BaseConnection'): - """ - Unregister a database connection object. - Does not close it. - """ - connection_is_obj = isinstance(connection, BaseConnection) - if (not isinstance(connection, str)) and (not connection_is_obj): - raise_auto_arg_type_error("connection") - connection_uid_str = connection.uid if connection_is_obj else connection - self.check_connection_integrity(connection_uid_str) - # TODO: consider implementing garbage collector ref-count check - del self._connections[connection_uid_str] - - def init_db(self, commit_fkey: bool = True) -> 'BaseConnection': - if not isinstance(commit_fkey, bool): - raise_auto_arg_type_error("commit_fkey") - connection = self._register_connection() - connection.open() - connection.enforce_foreign_key_constraints(commit_fkey) - return connection - - def init_db_auto_close(self): - if self.in_memory: - raise ValueError("In-memory databases cease to exist upon closure.") - # don't bother to commit fkey constraint because close() will commit before connection closure - connection = self.init_db(False) - connection.close() - self._unregister_connection(connection.uid) - - def close_db(self): - uids = [] - self.check_connection_registry_integrity() - for uid, connection in self._connections.items(): - connection.close() - uids.append(uid) - for uid in uids: - self._unregister_connection(uid) - if not len(self._connections) == 0: - raise RuntimeError("Not all connections were closed properly.") - self._connections = dict() - - -class MemDB(Database): - def __init__(self, options: dict[str, Any] = {}): - super().__init__(options=options) - if not self.in_memory: - raise ValueError("MemDB can only be used for in-memory databases.") - - def check_connection_registry_integrity(self, skip_registry_type_check: bool = False): - """ - Check the integrity of the connection registry. - Implements early raise if more than one connection is found, - since in-memory databases can only have one connection. - """ - if not isinstance(skip_registry_type_check, bool): - raise_auto_arg_type_error("skip_registry_type_check") - if not skip_registry_type_check: - if not isinstance(self._connections, dict): - raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}") - if (num_connections := len(self._connections)) > 1: - raise ValueError(f"In-memory databases can only have one active connection Found {num_connections}.") - return super().check_connection_registry_integrity(skip_registry_type_check=True) # Registry type already checked - - def init_db_auto_close(self): - raise ValueError("In-memory databases cease to exist upon closure.") - - -class FileDB(Database): - def __init__(self, options: dict[str, Any] = {}): - super().__init__(options=options) - if not self.is_file: - raise ValueError("FileDB can only be used for file-based databases.") - - -class BaseConnection: - def __init__(self, db: Database): - if not isinstance(db, Database): - raise_auto_arg_type_error("db") - self._db: Database = db - self._id = str(uuid.uuid4()) - self._connection: sqlite3.Connection | None = None - self._cursor: sqlite3.Cursor | None = None - - @property - def db_name(self): - return self._db._db_name - - @property - def db_file_name(self): - return self._db._db_file_name - - @property - def db_in_memory(self): - return self._db._in_memory - - @property - def db_is_file(self): - return self._db._is_file - - @property - def db_fkey_constraint(self): - return self._db._fkey_constraint - - @property - def uid(self): - # TODO: Hash together the uuid, db_name, and possibly also the location in memory to ensure uniqueness? - return self._id + # @property + # def fkey_constraint(self) -> bool: + # return self._fkey_constraint @property - def registered(self): - self._db.check_connection_integrity(self) # raises error if not registered - return True + def debug(self) -> bool: + return self._debug @property - def connection(self): - if self._connection is None: - raise RuntimeError("Connection is not open.") - if not isinstance(self._connection, sqlite3.Connection): - raise TypeError(f"Expected self._connection to be sqlite3.Connection, got {type(self._connection).__name__} instead.") - return self._connection - - @property - def cursor(self): - if self._cursor is None: - raise RuntimeError("Cursor is not open.") - if not isinstance(self._cursor, sqlite3.Cursor): - raise TypeError(f"Expected self._cursor to be sqlite3.Cursor, got {type(self._cursor).__name__} instead.") - return self._cursor - - def enforce_foreign_key_constraints(self, commit: bool = True): - if not isinstance(commit, bool): - raise_auto_arg_type_error("commit") - if self.db_fkey_constraint: - self.cursor.execute("PRAGMA foreign_keys = ON;") - if commit: - self.connection.commit() - - def _open(self): - self._connection = sqlite3.connect(self.db_file_name) - self._cursor = self._connection.cursor() - - def open(self): - if isinstance(self._connection, sqlite3.Connection): - raise RuntimeError("Connection is already open.") - if not (self._connection is None): - raise TypeError(f"Expected self._connection to be None, got {type(self._connection).__name__} instead.") - self._open() - - def _close(self): - self.cursor.close() - self._cursor = None - self.connection.close() - self._connection = None - - def close(self): - self.connection.commit() - self._close() - - def _execute(self, query: str, parameters: tuple = ()): + def engine(self) -> Engine: """ - Execute a SQL query with the given parameters. - Performs no checks/validation. Prefer `execute` unless you need raw access. - """ - result = self.cursor.execute(query, parameters) - return result + Get the SQLAlchemy database engine. - def execute(self, query: str, parameters: tuple | None = None): - """ - Execute a SQL query with the given parameters. + `self._engine` will be `None` if the database has not been initialized. + In which case, calling `self.engine` (this property) will raise an error. """ - # Validate query and parameters - if not isinstance(query, str): - raise_auto_arg_type_error("query") - if len(query) == 0: - raise ValueError("'query' argument of execute cannot be an empty string!") - # Create a new space in memory that points to the same object that `parameters` points to - params = parameters - # If `params` points to None, update it to point to an empty tuple - if params is None: - params = tuple() - # If `params` points to an object that isn't a tuple or None (per previous condition), raise a TypeError - elif not isinstance(params, tuple): - raise_auto_arg_type_error("parameters") - # Execute query with `params` - result = self._execute(query, params) - return result - - def commit(self): - self.connection.commit() + if self._engine is None: + raise ValueError("Database engine is not initialized. Call 'init_db()' first.") + if not isinstance(self._engine, Engine): + raise TypeError(f"Database engine is not of type 'Engine', got '{type(self._engine).__name__}'.") + return self._engine - def fetchone(self): - return self.cursor.fetchone() + # endregion - def fetchall(self): - return self.cursor.fetchall() - - def _validate_sql_identifier(self, identifier: str, identifier_type: str = "identifier") -> str: + def init_db(self): """ - Validate and sanitize SQL identifiers (table names, column names) to prevent SQL injection. - Args: - identifier: The identifier to validate - identifier_type: Type of identifier for error messages (e.g., "table name", "column name") - Returns: - The validated identifier - Raises: - ValueError: If the identifier is invalid or potentially dangerous + Initialize the database. """ - if not isinstance(identifier, str): - raise TypeError(f"SQL {identifier_type} must be a string, got {type(identifier).__name__}") - if len(identifier) == 0: - raise ValueError(f"SQL {identifier_type} cannot be empty") - # Check for valid identifier pattern: starts with letter/underscore, contains only alphanumeric/underscore - if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', identifier): - raise ValueError(f"Invalid SQL {identifier_type}: '{identifier}'. Must start with letter or underscore and contain only letters, numbers, and underscores.") - # Check against SQLite reserved words (common ones that could cause issues) - reserved_words = { - 'abort', 'action', 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', 'cascade', 'case', - 'cast', 'check', 'collate', 'column', 'commit', 'conflict', 'constraint', 'create', - 'cross', 'current', 'current_date', 'current_time', 'current_timestamp', 'database', - 'default', 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', 'do', - 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', 'exists', 'explain', - 'fail', 'filter', 'following', 'for', 'foreign', 'from', 'full', 'glob', 'group', - 'having', 'if', 'ignore', 'immediate', 'in', 'index', 'indexed', 'initially', 'inner', - 'insert', 'instead', 'intersect', 'into', 'is', 'isnull', 'join', 'key', 'left', - 'like', 'limit', 'match', 'natural', 'no', 'not', 'notnull', 'null', 'of', 'offset', - 'on', 'or', 'order', 'outer', 'over', 'partition', 'plan', 'pragma', 'preceding', - 'primary', 'query', 'raise', 'range', 'recursive', 'references', 'regexp', 'reindex', - 'release', 'rename', 'replace', 'restrict', 'right', 'rollback', 'row', 'rows', - 'savepoint', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', 'transaction', - 'trigger', 'unbounded', 'union', 'unique', 'update', 'using', 'vacuum', 'values', - 'view', 'virtual', 'when', 'where', 'window', 'with', 'without' - } - if identifier.lower() in reserved_words: - raise ValueError(f"SQL {identifier_type} '{identifier}' is a reserved word and cannot be used") - return identifier + self._engine = sqla_create_engine(f"sqlite:///{self._db_file_name}") + BaseModel.metadata.create_all(self.engine) - def _escape_sql_identifier(self, identifier: str) -> str: + def close_db(self): """ - Escape SQL identifier by wrapping in double quotes and escaping any internal quotes. - This should only be used after validation. + Close the database. """ - # Escape any double quotes in the identifier by doubling them - escaped = identifier.replace('"', '""') - return f'"{escaped}"' - - def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_table: bool = True, raise_if_exists: bool = True): - # Validate table_name argument - if not isinstance(table_name, str): - raise_auto_arg_type_error("table_name") - if len(table_name) == 0: - raise ValueError("'table_name' argument of create_table cannot be an empty string!") - # Validate temp_table argument - if not isinstance(temp_table, bool): - raise_auto_arg_type_error("temp_table") - if not isinstance(raise_if_exists, bool): - raise_auto_arg_type_error("raise_if_exists") - # Validate columns argument - if (not isinstance(columns, list)) or (not all( - isinstance(col, tuple) and len(col) == 2 - and isinstance(col[0], str) - and isinstance(col[1], str) - for col in columns)): - raise_auto_arg_type_error("columns") - # Validate and sanitize table name - validated_table_name = self._validate_sql_identifier(table_name, "table name") - escaped_table_name = self._escape_sql_identifier(validated_table_name) - # Check if table already exists using parameterized query - if raise_if_exists: - self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,)) - if self.cursor.fetchone() is not None: - raise ValueError(f"Table '{validated_table_name}' already exists.") - # Validate and construct columns portion of query - validated_columns = [] - for col_name, col_type in columns: - # Validate column name - validated_col_name = self._validate_sql_identifier(col_name, "column name") - escaped_col_name = self._escape_sql_identifier(validated_col_name) - # Validate column type - allow only safe, known SQLite types - allowed_types = { - 'TEXT', 'INTEGER', 'REAL', 'BLOB', 'NUMERIC', - 'VARCHAR', 'CHAR', 'NVARCHAR', 'NCHAR', - 'CLOB', 'DATE', 'DATETIME', 'TIMESTAMP', - 'BOOLEAN', 'DECIMAL', 'DOUBLE', 'FLOAT', - 'INT', 'BIGINT', 'SMALLINT', 'TINYINT' - } - # Allow type specifications with length/precision (e.g., VARCHAR(50), DECIMAL(10,2)) - base_type = re.match(r'^([A-Z]+)', col_type.upper()) - if not base_type or base_type.group(1) not in allowed_types: - raise ValueError(f"Unsupported column type: '{col_type}'. Must be one of: {', '.join(sorted(allowed_types))}") - # Basic validation for type specification format - if not re.match(r'^[A-Z]+(\([0-9,\s]+\))?$', col_type.upper()): - raise ValueError(f"Invalid column type format: '{col_type}'") - validated_columns.append(f"{escaped_col_name} {col_type.upper()}") - columns_qstr = ",\n ".join(validated_columns) - # Assemble full query with escaped identifiers - temp_keyword = " TEMPORARY" if temp_table else "" - query = f"""CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - {columns_qstr} - );""" - self.execute(query) - - def drop_table(self, table_name: str, raise_if_not_exists: bool = False): - # Validate table_name argument - if not isinstance(table_name, str): - raise_auto_arg_type_error("table_name") - if len(table_name) == 0: - raise ValueError("'table_name' argument of drop_table cannot be an empty string!") - if not isinstance(raise_if_not_exists, bool): - raise_auto_arg_type_error("raise_if_not_exists") - # Validate and sanitize table name - validated_table_name = self._validate_sql_identifier(table_name, "table name") - escaped_table_name = self._escape_sql_identifier(validated_table_name) - # Check if table exists using parameterized query - if raise_if_not_exists: - self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,)) - if self.cursor.fetchone() is None: - raise ValueError(f"Table '{validated_table_name}' does not exist.") - # Execute DROP statement with escaped identifier - self.cursor.execute(f"DROP TABLE IF EXISTS {escaped_table_name};") - - # def read(self): - # pass - - # def write(self): - # pass - - # def create(self): - # pass - - # def unlink(self): - # pass - - -class Connection(BaseConnection): - def __init__(self, db: Database): - super().__init__(db) - self.open() - self.enforce_foreign_key_constraints(True) - - -class FileConnectionCtx(BaseConnection): - def __init__(self, db: Database): - super().__init__(db) - if not self.db_is_file: - raise ValueError("FileConnectionCtx can only be used with file-based databases.") - - def __enter__(self): - self.open() - self.enforce_foreign_key_constraints(True) - return self + self.engine.dispose() + self._engine = None - def __exit__(self, exc_type, exc_value, exc_tb): - self.close() + # def connect(self): + # """ + # Create and return a new database connection. + # """ + # connection = self.engine.connect() + # return connection diff --git a/stat_log_db/src/stat_log_db/modules/base/__init__.py b/stat_log_db/src/stat_log_db/modules/base/__init__.py new file mode 100644 index 0000000..d2329d8 --- /dev/null +++ b/stat_log_db/src/stat_log_db/modules/base/__init__.py @@ -0,0 +1,3 @@ +from .models.base import BaseModel + +__all__ = ["BaseModel"] diff --git a/stat_log_db/src/stat_log_db/modules/base/models/base.py b/stat_log_db/src/stat_log_db/modules/base/models/base.py new file mode 100644 index 0000000..122b42c --- /dev/null +++ b/stat_log_db/src/stat_log_db/modules/base/models/base.py @@ -0,0 +1,16 @@ +# from typing import Any, List, Optional +from datetime import datetime + +from sqlalchemy import TIMESTAMP, func # , ForeignKey, String, +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column # , relationship + +# from .exceptions import raise_auto_arg_type_error + + +class BaseModel(DeclarativeBase): + __tablename__: str + + id: Mapped[int] = mapped_column(primary_key=True) + + created_at: Mapped[datetime] = mapped_column(TIMESTAMP, server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column(TIMESTAMP, server_default=func.now(), onupdate=func.now()) diff --git a/stat_log_db/src/stat_log_db/modules/log/__init__.py b/stat_log_db/src/stat_log_db/modules/log/__init__.py new file mode 100644 index 0000000..d00df4e --- /dev/null +++ b/stat_log_db/src/stat_log_db/modules/log/__init__.py @@ -0,0 +1,9 @@ +from .models.log_level import LogLevel +from .models.log_type import LogType +from .models.log import Log + +__all__ = [ + "Log", + "LogLevel", + "LogType" +] diff --git a/stat_log_db/src/stat_log_db/modules/log/models/log.py b/stat_log_db/src/stat_log_db/modules/log/models/log.py new file mode 100644 index 0000000..f614c9d --- /dev/null +++ b/stat_log_db/src/stat_log_db/modules/log/models/log.py @@ -0,0 +1,17 @@ +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from stat_log_db.modules.base import BaseModel +from stat_log_db.modules.log import LogType, LogLevel + + +class Log(BaseModel): + __tablename__ = "log" + + type_id: Mapped[int] = mapped_column(ForeignKey("log_type.id"), nullable=False) + type: Mapped[LogType] = relationship() + + level_id: Mapped[int] = mapped_column(ForeignKey("log_level.id"), nullable=False) + level: Mapped[LogLevel] = relationship() + + message: Mapped[str] = mapped_column(String, nullable=False) diff --git a/stat_log_db/src/stat_log_db/modules/log/models/log_level.py b/stat_log_db/src/stat_log_db/modules/log/models/log_level.py new file mode 100644 index 0000000..4f271e6 --- /dev/null +++ b/stat_log_db/src/stat_log_db/modules/log/models/log_level.py @@ -0,0 +1,10 @@ +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column + +from stat_log_db.modules.base import BaseModel + + +class LogLevel(BaseModel): + __tablename__ = "log_level" + + name: Mapped[str] = mapped_column(String, nullable=False) diff --git a/stat_log_db/src/stat_log_db/modules/log/models/log_type.py b/stat_log_db/src/stat_log_db/modules/log/models/log_type.py new file mode 100644 index 0000000..24cfdcf --- /dev/null +++ b/stat_log_db/src/stat_log_db/modules/log/models/log_type.py @@ -0,0 +1,10 @@ +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column + +from stat_log_db.modules.base import BaseModel + + +class LogType(BaseModel): + __tablename__ = "log_type" + + name: Mapped[str] = mapped_column(String, nullable=False) diff --git a/stat_log_db/src/stat_log_db/modules/tag/__init__.py b/stat_log_db/src/stat_log_db/modules/tag/__init__.py new file mode 100644 index 0000000..606a4a3 --- /dev/null +++ b/stat_log_db/src/stat_log_db/modules/tag/__init__.py @@ -0,0 +1,5 @@ +from .models.tag import Tag + +__all__ = [ + "Tag" +] diff --git a/stat_log_db/src/stat_log_db/modules/tag/models/tag.py b/stat_log_db/src/stat_log_db/modules/tag/models/tag.py new file mode 100644 index 0000000..065cdae --- /dev/null +++ b/stat_log_db/src/stat_log_db/modules/tag/models/tag.py @@ -0,0 +1,10 @@ +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column + +from stat_log_db.modules.base import BaseModel + + +class Tag(BaseModel): + __tablename__ = "tag" + + name: Mapped[str] = mapped_column(String, nullable=False) diff --git a/stat_log_db/tests/test_db.py b/stat_log_db/tests/test_db.py new file mode 100644 index 0000000..183dd12 --- /dev/null +++ b/stat_log_db/tests/test_db.py @@ -0,0 +1,68 @@ +""" +Test SQL injection protection in create_table and drop_table methods. +""" + +import pytest +import sys +from pathlib import Path + +from sqlalchemy import select +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session + +from stat_log_db.db import Database +from stat_log_db.modules.tag import Tag + + +# Add the src directory to the path to import the module +ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(ROOT / "stat_log_db" / "src")) + + +@pytest.fixture +def db(): + """Create a test in-memory database and close it after tests.""" + sl_db = Database({ + "is_mem": True + }) + sl_db.init_db() + yield sl_db + sl_db.close_db() + + +def test_db_engine_registered(db): + """ Test database operations """ + assert isinstance(db, Database) + assert db.engine is not None, "Database engine is not initialized" + assert isinstance(db.engine, Engine), "Database engine is not of type Engine (sqlalchemy.engine.Engine)" + + +def test_db_session(db): + """ Test database session creation """ + with Session(db.engine) as session: + tag_name = "Test" + test_tag = Tag( + name=tag_name + ) + session.add(test_tag) + session.commit() + tags = select(Tag).where(Tag.id == 1) + results = session.scalars(tags) + assert (num_res := len(results.all())) == 1, f"Expected 1 result, got {num_res}" + + +def test_db_basic_class(db): + """ Test basic class (Tag) """ + with Session(db.engine) as session: + tag_name = "Test" + test_tag = Tag( + name=tag_name + ) + session.add(test_tag) + session.commit() + tags_stmt = select(Tag).where(Tag.id == 1) + tags = session.scalars(tags_stmt).all() + assert (num_res := len(tags)) == 1, f"Expected 1 result, got {num_res}" + assert tags[0].name == tag_name + assert hasattr(tags[0], "id"), "Tag object does not have an 'id' attribute" + assert tags[0].id == 1, f"Expected tag id to be 1, got {tags[0].id}" diff --git a/stat_log_db/tests/test_sql_injection.py b/stat_log_db/tests/test_sql_injection.py deleted file mode 100644 index 447ba96..0000000 --- a/stat_log_db/tests/test_sql_injection.py +++ /dev/null @@ -1,184 +0,0 @@ -""" -Test SQL injection protection in create_table and drop_table methods. -""" - -import pytest -import sys -from pathlib import Path - -from stat_log_db.db import MemDB - - -# Add the src directory to the path to import the module -ROOT = Path(__file__).resolve().parent.parent -sys.path.insert(0, str(ROOT / "stat_log_db" / "src")) - - -@pytest.fixture -def mem_db(): - """Create a test in-memory database and clean up after tests.""" - sl_db = MemDB({ - "is_mem": True, - "fkey_constraint": True - }) - con = sl_db.init_db(True) - yield con - # Cleanup - sl_db.close_db() - - -class TestSQLInjectionProtection: - """Test class for SQL injection protection in database operations.""" - - def test_malicious_table_name_create(self, mem_db): - """Test that malicious SQL injection in table names is rejected.""" - with pytest.raises(ValueError, match="Invalid SQL table name"): - mem_db.create_table("test'; DROP TABLE users; --", [('notes', 'TEXT')], False, True) - - def test_reserved_word_table_name(self, mem_db): - """Test that SQL reserved words are rejected as table names.""" - with pytest.raises(ValueError, match="is a reserved word"): - mem_db.create_table("select", [('notes', 'TEXT')], False, True) - - def test_invalid_characters_table_name(self, mem_db): - """Test that invalid characters in table names are rejected.""" - with pytest.raises(ValueError, match="Invalid SQL table name"): - mem_db.create_table("test-table", [('notes', 'TEXT')], False, True) - - def test_malicious_column_name(self, mem_db): - """Test that malicious SQL injection in column names is rejected.""" - with pytest.raises(ValueError, match="Invalid SQL column name"): - mem_db.create_table("test_table", [('notes\'; DROP TABLE users; --', 'TEXT')], False, True) - - def test_invalid_column_type(self, mem_db): - """Test that invalid/malicious column types are rejected.""" - with pytest.raises(ValueError, match="Unsupported column type"): - mem_db.create_table("test_table", [('notes', 'MALICIOUS_TYPE; DROP TABLE users; --')], False, True) - - def test_valid_table_creation(self, mem_db): - """Test that valid table creation works correctly.""" - # This should not raise any exception - mem_db.create_table("test_table", [('notes', 'TEXT'), ('count', 'INTEGER')], False, True) - - # Verify table was created by attempting to insert data - mem_db.execute("INSERT INTO test_table (notes, count) VALUES (?, ?);", ("test note", 42)) - mem_db.commit() - - # Verify data was inserted - mem_db.execute("SELECT * FROM test_table;") - result = mem_db.fetchall() - assert len(result) == 1 - assert result[0][1] == "test note" # Column 0 is auto-increment id - assert result[0][2] == 42 - - def test_malicious_drop_table_name(self, mem_db): - """Test that malicious SQL injection in drop table is rejected.""" - # First create a valid table - mem_db.create_table("test_table", [('notes', 'TEXT')], False, True) - - # Then try to drop with malicious name - with pytest.raises(ValueError, match="Invalid SQL table name"): - mem_db.drop_table("test_table'; DROP TABLE sqlite_master; --", False) - - def test_valid_drop_table(self, mem_db): - """Test that valid table dropping works correctly.""" - # Create a table first - mem_db.create_table("test_table", [('notes', 'TEXT')], False, True) - - # Verify it exists by checking sqlite_master - mem_db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", ("test_table",)) - assert mem_db.fetchone() is not None - - # Drop the table - mem_db.drop_table("test_table", False) - - # Verify it's gone - mem_db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", ("test_table",)) - assert mem_db.fetchone() is None - - def test_empty_table_name_create(self, mem_db): - """Test that empty table names are rejected.""" - with pytest.raises(ValueError, match="cannot be an empty string"): - mem_db.create_table("", [('notes', 'TEXT')], False, True) - - def test_empty_table_name_drop(self, mem_db): - """Test that empty table names are rejected in drop operations.""" - with pytest.raises(ValueError, match="cannot be an empty string"): - mem_db.drop_table("", False) - - def test_empty_column_name(self, mem_db): - """Test that empty column names are rejected.""" - with pytest.raises(ValueError, match="cannot be empty"): - mem_db.create_table("test_table", [('', 'TEXT')], False, True) - - def test_column_name_with_numbers(self, mem_db): - """Test that column names with numbers are allowed.""" - mem_db.create_table("test_table", [('column1', 'TEXT'), ('column_2', 'INTEGER')], False, True) - - def test_table_name_with_underscore(self, mem_db): - """Test that table names starting with underscore are allowed.""" - mem_db.create_table("_test_table", [('notes', 'TEXT')], False, True) - - def test_valid_column_types(self, mem_db): - """Test that all supported column types work correctly.""" - valid_types = [ - ('text_col', 'TEXT'), - ('int_col', 'INTEGER'), - ('real_col', 'REAL'), - ('blob_col', 'BLOB'), - ('numeric_col', 'NUMERIC'), - ('varchar_col', 'VARCHAR(255)'), - ('decimal_col', 'DECIMAL(10,2)') - ] - - mem_db.create_table("type_test_table", valid_types, False, True) - - def test_case_insensitive_reserved_words(self, mem_db): - """Test that reserved words are caught regardless of case.""" - with pytest.raises(ValueError, match="is a reserved word"): - mem_db.create_table("SELECT", [('notes', 'TEXT')], False, True) - - with pytest.raises(ValueError, match="is a reserved word"): - mem_db.create_table("Select", [('notes', 'TEXT')], False, True) - - def test_raise_if_exists_functionality(self, mem_db): - """Test the raise_if_exists parameter works correctly.""" - # Create a table - mem_db.create_table("test_table", [('notes', 'TEXT')], False, True) - - # Try to create the same table with raise_if_exists=True (should fail) - with pytest.raises(ValueError, match="already exists"): - mem_db.create_table("test_table", [('notes', 'TEXT')], False, True) - - # Try to create the same table with raise_if_exists=False (should succeed) - mem_db.create_table("test_table", [('notes', 'TEXT')], False, False) - - def test_raise_if_not_exists_functionality(self, mem_db): - """Test the raise_if_not_exists parameter works correctly.""" - # Try to drop non-existent table with raise_if_not_exists=True (should fail) - with pytest.raises(ValueError, match="does not exist"): - mem_db.drop_table("nonexistent_table", True) - - # Try to drop non-existent table with raise_if_not_exists=False (should succeed) - mem_db.drop_table("nonexistent_table", False) - - def test_special_characters_rejection(self, mem_db): - """Test that various special characters are properly rejected.""" - special_chars = [ - "table;name", - "table'name", - 'table"name', - "table name", # space - "table-name", # hyphen - "table.name", # dot - "table(name)", # parentheses - "table[name]", # brackets - "table{name}", # braces - "table@name", # at symbol - "table#name", # hash - "table$name", # dollar (should be rejected by our implementation) - ] - - for table_name in special_chars: - with pytest.raises(ValueError, match="Invalid SQL"): - mem_db.create_table(table_name, [('notes', 'TEXT')], False, True) diff --git a/tools.sh b/tools.sh index 024a2f4..dccb741 100755 --- a/tools.sh +++ b/tools.sh @@ -110,20 +110,17 @@ if [ -n "$doc" ]; then fi # Clean artifacts [-c] -if [ $clean -eq 1 ]; then +if [ "$clean" -eq 1 ]; then echo "Cleaning up workspace..." - dirs_to_clean=( + static_dirs=( ".pytest_cache" - "tests/__pycache__" "stat_log_db/build" "stat_log_db/dist" - "stat_log_db/.pytest_cache" - "stat_log_db/tests/__pycache__" "stat_log_db/src/stat_log_db.egg-info" - "stat_log_db/src/stat_log_db/__pycache__" - "stat_log_db/src/stat_log_db/commands/__pycache__" ) - rm -rf "${dirs_to_clean[@]}" + rm -rf "${static_dirs[@]}" + # Recursively find and remove all __pycache__ directories + find . -type d -name '__pycache__' -exec rm -rf {} + echo "Cleanup complete." fi