338 lines
12 KiB
Python
Executable File
338 lines
12 KiB
Python
Executable File
"""
|
|
集成了表的基本操作。
|
|
"""
|
|
import datetime
|
|
from typing import Optional, Union, List, Sequence
|
|
|
|
import pandas as pd
|
|
from sqlalchemy import func, desc, Column, select, util, Result
|
|
from sqlalchemy.engine import ScalarResult, CursorResult
|
|
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
|
|
from sqlalchemy.sql import Select
|
|
from sqlalchemy.sql.elements import BinaryExpression
|
|
from sqlalchemy.sql.operators import ColumnOperators
|
|
|
|
from paste.db import baseadapter, engine
|
|
|
|
|
|
class BaseTable(baseadapter.BaseAdapter):
|
|
"""
|
|
表对象(数据模型)基类。集成基础数据操作功能。
|
|
实现异步执行方法。
|
|
"""
|
|
|
|
__abstract__ = True
|
|
|
|
@classmethod
|
|
async def raw_execute(cls, query, session: AsyncSession=None,
|
|
params=None, options=None) -> Optional[CursorResult]:
|
|
"""
|
|
异步连接执行原始查询,返回游标。链式调用返回对象的 all() 方法后,将得到 list[Row] 对象。
|
|
|
|
:param query: 查询语句
|
|
:param session: 数据库会话对象
|
|
:param params: 查询参数
|
|
:param options: 查询选项
|
|
:return: CursorResult 游标
|
|
"""
|
|
if options is None:
|
|
options = util.EMPTY_DICT
|
|
if session is None:
|
|
connection: AsyncConnection = await cls.get_aio_session().connection()
|
|
else:
|
|
connection = await session.connection()
|
|
|
|
try:
|
|
cursor_result: CursorResult = await connection.execute(query, params, execution_options=options)
|
|
await connection.commit()
|
|
return cursor_result
|
|
finally:
|
|
await connection.close()
|
|
|
|
@classmethod
|
|
async def orm_execute(cls, query, session: AsyncSession=None, params=None, options=None) -> Optional[Result]:
|
|
"""
|
|
异步会话执行查询,返回游标。链式调用返回对象的 all() 方法后,将得到 list[Row] 对象。
|
|
|
|
:param query: 查询请求
|
|
:param session: 数据库会话对象
|
|
:param params: 查询参数
|
|
:param options: 查询选项
|
|
"""
|
|
_has_session: bool = True
|
|
if session is None:
|
|
_has_session = False
|
|
session = cls.get_aio_session()
|
|
|
|
try:
|
|
result: Result = await session.execute(query, params, execution_options=options)
|
|
return result
|
|
finally:
|
|
if not _has_session:
|
|
await session.close()
|
|
|
|
@classmethod
|
|
async def orm_execute_scalars(cls, query, session: AsyncSession=None, **kwargs) -> ScalarResult:
|
|
"""
|
|
使用异步执行查询。并执行 scalars() 方法,这会进行 ORM 映射。
|
|
|
|
:param query: 查询对象
|
|
:param session: 数据库会话对象
|
|
:return: 数据模型对象列表
|
|
"""
|
|
result: Result = await cls.orm_execute(query, session, **kwargs)
|
|
return result.scalars()
|
|
|
|
@classmethod
|
|
async def query_all(cls, query: Select, session: AsyncSession=None) -> List['BaseTable']:
|
|
"""
|
|
使用异步执行 ORM 查询。并执行 scalars().all() 方法。
|
|
|
|
:param query: 查询对象
|
|
:param session: 数据库会话对象
|
|
:return: 数据模型对象列表,注意:当查询部分字段时仅返回查询列表中第一个字段的数据列表
|
|
"""
|
|
scalars = await cls.orm_execute_scalars(query, session)
|
|
return list(scalars.all())
|
|
|
|
@classmethod
|
|
async def query_first(cls, query: Select, session: AsyncSession=None) -> 'BaseTable':
|
|
"""
|
|
使用异步执行 ORM 查询。并执行 scalars().first() 方法。
|
|
|
|
:param query: 查询对象
|
|
:param session: 数据库会话对象
|
|
:return: 数据模型对象列表
|
|
"""
|
|
scalars = await cls.orm_execute_scalars(query, session)
|
|
return scalars.first()
|
|
|
|
@classmethod
|
|
async def query_count(cls, query: Select, is_only_count: bool = False, session: AsyncSession=None) -> int:
|
|
"""
|
|
使用异步执行 ORM 查询,查询原查询的数据行数。
|
|
|
|
:param query: 查询对象
|
|
:param is_only_count: 是否使用仅 count 方式查询
|
|
:param session: 数据库会话对象
|
|
:return: 数据行数
|
|
"""
|
|
if is_only_count:
|
|
count_query = query.with_only_columns(func.count(1))
|
|
else:
|
|
count_query = select(func.count(1)).select_from(query.subquery())
|
|
# 执行查询
|
|
row_count: int = (await cls.orm_execute_scalars(count_query, session)).first()
|
|
return row_count
|
|
|
|
@classmethod
|
|
async def query_as_df(cls, query, session: AsyncSession=None, params=None, options=None):
|
|
"""
|
|
执行 RAW 原始查询,并将数据请求转换为 :class:`pd.DataFrame` 返回。
|
|
注意:为了保持字段数据精度或便于数据处理,可在 SQL 语句中使用 :meth:`func.ifnull` 或 :meth:`func.convert` 等方法直接转换数据类型。
|
|
|
|
:param query: 查询语句
|
|
:param session: 数据库会话对象
|
|
:param params: 查询参数
|
|
:param options: 查询选项
|
|
:return: 数据 DataFrame
|
|
"""
|
|
result = await cls.raw_execute(query, session, params, options)
|
|
return pd.DataFrame(result.all(), columns=pd.Series(result.keys()))
|
|
|
|
@classmethod
|
|
async def async_row_count(cls, *where_expressions: Union[ColumnOperators, BinaryExpression],
|
|
session: AsyncSession=None) -> int:
|
|
"""
|
|
异步按条件取得数据行数量。
|
|
|
|
:param where_expressions: 条件表达式列表
|
|
:param session: 数据库会话对象
|
|
:return: 行数
|
|
"""
|
|
query = select(func.count(1)).select_from(cls).where(*where_expressions)
|
|
result = await cls.orm_execute_scalars(query, session)
|
|
return result.first()
|
|
|
|
@classmethod
|
|
async def async_exist_by_id(cls, id_val: Union[str, int]):
|
|
"""
|
|
检查是否有存在的主键。
|
|
|
|
:param id_val: 主键值
|
|
"""
|
|
count: int = await cls.async_row_count(cls.instrument_attr('id') == id_val)
|
|
return count >= 1
|
|
|
|
@classmethod
|
|
async def async_find_by_id(cls, id_val: Union[str, int]) -> Optional['BaseTable']:
|
|
"""
|
|
根据主键ID查找数据,确认有 id 字段后方可使用。
|
|
|
|
:param id_val: 主键值
|
|
:return: 数据模型对象
|
|
"""
|
|
query = select(cls).where(cls.instrument_attr('id') == id_val)
|
|
model = await cls.query_first(query=query)
|
|
return model
|
|
|
|
@classmethod
|
|
async def async_find_by_created(cls, from_t: datetime.datetime,
|
|
to_t: Optional[datetime.datetime] = None) -> List['BaseTable']:
|
|
"""
|
|
按创建时间搜索。本方法使用了固定字段名称,确认有 created_at 字段后方可使用。
|
|
|
|
:param from_t: 开始时间
|
|
:param to_t: 结束时间
|
|
"""
|
|
if to_t is None:
|
|
to_t = from_t
|
|
|
|
query = select(cls).where(cls.instrument_attr('created_at').between(from_t, to_t))
|
|
return await cls.query_all(query)
|
|
|
|
@classmethod
|
|
async def async_find_by_updated(cls, from_t: datetime.datetime,
|
|
to_t: Optional[datetime.datetime] = None) -> List['BaseTable']:
|
|
"""
|
|
按更新时间搜索。本方法使用了固定字段名称,确认有 updated_at 字段后方可使用。
|
|
|
|
:param from_t: 开始时间
|
|
:param to_t: 结束时间
|
|
"""
|
|
if to_t is None:
|
|
to_t = from_t
|
|
|
|
query = select(cls).where(cls.instrument_attr('updated_at').between(from_t, to_t))
|
|
return await cls.query_all(query)
|
|
|
|
@classmethod
|
|
async def async_find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]):
|
|
"""
|
|
根据数据列表查询已经在数据库中的数据。
|
|
|
|
:param row_list: 数据列表
|
|
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
|
|
:return: 查询到的数据模型列表
|
|
"""
|
|
model_list: list[cls] = []
|
|
query = cls.datalist_query(row_list=row_list, condition_cols=condition_cols)
|
|
if query is not None:
|
|
model_list = await cls.query_all(query=query)
|
|
|
|
return model_list
|
|
|
|
@classmethod
|
|
async def async_find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
|
|
"""
|
|
根据数据框架查询已经在数据库中的数据。
|
|
|
|
:param row_df: 数据框架
|
|
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
|
|
:return: 查询到的数据模型列表
|
|
"""
|
|
query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols)
|
|
if query is not None:
|
|
_result = await cls.orm_execute_scalars(query=query)
|
|
_rows: Sequence[cls] = _result.all()
|
|
return pd.DataFrame(_rows)
|
|
|
|
return None
|
|
|
|
async def async_find(self, likes: Optional[dict[str, str]] = None) -> List['BaseTable']:
|
|
"""
|
|
根据自身参数,查询数据库。
|
|
|
|
:param likes: 模糊条件
|
|
:return: 查询到的结果对象
|
|
"""
|
|
expressions = self.filter_expressions(likes=likes)
|
|
query = select(self.__class__).where(*expressions)
|
|
return await self.query_all(query)
|
|
|
|
async def async_find_first(self, likes: Optional[dict[str, str]] = None) -> Optional['BaseTable']:
|
|
"""
|
|
根据自身参数,查询数据库,仅查询第一条。
|
|
|
|
:param likes: 模糊条件
|
|
:return: 查询到的结果对象
|
|
"""
|
|
expressions = self.filter_expressions(likes=likes)
|
|
query = select(self.__class__).where(*expressions)
|
|
result = await self.query_first(query)
|
|
return result
|
|
|
|
async def async_find_piece(self, *where: Union[ColumnOperators, BinaryExpression],
|
|
offset=0, limit=500, is_desc=False,
|
|
likes: Optional[dict[str, str]] = None) -> List['BaseTable']:
|
|
"""
|
|
根据自身参数,查询数据库。
|
|
|
|
:param where: 查询条件
|
|
:param offset: 偏移量
|
|
:param limit: 读取数量
|
|
:param is_desc: 是否逆序排列
|
|
:param likes: 模糊条件
|
|
:return: 查询到的结果对象
|
|
"""
|
|
clz = self.__class__
|
|
expressions = self.filter_expressions(likes=likes)
|
|
if where is not None:
|
|
expressions += where
|
|
|
|
query = select(clz).where(*expressions)
|
|
if limit > 0:
|
|
query = query.limit(limit=limit)
|
|
if offset >= 0:
|
|
query = query.offset(offset=offset)
|
|
if is_desc:
|
|
if hasattr(clz, 'id'):
|
|
query = query.order_by(desc(clz.id))
|
|
|
|
return await self.query_all(query)
|
|
|
|
async def async_save(self, auto_expunge: Optional[bool] = True, session: Optional[AsyncSession] = None):
|
|
"""
|
|
保存数据模型对象,首先强制关闭原有会话,获取新的会话,并加入对象。
|
|
|
|
:param auto_expunge: 自动刷新对象并将其移出连接会话,不提供外部 session 时有效
|
|
:param session: 会话对象,主要用于事务
|
|
:return: 保存状态
|
|
"""
|
|
_has_session: bool = True
|
|
"""
|
|
该参数用于说明是否提供了外部 session 对象,默认为 True 时表示提供。
|
|
"""
|
|
|
|
self.before_save()
|
|
self.close_session()
|
|
if session is None:
|
|
_has_session = False
|
|
session = self.get_aio_session()
|
|
|
|
try:
|
|
session.add(self)
|
|
if not _has_session:
|
|
# 使用新会话时,直接提交
|
|
await session.commit()
|
|
self._is_new = False
|
|
except Exception as e:
|
|
await session.rollback()
|
|
raise e
|
|
else:
|
|
if auto_expunge and not _has_session:
|
|
await session.refresh(self)
|
|
session.expunge(self)
|
|
return True
|
|
finally:
|
|
if not _has_session:
|
|
# 使用新会话时,主动关闭
|
|
await session.close()
|
|
|
|
|
|
def create_all_tables():
|
|
"""
|
|
创建所有的表格。
|
|
"""
|
|
baseadapter.registry.metadata.create_all(engine.connect_engine())
|