Source code for aiodesa.database

"""
aiodesa.Database: Simple SQLite Database Interface

This module provides the `Db` class, a simple SQLite database interface that
supports asynchronous operations.

Classes:

- :class:`Db`: Represents a simple SQLite database interface.

Example:

.. code-block:: python

    from dataclasses import dataclass

    from aiodesa import Db

    @dataclass
    class Users:
        username: str
        id: str | None = None
        table_name: str = "users"

    async with Db("database.sqlite3") as db:
        await db.read_table_schemas(Users)
"""

from dataclasses import is_dataclass, fields
from typing import Tuple, Callable, Any, Coroutine
from pathlib import Path
import aiosqlite
from aiodesa.utils.table import make_schema, TableSchema


[docs] class Db: """ Represents a simple SQLite database interface. Args: db_path : str The path to the SQLite database file. Example: .. code-block:: python class Users: username: str id: str | None = None table_name: str = "users" async with Db("database.sqlite3") as db: await db.read_table_schemas(Users) ... """ _tables: dict db_path: Path _conn: Any def __init__(self, db_path: str) -> None: self.db_path = Path(db_path) self._conn = None self._create_db() self._tables = {} def _create_db(self) -> None: """ Internal method to create the database file if it does not exist. Notes: - This method is automatically called during the initialization of the Db class. - It ensures that the SQLite database file is created at the specified path if it does not exist. """ if not self.db_path.exists(): self.db_path.parent.mkdir(parents=True, exist_ok=True) self.db_path.touch() async def _process_single_data_class(self, schema: Any) -> None: """ Process a single data class schema. Args: schema: The data class schema representing a table. Returns: This method does not return any value. """ if not is_dataclass(schema): raise ValueError("Provided schema is not a data class") self._tables[schema.table_name] = schema class_fields = fields(schema) for field in class_fields: if field.name == "table_name": schema_ = make_schema(str(field.default), schema) await self._create_table(schema_, field.name)
[docs] async def read_table_schemas(self, class_obj: Any | Tuple[Any, ...]) -> None: """Read table schemas and create tables in the database. Args: schema: The schema or tuple of schemas to be processed. Each schema should be a data class representing a table. Returns: This method does not return any value. Example: .. code-block:: python class Users: username: str id: str | None = None table_name: str = "users" async with Db("database.sqlite3") as db: await db.read_table_schemas(Users) ... Note: Provide any additional notes or considerations about the method. """ # single dataclass if is_dataclass(class_obj): await self._process_single_data_class(class_obj) return # tuple of dataclasses if isinstance(class_obj, tuple): for _obj in class_obj: await self._process_single_data_class(_obj) return
async def _table_exists(self, table_name: str) -> bool | None: """ Create a table in the database based on the provided TableSchema instance. Args: table_name: The name of the table. Returns: None This method creates a table in the database with the specified name and schema. """ if self._conn is not None: query = "SELECT name FROM sqlite_master \ WHERE type='table' AND name=?;" cursor = await self._conn.execute(query, (table_name,)) return await cursor.fetchone() is not None return None async def _create_table(self, named_data: TableSchema, name: str) -> None: """ Internal method to create a table in the database based on the provided TableSchema instance. Args: named_data: The TableSchema instance containing the table_name and SQL data definition. name: The name of the table. Returns: None Example: .. code-block:: python if is_dataclass(schema): class_fields = fields(schema) for field in class_fields: if field.name == "table_name": schema_ = make_schema(str(field.default), schema) await self._create_table(schema_, field.name) return This method creates a table in the database with the specified name and schema. Note: The `named_data` parameter should include the `table_name` property for the name of the table and the `sql_definition` property for the SQL data definition of the table. """ if self._conn is not None: if not await self._table_exists(name): async with self._conn.executescript(named_data.data) as cursor: await cursor.fetchall() await self._conn.commit()
[docs] def insert(self, data_class: Any) -> Callable[..., Coroutine[Any, Any, None]]: """ Create a record and insert it into the specified table. Args: data_class: The data class representing the table structure. Returns: A function to be called with the record data. Example: .. code-block:: python class Users: username: str id: str | None = None table_name: str = "users" async with Db("database.sqlite3") as db: await db.read_table_schemas(Users) ... insert = db.update(UserEcon) await insert("john_doe") """ async def _record(*args: Any, **kwargs: Any) -> None: data_cls = self._tables[data_class.table_name](*args, **kwargs) field_vals = {} for field in fields(data_cls): value = getattr(data_cls, field.name) if value is not None and value != data_cls.table_name: field_vals[field.name] = value insertion_vals = tuple(field_vals.values()) columns_str = ", ".join(field_vals.keys()) placeholders = ", ".join("?" for _ in insertion_vals) sql = f"INSERT INTO {data_class.table_name} \ ({columns_str}) VALUES ({placeholders});" await self._conn.execute(sql, insertion_vals) await self._conn.commit() return None return _record
[docs] def update( self, data_class: Any, column_identifier: None | str = None ) -> Callable[..., Coroutine[Any, Any, None]]: """ Create a record update operation for the specified table. Args: data_class: The data class representing the table structure. column_identifier: The column to use for identifying records. Returns: A function to be called with the record data for updating. Example: .. code-block:: python class Users: username: str id: str | None = None table_name: str = "users" async with Db("database.sqlite3") as db: await db.read_table_schemas(Users) ... update = db.update(UserEcon) await update("john_doe") Note: If the `column_identifier` is not provided, the primary key of the data class will be used as the identifier. """ async def _record(*args, **kwargs) -> None: data_cls = self._tables[data_class.table_name](*args, **kwargs) values = [] set_clauses_placeholders = [] for column, value in kwargs.items(): values.append(value) set_clause = f"{column} = ?" set_clauses_placeholders.append(set_clause) set_clause_string = ", ".join(set_clauses_placeholders) values.extend(args) identifier = ( column_identifier if column_identifier is not None else data_cls.primary_key ) sql = f"UPDATE {data_class.table_name} SET \ {set_clause_string} WHERE {identifier} = ?" await self._conn.execute(sql, tuple(values)) await self._conn.commit() return _record
[docs] def find( self, data_class: Any, column_identifier: None | str = None ) -> Callable[..., Coroutine[Any, Any, None]]: """ Create a record retrieval operation for the specified table. Args: data_class: The data class representing the table structure. column_identifier: The column to use for identifying records. Defaults to the primary key of the data class if not specified. Returns: A function to be called with the identifier for record retrieval. Example: .. code-block:: python class MyBestFriends: username: str id: str | None = None table_name: str = "users" async with Db("database.sqlite3") as db: await db.read_table_schemas(MyBestFriends) ... find_jimmy = db.find(MyBestFriends) jimmy = await find_jimmy("jimmy") """ async def _record(*args, **kwargs) -> None: data_cls = self._tables[data_class.table_name](*args, **kwargs) identifier = ( column_identifier if column_identifier is not None else data_cls.primary_key ) results = [] sql = f"SELECT * FROM {data_cls.table_name} WHERE {identifier} = ?" sql_args = (args[0],) async with self._conn.execute(sql, sql_args) as cursor: results = await cursor.fetchall() if len(results) > 0: rows_fetched = results[0] data_cls = data_class(*rows_fetched, *results[1:]) return data_cls return None return _record
[docs] def find_all(self, data_class: Any) -> Callable[..., Coroutine[Any, Any, list]]: """ Create record retrieval operation to fetch all records from the specified table. Args: data_class: The data class representing the table structure. Returns: A function to be called with optional additional query parameters. Example: .. code-block:: python class Users: username: str id: str | None = None table_name: str = "users" async with Db("database.sqlite3") as db: await db.read_table_schemas(Users) ... find_all_users = db.find_all(Users) all_users = await find_all_users() """ async def _records() -> list: sql = f"SELECT * FROM {data_class.table_name}" async with self._conn.execute(sql) as cursor: results = await cursor.fetchall() records = [] for row in results: record = data_class(*row) records.append(record) return records return _records
[docs] def delete( self, data_class: Any, column_identifier: None | str = None ) -> Callable[..., Coroutine[Any, Any, None]]: """ Create a record deletion operation for the specified table. This defaults to the primary key ifthe column_identifier is not provided. Args: data_class: The data class representing the table structure. column_identifier: The column to use for identifying records. Returns: A function to be called with the identifier for record deletion. Example: .. code-block:: python class Users: username: str id: str | None = None table_name: str = "users" async with Db("database.sqlite3") as db: await db.read_table_schemas(Users) ... delete = db.delete(UserEcon) await delete("john_doe") """ async def _record(*args, **kwargs) -> None: data_cls = self._tables[data_class.table_name](*args, **kwargs) identifier = ( column_identifier if column_identifier is not None else data_cls.primary_key ) sql = f"DELETE FROM {data_cls.table_name} WHERE {identifier} = ?" sql_args = (args[0],) async with self._conn.execute(sql, sql_args) as cursor: await cursor.fetchall() await self._conn.commit() return _record
async def _connect(self) -> None: """ Establish a connection to the SQLite database. Returns: None Example: .. code-block:: python connection = YourDatabaseConnection() await connection.connect() # The database connection is now established. Note: This method initializes the connection to the SQLite database using the provided `db_path`. """ self._conn = await aiosqlite.connect(self.db_path) async def _close(self) -> None: """ Close the connection to the SQLite database. Returns: None Example: .. code-block:: python connection = YourDatabaseConnection() await connection.connect() # Your database operations here await connection.close() # The database connection is now closed. Note: This method closes the connection to the SQLite database if it is open. """ if self._conn is not None: await self._conn.close() self._conn = None async def __aenter__(self) -> "Db": """ Asynchronous context manager entry point. Automatically connects to the database upon entering the context. Returns: Db: The Db instance with an active database connection. Example: .. code-block:: python async with YourDatabaseConnection() as connection: # Your asynchronous code here # Upon entering the context, the database connection is automatically established. Note: This method is intended for use with the `async with` statement in an asynchronous context manager. The returned `Db` instance represents the connection to the database. """ await self._connect() # await self._conn.execute("BEGIN") return self async def __aexit__(self, exc_type, exc_value, traceback): """ Asynchronous context manager exit point. Automatically closes the database connection upon exiting the context. Args: exc_type (Type): The type of the exception raised, if any. exc_value (Exception): The exception object, if an exception occurred. Otherwise, None. traceback (TracebackType): The traceback information related to the exception, if any. Returns: None Example: .. 
code-block:: python async with YourDatabaseConnection() as connection: # Your asynchronous code here # Upon exiting the context, the database connection is automatically closed. Note: This method is intended for use with the `async with` statement in an asynchronous context manager. """ await self._close()