Files
paste-framework/paste/db/basetable.py
T
2026-06-02 16:26:10 +08:00

338 lines
12 KiB
Python
Executable File

"""
集成了表的基本操作。
"""
import datetime
from typing import Optional, Union, List, Sequence
import pandas as pd
from sqlalchemy import func, desc, Column, select, util, Result
from sqlalchemy.engine import ScalarResult, CursorResult
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from sqlalchemy.sql import Select
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.operators import ColumnOperators
from paste.db import baseadapter, engine
class BaseTable(baseadapter.BaseAdapter):
"""
表对象(数据模型)基类。集成基础数据操作功能。
实现异步执行方法。
"""
__abstract__ = True
@classmethod
async def raw_execute(cls, query, session: AsyncSession=None,
params=None, options=None) -> Optional[CursorResult]:
"""
异步连接执行原始查询,返回游标。链式调用返回对象的 all() 方法后,将得到 list[Row] 对象。
:param query: 查询语句
:param session: 数据库会话对象
:param params: 查询参数
:param options: 查询选项
:return: CursorResult 游标
"""
if options is None:
options = util.EMPTY_DICT
if session is None:
connection: AsyncConnection = await cls.get_aio_session().connection()
else:
connection = await session.connection()
try:
cursor_result: CursorResult = await connection.execute(query, params, execution_options=options)
await connection.commit()
return cursor_result
finally:
await connection.close()
@classmethod
async def orm_execute(cls, query, session: AsyncSession=None, params=None, options=None) -> Optional[Result]:
"""
异步会话执行查询,返回游标。链式调用返回对象的 all() 方法后,将得到 list[Row] 对象。
:param query: 查询请求
:param session: 数据库会话对象
:param params: 查询参数
:param options: 查询选项
"""
_has_session: bool = True
if session is None:
_has_session = False
session = cls.get_aio_session()
try:
result: Result = await session.execute(query, params, execution_options=options)
return result
finally:
if not _has_session:
await session.close()
@classmethod
async def orm_execute_scalars(cls, query, session: AsyncSession=None, **kwargs) -> ScalarResult:
"""
使用异步执行查询。并执行 scalars() 方法,这会进行 ORM 映射。
:param query: 查询对象
:param session: 数据库会话对象
:return: 数据模型对象列表
"""
result: Result = await cls.orm_execute(query, session, **kwargs)
return result.scalars()
@classmethod
async def query_all(cls, query: Select, session: AsyncSession=None) -> List['BaseTable']:
"""
使用异步执行 ORM 查询。并执行 scalars().all() 方法。
:param query: 查询对象
:param session: 数据库会话对象
:return: 数据模型对象列表,注意:当查询部分字段时仅返回查询列表中第一个字段的数据列表
"""
scalars = await cls.orm_execute_scalars(query, session)
return list(scalars.all())
@classmethod
async def query_first(cls, query: Select, session: AsyncSession=None) -> 'BaseTable':
"""
使用异步执行 ORM 查询。并执行 scalars().first() 方法。
:param query: 查询对象
:param session: 数据库会话对象
:return: 数据模型对象列表
"""
scalars = await cls.orm_execute_scalars(query, session)
return scalars.first()
@classmethod
async def query_count(cls, query: Select, is_only_count: bool = False, session: AsyncSession=None) -> int:
"""
使用异步执行 ORM 查询,查询原查询的数据行数。
:param query: 查询对象
:param is_only_count: 是否使用仅 count 方式查询
:param session: 数据库会话对象
:return: 数据行数
"""
if is_only_count:
count_query = query.with_only_columns(func.count(1))
else:
count_query = select(func.count(1)).select_from(query.subquery())
# 执行查询
row_count: int = (await cls.orm_execute_scalars(count_query, session)).first()
return row_count
@classmethod
async def query_as_df(cls, query, session: AsyncSession=None, params=None, options=None):
"""
执行 RAW 原始查询,并将数据请求转换为 :class:`pd.DataFrame` 返回。
注意:为了保持字段数据精度或便于数据处理,可在 SQL 语句中使用 :meth:`func.ifnull` 或 :meth:`func.convert` 等方法直接转换数据类型。
:param query: 查询语句
:param session: 数据库会话对象
:param params: 查询参数
:param options: 查询选项
:return: 数据 DataFrame
"""
result = await cls.raw_execute(query, session, params, options)
return pd.DataFrame(result.all(), columns=pd.Series(result.keys()))
@classmethod
async def async_row_count(cls, *where_expressions: Union[ColumnOperators, BinaryExpression],
session: AsyncSession=None) -> int:
"""
异步按条件取得数据行数量。
:param where_expressions: 条件表达式列表
:param session: 数据库会话对象
:return: 行数
"""
query = select(func.count(1)).select_from(cls).where(*where_expressions)
result = await cls.orm_execute_scalars(query, session)
return result.first()
@classmethod
async def async_exist_by_id(cls, id_val: Union[str, int]):
"""
检查是否有存在的主键。
:param id_val: 主键值
"""
count: int = await cls.async_row_count(cls.instrument_attr('id') == id_val)
return count >= 1
@classmethod
async def async_find_by_id(cls, id_val: Union[str, int]) -> Optional['BaseTable']:
"""
根据主键ID查找数据,确认有 id 字段后方可使用。
:param id_val: 主键值
:return: 数据模型对象
"""
query = select(cls).where(cls.instrument_attr('id') == id_val)
model = await cls.query_first(query=query)
return model
@classmethod
async def async_find_by_created(cls, from_t: datetime.datetime,
to_t: Optional[datetime.datetime] = None) -> List['BaseTable']:
"""
按创建时间搜索。本方法使用了固定字段名称,确认有 created_at 字段后方可使用。
:param from_t: 开始时间
:param to_t: 结束时间
"""
if to_t is None:
to_t = from_t
query = select(cls).where(cls.instrument_attr('created_at').between(from_t, to_t))
return await cls.query_all(query)
@classmethod
async def async_find_by_updated(cls, from_t: datetime.datetime,
to_t: Optional[datetime.datetime] = None) -> List['BaseTable']:
"""
按更新时间搜索。本方法使用了固定字段名称,确认有 updated_at 字段后方可使用。
:param from_t: 开始时间
:param to_t: 结束时间
"""
if to_t is None:
to_t = from_t
query = select(cls).where(cls.instrument_attr('updated_at').between(from_t, to_t))
return await cls.query_all(query)
@classmethod
async def async_find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]):
"""
根据数据列表查询已经在数据库中的数据。
:param row_list: 数据列表
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
:return: 查询到的数据模型列表
"""
model_list: list[cls] = []
query = cls.datalist_query(row_list=row_list, condition_cols=condition_cols)
if query is not None:
model_list = await cls.query_all(query=query)
return model_list
@classmethod
async def async_find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
"""
根据数据框架查询已经在数据库中的数据。
:param row_df: 数据框架
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
:return: 查询到的数据模型列表
"""
query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols)
if query is not None:
_result = await cls.orm_execute_scalars(query=query)
_rows: Sequence[cls] = _result.all()
return pd.DataFrame(_rows)
return None
async def async_find(self, likes: Optional[dict[str, str]] = None) -> List['BaseTable']:
"""
根据自身参数,查询数据库。
:param likes: 模糊条件
:return: 查询到的结果对象
"""
expressions = self.filter_expressions(likes=likes)
query = select(self.__class__).where(*expressions)
return await self.query_all(query)
async def async_find_first(self, likes: Optional[dict[str, str]] = None) -> Optional['BaseTable']:
"""
根据自身参数,查询数据库,仅查询第一条。
:param likes: 模糊条件
:return: 查询到的结果对象
"""
expressions = self.filter_expressions(likes=likes)
query = select(self.__class__).where(*expressions)
result = await self.query_first(query)
return result
async def async_find_piece(self, *where: Union[ColumnOperators, BinaryExpression],
offset=0, limit=500, is_desc=False,
likes: Optional[dict[str, str]] = None) -> List['BaseTable']:
"""
根据自身参数,查询数据库。
:param where: 查询条件
:param offset: 偏移量
:param limit: 读取数量
:param is_desc: 是否逆序排列
:param likes: 模糊条件
:return: 查询到的结果对象
"""
clz = self.__class__
expressions = self.filter_expressions(likes=likes)
if where is not None:
expressions += where
query = select(clz).where(*expressions)
if limit > 0:
query = query.limit(limit=limit)
if offset >= 0:
query = query.offset(offset=offset)
if is_desc:
if hasattr(clz, 'id'):
query = query.order_by(desc(clz.id))
return await self.query_all(query)
async def async_save(self, auto_expunge: Optional[bool] = True, session: Optional[AsyncSession] = None):
"""
保存数据模型对象,首先强制关闭原有会话,获取新的会话,并加入对象。
:param auto_expunge: 自动刷新对象并将其移出连接会话,不提供外部 session 时有效
:param session: 会话对象,主要用于事务
:return: 保存状态
"""
_has_session: bool = True
"""
该参数用于说明是否提供了外部 session 对象,默认为 True 时表示提供。
"""
self.before_save()
self.close_session()
if session is None:
_has_session = False
session = self.get_aio_session()
try:
session.add(self)
if not _has_session:
# 使用新会话时,直接提交
await session.commit()
self._is_new = False
except Exception as e:
await session.rollback()
raise e
else:
if auto_expunge and not _has_session:
await session.refresh(self)
session.expunge(self)
return True
finally:
if not _has_session:
# 使用新会话时,主动关闭
await session.close()
def create_all_tables():
"""
创建所有的表格。
"""
baseadapter.registry.metadata.create_all(engine.connect_engine())