""" 集成了表的基本操作。 """ import datetime from typing import Optional, Union, List, Sequence import pandas as pd from sqlalchemy import func, desc, Column, select, util, Result from sqlalchemy.engine import ScalarResult, CursorResult from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from sqlalchemy.sql import Select from sqlalchemy.sql.elements import BinaryExpression from sqlalchemy.sql.operators import ColumnOperators from paste.db import baseadapter, engine class BaseTable(baseadapter.BaseAdapter): """ 表对象(数据模型)基类。集成基础数据操作功能。 实现异步执行方法。 """ __abstract__ = True @classmethod async def raw_execute(cls, query, session: AsyncSession=None, params=None, options=None) -> Optional[CursorResult]: """ 异步连接执行原始查询,返回游标。链式调用返回对象的 all() 方法后,将得到 list[Row] 对象。 :param query: 查询语句 :param session: 数据库会话对象 :param params: 查询参数 :param options: 查询选项 :return: CursorResult 游标 """ if options is None: options = util.EMPTY_DICT if session is None: connection: AsyncConnection = await cls.get_aio_session().connection() else: connection = await session.connection() try: cursor_result: CursorResult = await connection.execute(query, params, execution_options=options) await connection.commit() return cursor_result finally: await connection.close() @classmethod async def orm_execute(cls, query, session: AsyncSession=None, params=None, options=None) -> Optional[Result]: """ 异步会话执行查询,返回游标。链式调用返回对象的 all() 方法后,将得到 list[Row] 对象。 :param query: 查询请求 :param session: 数据库会话对象 :param params: 查询参数 :param options: 查询选项 """ _has_session: bool = True if session is None: _has_session = False session = cls.get_aio_session() try: result: Result = await session.execute(query, params, execution_options=options) return result finally: if not _has_session: await session.close() @classmethod async def orm_execute_scalars(cls, query, session: AsyncSession=None, **kwargs) -> ScalarResult: """ 使用异步执行查询。并执行 scalars() 方法,这会进行 ORM 映射。 :param query: 查询对象 :param session: 数据库会话对象 :return: 数据模型对象列表 """ result: Result = await cls.orm_execute(query, session, **kwargs) return result.scalars() @classmethod async def query_all(cls, query: Select, session: AsyncSession=None) -> List['BaseTable']: """ 使用异步执行 ORM 查询。并执行 scalars().all() 方法。 :param query: 查询对象 :param session: 数据库会话对象 :return: 数据模型对象列表,注意:当查询部分字段时仅返回查询列表中第一个字段的数据列表 """ scalars = await cls.orm_execute_scalars(query, session) return list(scalars.all()) @classmethod async def query_first(cls, query: Select, session: AsyncSession=None) -> 'BaseTable': """ 使用异步执行 ORM 查询。并执行 scalars().first() 方法。 :param query: 查询对象 :param session: 数据库会话对象 :return: 数据模型对象列表 """ scalars = await cls.orm_execute_scalars(query, session) return scalars.first() @classmethod async def query_count(cls, query: Select, is_only_count: bool = False, session: AsyncSession=None) -> int: """ 使用异步执行 ORM 查询,查询原查询的数据行数。 :param query: 查询对象 :param is_only_count: 是否使用仅 count 方式查询 :param session: 数据库会话对象 :return: 数据行数 """ if is_only_count: count_query = query.with_only_columns(func.count(1)) else: count_query = select(func.count(1)).select_from(query.subquery()) # 执行查询 row_count: int = (await cls.orm_execute_scalars(count_query, session)).first() return row_count @classmethod async def query_as_df(cls, query, session: AsyncSession=None, params=None, options=None): """ 执行 RAW 原始查询,并将数据请求转换为 :class:`pd.DataFrame` 返回。 注意:为了保持字段数据精度或便于数据处理,可在 SQL 语句中使用 :meth:`func.ifnull` 或 :meth:`func.convert` 等方法直接转换数据类型。 :param query: 查询语句 :param session: 数据库会话对象 :param params: 查询参数 :param options: 查询选项 :return: 数据 DataFrame """ result = await cls.raw_execute(query, session, params, options) return pd.DataFrame(result.all(), columns=pd.Series(result.keys())) @classmethod async def async_row_count(cls, *where_expressions: Union[ColumnOperators, BinaryExpression], session: AsyncSession=None) -> int: """ 异步按条件取得数据行数量。 :param where_expressions: 条件表达式列表 :param session: 数据库会话对象 :return: 行数 """ query = select(func.count(1)).select_from(cls).where(*where_expressions) result = await cls.orm_execute_scalars(query, session) return result.first() @classmethod async def async_exist_by_id(cls, id_val: Union[str, int]): """ 检查是否有存在的主键。 :param id_val: 主键值 """ count: int = await cls.async_row_count(cls.instrument_attr('id') == id_val) return count >= 1 @classmethod async def async_find_by_id(cls, id_val: Union[str, int]) -> Optional['BaseTable']: """ 根据主键ID查找数据,确认有 id 字段后方可使用。 :param id_val: 主键值 :return: 数据模型对象 """ query = select(cls).where(cls.instrument_attr('id') == id_val) model = await cls.query_first(query=query) return model @classmethod async def async_find_by_created(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> List['BaseTable']: """ 按创建时间搜索。本方法使用了固定字段名称,确认有 created_at 字段后方可使用。 :param from_t: 开始时间 :param to_t: 结束时间 """ if to_t is None: to_t = from_t query = select(cls).where(cls.instrument_attr('created_at').between(from_t, to_t)) return await cls.query_all(query) @classmethod async def async_find_by_updated(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> List['BaseTable']: """ 按更新时间搜索。本方法使用了固定字段名称,确认有 updated_at 字段后方可使用。 :param from_t: 开始时间 :param to_t: 结束时间 """ if to_t is None: to_t = from_t query = select(cls).where(cls.instrument_attr('updated_at').between(from_t, to_t)) return await cls.query_all(query) @classmethod async def async_find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]): """ 根据数据列表查询已经在数据库中的数据。 :param row_list: 数据列表 :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同 :return: 查询到的数据模型列表 """ model_list: list[cls] = [] query = cls.datalist_query(row_list=row_list, condition_cols=condition_cols) if query is not None: model_list = await cls.query_all(query=query) return model_list @classmethod async def async_find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]): """ 根据数据框架查询已经在数据库中的数据。 :param row_df: 数据框架 :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同 :return: 查询到的数据模型列表 """ query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols) if query is not None: _result = await cls.orm_execute_scalars(query=query) _rows: Sequence[cls] = _result.all() return pd.DataFrame(_rows) return None async def async_find(self, likes: Optional[dict[str, str]] = None) -> List['BaseTable']: """ 根据自身参数,查询数据库。 :param likes: 模糊条件 :return: 查询到的结果对象 """ expressions = self.filter_expressions(likes=likes) query = select(self.__class__).where(*expressions) return await self.query_all(query) async def async_find_first(self, likes: Optional[dict[str, str]] = None) -> Optional['BaseTable']: """ 根据自身参数,查询数据库,仅查询第一条。 :param likes: 模糊条件 :return: 查询到的结果对象 """ expressions = self.filter_expressions(likes=likes) query = select(self.__class__).where(*expressions) result = await self.query_first(query) return result async def async_find_piece(self, *where: Union[ColumnOperators, BinaryExpression], offset=0, limit=500, is_desc=False, likes: Optional[dict[str, str]] = None) -> List['BaseTable']: """ 根据自身参数,查询数据库。 :param where: 查询条件 :param offset: 偏移量 :param limit: 读取数量 :param is_desc: 是否逆序排列 :param likes: 模糊条件 :return: 查询到的结果对象 """ clz = self.__class__ expressions = self.filter_expressions(likes=likes) if where is not None: expressions += where query = select(clz).where(*expressions) if limit > 0: query = query.limit(limit=limit) if offset >= 0: query = query.offset(offset=offset) if is_desc: if hasattr(clz, 'id'): query = query.order_by(desc(clz.id)) return await self.query_all(query) async def async_save(self, auto_expunge: Optional[bool] = True, session: Optional[AsyncSession] = None): """ 保存数据模型对象,首先强制关闭原有会话,获取新的会话,并加入对象。 :param auto_expunge: 自动刷新对象并将其移出连接会话,不提供外部 session 时有效 :param session: 会话对象,主要用于事务 :return: 保存状态 """ _has_session: bool = True """ 该参数用于说明是否提供了外部 session 对象,默认为 True 时表示提供。 """ self.before_save() self.close_session() if session is None: _has_session = False session = self.get_aio_session() try: session.add(self) if not _has_session: # 使用新会话时,直接提交 await session.commit() self._is_new = False except Exception as e: await session.rollback() raise e else: if auto_expunge and not _has_session: await session.refresh(self) session.expunge(self) return True finally: if not _has_session: # 使用新会话时,主动关闭 await session.close() def create_all_tables(): """ 创建所有的表格。 """ baseadapter.registry.metadata.create_all(engine.connect_engine())