From 9b2b65bfff47dfa5a8f3767c241eac8583529251 Mon Sep 17 00:00:00 2001 From: SyntaxAerror Date: Sat, 30 Aug 2025 18:41:43 -0400 Subject: [PATCH 1/9] Add db.py for managing sqlite database - Add Database class for creating sqlite database, and managing connections and queries. - Update cli.py with basic "hello world" example --- .gitignore | 3 ++ stat_log_db/src/stat_log_db/__init__.py | 1 + stat_log_db/src/stat_log_db/cli.py | 9 ++++++ stat_log_db/src/stat_log_db/db.py | 39 +++++++++++++++++++++++++ 4 files changed, 52 insertions(+) create mode 100644 stat_log_db/src/stat_log_db/db.py diff --git a/.gitignore b/.gitignore index 14baa7a..46b1768 100644 --- a/.gitignore +++ b/.gitignore @@ -206,3 +206,6 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +# sqlite +*.sqlite \ No newline at end of file diff --git a/stat_log_db/src/stat_log_db/__init__.py b/stat_log_db/src/stat_log_db/__init__.py index 62fa98c..0a3f2c1 100644 --- a/stat_log_db/src/stat_log_db/__init__.py +++ b/stat_log_db/src/stat_log_db/__init__.py @@ -1,3 +1,4 @@ from . import exceptions from . import parser +from . import db from . import cli diff --git a/stat_log_db/src/stat_log_db/cli.py b/stat_log_db/src/stat_log_db/cli.py index 9450174..6c2fb47 100644 --- a/stat_log_db/src/stat_log_db/cli.py +++ b/stat_log_db/src/stat_log_db/cli.py @@ -2,6 +2,7 @@ import sys from .parser import create_parser +from .db import Database def main(): @@ -17,6 +18,14 @@ def main(): print(f"{args=}") + sl_db = Database('sl_db.sqlite') + + sl_db.execute("CREATE TABLE IF NOT EXISTS logs (id INTEGER PRIMARY KEY, message TEXT)") + sl_db.execute("INSERT INTO logs (message) VALUES (?)", ("Hello, world!",)) + sl_db.commit() + sl_db.execute("SELECT * FROM logs") + sql_logs = sl_db.fetchall() + print(sql_logs) if __name__ == "__main__": main() diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py new file mode 100644 index 0000000..7d074d2 --- /dev/null +++ b/stat_log_db/src/stat_log_db/db.py @@ -0,0 +1,39 @@ +import sqlite3 + +from .exceptions import raise_type_error_with_signature + + +class Database: + def __init__(self, db_name: str): + self._db_name = db_name + self._connection = sqlite3.connect(self._db_name) + self._cursor = self._connection.cursor() + + def __del__(self): # TODO: Is this the right way to handle database closure? + self._connection.close() + + @property + def db_name(self): + return self._db_name + + @property + def connection(self): + return self._connection + + @property + def cursor(self): + return self._cursor + + def commit(self): + self._connection.commit() + + def execute(self, query: str, params: tuple = ()): + if not isinstance(query, str): + raise_type_error_with_signature("query") + if not isinstance(params, tuple): + raise_type_error_with_signature("params") + self._cursor.execute(query, params) + self.commit() + + def fetchall(self): + return self._cursor.fetchall() From 6cebf57a19a400cf289e6e86d7375afe315fe715 Mon Sep 17 00:00:00 2001 From: SyntaxAerror Date: Sun, 31 Aug 2025 01:59:05 -0400 Subject: [PATCH 2/9] Improve db.py with better abstraction - Implement Database/Connection classes for managing SQLite database connections and queries. --- stat_log_db/src/stat_log_db/cli.py | 27 ++- stat_log_db/src/stat_log_db/db.py | 246 ++++++++++++++++++++-- stat_log_db/src/stat_log_db/exceptions.py | 6 +- stat_log_db/src/stat_log_db/parser.py | 6 +- 4 files changed, 252 insertions(+), 33 deletions(-) diff --git a/stat_log_db/src/stat_log_db/cli.py b/stat_log_db/src/stat_log_db/cli.py index 6c2fb47..922d913 100644 --- a/stat_log_db/src/stat_log_db/cli.py +++ b/stat_log_db/src/stat_log_db/cli.py @@ -2,7 +2,7 @@ import sys from .parser import create_parser -from .db import Database +from .db import Database, BaseConnection def main(): @@ -16,16 +16,21 @@ def main(): args = parser.parse_args() - print(f"{args=}") - - sl_db = Database('sl_db.sqlite') - - sl_db.execute("CREATE TABLE IF NOT EXISTS logs (id INTEGER PRIMARY KEY, message TEXT)") - sl_db.execute("INSERT INTO logs (message) VALUES (?)", ("Hello, world!",)) - sl_db.commit() - sl_db.execute("SELECT * FROM logs") - sql_logs = sl_db.fetchall() - print(sql_logs) + # print(f"{args=}") + + db_filename = 'sl_db.sqlite' + sl_db = Database(db_filename) + con = sl_db.init_db(False) + if isinstance(con, BaseConnection): + con.execute("CREATE TABLE IF NOT EXISTS logs (id INTEGER PRIMARY KEY, message TEXT);") + con.execute("INSERT INTO logs (message) VALUES (?);", ("Hello, world!",)) + con.commit() + con.execute("SELECT * FROM logs;") + sql_logs = con.fetchall() + print(sql_logs) + + sl_db.close_db() + # os.remove(db_filename) if __name__ == "__main__": main() diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py index 7d074d2..ec55182 100644 --- a/stat_log_db/src/stat_log_db/db.py +++ b/stat_log_db/src/stat_log_db/db.py @@ -1,39 +1,251 @@ +from abc import ABC, abstractmethod import sqlite3 +import uuid -from .exceptions import raise_type_error_with_signature +from .exceptions import raise_auto_arg_type_error -class Database: - def __init__(self, db_name: str): - self._db_name = db_name - self._connection = sqlite3.connect(self._db_name) - self._cursor = self._connection.cursor() - def __del__(self): # TODO: Is this the right way to handle database closure? - self._connection.close() +class Database(ABC): + def __init__(self, db_name: str | None = None, fkey_constraint: bool = True): + # Validate arguments + if not isinstance(db_name, (str, type(None))): + raise_auto_arg_type_error("db_name") + if not isinstance(fkey_constraint, bool): + raise_auto_arg_type_error("fkey_constraint") + self._db_name: str = ":memory:" if db_name is None else db_name + self._in_memory: bool = bool(self._db_name == ":memory:") + self._is_file: bool = bool(self._db_name != ":memory:") + self._fkey_constraint: bool = fkey_constraint + # Keep track of active connections (to ensure that they are closed) + self._connections: dict[str, BaseConnection] = dict() @property def db_name(self): return self._db_name + @property + def db_in_memory(self): + return self._in_memory + + @property + def db_is_file(self): + return self._is_file + + @property + def db_fkey_constraint(self): + return self._fkey_constraint + + def _register_connection(self): + """ + Creates a new database connection object and registers it. + Does not open the connection. + """ + con = BaseConnection(self) + self._connections[con.uid] = con + return con + + def _unregister_connection(self, connection: 'str | BaseConnection'): + """ + Unregister a database connection object. + Does not close it. + """ + connection_registry_key = None + if isinstance(connection, str): + connection_registry_key = connection + elif isinstance(connection, BaseConnection): + connection_registry_key = connection.uid + else: + raise_auto_arg_type_error("con") + if (connection_registry_key is None) or (connection_registry_key not in self._connections): + raise ValueError(f"Connection {connection} is not registered.") + del self._connections[connection_registry_key] + + def init_db(self, close_connection: bool = True): + if not isinstance(close_connection, bool): + raise_auto_arg_type_error("close_connection") + if self._in_memory and close_connection: + raise ValueError("In-memory databases cease to exist upon closure.") + connection = self._register_connection() + connection.open() + connection.enforce_foreign_key_constraints(False) + if close_connection: + connection.close() + self._unregister_connection(connection) + else: + return connection + + def close_db(self): + uids = [] + for uid, connection in self._connections.items(): + if not isinstance(connection, BaseConnection): + raise TypeError(f"Expected connection to be BaseConnection, got {type(connection).__name__} instead.") + if connection.uid != uid: + raise ValueError(f"Connection {connection.uid} is registered under non-matching uid: {uid}") + connection.close() + uids.append(uid) + for uid in uids: + self._unregister_connection(uid) + if not len(self._connections) == 0: + raise RuntimeError("Not all connections were closed properly.") + self._connections = dict() + + # @abstractmethod + # def create_table(self): + # pass + + # @abstractmethod + # def drop_table(self): + # pass + + # @abstractmethod + # def read(self): + # pass + + # @abstractmethod + # def write(self): + # pass + + # @abstractmethod + # def create(self): + # pass + + # @abstractmethod + # def unlink(self): + # pass + + +class BaseConnection: + def __init__(self, db: Database): + if not isinstance(db, Database): + raise_auto_arg_type_error("db") + self._db: Database = db + self._id = str(uuid.uuid4()) + self._connection: sqlite3.Connection | None = None + self._cursor: sqlite3.Cursor | None = None + + @property + def db_name(self): + return self._db._db_name + + @property + def db_in_memory(self): + return self._db._in_memory + + @property + def db_is_file(self): + return self._db._is_file + + @property + def db_fkey_constraint(self): + return self._db._fkey_constraint + + @property + def uid(self): + # TODO: Hash together the uuid, db_name, and possibly also the location in memory to ensure uniqueness? + return self._id + + @property + def registered(self): + return self.uid in self._db._connections + @property def connection(self): + if self._connection is None: + raise RuntimeError("Connection is not open.") + if not isinstance(self._connection, sqlite3.Connection): + raise TypeError(f"Expected self._connection to be sqlite3.Connection, got {type(self._connection).__name__} instead.") return self._connection @property def cursor(self): + if self._cursor is None: + raise RuntimeError("Cursor is not open.") + if not isinstance(self._cursor, sqlite3.Cursor): + raise TypeError(f"Expected self._cursor to be sqlite3.Cursor, got {type(self._cursor).__name__} instead.") return self._cursor - def commit(self): - self._connection.commit() + def enforce_foreign_key_constraints(self, commit: bool = True): + if not isinstance(commit, bool): + raise_auto_arg_type_error("commit") + if self.db_fkey_constraint: + self.cursor.execute("PRAGMA foreign_keys = ON;") + if commit: + self.connection.commit() + + def _open(self): + self._connection = sqlite3.connect(self.db_name) + self._cursor = self._connection.cursor() + + def open(self): + if isinstance(self._connection, sqlite3.Connection): + raise RuntimeError("Connection is already open.") + if not (self._connection is None): + raise TypeError(f"Expected self._connection to be None, got {type(self._connection).__name__} instead.") + self._open() - def execute(self, query: str, params: tuple = ()): + def _close(self): + self.cursor.close() + self._cursor = None + self.connection.close() + self._connection = None + + def close(self): + self.connection.commit() + self._close() + + def _execute(self, query: str, parameters: tuple = ()): + """ + Execute a SQL query with the given parameters. + """ + result = self.cursor.execute(query, parameters) + return result + + def execute(self, query: str, parameters: tuple | None = None): + """ + Execute a SQL query with the given parameters. + """ + # Validate query and parameters if not isinstance(query, str): - raise_type_error_with_signature("query") - if not isinstance(params, tuple): - raise_type_error_with_signature("params") - self._cursor.execute(query, params) - self.commit() + raise_auto_arg_type_error("query") + if len(query) == 0: + raise ValueError(f"'query' argument of execute cannot be an empty string!") + # Create a new space in memory that points to the same object that `parameters` points to + params = parameters + # If `params` points to None, update it to point to an empty tuple + if params is None: + params = tuple() + # If `params` points to an object that isn't a tuple or None (per previous condition), raise a TypeError + elif not isinstance(params, tuple): + raise_auto_arg_type_error("parameters") + # Execute query with `params` + result = self._execute(query, params) + return result + + def commit(self): + self.connection.commit() def fetchall(self): - return self._cursor.fetchall() + return self.cursor.fetchall() + + +class Connection(BaseConnection): + def __init__(self, db: Database): + super().__init__(db) + self.open() + self.enforce_foreign_key_constraints(True) + + +class FileConnectionCtx(BaseConnection): + def __init__(self, db: Database): + super().__init__(db) + if not self.db_is_file: + raise ValueError("FileConnectionCtx can only be used with file-based databases.") + + def __enter__(self): + self.open() + self.enforce_foreign_key_constraints(True) + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.close() diff --git a/stat_log_db/src/stat_log_db/exceptions.py b/stat_log_db/src/stat_log_db/exceptions.py index 6b985e6..c0c03eb 100644 --- a/stat_log_db/src/stat_log_db/exceptions.py +++ b/stat_log_db/src/stat_log_db/exceptions.py @@ -1,6 +1,8 @@ -def raise_type_error_with_signature(argument_name: str | list[str] | tuple[str, ...] | set[str] | None = None): - """Generate a standard type error message.""" +def raise_auto_arg_type_error(argument_name: str | list[str] | tuple[str, ...] | set[str] | None = None): + """ + Raise a TypeError with a standard message for arguments not matching the parameter's requested type. + """ message = f"TypeError with one or more of argument(s): {argument_name}" try: import inspect diff --git a/stat_log_db/src/stat_log_db/parser.py b/stat_log_db/src/stat_log_db/parser.py index 1026230..4a25116 100644 --- a/stat_log_db/src/stat_log_db/parser.py +++ b/stat_log_db/src/stat_log_db/parser.py @@ -1,13 +1,13 @@ import argparse -from .exceptions import raise_type_error_with_signature +from .exceptions import raise_auto_arg_type_error def create_parser(parser_args: dict, version: str | int = "0.0.1") -> argparse.ArgumentParser: """Create the main argument parser.""" # Validate parser_args if not isinstance(parser_args, dict): - raise_type_error_with_signature("parser_args") + raise_auto_arg_type_error("parser_args") # Default formatter class if "formatter_class" not in parser_args: parser_args["formatter_class"] = argparse.RawDescriptionHelpFormatter @@ -18,7 +18,7 @@ def create_parser(parser_args: dict, version: str | int = "0.0.1") -> argparse.A # Validate version if not isinstance(version, (str, int)): - raise_type_error_with_signature("version") + raise_auto_arg_type_error("version") # Add version argument parser.add_argument( From 5d445dda7c5dc79c23b1b3dee3c10f7781df4da2 Mon Sep 17 00:00:00 2001 From: bjthres1 Date: Sun, 31 Aug 2025 15:54:54 -0400 Subject: [PATCH 3/9] DB Connection Registry Validation, Better Inheritance, Table Methods --- stat_log_db/pyproject.toml | 2 +- stat_log_db/src/stat_log_db/cli.py | 25 ++- stat_log_db/src/stat_log_db/db.py | 275 ++++++++++++++++++++++------- 3 files changed, 228 insertions(+), 74 deletions(-) diff --git a/stat_log_db/pyproject.toml b/stat_log_db/pyproject.toml index fd0139d..432b728 100644 --- a/stat_log_db/pyproject.toml +++ b/stat_log_db/pyproject.toml @@ -7,7 +7,7 @@ name = "stat-log-db" version = "0.0.1" description = "" readme = "README.md" -requires-python = "==3.12.10" +requires-python = ">=3.12.10" dependencies = [ ] diff --git a/stat_log_db/src/stat_log_db/cli.py b/stat_log_db/src/stat_log_db/cli.py index 922d913..d9daf53 100644 --- a/stat_log_db/src/stat_log_db/cli.py +++ b/stat_log_db/src/stat_log_db/cli.py @@ -2,7 +2,7 @@ import sys from .parser import create_parser -from .db import Database, BaseConnection +from .db import Database, MemDB, FileDB, BaseConnection def main(): @@ -18,19 +18,18 @@ def main(): # print(f"{args=}") - db_filename = 'sl_db.sqlite' - sl_db = Database(db_filename) - con = sl_db.init_db(False) - if isinstance(con, BaseConnection): - con.execute("CREATE TABLE IF NOT EXISTS logs (id INTEGER PRIMARY KEY, message TEXT);") - con.execute("INSERT INTO logs (message) VALUES (?);", ("Hello, world!",)) - con.commit() - con.execute("SELECT * FROM logs;") - sql_logs = con.fetchall() - print(sql_logs) - + sl_db = MemDB(":memory:", True, True) + con = sl_db.init_db(True) + con.create_table("test", [('notes', 'TEXT')], False, True) + con.execute("INSERT INTO test (notes) VALUES (?);", ("Hello world!",)) + con.commit() + con.execute("SELECT * FROM test;") + sql_logs = con.fetchall() + print(sql_logs) + con.drop_table("test", True) sl_db.close_db() - # os.remove(db_filename) + if sl_db.is_file: + os.remove(sl_db.file_name) if __name__ == "__main__": main() diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py index ec55182..c4a07c4 100644 --- a/stat_log_db/src/stat_log_db/db.py +++ b/stat_log_db/src/stat_log_db/db.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +# from abc import ABC, abstractmethod import sqlite3 import uuid @@ -6,82 +6,156 @@ from .exceptions import raise_auto_arg_type_error -class Database(ABC): - def __init__(self, db_name: str | None = None, fkey_constraint: bool = True): +class Database(): + def __init__(self, db_name: str | None = None, is_mem: bool = False, fkey_constraint: bool = True): # Validate arguments - if not isinstance(db_name, (str, type(None))): + # database name + if db_name is None: + self._db_name = str(uuid.uuid4()) + elif not isinstance(db_name, str): raise_auto_arg_type_error("db_name") + else: + self._db_name = db_name + # is memory or file database + if not isinstance(is_mem, bool): + raise_auto_arg_type_error("is_mem") + self._in_memory = is_mem + self._is_file = not is_mem + # database file name + if is_mem: + self._db_file_name = ":memory:" + else: + self._db_file_name = self._db_name.replace(" ", "_") if not isinstance(fkey_constraint, bool): raise_auto_arg_type_error("fkey_constraint") - self._db_name: str = ":memory:" if db_name is None else db_name - self._in_memory: bool = bool(self._db_name == ":memory:") - self._is_file: bool = bool(self._db_name != ":memory:") - self._fkey_constraint: bool = fkey_constraint + self._fkey_constraint = fkey_constraint # Keep track of active connections (to ensure that they are closed) self._connections: dict[str, BaseConnection] = dict() @property - def db_name(self): + def name(self) -> str: return self._db_name @property - def db_in_memory(self): + def file_name(self) -> str: + return self._db_file_name + + @property + def in_memory(self) -> bool: return self._in_memory @property - def db_is_file(self): + def is_file(self) -> bool: return self._is_file @property - def db_fkey_constraint(self): + def fkey_constraint(self) -> bool: return self._fkey_constraint + def check_connection_integrity(self, connection: 'str | BaseConnection', skip_registry_type_check: bool = False): + """ + Check the integrity of a given connection's registration. + The connection to be checked can be passed as an UID string or a connection object (instance of BaseConnection). + """ + if not isinstance(skip_registry_type_check, bool): + raise_auto_arg_type_error("skip_registry_type_check") + connection_is_uid_str = isinstance(connection, str) + connection_is_obj = isinstance(connection, BaseConnection) + if (not connection_is_uid_str) and (not connection_is_obj): + raise_auto_arg_type_error("connection") + if self._connections is None or len(self._connections) == 0: + raise ValueError(f"Connection {connection.uid if connection_is_obj else connection} is not registered, as Connection Registry contains no connections.") + # Check that the registry is of the expected type + if (not skip_registry_type_check) and (not isinstance(self._connections, dict)): + raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}") + # If the passed-in connection is a uid string, + # search the registry keys for that uid string. + # Check that a matching connection is found, + # that it has a valid UID, and that it is registered + # under the uid that it has (registry key = found connection's uid). + if connection_is_uid_str: + if len(connection) == 0: + raise ValueError("Connection UID string is empty.") + found_connection = self._connections.get(connection, None) + if found_connection is None: + raise ValueError(f"Connection '{connection}' is not registered.") + if not isinstance(found_connection.uid, str): + raise TypeError(f"Expected the found connection's uid to be str, got {type(found_connection.uid).__name__} instead.") + if len(found_connection.uid) == 0: + raise ValueError("Found connection's uid string is empty.") + if found_connection.uid != connection: + raise ValueError(f"Connection '{connection}' is registered under non-matching uid: {found_connection.uid}") + # If the passed-in connection is a BaseConnection object, + # check that it has a valid uid and that it's UID is in the registry + elif connection_is_obj: + if not isinstance(connection.uid, str): + raise TypeError(f"Expected the connection's uid to be str, got {type(connection.uid).__name__} instead.") + if connection.uid not in self._connections: + raise ValueError(f"Connection '{connection.uid}' is not registered, or is registered under the wrong uid.") + + def check_connection_registry_integrity(self, skip_registry_type_check: bool = False): + """ + Check the integrity of the connection registry. + If not all connections are registered, no error is raised. + """ + if not isinstance(skip_registry_type_check, bool): + raise_auto_arg_type_error("skip_registry_type_check") + # Check that the registry is of the expected type + if (not skip_registry_type_check) and (not isinstance(self._connections, dict)): + raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}") + # If there are no connections, nothing to check + if len(self._connections) == 0: + return + # Check that all registered connections are registered under a UID of the correct type and are instances of BaseConnection + if any((not isinstance(uid, str)) or (not isinstance(conn, BaseConnection)) for uid, conn in self._connections.items()): + raise TypeError("All connections must be registered by their UID string and be instances of BaseConnection.") + # Perform individual connection integrity checks + for uid in self._connections.keys(): + self.check_connection_integrity(uid, skip_registry_type_check=True) # Registry type already checked + def _register_connection(self): """ Creates a new database connection object and registers it. Does not open the connection. """ - con = BaseConnection(self) - self._connections[con.uid] = con - return con + connection = BaseConnection(self) + self._connections[connection.uid] = connection + self.check_connection_integrity(connection) + return connection def _unregister_connection(self, connection: 'str | BaseConnection'): """ Unregister a database connection object. Does not close it. """ - connection_registry_key = None - if isinstance(connection, str): - connection_registry_key = connection - elif isinstance(connection, BaseConnection): - connection_registry_key = connection.uid - else: - raise_auto_arg_type_error("con") - if (connection_registry_key is None) or (connection_registry_key not in self._connections): - raise ValueError(f"Connection {connection} is not registered.") - del self._connections[connection_registry_key] - - def init_db(self, close_connection: bool = True): - if not isinstance(close_connection, bool): - raise_auto_arg_type_error("close_connection") - if self._in_memory and close_connection: - raise ValueError("In-memory databases cease to exist upon closure.") + connection_is_obj = isinstance(connection, BaseConnection) + if (not isinstance(connection, str)) and (not connection_is_obj): + raise_auto_arg_type_error("connection") + connection_uid_str = connection.uid if connection_is_obj else connection + self.check_connection_integrity(connection_uid_str) + # TODO: consider implementing garbage collector ref-count check + del self._connections[connection_uid_str] + + def init_db(self, commit_fkey: bool = True) -> 'BaseConnection': + if not isinstance(commit_fkey, bool): + raise_auto_arg_type_error("commit_fkey") connection = self._register_connection() connection.open() - connection.enforce_foreign_key_constraints(False) - if close_connection: - connection.close() - self._unregister_connection(connection) - else: - return connection + connection.enforce_foreign_key_constraints(commit_fkey) + return connection + + def init_db_auto_close(self): + if self.in_memory: + raise ValueError("In-memory databases cease to exist upon closure.") + # don't bother to commit fkey constraint because close() will commit before connection closure + connection = self.init_db(False) + connection.close() + self._unregister_connection(connection.uid) def close_db(self): uids = [] + self.check_connection_registry_integrity() for uid, connection in self._connections.items(): - if not isinstance(connection, BaseConnection): - raise TypeError(f"Expected connection to be BaseConnection, got {type(connection).__name__} instead.") - if connection.uid != uid: - raise ValueError(f"Connection {connection.uid} is registered under non-matching uid: {uid}") connection.close() uids.append(uid) for uid in uids: @@ -90,29 +164,37 @@ def close_db(self): raise RuntimeError("Not all connections were closed properly.") self._connections = dict() - # @abstractmethod - # def create_table(self): - # pass - # @abstractmethod - # def drop_table(self): - # pass +class MemDB(Database): + def __init__(self, db_name: str | None = None, is_mem: bool = False, fkey_constraint: bool = True): + super().__init__(db_name, is_mem, fkey_constraint) + if not self.in_memory: + raise ValueError("MemDB can only be used for in-memory databases.") - # @abstractmethod - # def read(self): - # pass + def check_connection_registry_integrity(self, skip_registry_type_check: bool = False): + """ + Check the integrity of the connection registry. + Implements early raise if more than one connection is found, + since in-memory databases can only have one connection. + """ + if not isinstance(skip_registry_type_check, bool): + raise_auto_arg_type_error("skip_registry_type_check") + if not skip_registry_type_check: + if not isinstance(self._connections, dict): + raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}") + if (num_connections := len(self._connections)) > 1: + raise ValueError(f"In-memory databases can only have one active connection Found {num_connections}.") + return super().check_connection_registry_integrity(skip_registry_type_check=True) # Registry type already checked - # @abstractmethod - # def write(self): - # pass + def init_db_auto_close(self): + raise ValueError("In-memory databases cease to exist upon closure.") - # @abstractmethod - # def create(self): - # pass - # @abstractmethod - # def unlink(self): - # pass +class FileDB(Database): + def __init__(self, db_name: str | None = None, fkey_constraint: bool = True): + super().__init__(db_name, fkey_constraint) + if not self.is_file: + raise ValueError("FileDB can only be used for file-based databases.") class BaseConnection: @@ -128,6 +210,10 @@ def __init__(self, db: Database): def db_name(self): return self._db._db_name + @property + def db_file_name(self): + return self._db._db_file_name + @property def db_in_memory(self): return self._db._in_memory @@ -147,7 +233,8 @@ def uid(self): @property def registered(self): - return self.uid in self._db._connections + self._db.check_connection_integrity(self) # raises error if not registered + return True @property def connection(self): @@ -197,6 +284,7 @@ def close(self): def _execute(self, query: str, parameters: tuple = ()): """ Execute a SQL query with the given parameters. + Performs no checks/validation. Prefer `execute` unless you need raw access. """ result = self.cursor.execute(query, parameters) return result @@ -225,9 +313,76 @@ def execute(self, query: str, parameters: tuple | None = None): def commit(self): self.connection.commit() + def fetchone(self): + return self.cursor.fetchone() + def fetchall(self): return self.cursor.fetchall() + def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_table: bool = True, raise_if_exists: bool = True): + # Validate table_name argument + if not isinstance(table_name, str): + raise_auto_arg_type_error("table_name") + if len(table_name) == 0: + raise ValueError(f"'table_name' argument of create_table cannot be an empty string!") + if not isinstance(raise_if_exists, bool): + raise_auto_arg_type_error("raise_if_exists") + # Check if table already exists + if raise_if_exists: + self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (table_name,)) + if self.cursor.fetchone() is not None: + raise ValueError(f"Table '{table_name}' already exists.") + # Validate temp_table argument + if not isinstance(temp_table, bool): + raise_auto_arg_type_error("temp_table") + # Validate columns argument + if (not isinstance(columns, list)) or (not all( + isinstance(col, tuple) and len(col) == 2 + and isinstance(col[0], str) + and isinstance(col[1], str) + for col in columns)): + raise_auto_arg_type_error("columns") + # Construct columns portion of query + # TODO: construct parameters for columns rather than f-string to prevent SQL injection + columns_qstr = "" + for col in columns: + columns_qstr += f"{col[0]} {col[1]},\n" + columns_qstr = columns_qstr.rstrip(",\n") # Remove trailing comma and newline + # Assemble full query + query = f"""--sql + CREATE{" TEMPORARY" if temp_table else ""} TABLE IF NOT EXISTS '{table_name}' ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + {columns_qstr} + ); + """ + self.execute(query) + + def drop_table(self, table_name: str, raise_if_not_exists: bool = False): + # Validate table_name argument + if not isinstance(table_name, str): + raise_auto_arg_type_error("table_name") + if len(table_name) == 0: + raise ValueError(f"'table_name' argument of drop_table cannot be an empty string!") + if not isinstance(raise_if_not_exists, bool): + raise_auto_arg_type_error("raise_if_not_exists") + if raise_if_not_exists: + self.cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';") + if self.cursor.fetchone() is None: + raise ValueError(f"Table '{table_name}' does not exist.") + self.cursor.execute(f"DROP TABLE IF EXISTS '{table_name}';") + + # def read(self): + + + # def write(self): + + + # def create(self): + + + # def unlink(self): + + class Connection(BaseConnection): def __init__(self, db: Database): From 0aacc4bb7ce4e747547750bcbaba4a2c7d52ce04 Mon Sep 17 00:00:00 2001 From: bjthres1 Date: Mon, 1 Sep 2025 09:15:27 -0400 Subject: [PATCH 4/9] refactor Database init to take opts dict --- stat_log_db/src/stat_log_db/cli.py | 15 +++++---- stat_log_db/src/stat_log_db/db.py | 50 ++++++++++++++---------------- 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/stat_log_db/src/stat_log_db/cli.py b/stat_log_db/src/stat_log_db/cli.py index d9daf53..afccc41 100644 --- a/stat_log_db/src/stat_log_db/cli.py +++ b/stat_log_db/src/stat_log_db/cli.py @@ -9,16 +9,19 @@ def main(): """Main CLI entry point.""" # TODO: Read info from pyproject.toml? - parser = create_parser({ - "prog": "sldb", - "description": "My CLI tool", - }, "0.0.1") + # parser = create_parser({ + # "prog": "sldb", + # "description": "My CLI tool", + # }, "0.0.1") - args = parser.parse_args() + # args = parser.parse_args() # print(f"{args=}") - sl_db = MemDB(":memory:", True, True) + sl_db = MemDB({ + "is_mem": True, + "fkey_constraint": True + }) con = sl_db.init_db(True) con.create_table("test", [('notes', 'TEXT')], False, True) con.execute("INSERT INTO test (notes) VALUES (?);", ("Hello world!",)) diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py index c4a07c4..81785ff 100644 --- a/stat_log_db/src/stat_log_db/db.py +++ b/stat_log_db/src/stat_log_db/db.py @@ -1,34 +1,32 @@ # from abc import ABC, abstractmethod import sqlite3 import uuid +from typing import Any from .exceptions import raise_auto_arg_type_error class Database(): - def __init__(self, db_name: str | None = None, is_mem: bool = False, fkey_constraint: bool = True): + def __init__(self, options: dict[str, Any] = {}): # Validate arguments - # database name - if db_name is None: - self._db_name = str(uuid.uuid4()) - elif not isinstance(db_name, str): - raise_auto_arg_type_error("db_name") - else: - self._db_name = db_name - # is memory or file database - if not isinstance(is_mem, bool): - raise_auto_arg_type_error("is_mem") - self._in_memory = is_mem - self._is_file = not is_mem - # database file name - if is_mem: - self._db_file_name = ":memory:" - else: - self._db_file_name = self._db_name.replace(" ", "_") - if not isinstance(fkey_constraint, bool): - raise_auto_arg_type_error("fkey_constraint") - self._fkey_constraint = fkey_constraint + valid_options = { + "db_name": str, + "is_mem": bool, + "fkey_constraint": bool + } + for opt, opt_type in options.items(): + if opt not in valid_options.keys(): + raise ValueError(f"Invalid option provided: '{opt}'. Must be one of {list(valid_options.keys())}.") + expected_type = valid_options[opt] + if not isinstance(opt_type, expected_type): + raise TypeError(f"Option '{opt}' must be of type {expected_type.__name__}, got {type(opt_type).__name__}.") + # Assign arguments to class attributes + self._in_memory: bool = options.get("is_mem", False) + self._is_file: bool = bool(not self._in_memory) + self._db_name: str = options.get("db_name", str(uuid.uuid4())) + self._db_file_name: str = ":memory:" if self._in_memory else self._db_name.replace(" ", "_") + self._fkey_constraint: bool = options.get("fkey_constraint", True) # Keep track of active connections (to ensure that they are closed) self._connections: dict[str, BaseConnection] = dict() @@ -166,8 +164,8 @@ def close_db(self): class MemDB(Database): - def __init__(self, db_name: str | None = None, is_mem: bool = False, fkey_constraint: bool = True): - super().__init__(db_name, is_mem, fkey_constraint) + def __init__(self, options: dict[str, Any] = {}): + super().__init__(options=options) if not self.in_memory: raise ValueError("MemDB can only be used for in-memory databases.") @@ -191,8 +189,8 @@ def init_db_auto_close(self): class FileDB(Database): - def __init__(self, db_name: str | None = None, fkey_constraint: bool = True): - super().__init__(db_name, fkey_constraint) + def __init__(self, options: dict[str, Any] = {}): + super().__init__(options=options) if not self.is_file: raise ValueError("FileDB can only be used for file-based databases.") @@ -261,7 +259,7 @@ def enforce_foreign_key_constraints(self, commit: bool = True): self.connection.commit() def _open(self): - self._connection = sqlite3.connect(self.db_name) + self._connection = sqlite3.connect(self.db_file_name) self._cursor = self._connection.cursor() def open(self): From 4304de51561a081837a233fdc63c7e518f4e6d72 Mon Sep 17 00:00:00 2001 From: bjthres1 Date: Mon, 1 Sep 2025 12:26:04 -0400 Subject: [PATCH 5/9] Implement basic sanitization, injection tests, improve git-bash lookup on windows --- stat_log_db/src/stat_log_db/db.py | 134 ++++++++++++++--- stat_log_db/tests/test_sql_injection.py | 183 ++++++++++++++++++++++++ tests/test_tools.py | 33 ++++- 3 files changed, 332 insertions(+), 18 deletions(-) create mode 100644 stat_log_db/tests/test_sql_injection.py diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py index 81785ff..867d221 100644 --- a/stat_log_db/src/stat_log_db/db.py +++ b/stat_log_db/src/stat_log_db/db.py @@ -1,4 +1,5 @@ # from abc import ABC, abstractmethod +import re import sqlite3 import uuid from typing import Any @@ -317,22 +318,88 @@ def fetchone(self): def fetchall(self): return self.cursor.fetchall() + def _validate_sql_identifier(self, identifier: str, identifier_type: str = "identifier") -> str: + """ + Validate and sanitize SQL identifiers (table names, column names) to prevent SQL injection. + + Args: + identifier: The identifier to validate + identifier_type: Type of identifier for error messages (e.g., "table name", "column name") + + Returns: + The validated identifier + + Raises: + ValueError: If the identifier is invalid or potentially dangerous + """ + if not isinstance(identifier, str): + raise TypeError(f"SQL {identifier_type} must be a string, got {type(identifier).__name__}") + + if len(identifier) == 0: + raise ValueError(f"SQL {identifier_type} cannot be empty") + + # Check for valid identifier pattern: starts with letter/underscore, contains only alphanumeric/underscore + if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', identifier): + raise ValueError(f"Invalid SQL {identifier_type}: '{identifier}'. Must start with letter or underscore and contain only letters, numbers, and underscores.") + + # Check against SQLite reserved words (common ones that could cause issues) + reserved_words = { + 'abort', 'action', 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', + 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', 'cascade', 'case', + 'cast', 'check', 'collate', 'column', 'commit', 'conflict', 'constraint', 'create', + 'cross', 'current', 'current_date', 'current_time', 'current_timestamp', 'database', + 'default', 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', 'do', + 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', 'exists', 'explain', + 'fail', 'filter', 'following', 'for', 'foreign', 'from', 'full', 'glob', 'group', + 'having', 'if', 'ignore', 'immediate', 'in', 'index', 'indexed', 'initially', 'inner', + 'insert', 'instead', 'intersect', 'into', 'is', 'isnull', 'join', 'key', 'left', + 'like', 'limit', 'match', 'natural', 'no', 'not', 'notnull', 'null', 'of', 'offset', + 'on', 'or', 'order', 'outer', 'over', 'partition', 'plan', 'pragma', 'preceding', + 'primary', 'query', 'raise', 'range', 'recursive', 'references', 'regexp', 'reindex', + 'release', 'rename', 'replace', 'restrict', 'right', 'rollback', 'row', 'rows', + 'savepoint', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', 'transaction', + 'trigger', 'unbounded', 'union', 'unique', 'update', 'using', 'vacuum', 'values', + 'view', 'virtual', 'when', 'where', 'window', 'with', 'without' + } + + if identifier.lower() in reserved_words: + raise ValueError(f"SQL {identifier_type} '{identifier}' is a reserved word and cannot be used") + + return identifier + + def _escape_sql_identifier(self, identifier: str) -> str: + """ + Escape SQL identifier by wrapping in double quotes and escaping any internal quotes. + This should only be used after validation. + """ + # Escape any double quotes in the identifier by doubling them + escaped = identifier.replace('"', '""') + return f'"{escaped}"' + def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_table: bool = True, raise_if_exists: bool = True): # Validate table_name argument if not isinstance(table_name, str): raise_auto_arg_type_error("table_name") if len(table_name) == 0: raise ValueError(f"'table_name' argument of create_table cannot be an empty string!") + + # Validate and sanitize table name + validated_table_name = self._validate_sql_identifier(table_name, "table name") + escaped_table_name = self._escape_sql_identifier(validated_table_name) + if not isinstance(raise_if_exists, bool): raise_auto_arg_type_error("raise_if_exists") - # Check if table already exists + + # Check if table already exists using parameterized query if raise_if_exists: - self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (table_name,)) + self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,)) if self.cursor.fetchone() is not None: - raise ValueError(f"Table '{table_name}' already exists.") + raise ValueError(f"Table '{validated_table_name}' already exists.") + # Validate temp_table argument if not isinstance(temp_table, bool): raise_auto_arg_type_error("temp_table") + # Validate columns argument if (not isinstance(columns, list)) or (not all( isinstance(col, tuple) and len(col) == 2 @@ -340,19 +407,43 @@ def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_tab and isinstance(col[1], str) for col in columns)): raise_auto_arg_type_error("columns") - # Construct columns portion of query - # TODO: construct parameters for columns rather than f-string to prevent SQL injection - columns_qstr = "" - for col in columns: - columns_qstr += f"{col[0]} {col[1]},\n" - columns_qstr = columns_qstr.rstrip(",\n") # Remove trailing comma and newline - # Assemble full query - query = f"""--sql - CREATE{" TEMPORARY" if temp_table else ""} TABLE IF NOT EXISTS '{table_name}' ( + + # Validate and construct columns portion of query + validated_columns = [] + for col_name, col_type in columns: + # Validate column name + validated_col_name = self._validate_sql_identifier(col_name, "column name") + escaped_col_name = self._escape_sql_identifier(validated_col_name) + + # Validate column type - allow only safe, known SQLite types + allowed_types = { + 'TEXT', 'INTEGER', 'REAL', 'BLOB', 'NUMERIC', + 'VARCHAR', 'CHAR', 'NVARCHAR', 'NCHAR', + 'CLOB', 'DATE', 'DATETIME', 'TIMESTAMP', + 'BOOLEAN', 'DECIMAL', 'DOUBLE', 'FLOAT', + 'INT', 'BIGINT', 'SMALLINT', 'TINYINT' + } + + # Allow type specifications with length/precision (e.g., VARCHAR(50), DECIMAL(10,2)) + base_type = re.match(r'^([A-Z]+)', col_type.upper()) + if not base_type or base_type.group(1) not in allowed_types: + raise ValueError(f"Unsupported column type: '{col_type}'. Must be one of: {', '.join(sorted(allowed_types))}") + + # Basic validation for type specification format + if not re.match(r'^[A-Z]+(\([0-9,\s]+\))?$', col_type.upper()): + raise ValueError(f"Invalid column type format: '{col_type}'") + + validated_columns.append(f"{escaped_col_name} {col_type.upper()}") + + columns_qstr = ",\n ".join(validated_columns) + + # Assemble full query with escaped identifiers + temp_keyword = " TEMPORARY" if temp_table else "" + query = f"""CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} ( id INTEGER PRIMARY KEY AUTOINCREMENT, {columns_qstr} - ); - """ + );""" + self.execute(query) def drop_table(self, table_name: str, raise_if_not_exists: bool = False): @@ -361,13 +452,22 @@ def drop_table(self, table_name: str, raise_if_not_exists: bool = False): raise_auto_arg_type_error("table_name") if len(table_name) == 0: raise ValueError(f"'table_name' argument of drop_table cannot be an empty string!") + + # Validate and sanitize table name + validated_table_name = self._validate_sql_identifier(table_name, "table name") + escaped_table_name = self._escape_sql_identifier(validated_table_name) + if not isinstance(raise_if_not_exists, bool): raise_auto_arg_type_error("raise_if_not_exists") + + # Check if table exists using parameterized query if raise_if_not_exists: - self.cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';") + self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,)) if self.cursor.fetchone() is None: - raise ValueError(f"Table '{table_name}' does not exist.") - self.cursor.execute(f"DROP TABLE IF EXISTS '{table_name}';") + raise ValueError(f"Table '{validated_table_name}' does not exist.") + + # Execute DROP statement with escaped identifier + self.cursor.execute(f"DROP TABLE IF EXISTS {escaped_table_name};") # def read(self): diff --git a/stat_log_db/tests/test_sql_injection.py b/stat_log_db/tests/test_sql_injection.py new file mode 100644 index 0000000..62cc599 --- /dev/null +++ b/stat_log_db/tests/test_sql_injection.py @@ -0,0 +1,183 @@ +""" +Test SQL injection protection in create_table and drop_table methods. +""" + +import pytest +import sys +from pathlib import Path + +# Add the src directory to the path to import the module +ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(ROOT / "stat_log_db" / "src")) + +from stat_log_db.db import MemDB + + +@pytest.fixture +def mem_db(): + """Create a test in-memory database and clean up after tests.""" + sl_db = MemDB({ + "is_mem": True, + "fkey_constraint": True + }) + con = sl_db.init_db(True) + yield con + # Cleanup + sl_db.close_db() + + +class TestSQLInjectionProtection: + """Test class for SQL injection protection in database operations.""" + + def test_malicious_table_name_create(self, mem_db): + """Test that malicious SQL injection in table names is rejected.""" + with pytest.raises(ValueError, match="Invalid SQL table name"): + mem_db.create_table("test'; DROP TABLE users; --", [('notes', 'TEXT')], False, True) + + def test_reserved_word_table_name(self, mem_db): + """Test that SQL reserved words are rejected as table names.""" + with pytest.raises(ValueError, match="is a reserved word"): + mem_db.create_table("select", [('notes', 'TEXT')], False, True) + + def test_invalid_characters_table_name(self, mem_db): + """Test that invalid characters in table names are rejected.""" + with pytest.raises(ValueError, match="Invalid SQL table name"): + mem_db.create_table("test-table", [('notes', 'TEXT')], False, True) + + def test_malicious_column_name(self, mem_db): + """Test that malicious SQL injection in column names is rejected.""" + with pytest.raises(ValueError, match="Invalid SQL column name"): + mem_db.create_table("test_table", [('notes\'; DROP TABLE users; --', 'TEXT')], False, True) + + def test_invalid_column_type(self, mem_db): + """Test that invalid/malicious column types are rejected.""" + with pytest.raises(ValueError, match="Unsupported column type"): + mem_db.create_table("test_table", [('notes', 'MALICIOUS_TYPE; DROP TABLE users; --')], False, True) + + def test_valid_table_creation(self, mem_db): + """Test that valid table creation works correctly.""" + # This should not raise any exception + mem_db.create_table("test_table", [('notes', 'TEXT'), ('count', 'INTEGER')], False, True) + + # Verify table was created by attempting to insert data + mem_db.execute("INSERT INTO test_table (notes, count) VALUES (?, ?);", ("test note", 42)) + mem_db.commit() + + # Verify data was inserted + mem_db.execute("SELECT * FROM test_table;") + result = mem_db.fetchall() + assert len(result) == 1 + assert result[0][1] == "test note" # Column 0 is auto-increment id + assert result[0][2] == 42 + + def test_malicious_drop_table_name(self, mem_db): + """Test that malicious SQL injection in drop table is rejected.""" + # First create a valid table + mem_db.create_table("test_table", [('notes', 'TEXT')], False, True) + + # Then try to drop with malicious name + with pytest.raises(ValueError, match="Invalid SQL table name"): + mem_db.drop_table("test_table'; DROP TABLE sqlite_master; --", False) + + def test_valid_drop_table(self, mem_db): + """Test that valid table dropping works correctly.""" + # Create a table first + mem_db.create_table("test_table", [('notes', 'TEXT')], False, True) + + # Verify it exists by checking sqlite_master + mem_db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", ("test_table",)) + assert mem_db.fetchone() is not None + + # Drop the table + mem_db.drop_table("test_table", False) + + # Verify it's gone + mem_db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", ("test_table",)) + assert mem_db.fetchone() is None + + def test_empty_table_name_create(self, mem_db): + """Test that empty table names are rejected.""" + with pytest.raises(ValueError, match="cannot be an empty string"): + mem_db.create_table("", [('notes', 'TEXT')], False, True) + + def test_empty_table_name_drop(self, mem_db): + """Test that empty table names are rejected in drop operations.""" + with pytest.raises(ValueError, match="cannot be an empty string"): + mem_db.drop_table("", False) + + def test_empty_column_name(self, mem_db): + """Test that empty column names are rejected.""" + with pytest.raises(ValueError, match="cannot be empty"): + mem_db.create_table("test_table", [('', 'TEXT')], False, True) + + def test_column_name_with_numbers(self, mem_db): + """Test that column names with numbers are allowed.""" + mem_db.create_table("test_table", [('column1', 'TEXT'), ('column_2', 'INTEGER')], False, True) + + def test_table_name_with_underscore(self, mem_db): + """Test that table names starting with underscore are allowed.""" + mem_db.create_table("_test_table", [('notes', 'TEXT')], False, True) + + def test_valid_column_types(self, mem_db): + """Test that all supported column types work correctly.""" + valid_types = [ + ('text_col', 'TEXT'), + ('int_col', 'INTEGER'), + ('real_col', 'REAL'), + ('blob_col', 'BLOB'), + ('numeric_col', 'NUMERIC'), + ('varchar_col', 'VARCHAR(255)'), + ('decimal_col', 'DECIMAL(10,2)') + ] + + mem_db.create_table("type_test_table", valid_types, False, True) + + def test_case_insensitive_reserved_words(self, mem_db): + """Test that reserved words are caught regardless of case.""" + with pytest.raises(ValueError, match="is a reserved word"): + mem_db.create_table("SELECT", [('notes', 'TEXT')], False, True) + + with pytest.raises(ValueError, match="is a reserved word"): + mem_db.create_table("Select", [('notes', 'TEXT')], False, True) + + def test_raise_if_exists_functionality(self, mem_db): + """Test the raise_if_exists parameter works correctly.""" + # Create a table + mem_db.create_table("test_table", [('notes', 'TEXT')], False, True) + + # Try to create the same table with raise_if_exists=True (should fail) + with pytest.raises(ValueError, match="already exists"): + mem_db.create_table("test_table", [('notes', 'TEXT')], False, True) + + # Try to create the same table with raise_if_exists=False (should succeed) + mem_db.create_table("test_table", [('notes', 'TEXT')], False, False) + + def test_raise_if_not_exists_functionality(self, mem_db): + """Test the raise_if_not_exists parameter works correctly.""" + # Try to drop non-existent table with raise_if_not_exists=True (should fail) + with pytest.raises(ValueError, match="does not exist"): + mem_db.drop_table("nonexistent_table", True) + + # Try to drop non-existent table with raise_if_not_exists=False (should succeed) + mem_db.drop_table("nonexistent_table", False) + + def test_special_characters_rejection(self, mem_db): + """Test that various special characters are properly rejected.""" + special_chars = [ + "table;name", + "table'name", + 'table"name', + "table name", # space + "table-name", # hyphen + "table.name", # dot + "table(name)", # parentheses + "table[name]", # brackets + "table{name}", # braces + "table@name", # at symbol + "table#name", # hash + "table$name", # dollar (should be rejected by our implementation) + ] + + for table_name in special_chars: + with pytest.raises(ValueError, match="Invalid SQL"): + mem_db.create_table(table_name, [('notes', 'TEXT')], False, True) diff --git a/tests/test_tools.py b/tests/test_tools.py index 47b7002..bb35ea9 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -44,6 +44,37 @@ def is_installed(package: str) -> bool: return result.returncode == 0 +def _find_bash_executable(): # TODO: Improve this + """Find bash executable, preferring Git Bash on Windows.""" + if os.name != "nt": + return "bash" + # Common Git Bash locations on Windows + common_paths = [ + r"C:\Program Files\Git\bin\bash.exe", + r"C:\Program Files (x86)\Git\bin\bash.exe", + r"C:\Users\{}\AppData\Local\Programs\Git\bin\bash.exe".format(os.getenv("USERNAME", "")), + r"C:\Git\bin\bash.exe", + ] + # Check common paths first + for path in common_paths: + if os.path.isfile(path): + return path + # Try to find bash using 'where' command + try: + result = subprocess.run(["where", "bash"], capture_output=True, text=True, check=True) + bash_path = result.stdout.strip().split('\n')[0] # Get first result + if os.path.isfile(bash_path): + return bash_path + except (subprocess.CalledProcessError, FileNotFoundError, IndexError): + pass + # If we get here, bash was not found + raise FileNotFoundError( + "Git Bash not found. Please install Git for Windows from https://git-scm.com/download/win " + "or ensure bash.exe is in your PATH. Tried the following locations:\n" + + "\n".join(f" - {path}" for path in common_paths) + ) + + def run_tools(args, use_test_venv=False): """Run tools.sh returning (code, stdout+stderr).""" env = os.environ.copy() @@ -53,7 +84,7 @@ def run_tools(args, use_test_venv=False): env["PATH"] = str(scripts_dir) + os.pathsep + env.get("PATH", "") env["VIRTUAL_ENV"] = str(VENV_TEST) env["PYTHONHOME"] = "" # ensure venv python resolution - bash = r"C:\Program Files\Git\bin\bash.exe" if os.name == "nt" else "bash" # TODO: indicate to the user that they need git bash + bash = _find_bash_executable() proc = subprocess.run([bash, str(SCRIPT), *args], capture_output=True, text=True, cwd=ROOT, env=env) return proc.returncode, proc.stdout + proc.stderr From 817f5dd4a9de91ff2abab9e1fcba7f22cb5fe440 Mon Sep 17 00:00:00 2001 From: bjthres1 Date: Sat, 6 Sep 2025 11:21:22 -0400 Subject: [PATCH 6/9] install stat_log_db in python-package.yml for tests --- .github/workflows/python-package.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index c552b19..7c2aaad 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -35,6 +35,9 @@ jobs: flake8 . --count --select=E9,F63,F7,F82 --ignore=E261 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Install stat-log-db package (dev) + run: | + python -m pip install -e ./stat_log_db[dev] - name: Test with pytest run: | pytest From 93cb808aaa271db6358d84dc90ac4516533efc26 Mon Sep 17 00:00:00 2001 From: bjthres1 Date: Sat, 6 Sep 2025 11:33:51 -0400 Subject: [PATCH 7/9] fix(style): flake8 ignore rules & improve style Updated flake8 ignore rules to include W293 and W503. Improved comments for clarity on linting behavior. See [Action Run #17 (17516215023) > build > Lint with flake8](https://github.com/SyntaxAerror/stat-log-db/actions/runs/17516215023) --- .github/workflows/python-package.yml | 6 ++++-- stat_log_db/src/stat_log_db/db.py | 2 +- tests/test_tools.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 7c2aaad..af1c252 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -31,8 +31,10 @@ jobs: if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --ignore=E261 --show-source --statistics + # stop the build if th1ere are Python syntax errors or undefined names + # Notes: + # W503: https://peps.python.org/pep-0008/#should-a-line-break-before-or-after-a-binary-operator + flake8 . --count --select=E9,F63,F7,F82 --ignore=E261,W293,W503 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Install stat-log-db package (dev) diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py index 867d221..1fc3914 100644 --- a/stat_log_db/src/stat_log_db/db.py +++ b/stat_log_db/src/stat_log_db/db.py @@ -451,7 +451,7 @@ def drop_table(self, table_name: str, raise_if_not_exists: bool = False): if not isinstance(table_name, str): raise_auto_arg_type_error("table_name") if len(table_name) == 0: - raise ValueError(f"'table_name' argument of drop_table cannot be an empty string!") + raise ValueError("'table_name' argument of drop_table cannot be an empty string!") # Validate and sanitize table name validated_table_name = self._validate_sql_identifier(table_name, "table name") diff --git a/tests/test_tools.py b/tests/test_tools.py index bb35ea9..ed2e885 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -51,7 +51,7 @@ def _find_bash_executable(): # TODO: Improve this # Common Git Bash locations on Windows common_paths = [ r"C:\Program Files\Git\bin\bash.exe", - r"C:\Program Files (x86)\Git\bin\bash.exe", + r"C:\Program Files (x86)\Git\bin\bash.exe", r"C:\Users\{}\AppData\Local\Programs\Git\bin\bash.exe".format(os.getenv("USERNAME", "")), r"C:\Git\bin\bash.exe", ] From 4226874c8c2e42121c48e4248878380412724047 Mon Sep 17 00:00:00 2001 From: bjthres1 Date: Sat, 6 Sep 2025 12:33:30 -0400 Subject: [PATCH 8/9] fix(tests): add style testing Added a new test to tools.sh for style checks using flake8 to ensure code quality. Added flake8 to pyproject dev dependencies. --- .flake8 | 4 +++- .github/workflows/python-package.yml | 6 ++---- stat_log_db/pyproject.toml | 3 ++- stat_log_db/src/stat_log_db/cli.py | 1 + stat_log_db/src/stat_log_db/db.py | 4 ++-- tests/test_tools.py | 13 ++++++++++--- tools.sh | 7 ++++++- 7 files changed, 26 insertions(+), 12 deletions(-) diff --git a/.flake8 b/.flake8 index 2376c0a..f3e0942 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,7 @@ [flake8] -ignore = E261 +# Notes: +# W503: https://peps.python.org/pep-0008/#should-a-line-break-before-or-after-a-binary-operator +ignore = E261, W293, W503 max-line-length = 200 exclude = .git,.github,__pycache__,.pytest_cache,.venv,.venv_test,.vscode max-complexity = 10 \ No newline at end of file diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index af1c252..aedf9ad 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -31,10 +31,8 @@ jobs: if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | - # stop the build if th1ere are Python syntax errors or undefined names - # Notes: - # W503: https://peps.python.org/pep-0008/#should-a-line-break-before-or-after-a-binary-operator - flake8 . --count --select=E9,F63,F7,F82 --ignore=E261,W293,W503 --show-source --statistics + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Install stat-log-db package (dev) diff --git a/stat_log_db/pyproject.toml b/stat_log_db/pyproject.toml index 432b728..df2a999 100644 --- a/stat_log_db/pyproject.toml +++ b/stat_log_db/pyproject.toml @@ -14,7 +14,8 @@ dependencies = [ [project.optional-dependencies] dev = [ "pytest==8.4.1", - "pytest-cov==6.2.1" + "pytest-cov==6.2.1", + "flake8==7.3.0" ] [project.scripts] diff --git a/stat_log_db/src/stat_log_db/cli.py b/stat_log_db/src/stat_log_db/cli.py index afccc41..50116fa 100644 --- a/stat_log_db/src/stat_log_db/cli.py +++ b/stat_log_db/src/stat_log_db/cli.py @@ -34,5 +34,6 @@ def main(): if sl_db.is_file: os.remove(sl_db.file_name) + if __name__ == "__main__": main() diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py index 1fc3914..babeca3 100644 --- a/stat_log_db/src/stat_log_db/db.py +++ b/stat_log_db/src/stat_log_db/db.py @@ -296,7 +296,7 @@ def execute(self, query: str, parameters: tuple | None = None): if not isinstance(query, str): raise_auto_arg_type_error("query") if len(query) == 0: - raise ValueError(f"'query' argument of execute cannot be an empty string!") + raise ValueError("'query' argument of execute cannot be an empty string!") # Create a new space in memory that points to the same object that `parameters` points to params = parameters # If `params` points to None, update it to point to an empty tuple @@ -381,7 +381,7 @@ def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_tab if not isinstance(table_name, str): raise_auto_arg_type_error("table_name") if len(table_name) == 0: - raise ValueError(f"'table_name' argument of create_table cannot be an empty string!") + raise ValueError("'table_name' argument of create_table cannot be an empty string!") # Validate and sanitize table name validated_table_name = self._validate_sql_identifier(table_name, "table name") diff --git a/tests/test_tools.py b/tests/test_tools.py index ed2e885..8165d7e 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -70,8 +70,8 @@ def _find_bash_executable(): # TODO: Improve this # If we get here, bash was not found raise FileNotFoundError( "Git Bash not found. Please install Git for Windows from https://git-scm.com/download/win " - "or ensure bash.exe is in your PATH. Tried the following locations:\n" + - "\n".join(f" - {path}" for path in common_paths) + "or ensure bash.exe is in your PATH. Tried the following locations:\n" + + "\n".join(f" - {path}" for path in common_paths) ) @@ -121,7 +121,6 @@ def test_help(): except AssertionError: assert out.strip() == readme_content.strip(), "Help output does not match README content (leading & trailing whitespace stripped)" - @pytest.mark.skipif(GITHUB_ACTIONS, reason="Skipping test on GitHub Actions") def test_install_dev(test_venv): code, out = run_tools(["-id"], use_test_venv=True) @@ -186,6 +185,14 @@ def test_test_invalid_arg(): assert ("Unsupported argument" in out) or ("Invalid test mode" in out) +@pytest.mark.skipif(GITHUB_ACTIONS, reason="Skipping test on GitHub Actions") +def test_test_style(): + code, out = run_tools(["-ts"]) + assert code == 0 + assert "Running style tests" in out + assert "flake8" in out + + @pytest.mark.skipif(GITHUB_ACTIONS, reason="Skipping test on GitHub Actions") def test_clean(): code, out = run_tools(["-c"]) diff --git a/tools.sh b/tools.sh index 9d88ee0..19f2e9f 100755 --- a/tools.sh +++ b/tools.sh @@ -5,7 +5,7 @@ supported_installation_opts="d n" install="" uninstall=0 clean=0 -supported_test_opts="p t a d" +supported_test_opts="p t a d s" test="" while getopts ":i:t:chu" flag; do @@ -68,6 +68,11 @@ if [ -n "$test" ]; then echo "Running all tests..." pytest ;; + s) + echo "Running style tests (flake8)..." + flake8 . + # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + ;; *) echo "Invalid test mode '$test'. Use one of: $supported_test_opts" >&2 exit 1 From d7daf058ba93fc50f3d537874998436a3e649567 Mon Sep 17 00:00:00 2001 From: bjthres1 Date: Sat, 6 Sep 2025 16:23:18 -0400 Subject: [PATCH 9/9] refactor(style): style fixes per flake8 - cli: clean up imports and improve code readability - db: enhance SQL identifier validation and sanitization - test: reorganize imports in SQL injection tests --- stat_log_db/src/stat_log_db/cli.py | 6 +-- stat_log_db/src/stat_log_db/db.py | 57 +++++++------------------ stat_log_db/tests/test_sql_injection.py | 5 ++- tests/test_tools.py | 1 + 4 files changed, 23 insertions(+), 46 deletions(-) diff --git a/stat_log_db/src/stat_log_db/cli.py b/stat_log_db/src/stat_log_db/cli.py index 50116fa..55189b7 100644 --- a/stat_log_db/src/stat_log_db/cli.py +++ b/stat_log_db/src/stat_log_db/cli.py @@ -1,8 +1,8 @@ import os -import sys +# import sys -from .parser import create_parser -from .db import Database, MemDB, FileDB, BaseConnection +# from .parser import create_parser +from .db import MemDB # , FileDB, Database, BaseConnection def main(): diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py index babeca3..6b29ea9 100644 --- a/stat_log_db/src/stat_log_db/db.py +++ b/stat_log_db/src/stat_log_db/db.py @@ -321,27 +321,21 @@ def fetchall(self): def _validate_sql_identifier(self, identifier: str, identifier_type: str = "identifier") -> str: """ Validate and sanitize SQL identifiers (table names, column names) to prevent SQL injection. - Args: identifier: The identifier to validate identifier_type: Type of identifier for error messages (e.g., "table name", "column name") - Returns: The validated identifier - Raises: ValueError: If the identifier is invalid or potentially dangerous """ if not isinstance(identifier, str): raise TypeError(f"SQL {identifier_type} must be a string, got {type(identifier).__name__}") - if len(identifier) == 0: raise ValueError(f"SQL {identifier_type} cannot be empty") - # Check for valid identifier pattern: starts with letter/underscore, contains only alphanumeric/underscore if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', identifier): raise ValueError(f"Invalid SQL {identifier_type}: '{identifier}'. Must start with letter or underscore and contain only letters, numbers, and underscores.") - # Check against SQLite reserved words (common ones that could cause issues) reserved_words = { 'abort', 'action', 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', @@ -361,10 +355,8 @@ def _validate_sql_identifier(self, identifier: str, identifier_type: str = "iden 'trigger', 'unbounded', 'union', 'unique', 'update', 'using', 'vacuum', 'values', 'view', 'virtual', 'when', 'where', 'window', 'with', 'without' } - if identifier.lower() in reserved_words: raise ValueError(f"SQL {identifier_type} '{identifier}' is a reserved word and cannot be used") - return identifier def _escape_sql_identifier(self, identifier: str) -> str: @@ -382,24 +374,11 @@ def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_tab raise_auto_arg_type_error("table_name") if len(table_name) == 0: raise ValueError("'table_name' argument of create_table cannot be an empty string!") - - # Validate and sanitize table name - validated_table_name = self._validate_sql_identifier(table_name, "table name") - escaped_table_name = self._escape_sql_identifier(validated_table_name) - - if not isinstance(raise_if_exists, bool): - raise_auto_arg_type_error("raise_if_exists") - - # Check if table already exists using parameterized query - if raise_if_exists: - self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,)) - if self.cursor.fetchone() is not None: - raise ValueError(f"Table '{validated_table_name}' already exists.") - # Validate temp_table argument if not isinstance(temp_table, bool): raise_auto_arg_type_error("temp_table") - + if not isinstance(raise_if_exists, bool): + raise_auto_arg_type_error("raise_if_exists") # Validate columns argument if (not isinstance(columns, list)) or (not all( isinstance(col, tuple) and len(col) == 2 @@ -407,14 +386,20 @@ def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_tab and isinstance(col[1], str) for col in columns)): raise_auto_arg_type_error("columns") - + # Validate and sanitize table name + validated_table_name = self._validate_sql_identifier(table_name, "table name") + escaped_table_name = self._escape_sql_identifier(validated_table_name) + # Check if table already exists using parameterized query + if raise_if_exists: + self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,)) + if self.cursor.fetchone() is not None: + raise ValueError(f"Table '{validated_table_name}' already exists.") # Validate and construct columns portion of query validated_columns = [] for col_name, col_type in columns: # Validate column name validated_col_name = self._validate_sql_identifier(col_name, "column name") escaped_col_name = self._escape_sql_identifier(validated_col_name) - # Validate column type - allow only safe, known SQLite types allowed_types = { 'TEXT', 'INTEGER', 'REAL', 'BLOB', 'NUMERIC', @@ -423,27 +408,21 @@ def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_tab 'BOOLEAN', 'DECIMAL', 'DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'SMALLINT', 'TINYINT' } - # Allow type specifications with length/precision (e.g., VARCHAR(50), DECIMAL(10,2)) base_type = re.match(r'^([A-Z]+)', col_type.upper()) if not base_type or base_type.group(1) not in allowed_types: raise ValueError(f"Unsupported column type: '{col_type}'. Must be one of: {', '.join(sorted(allowed_types))}") - # Basic validation for type specification format if not re.match(r'^[A-Z]+(\([0-9,\s]+\))?$', col_type.upper()): raise ValueError(f"Invalid column type format: '{col_type}'") - validated_columns.append(f"{escaped_col_name} {col_type.upper()}") - columns_qstr = ",\n ".join(validated_columns) - # Assemble full query with escaped identifiers temp_keyword = " TEMPORARY" if temp_table else "" query = f"""CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} ( id INTEGER PRIMARY KEY AUTOINCREMENT, {columns_qstr} );""" - self.execute(query) def drop_table(self, table_name: str, raise_if_not_exists: bool = False): @@ -452,34 +431,30 @@ def drop_table(self, table_name: str, raise_if_not_exists: bool = False): raise_auto_arg_type_error("table_name") if len(table_name) == 0: raise ValueError("'table_name' argument of drop_table cannot be an empty string!") - + if not isinstance(raise_if_not_exists, bool): + raise_auto_arg_type_error("raise_if_not_exists") # Validate and sanitize table name validated_table_name = self._validate_sql_identifier(table_name, "table name") escaped_table_name = self._escape_sql_identifier(validated_table_name) - - if not isinstance(raise_if_not_exists, bool): - raise_auto_arg_type_error("raise_if_not_exists") - # Check if table exists using parameterized query if raise_if_not_exists: self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,)) if self.cursor.fetchone() is None: raise ValueError(f"Table '{validated_table_name}' does not exist.") - # Execute DROP statement with escaped identifier self.cursor.execute(f"DROP TABLE IF EXISTS {escaped_table_name};") # def read(self): - + # pass # def write(self): - + # pass # def create(self): - + # pass # def unlink(self): - + # pass class Connection(BaseConnection): diff --git a/stat_log_db/tests/test_sql_injection.py b/stat_log_db/tests/test_sql_injection.py index 62cc599..447ba96 100644 --- a/stat_log_db/tests/test_sql_injection.py +++ b/stat_log_db/tests/test_sql_injection.py @@ -6,12 +6,13 @@ import sys from pathlib import Path +from stat_log_db.db import MemDB + + # Add the src directory to the path to import the module ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(ROOT / "stat_log_db" / "src")) -from stat_log_db.db import MemDB - @pytest.fixture def mem_db(): diff --git a/tests/test_tools.py b/tests/test_tools.py index 8165d7e..60612b5 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -121,6 +121,7 @@ def test_help(): except AssertionError: assert out.strip() == readme_content.strip(), "Help output does not match README content (leading & trailing whitespace stripped)" + @pytest.mark.skipif(GITHUB_ACTIONS, reason="Skipping test on GitHub Actions") def test_install_dev(test_venv): code, out = run_tools(["-id"], use_test_venv=True)