831 lines
26 KiB
Python
Executable File
831 lines
26 KiB
Python
Executable File
"""
|
|
数据访问适配器。主要封装了数据访问的基础方法。
|
|
"""
|
|
import datetime
|
|
import random
|
|
import uuid
|
|
from typing import Optional, Any, Union, List, Sequence
|
|
|
|
import pandas as pd
|
|
from sqlalchemy import Column, String, DateTime, Date, inspect, Table, Integer, select, Numeric, ForeignKey, func, \
|
|
text, desc, Result, Engine, tuple_
|
|
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, async_sessionmaker
|
|
from sqlalchemy.orm import sessionmaker, Session, InstanceState, InstrumentedAttribute, registry
|
|
from sqlalchemy.sql import Select
|
|
from sqlalchemy.sql.base import ColumnCollection
|
|
from sqlalchemy.sql.compiler import StrSQLCompiler
|
|
from sqlalchemy.sql.elements import BinaryExpression
|
|
|
|
from paste.db import engine
|
|
from paste.util import udict
|
|
|
|
LOCAL_DATE_FORMAT = '%Y-%m-%d'
|
|
LOCAL_TIME_FORMAT = '%H:%M:%S'
|
|
LOCAL_DATETIME_FORMAT = f'{LOCAL_DATE_FORMAT} {LOCAL_TIME_FORMAT}'
|
|
|
|
|
|
def guid() -> str:
|
|
"""
|
|
生成 GUID。
|
|
|
|
:return: GUID
|
|
"""
|
|
return str(uuid.uuid5(uuid.NAMESPACE_DNS, str(uuid.uuid1()) + str(random.random()))).replace('-', '')
|
|
|
|
|
|
def get_session() -> Session:
|
|
"""
|
|
获取线程安全的连接会话对象。
|
|
|
|
:return: 连接会话对象
|
|
"""
|
|
_engine: Engine = engine.connect_engine()
|
|
session_maker = sessionmaker(bind=_engine)
|
|
return session_maker()
|
|
|
|
|
|
def get_aio_session() -> AsyncSession:
|
|
"""
|
|
获取异步连接会话对象。
|
|
|
|
:return: 连接会话对象
|
|
"""
|
|
_engine: AsyncEngine = engine.async_connect_engine()
|
|
session_maker = async_sessionmaker(bind=_engine, class_=AsyncSession)
|
|
return session_maker()
|
|
|
|
|
|
Base = registry().generate_base()
|
|
"""
|
|
取得 SQLAlchemy ORM 的声明式基类。
|
|
"""
|
|
|
|
|
|
class BaseAdapter(Base):
|
|
"""
|
|
适配器类,集成数据库引擎、连接、会话、元数据操作、部分 DML 数据处理等功能。
|
|
主要实现了同步执行方法。
|
|
"""
|
|
__abstract__ = True
|
|
|
|
def __init__(self, **kwargs):
|
|
"""
|
|
构造模型对象。
|
|
|
|
:param args:
|
|
:param kwargs:
|
|
"""
|
|
self._is_new = True
|
|
super().__init__(**kwargs)
|
|
|
|
@classmethod
|
|
def get_session(cls):
|
|
"""
|
|
线程安全的连接会话。
|
|
|
|
:return: 同步会话对象
|
|
"""
|
|
return get_session()
|
|
|
|
@classmethod
|
|
def get_aio_session(cls):
|
|
"""
|
|
异步连接会话。
|
|
|
|
:return: 异步会话对象
|
|
"""
|
|
return get_aio_session()
|
|
|
|
@classmethod
|
|
def ping(cls):
|
|
"""
|
|
尝试连接数据库服务器。
|
|
|
|
:return: 连接结果
|
|
"""
|
|
engine.connect_engine().connect()
|
|
|
|
@classmethod
|
|
async def tables_in_db(cls):
|
|
"""
|
|
取得数据库中所有表名称。
|
|
:return: 表名称列表
|
|
"""
|
|
_query = select(text('table_name')).select_from(
|
|
text('information_schema.tables')
|
|
).where(
|
|
text('table_schema = database()')
|
|
)
|
|
_names = list()
|
|
_session = cls.get_aio_session()
|
|
try:
|
|
_result: Result = await _session.execute(_query)
|
|
_names = list(_result.scalars().all())
|
|
finally:
|
|
await _session.close()
|
|
return _names
|
|
|
|
@classmethod
|
|
async def is_table_exist(cls, table_name: str):
|
|
"""
|
|
判断表是否存在。
|
|
:param table_name: 要判断的表名称
|
|
"""
|
|
_query = select(func.count()).select_from(
|
|
text('information_schema.tables')
|
|
).where(
|
|
text('table_schema = database()'), text('table_name = :table_name')
|
|
).params(
|
|
table_name=table_name
|
|
)
|
|
_count = 0
|
|
_session = cls.get_aio_session()
|
|
try:
|
|
_result: Result = await _session.execute(_query)
|
|
_count = int(_result.scalar())
|
|
finally:
|
|
await _session.close()
|
|
return _count > 0
|
|
|
|
@classmethod
|
|
def table(cls) -> Table:
|
|
"""
|
|
取得当前模型对象所对应的表对象。
|
|
|
|
:return: 表对象
|
|
"""
|
|
tables = cls.metadata.tables
|
|
assert cls.__tablename__ is not None, '属性:cls.__tablename__ 未找到.'
|
|
assert cls.__tablename__ in tables, ('数据库中未找到表:%s.' % cls.__tablename__)
|
|
return tables[cls.__tablename__]
|
|
|
|
@classmethod
|
|
def columns(cls) -> List[Column]:
|
|
"""
|
|
当前对象配置的所有列。
|
|
|
|
:return: 所有列
|
|
"""
|
|
cols = cls.table().columns
|
|
assert isinstance(cols, ColumnCollection), '%s 列类型错误' % cls.__name__
|
|
return list(cols)
|
|
|
|
@classmethod
|
|
def instrument_attr(cls, column_name: str) -> Optional[InstrumentedAttribute]:
|
|
"""
|
|
取得查询条件绑定属性,用于配置查询条件表达式。
|
|
|
|
:param column_name: 列名称
|
|
:return: 类配置的列信息,即查询条件绑定属性
|
|
"""
|
|
attr = getattr(cls, column_name)
|
|
if not isinstance(attr, InstrumentedAttribute):
|
|
return None
|
|
return attr
|
|
|
|
@classmethod
|
|
def label(cls, column: Union[Column, Any]):
|
|
"""
|
|
以字段的 comment 作为标签返回。
|
|
|
|
:return: 指定列的备注
|
|
"""
|
|
return column.comment
|
|
|
|
@classmethod
|
|
def labels(cls) -> dict:
|
|
"""
|
|
取得所有列的名称描述。
|
|
|
|
:return: 所有列的描述字典,结构为:{`field`: `comment`}
|
|
"""
|
|
_label_dict = {}
|
|
for _col in cls.columns():
|
|
_label_dict[_col.key] = _col.comment
|
|
return _label_dict
|
|
|
|
@classmethod
|
|
def field(cls, column: Union[Column, Any]) -> str:
|
|
"""
|
|
取得字段名称。
|
|
|
|
:param column: 列对象
|
|
:return: 字段名称
|
|
"""
|
|
return column.key
|
|
|
|
@classmethod
|
|
def fields(cls) -> List[str]:
|
|
"""
|
|
取得所有字段名称列表。
|
|
|
|
:return: 所有字段名称列表
|
|
"""
|
|
return [col.key for col in cls.columns()]
|
|
|
|
@classmethod
|
|
def new_id(cls, **kwargs) -> str:
|
|
"""
|
|
用GUID作为新的ID。
|
|
|
|
:return: ID
|
|
"""
|
|
return guid()
|
|
|
|
@property
|
|
def is_new(self):
|
|
"""
|
|
通过检查是否已经存在ID判断是否为新建对象。
|
|
|
|
:return: 是否为新建对象
|
|
"""
|
|
return hasattr(self, '_is_new') and self._is_new
|
|
|
|
@classmethod
|
|
def raw_sql(cls, query) -> StrSQLCompiler:
|
|
"""
|
|
显示编译后的原始 SQL 命令。
|
|
|
|
:return: 编译后的 SQL 命令文本
|
|
"""
|
|
raw_sql = query.compile(compile_kwargs={'literal_binds': True})
|
|
return raw_sql
|
|
|
|
@classmethod
|
|
def row_count(cls, *where_expressions: BinaryExpression) -> int:
|
|
"""
|
|
按条件取得记录行数。
|
|
|
|
:param where_expressions: 查询条件
|
|
:return: 返回记录行数
|
|
"""
|
|
query = select(func.count(1)).select_from(cls).where(*where_expressions)
|
|
session = cls.get_session()
|
|
count = session.execute(query).scalars().first()
|
|
session.close()
|
|
return count
|
|
|
|
@classmethod
|
|
def exist_by_id(cls, id_val: Union[str, int]):
|
|
"""
|
|
检查是否有存在的主键。
|
|
|
|
:param id_val: 主键值
|
|
"""
|
|
_wheres = [cls.instrument_attr('id') == id_val]
|
|
count: int = cls.row_count(*_wheres)
|
|
return count >= 1
|
|
|
|
@classmethod
|
|
def search_wheres(cls, likes: Optional[dict[str, str]] = None, **kwargs):
|
|
"""
|
|
按参数组织查询条件。
|
|
|
|
:param likes: 需要执行模糊查询的列
|
|
:param kwargs: 属性填充参数
|
|
:return: 查询条件列表
|
|
"""
|
|
if likes is None:
|
|
likes = []
|
|
_query_model = cls().copy_from_dict(kwargs)
|
|
return _query_model.filter_expressions(likes=likes)
|
|
|
|
@classmethod
|
|
def datalist_query(cls, row_list: List[dict], condition_cols: List[Column]) -> Optional[Select]:
|
|
"""
|
|
根据数据列表、查询条件列生成查询。仅支持单表查询。
|
|
|
|
:param row_list: 数据列表
|
|
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
|
|
:return: 查询到的数据模型列表
|
|
"""
|
|
if not row_list or not condition_cols:
|
|
return None
|
|
|
|
_condition_fields = [_col.key for _col in condition_cols]
|
|
|
|
# 验证字段存在
|
|
for field in _condition_fields:
|
|
if not hasattr(cls, field):
|
|
raise AttributeError(f"Field '{field}' not found in {cls.__name__}")
|
|
|
|
# 构建值列表的元组形式
|
|
_values_tuples = []
|
|
for _row in row_list:
|
|
_val_list = []
|
|
for _f in _condition_fields:
|
|
_val = udict.get_with_default(_row, _f, '')
|
|
# 单独处理日期时间格式,但是此处固定格式
|
|
if isinstance(_val, datetime.datetime):
|
|
_val = _val.strftime(LOCAL_DATETIME_FORMAT)
|
|
if isinstance(_val, datetime.date):
|
|
_val = _val.strftime(LOCAL_DATE_FORMAT)
|
|
_val_list.append(_val)
|
|
|
|
# 过滤掉全为空的元组
|
|
if any(v != '' for v in _val_list):
|
|
_values_tuples.append(tuple(_val_list))
|
|
|
|
if not _values_tuples:
|
|
return None
|
|
|
|
# 去重,提高性能
|
|
_values_tuples = list(set(_values_tuples))
|
|
|
|
# 使用参数化查询
|
|
if len(condition_cols) == 1:
|
|
# 单字段查询
|
|
field = getattr(cls, _condition_fields[0])
|
|
_single_values = [t[0] for t in _values_tuples]
|
|
query = select(cls).where(field.in_(_single_values))
|
|
else:
|
|
# 多字段组合查询
|
|
conditions = [getattr(cls, field) for field in _condition_fields]
|
|
query = select(cls).where(tuple_(*conditions).in_(_values_tuples))
|
|
|
|
return query
|
|
|
|
@classmethod
|
|
def dataframe_query(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
|
|
"""
|
|
根据数据列表、查询条件列生成查询。仅支持单表查询。
|
|
|
|
:param row_df: 数据框架
|
|
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
|
|
:return: 查询到的数据模型列表
|
|
"""
|
|
if row_df.empty or not condition_cols:
|
|
return None
|
|
|
|
_condition_fields = [_col.key for _col in condition_cols]
|
|
|
|
# 验证字段存在
|
|
for field in _condition_fields:
|
|
if not hasattr(cls, field):
|
|
raise AttributeError(f"Field '{field}' not found in {cls.__name__}")
|
|
|
|
# 去除重复行并删除包含 NaN 的行
|
|
row_df = row_df[_condition_fields].drop_duplicates().dropna()
|
|
|
|
if row_df.empty:
|
|
return None
|
|
|
|
if len(condition_cols) == 1:
|
|
# 单字段查询
|
|
_field_name = _condition_fields[0]
|
|
field = getattr(cls, _field_name)
|
|
# 直接获取 Series 的值
|
|
_values_list = row_df[_field_name].tolist()
|
|
if not _values_list:
|
|
return None
|
|
query = select(cls).where(field.in_(_values_list))
|
|
else:
|
|
# 多字段组合查询 - 使用向量化操作避免 iterrows()
|
|
_values_tuples = [tuple(row) for row in row_df.values]
|
|
# 再次去重确保
|
|
_values_tuples = list(set(_values_tuples))
|
|
|
|
if not _values_tuples:
|
|
return None
|
|
|
|
conditions = [getattr(cls, field) for field in _condition_fields]
|
|
query = select(cls).where(tuple_(*conditions).in_(_values_tuples))
|
|
|
|
return query
|
|
|
|
@classmethod
|
|
def find_by_id(cls, id_val: Union[str, int], reset_session: Optional[bool] = True) -> Optional['BaseAdapter']:
|
|
"""
|
|
根据主键ID查找数据,确认有 id 字段后方可使用。
|
|
|
|
:param id_val: 主键值
|
|
:param reset_session: 重置会话连接
|
|
:return: 数据模型对象
|
|
"""
|
|
query = select(cls).where(cls.instrument_attr('id') == id_val)
|
|
session = cls.get_session()
|
|
model = session.execute(query).scalars().first()
|
|
session.close()
|
|
|
|
if reset_session and isinstance(model, BaseAdapter):
|
|
model.reset_session()
|
|
|
|
return model
|
|
|
|
@classmethod
|
|
def find_by_created(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']:
|
|
"""
|
|
按创建时间搜索。本方法使用了固定字段名称,确认有 created_at 字段后方可使用。
|
|
|
|
:param from_t: 开始时间
|
|
:param to_t: 结束时间
|
|
"""
|
|
if to_t is None:
|
|
to_t = from_t
|
|
|
|
query = select(cls).where(cls.instrument_attr('created_at').between(from_t, to_t))
|
|
session = cls.get_session()
|
|
rows = session.execute(query).scalars().all()
|
|
session.close()
|
|
return rows
|
|
|
|
@classmethod
|
|
def find_by_updated(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']:
|
|
"""
|
|
按更新时间搜索。本方法使用了固定字段名称,确认有 updated_at 字段后方可使用。
|
|
|
|
:param from_t: 开始时间
|
|
:param to_t: 结束时间
|
|
"""
|
|
if to_t is None:
|
|
to_t = from_t
|
|
|
|
query = select(cls).where(cls.instrument_attr('updated_at').between(from_t, to_t))
|
|
session = cls.get_session()
|
|
rows = session.execute(query).scalars().all()
|
|
session.close()
|
|
return rows
|
|
|
|
@classmethod
|
|
def find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]) -> Sequence['BaseAdapter']:
|
|
"""
|
|
根据数据列表查询已经在数据库中的数据。
|
|
|
|
:param row_list: 数据列表
|
|
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
|
|
:return: 查询到的数据模型列表
|
|
"""
|
|
rows: Sequence[cls] = []
|
|
_query = cls.datalist_query(row_list=row_list, condition_cols=condition_cols)
|
|
if _query is not None:
|
|
session = cls.get_session()
|
|
rows = session.execute(_query).scalars().all()
|
|
session.close()
|
|
|
|
return rows
|
|
|
|
@classmethod
|
|
def find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
|
|
"""
|
|
根据数据框架查询已经在数据库中的数据。
|
|
|
|
:param row_df: 数据框架
|
|
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
|
|
:return: 查询到的数据模型列表
|
|
"""
|
|
_query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols)
|
|
if _query is not None:
|
|
session = cls.get_session()
|
|
_rows = session.connection().execute(_query).all()
|
|
session.close()
|
|
return pd.DataFrame(_rows)
|
|
|
|
return None
|
|
|
|
def load(self, model_data: dict):
|
|
"""
|
|
从字典对象载入数据。
|
|
该方法仅从字典中读取与对象属性对应的数据,忽略其他数据。
|
|
注意:与 copyXXX 方法不同,该方法跳过 model_data 中的 None 值,保留原始值不变
|
|
|
|
:param model_data: 包含数据的字典对象
|
|
"""
|
|
for attr in self.__dir__():
|
|
if model_data.get(attr, None) is not None:
|
|
self.__setattr__(attr, model_data.get(attr))
|
|
return self
|
|
|
|
def copy_from(self, source: 'BaseAdapter', mapping_fields: Optional[List[str]] = None):
|
|
"""
|
|
从源数据模型复制数据,仅复制相同字段的数据。
|
|
|
|
:param source: 源数据对象
|
|
:param mapping_fields: 映射字段列表,若为空,则从源数据模型中获取字段列表
|
|
"""
|
|
|
|
# 源对象字段列表
|
|
if mapping_fields is None:
|
|
mapping_fields = [column.key for column in source.columns()]
|
|
|
|
for column in self.columns():
|
|
# 若源对象中不包含同名字段,则跳过
|
|
if column.key not in mapping_fields:
|
|
continue
|
|
|
|
# 取出对象属性值
|
|
attr_val = source.ins_value(column.key)
|
|
setattr(self, column.key, attr_val)
|
|
|
|
return self
|
|
|
|
def copy_from_dict(self, source: dict, mapping_keys: Union[list, tuple, set] = None, skip_none: bool = False):
|
|
"""
|
|
从源字典对象复制数据,仅复制相同键值的数据。
|
|
|
|
:param source: 源数据对象
|
|
:param mapping_keys: 映射键列表,若为空,则从源字典对象中获取键列表
|
|
:param skip_none: 是否跳过 None 值,即若源数据对象中属性值为 None 时,跳过
|
|
:return 返回自身
|
|
"""
|
|
|
|
# 源对象字段列表
|
|
if mapping_keys is None:
|
|
mapping_keys = source.keys()
|
|
|
|
for column in self.columns():
|
|
if column.key not in mapping_keys:
|
|
# 若源对象中不包含同名字段,则跳过
|
|
continue
|
|
|
|
# 取出对象属性值
|
|
attr_val = source.get(column.key, None)
|
|
if skip_none and attr_val is None:
|
|
# 若设置了跳过 None 且此时属性值为空时,跳过
|
|
continue
|
|
setattr(self, column.key, attr_val)
|
|
|
|
return self
|
|
|
|
def inspect(self) -> InstanceState:
|
|
"""
|
|
取得对象的监视对象。
|
|
|
|
:return: 监视对象
|
|
"""
|
|
return inspect(self)
|
|
|
|
def session(self) -> Session:
|
|
"""
|
|
取得当前对象的连接会话对象。
|
|
|
|
:return: 连接会话对象
|
|
"""
|
|
return self.inspect().session
|
|
|
|
def has_session(self) -> bool:
|
|
"""
|
|
检测对象是否已有连接会话对象。
|
|
|
|
:return: 有则返回 True 否则返回 False
|
|
"""
|
|
return self.session() is not None
|
|
|
|
def add_session(self):
|
|
"""
|
|
将对象加入到连接会话对象。
|
|
"""
|
|
if not self.has_session():
|
|
self.get_session().add(self)
|
|
return self.session()
|
|
|
|
def close_session(self):
|
|
"""
|
|
关闭会话。
|
|
"""
|
|
if self.has_session():
|
|
self.session().close()
|
|
|
|
def reset_session(self):
|
|
"""
|
|
重置对象的连接会话。
|
|
"""
|
|
self.close_session()
|
|
self.add_session()
|
|
|
|
def ins_value(self, column_name: str) -> Any:
|
|
"""
|
|
取得查询条件绑定的值,该值与对象属性一致。与::
|
|
self.column_name
|
|
或
|
|
self['column_name']
|
|
|
|
:param column_name: 列名
|
|
:return: 对象属性值
|
|
"""
|
|
# 取出对象属性值
|
|
attr_val = getattr(self, column_name)
|
|
if attr_val is None:
|
|
return None
|
|
return attr_val
|
|
|
|
def filter_expressions(self, likes: Optional[dict[str, str]] = None) -> List[BinaryExpression]:
|
|
"""
|
|
自动生成查询条件表达式。
|
|
参数 likes 为需要执行模糊查询的字段字典::
|
|
|
|
{
|
|
field_name1: '%{}%',
|
|
field_name2: '{}%',
|
|
field_name3: '%{}',
|
|
...
|
|
}
|
|
|
|
:param likes: 需要执行模糊查询的字段字典
|
|
:return: 包含查询条件表达式的列表
|
|
"""
|
|
expressions = list()
|
|
for column in self.columns():
|
|
# 取出对象属性值
|
|
attr_val = self.ins_value(column.key)
|
|
# 若属性值为空,不增加条件,跳出
|
|
if attr_val in (None, ''):
|
|
continue
|
|
|
|
# 取出类属性
|
|
attr = self.instrument_attr(column.key)
|
|
# 若类属性类型不正确,不增加条件,跳出
|
|
if attr is None:
|
|
continue
|
|
|
|
# 取出列的数据类型
|
|
column_type = column.type
|
|
if isinstance(column_type, (String, Integer, Numeric, ForeignKey)):
|
|
# 针对属性数据类型,进行不同处理
|
|
if isinstance(attr_val, (list, tuple, set)):
|
|
if not attr_val:
|
|
# 跳过0长数组
|
|
continue
|
|
# 数组使用 in
|
|
expressions.append(attr.in_(attr_val))
|
|
elif isinstance(attr_val, (int, float)):
|
|
# 数值类型
|
|
expressions.append(attr == attr_val)
|
|
else:
|
|
# 字符类型,使用 like
|
|
if likes and attr.key in likes:
|
|
_f_str = likes[attr.key]
|
|
expressions.append(attr.like(_f_str.format(attr_val)))
|
|
else:
|
|
expressions.append(attr == attr_val)
|
|
elif isinstance(column_type, (DateTime, Date)):
|
|
if isinstance(attr_val, (list, tuple, set)):
|
|
if len(attr_val) == 2:
|
|
# 如果长度为 2 则使用 between and 方式
|
|
expressions.append(attr.between(attr_val[0], attr_val[1]))
|
|
else:
|
|
# 数组使用 in
|
|
expressions.append(attr.in_(attr_val))
|
|
else:
|
|
expressions.append(attr.between(attr_val, attr_val))
|
|
else:
|
|
# 其他类型,用等号
|
|
expressions.append(attr == attr_val)
|
|
|
|
return expressions
|
|
|
|
def gen_query(self, likes: Optional[dict[str, str]] = None):
|
|
"""
|
|
根据自身参数生成查询对象。
|
|
|
|
:param likes: 需要执行模糊查询的字段字典
|
|
:return: 查询对象
|
|
"""
|
|
cls = self.__class__
|
|
expressions = self.filter_expressions(likes)
|
|
_query: Select = select(cls).where(*expressions)
|
|
return _query
|
|
|
|
def find(self) -> Sequence['BaseAdapter']:
|
|
"""
|
|
根据自身参数,查询数据库。
|
|
|
|
:return: 查询到的结果对象
|
|
"""
|
|
session = self.get_session()
|
|
_model_list = session.execute(self.gen_query()).scalars().all()
|
|
session.close()
|
|
return _model_list
|
|
|
|
def find_first(self) -> 'BaseAdapter':
|
|
"""
|
|
根据自身参数,查询数据库,仅查询第一条。
|
|
|
|
:return: 查询到的结果对象
|
|
"""
|
|
session = self.get_session()
|
|
_model = session.execute(self.gen_query()).scalars().first()
|
|
session.close()
|
|
return _model
|
|
|
|
def find_piece(self, *where: BinaryExpression, offset=0, limit=500, is_desc=False,
|
|
likes: Optional[dict[str, str]] = None):
|
|
"""
|
|
根据自身参数,查询数据库。
|
|
|
|
:param where: 查询条件
|
|
:param offset: 偏移量
|
|
:param limit: 读取数量
|
|
:param is_desc: 是否逆序排列
|
|
:param likes: 模糊条件
|
|
:return: 查询到的结果对象
|
|
"""
|
|
clz = self.__class__
|
|
expressions = self.filter_expressions(likes=likes)
|
|
if where is not None:
|
|
expressions += where
|
|
|
|
query = select(clz).where(*expressions)
|
|
if limit > 0:
|
|
query = query.limit(limit=limit)
|
|
if offset >= 0:
|
|
query = query.offset(offset=offset)
|
|
if is_desc:
|
|
if hasattr(clz, 'id'):
|
|
query = query.order_by(desc(clz.id))
|
|
|
|
session = self.get_session()
|
|
_model_list = session.execute(query).scalars().all()
|
|
session.close()
|
|
return _model_list
|
|
|
|
def before_save(self):
|
|
"""
|
|
保存前的动作,一般应当在子类中覆盖该方法,增加在保存前应当执行的动作。
|
|
"""
|
|
if self.is_new:
|
|
if hasattr(self, 'id'):
|
|
setattr(self, 'id', self.new_id())
|
|
|
|
if hasattr(self, 'created_at'):
|
|
setattr(self, 'created_at', datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT))
|
|
|
|
if hasattr(self, 'updated_at'):
|
|
setattr(self, 'updated_at', datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT))
|
|
|
|
return self
|
|
|
|
def save(self, auto_expunge: Optional[bool] = True, session: Optional[Session] = None):
|
|
"""
|
|
保存数据模型对象,若该对象尚未有连接会话,则自动加入连接会话。
|
|
|
|
:param auto_expunge: 自动刷新对象并将其移出连接会话,不提供外部 session 时有效
|
|
:param session: 会话对象,主要用于事务
|
|
:return: 保存状态
|
|
"""
|
|
_has_session: bool = True
|
|
"""
|
|
该参数用于说明是否提供了外部 session 对象,默认为 True 时表示提供。
|
|
"""
|
|
|
|
self.before_save()
|
|
if session is None:
|
|
_has_session = False
|
|
session = self.add_session()
|
|
|
|
try:
|
|
session.add(self)
|
|
if not _has_session:
|
|
# 使用新会话时,直接提交
|
|
session.commit()
|
|
self._is_new = False
|
|
except Exception as e:
|
|
session.rollback()
|
|
raise e
|
|
else:
|
|
if auto_expunge and not _has_session:
|
|
session.refresh(self)
|
|
session.expunge(self)
|
|
return True
|
|
|
|
def to_dict(self) -> dict:
|
|
"""
|
|
数据模型转字典。递归处理内部类型对象。
|
|
|
|
:return: 转换后的字典数据
|
|
"""
|
|
# 模型数据字典
|
|
m_dict = {}
|
|
|
|
# 遍历处理内部转换
|
|
for _key, _val in dict(self.__dict__).items():
|
|
if f'{_key}'.startswith('_'):
|
|
# 跳过所有私有属性
|
|
continue
|
|
|
|
if isinstance(_val, BaseAdapter):
|
|
# 内部数据对象数据,直接转换
|
|
m_dict[_key] = _val.to_dict()
|
|
elif isinstance(_val, list):
|
|
# 遍历转换数据对象列表
|
|
_tmp_list = []
|
|
for _i, _v in enumerate(_val):
|
|
if isinstance(_v, BaseAdapter):
|
|
_tmp_list.append(_v.to_dict())
|
|
else:
|
|
_tmp_list.append(_v)
|
|
m_dict[_key] = _tmp_list
|
|
elif isinstance(_val, dict):
|
|
# 遍历转换数据对象字典
|
|
_tmp_dict = {}
|
|
for _ik, _iv in _val.items():
|
|
if isinstance(_iv, BaseAdapter):
|
|
_tmp_dict[_ik] = _iv.to_dict()
|
|
else:
|
|
_tmp_dict[_ik] = _iv
|
|
m_dict[_key] = _tmp_dict
|
|
else:
|
|
# 其他属性直接赋值
|
|
m_dict[_key] = _val
|
|
|
|
return m_dict
|