Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:

strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]

services:
postgres:
Expand All @@ -29,6 +29,14 @@ jobs:
- 5432:5432
options: --health-cmd pg_isready --health-interval 5s --health-timeout 2s --health-retries 5

mysql:
image: mariadb:latest
env:
MARIADB_ROOT_PASSWORD: based
MARIADB_DB: based
ports:
- 3306:3306

steps:
- uses: "actions/checkout@v4"
- uses: "actions/setup-python@v5"
Expand All @@ -41,7 +49,8 @@ jobs:
- name: "Run tests"
env:
BASED_TEST_DB_URLS: |
postgresql://based:based@localhost:5432/based
postgresql://based:based@localhost:5432/based,
mysql://root:based@127.0.0.1:3306/based
run: "make test"

coverage:
Expand All @@ -59,6 +68,14 @@ jobs:
- 5432:5432
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5

mysql:
image: mariadb:latest
env:
MARIADB_ROOT_PASSWORD: based
MARIADB_DB: based
ports:
- 3306:3306

steps:
- uses: "actions/checkout@v4"
- uses: "actions/setup-python@v5"
Expand All @@ -69,7 +86,8 @@ jobs:
- name: "Run tests"
env:
BASED_TEST_DB_URLS: |
postgresql://based:based@localhost:5432/based
postgresql://based:based@localhost:5432/based,
mysql://root:based@127.0.0.1:3306/based
run: "make test"
- name: Coverage report
uses: irongut/CodeCoverageSummary@v1.3.0
Expand Down
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
.venv
*.egg-info

.mypy_cache
.ruff_cache

.coverage
Expand Down
2 changes: 1 addition & 1 deletion LICENSE.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2024 ansipunk
Copyright (c) 2024 ansipunk <kysput@gmail.com>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: help bootstrap lint test clean
.PHONY: help bootstrap lint test build clean
DEFAULT: help

VENV = .venv
Expand All @@ -9,16 +9,16 @@ help:
@echo " bootstrap - setup development environment"
@echo " lint - run static code analysis"
@echo " test - run project tests"
@echo " build - build packages"
@echo " clean - clean environment and remove development artifacts"

bootstrap:
python3 -m venv $(VENV)
$(PYTHON) -m pip install --upgrade pip==24.2 setuptools==75.2.0 wheel==0.44.0 build==1.2.2.post1
$(PYTHON) -m pip install -e .[postgres,sqlite,dev]
$(PYTHON) -m pip install -e .[postgres,sqlite,mysql,dev]

lint: $(VENV)
$(PYTHON) -m ruff check based tests
$(PYTHON) -m mypy --strict based

test: $(VENV)
$(PYTHON) -m pytest
Expand Down
21 changes: 15 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
A based asynchronous database connection manager.

Based is designed to be used with SQLAlchemy Core requests. Currently, the only
supported databases are SQLite and PostgreSQL. It's fairly simple to add a new
backend, should you need one. Work in progress - any contributions - issues or
pull requests - are very welcome. API might change, as library is still at its
early experiment stage.
supported databases are SQLite, PostgreSQL and MySQL. It's fairly simple to add
a new backend, should you need one. Work in progress - any contributions -
issues or pull requests - are very welcome. API might change, as library is
still at its early experiment stage.

This library is inspired by [databases](https://github.com/encode/databases).

## Usage

```bash
pip install based[sqlite] # or based[postgres]
pip install based[sqlite] # or based[postgres] or based[mysql]
```

```python
Expand Down Expand Up @@ -99,13 +99,22 @@ need to implement `Backend` class and add its initialization to the `Database`
class. You only need to implement methods that raise `NotImplementedError` in
the base class, adding private helpers as needed.

### Testing

Pass database URLs for those you want to run the tests against. Comma separated
list.

```bash
BASED_TEST_DB_URLS='postgresql://postgres:postgres@localhost:5432/postgres,mysql://root:mariadb@127.0.0.1:3306/mariadb' make test`
```

## TODO

- [x] CI/CD
- [x] Building and uploading packages to PyPi
- [x] Testing with multiple Python versions
- [ ] Database URL parsing and building
- [ ] MySQL backend
- [x] MySQL backend
- [x] Add comments and docstrings
- [x] Add lock for PostgreSQL in `force_rollback` mode and SQLite in both modes
- [x] Refactor tests
Expand Down
2 changes: 1 addition & 1 deletion based/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.2"
__version__ = "0.5.0"

from based.backends import Session
from based.database import Database
Expand Down
91 changes: 44 additions & 47 deletions based/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,12 @@ def __init__( # noqa: D107
async def _execute(
self,
query: typing.Union[ClauseElement, str],
params: typing.Optional[typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]] = None,
params: typing.Optional[
typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]
] = None,
) -> typing.Any: # noqa: ANN401
"""Execute the provided query and return a corresponding Cursor object.

Expand All @@ -145,34 +147,17 @@ async def _execute(
"""
return await self._conn.execute(query, params)

async def _execute_within_transaction(
self,
query: typing.Union[ClauseElement, str],
params: typing.Optional[typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]] = None,
) -> typing.Any: # noqa: ANN401
await self.create_transaction()

try:
cursor = await self._conn.execute(query, params)
except Exception:
await self.cancel_transaction()
raise
else:
await self.commit_transaction()

return cursor

def _compile_query(
self, query: ClauseElement,
self,
query: ClauseElement,
) -> typing.Tuple[
str,
typing.Optional[typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]],
typing.Optional[
typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]
],
]:
compiled_query = query.compile(
dialect=self._dialect,
Expand All @@ -182,7 +167,9 @@ def _compile_query(
return str(compiled_query), compiled_query.params

def _cast_row(
self, cursor: typing.Any, row: typing.Any, # noqa: ANN401
self,
cursor: typing.Any, # noqa: ANN401
row: typing.Any, # noqa: ANN401
) -> typing.Dict[str, typing.Any]:
"""Cast a driver specific Row object to a more general mapping."""
fields = [column[0] for column in cursor.description]
Expand All @@ -191,10 +178,12 @@ def _cast_row(
async def execute(
self,
query: typing.Union[ClauseElement, str],
params: typing.Optional[typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]] = None,
params: typing.Optional[
typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]
] = None,
) -> None:
"""Execute the provided query.

Expand All @@ -207,15 +196,17 @@ async def execute(
"""
if isinstance(query, ClauseElement):
query, params = self._compile_query(query)
await self._execute_within_transaction(query, params)
await self._execute(query, params)

async def fetch_one(
self,
query: typing.Union[ClauseElement, str],
params: typing.Optional[typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]] = None,
params: typing.Optional[
typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]
] = None,
) -> typing.Optional[typing.Dict[str, typing.Any]]:
"""Execute the provided query.

Expand All @@ -234,19 +225,23 @@ async def fetch_one(
if isinstance(query, ClauseElement):
query, params = self._compile_query(query)

cursor = await self._execute_within_transaction(query, params)
cursor = await self._execute(query, params)
row = await cursor.fetchone()
if not row:
return None
return self._cast_row(cursor, row)
row = self._cast_row(cursor, row)
await cursor.close()
return row

async def fetch_all(
self,
query: typing.Union[ClauseElement, str],
params: typing.Optional[typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]] = None,
params: typing.Optional[
typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]
] = None,
) -> typing.List[typing.Dict[str, typing.Any]]:
"""Execute the provided query.

Expand All @@ -264,9 +259,11 @@ async def fetch_all(
if isinstance(query, ClauseElement):
query, params = self._compile_query(query)

cursor = await self._execute_within_transaction(query, params)
cursor = await self._execute(query, params)
rows = await cursor.fetchall()
return [self._cast_row(cursor, row) for row in rows]
rows = [self._cast_row(cursor, row) for row in rows]
await cursor.close()
return rows

async def create_transaction(self) -> None:
"""Create a transaction and add it to the transaction stack."""
Expand Down
91 changes: 91 additions & 0 deletions based/backends/mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import typing
from contextlib import asynccontextmanager

import asyncmy
from sqlalchemy import URL, make_url
from sqlalchemy.dialects.mysql.asyncmy import dialect
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import ClauseElement

from based.backends import Backend, Session


class MySQL(Backend):
"""A MySQL backend for based.Database using asyncmy."""

_url: URL
_pool: asyncmy.Pool
_force_rollback: bool
_force_rollback_connection: asyncmy.Connection
_dialect: Dialect

def __init__(self, url: str, *, force_rollback: bool = False) -> None: # noqa: D107
self._url = make_url(url)
self._force_rollback = force_rollback
self._dialect = dialect() # type: ignore

async def _connect(self) -> None:
self._pool = await asyncmy.create_pool(
user=self._url.username,
password=self._url.password,
host=self._url.host,
port=self._url.port,
database=self._url.database,
)

if self._force_rollback:
self._force_rollback_connection = await self._pool.acquire()

async def _disconnect(self) -> None:
if self._force_rollback:
await self._force_rollback_connection.rollback()
self._pool.release(self._force_rollback_connection)

self._pool.close()
await self._pool.wait_closed()

@asynccontextmanager
async def _session(self) -> typing.AsyncGenerator["Session", None]:
if self._force_rollback:
connection = self._force_rollback_connection
else:
connection = await self._pool.acquire()

session = _MySQLSession(connection, self._dialect)

if self._force_rollback:
await session.create_transaction()

try:
yield session
except Exception:
await session.cancel_transaction()
raise
else:
await session.commit_transaction()
else:
try:
yield session
except Exception:
await connection.rollback()
raise
else:
await connection.commit()
finally:
self._pool.release(connection)


class _MySQLSession(Session):
async def _execute(
self,
query: typing.Union[ClauseElement, str],
params: typing.Optional[
typing.Union[
typing.Dict[str, typing.Any],
typing.List[typing.Any],
]
] = None,
) -> asyncmy.cursors.Cursor:
cursor = self._conn.cursor()
await cursor.execute(query, params)
return cursor
Loading
Loading