"""
aiodesa.Database: Simple SQLite Database Interface
This module provides the `Db` class, a simple SQLite database interface that
supports asynchronous operations.
Classes:
- :class:`Db`: Represents a simple SQLite database interface.
Example:

.. code-block:: python

    from dataclasses import dataclass
    from aiodesa import Db

    @dataclass
    class Users:
        username: str
        id: str | None = None
        table_name: str = "users"

    async with Db("database.sqlite3") as db:
        await db.read_table_schemas(Users)
"""
from dataclasses import is_dataclass, fields
from typing import Tuple, Callable, Any, Coroutine
from pathlib import Path
import aiosqlite
from aiodesa.utils.table import make_schema, TableSchema
[docs]
class Db:
    """
    Represents a simple SQLite database interface.
    Args:
        db_path : str
                        The path to the SQLite database file.
    Example:
    .. code-block:: python
        class Users:
                username: str
                id: str | None = None
                table_name: str = "users"
        async with Db("database.sqlite3") as db:
            await db.read_table_schemas(Users)
            ...
    """
    _tables: dict
    db_path: Path
    _conn: Any
    def __init__(self, db_path: str) -> None:
        self.db_path = Path(db_path)
        self._conn = None
        self._create_db()
        self._tables = {}
    def _create_db(self) -> None:
        """
        Internal method to create the database file if it does not exist.
        Notes:
        - This method is automatically called during the initialization of the
        Db class.
        - It ensures that the SQLite database file is created at the specified
        path if
        it does not exist.
        """
        if not self.db_path.exists():
            self.db_path.parent.mkdir(parents=True, exist_ok=True)
            self.db_path.touch()
    async def _process_single_data_class(self, schema: Any) -> None:
        """
        Process a single data class schema.
        Args:
            schema: The data class schema representing a table.
        Returns:
            This method does not return any value.
        """
        if not is_dataclass(schema):
            raise ValueError("Provided schema is not a data class")
        self._tables[schema.table_name] = schema
        class_fields = fields(schema)
        for field in class_fields:
            if field.name == "table_name":
                schema_ = make_schema(str(field.default), schema)
                await self._create_table(schema_, field.name)
[docs]
    async def read_table_schemas(self, class_obj: Any | Tuple[Any, ...]) -> None:
        """Read table schemas and create tables in the database.
        Args:
            schema:
                The schema or tuple of schemas to be processed. Each schema
                should be a data class representing a table.
        Returns:
            This method does not return any value.
        Example:
        .. code-block:: python
            class Users:
                    username: str
                    id: str | None = None
                    table_name: str = "users"
            async with Db("database.sqlite3") as db:
                        await db.read_table_schemas(Users)
                ...
        Note:
            Provide any additional notes or considerations about the method.
        """
        # single dataclass
        if is_dataclass(class_obj):
            await self._process_single_data_class(class_obj)
            return
        # tuple of dataclasses
        if isinstance(class_obj, tuple):
            for _obj in class_obj:
                await self._process_single_data_class(_obj)
            return 
    async def _table_exists(self, table_name: str) -> bool | None:
        """
        Create a table in the database based on the provided
        TableSchema instance.
        Args:
            table_name: The name of the table.
        Returns:
            None
        This method creates a table in the database with the specified name
        and schema.
        """
        if self._conn is not None:
            query = "SELECT name FROM sqlite_master \
                  WHERE type='table' AND name=?;"
            cursor = await self._conn.execute(query, (table_name,))
            return await cursor.fetchone() is not None
        return None
    async def _create_table(self, named_data: TableSchema, name: str) -> None:
        """
        Internal method to create a table in the database based on the provided
        TableSchema instance.
        Args:
            named_data: The TableSchema instance containing the table_name
            and SQL data definition.
            name: The name of the table.
        Returns:
            None
        Example:
        .. code-block:: python
            if is_dataclass(schema):
            class_fields = fields(schema)
            for field in class_fields:
                if field.name == "table_name":
                    schema_ = make_schema(str(field.default), schema)
                    await self._create_table(schema_, field.name)
            return
        This method creates a table in the database with the specified name
        and schema.
        Note:
        The `named_data` parameter should include the `table_name` property
        for the name of the table and the `sql_definition` property for the
        SQL data definition of the table.
        """
        if self._conn is not None:
            if not await self._table_exists(name):
                async with self._conn.executescript(named_data.data) as cursor:
                    await cursor.fetchall()
                await self._conn.commit()
[docs]
    def insert(self, data_class: Any) -> Callable[..., Coroutine[Any, Any, None]]:
        """
        Create a record and insert it into the specified table.
        Args:
            data_class: The data class representing the table structure.
        Returns:
            A function to be called with the record data.
        Example:
        .. code-block:: python
            class Users:
                    username: str
                    id: str | None = None
                    table_name: str = "users"
            async with Db("database.sqlite3") as db:
                        await db.read_table_schemas(Users)
                ...
                insert = db.update(UserEcon)
                await insert("john_doe")
        """
        async def _record(*args: Any, **kwargs: Any) -> None:
            data_cls = self._tables[data_class.table_name](*args, **kwargs)
            field_vals = {}
            for field in fields(data_cls):
                value = getattr(data_cls, field.name)
                if value is not None and value != data_cls.table_name:
                    field_vals[field.name] = value
            insertion_vals = tuple(field_vals.values())
            columns_str = ", ".join(field_vals.keys())
            placeholders = ", ".join("?" for _ in insertion_vals)
            sql = f"INSERT INTO {data_class.table_name} \
                ({columns_str}) VALUES ({placeholders});"
            await self._conn.execute(sql, insertion_vals)
            await self._conn.commit()
            return None
        return _record 
[docs]
    def update(
        self, data_class: Any, column_identifier: None | str = None
    ) -> Callable[..., Coroutine[Any, Any, None]]:
        """
        Create a record update operation for the specified table.
        Args:
            data_class: The data class representing the table structure.
            column_identifier: The column to use for identifying records.
        Returns:
            A function to be called with the record data for updating.
        Example:
        .. code-block:: python
            class Users:
                    username: str
                    id: str | None = None
                    table_name: str = "users"
            async with Db("database.sqlite3") as db:
                await db.read_table_schemas(Users)
                ...
                update = db.update(UserEcon)
                await update("john_doe")
        Note:
        If the `column_identifier` is not provided, the primary key of the
        data class will be used as the identifier.
        """
        async def _record(*args, **kwargs) -> None:
            data_cls = self._tables[data_class.table_name](*args, **kwargs)
            values = []
            set_clauses_placeholders = []
            for column, value in kwargs.items():
                values.append(value)
                set_clause = f"{column} = ?"
                set_clauses_placeholders.append(set_clause)
            set_clause_string = ", ".join(set_clauses_placeholders)
            values.extend(args)
            identifier = (
                column_identifier
                if column_identifier is not None
                else data_cls.primary_key
            )
            sql = f"UPDATE {data_class.table_name} SET \
                {set_clause_string} WHERE {identifier} = ?"
            await self._conn.execute(sql, tuple(values))
            await self._conn.commit()
        return _record 
[docs]
    def find(
        self, data_class: Any, column_identifier: None | str = None
    ) -> Callable[..., Coroutine[Any, Any, None]]:
        """
        Create a record retrieval operation for the specified table.
        Args:
            data_class: The data class representing the table structure.
            column_identifier: The column to use for identifying records.
        Defaults to the primary key of the data class if not specified.
        Returns:
            A function to be called with the identifier for record retrieval.
        Example:
        .. code-block:: python
            class MyBestFriends:
                    username: str
                    id: str | None = None
                    table_name: str = "users"
            async with Db("database.sqlite3") as db:
                        await db.read_table_schemas(MyBestFriends)
                ...
                find_jimmy = db.find(MyBestFriends)
                jimmy = await find_jimmy("jimmy")
        """
        async def _record(*args, **kwargs) -> None:
            data_cls = self._tables[data_class.table_name](*args, **kwargs)
            identifier = (
                column_identifier
                if column_identifier is not None
                else data_cls.primary_key
            )
            results = []
            sql = f"SELECT * FROM {data_cls.table_name} WHERE {identifier} = ?"
            sql_args = (args[0],)
            async with self._conn.execute(sql, sql_args) as cursor:
                results = await cursor.fetchall()
            if len(results) > 0:
                rows_fetched = results[0]
                data_cls = data_class(*rows_fetched, *results[1:])
                return data_cls
            return None
        return _record 
[docs]
    def find_all(self, data_class: Any) -> Callable[..., Coroutine[Any, Any, list]]:
        """
        Create record retrieval operation to fetch all records from the specified table.
        Args:
            data_class: The data class representing the table structure.
        Returns:
            A function to be called with optional additional query parameters.
        Example:
        .. code-block:: python
            class Users:
                    username: str
                    id: str | None = None
                    table_name: str = "users"
            async with Db("database.sqlite3") as db:
                await db.read_table_schemas(Users)
                ...
                find_all_users = db.find_all(Users)
                all_users = await find_all_users()
        """
        async def _records() -> list:
            sql = f"SELECT * FROM {data_class.table_name}"
            async with self._conn.execute(sql) as cursor:
                results = await cursor.fetchall()
                records = []
                for row in results:
                    record = data_class(*row)
                    records.append(record)
                return records
        return _records 
[docs]
    def delete(
        self, data_class: Any, column_identifier: None | str = None
    ) -> Callable[..., Coroutine[Any, Any, None]]:
        """
        Create a record deletion operation for the specified table.
        This defaults to the primary key ifthe column_identifier is
        not provided.
        Args:
            data_class: The data class representing the table structure.
            column_identifier: The column to use for identifying records.
        Returns:
            A function to be called with the identifier for record deletion.
        Example:
        .. code-block:: python
            class Users:
                    username: str
                    id: str | None = None
                    table_name: str = "users"
            async with Db("database.sqlite3") as db:
                        await db.read_table_schemas(Users)
                ...
                delete = db.delete(UserEcon)
                await delete("john_doe")
        """
        async def _record(*args, **kwargs) -> None:
            data_cls = self._tables[data_class.table_name](*args, **kwargs)
            identifier = (
                column_identifier
                if column_identifier is not None
                else data_cls.primary_key
            )
            sql = f"DELETE FROM {data_cls.table_name} WHERE {identifier} = ?"
            sql_args = (args[0],)
            async with self._conn.execute(sql, sql_args) as cursor:
                await cursor.fetchall()
            await self._conn.commit()
        return _record 
    async def _connect(self) -> None:
        """
        Establish a connection to the SQLite database.
        Returns:
        None
        Example:
        .. code-block:: python
            connection = YourDatabaseConnection()
            await connection.connect()
            # The database connection is now established.
        Note:
        This method initializes the connection to the SQLite database
        using the provided `db_path`.
        """
        self._conn = await aiosqlite.connect(self.db_path)
    async def _close(self) -> None:
        """
        Close the connection to the SQLite database.
        Returns:
        None
        Example:
        .. code-block:: python
            connection = YourDatabaseConnection()
            await connection.connect()
            # Your database operations here
            await connection.close()
            # The database connection is now closed.
        Note:
        This method closes the connection to the SQLite database if it is open.
        """
        if self._conn is not None:
            await self._conn.close()
            self._conn = None
    async def __aenter__(self) -> "Db":
        """
        Asynchronous context manager entry point.
        Automatically connects to the database upon entering the context.
        Returns:
        Db:
            The Db instance with an active database connection.
        Example:
        .. code-block:: python
            async with YourDatabaseConnection() as connection:
                # Your asynchronous code here
        # Upon entering the context, the database connection is automatically
        established.
        Note:
        This method is intended for use with the `async with` statement in an
        asynchronous context manager. The returned `Db` instance represents
        the connection to the database.
        """
        await self._connect()
        # await self._conn.execute("BEGIN")
        return self
    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Asynchronous context manager exit point.
        Automatically closes the database connection upon exiting the context.
        Args:
            exc_type (Type): The type of the exception raised, if any.
            exc_value (Exception): The exception object, if an exception
        occurred. Otherwise, None.
            traceback (TracebackType): The traceback information related to
            the exception, if any.
        Returns:
        None
        Example:
        .. code-block:: python
            async with YourDatabaseConnection() as connection:
                # Your asynchronous code here
        # Upon exiting the context, the database connection is
        automatically closed.
        Note:
        This method is intended for use with the `async with` statement in an
        asynchronous context manager.
        """
        await self._close()