Merge commit '47296980495f8bbfc9493e93de85dd62de6fa6b9' as 'paste-framework'
This commit is contained in:
Executable
+830
@@ -0,0 +1,830 @@
|
||||
"""
|
||||
数据访问适配器。主要封装了数据访问的基础方法。
|
||||
"""
|
||||
import datetime
|
||||
import random
|
||||
import uuid
|
||||
from typing import Optional, Any, Union, List, Sequence
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import Column, String, DateTime, Date, inspect, Table, Integer, select, Numeric, ForeignKey, func, \
|
||||
text, desc, Result, Engine, tuple_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, async_sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker, Session, InstanceState, InstrumentedAttribute, registry
|
||||
from sqlalchemy.sql import Select
|
||||
from sqlalchemy.sql.base import ColumnCollection
|
||||
from sqlalchemy.sql.compiler import StrSQLCompiler
|
||||
from sqlalchemy.sql.elements import BinaryExpression
|
||||
|
||||
from paste.db import engine
|
||||
from paste.util import udict
|
||||
|
||||
LOCAL_DATE_FORMAT = '%Y-%m-%d'
|
||||
LOCAL_TIME_FORMAT = '%H:%M:%S'
|
||||
LOCAL_DATETIME_FORMAT = f'{LOCAL_DATE_FORMAT} {LOCAL_TIME_FORMAT}'
|
||||
|
||||
|
||||
def guid() -> str:
|
||||
"""
|
||||
生成 GUID。
|
||||
|
||||
:return: GUID
|
||||
"""
|
||||
return str(uuid.uuid5(uuid.NAMESPACE_DNS, str(uuid.uuid1()) + str(random.random()))).replace('-', '')
|
||||
|
||||
|
||||
def get_session() -> Session:
|
||||
"""
|
||||
获取线程安全的连接会话对象。
|
||||
|
||||
:return: 连接会话对象
|
||||
"""
|
||||
_engine: Engine = engine.connect_engine()
|
||||
session_maker = sessionmaker(bind=_engine)
|
||||
return session_maker()
|
||||
|
||||
|
||||
def get_aio_session() -> AsyncSession:
|
||||
"""
|
||||
获取异步连接会话对象。
|
||||
|
||||
:return: 连接会话对象
|
||||
"""
|
||||
_engine: AsyncEngine = engine.async_connect_engine()
|
||||
session_maker = async_sessionmaker(bind=_engine, class_=AsyncSession)
|
||||
return session_maker()
|
||||
|
||||
|
||||
Base = registry().generate_base()
|
||||
"""
|
||||
取得 SQLAlchemy ORM 的声明式基类。
|
||||
"""
|
||||
|
||||
|
||||
class BaseAdapter(Base):
|
||||
"""
|
||||
适配器类,集成数据库引擎、连接、会话、元数据操作、部分 DML 数据处理等功能。
|
||||
主要实现了同步执行方法。
|
||||
"""
|
||||
__abstract__ = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
构造模型对象。
|
||||
|
||||
:param args:
|
||||
:param kwargs:
|
||||
"""
|
||||
self._is_new = True
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_session(cls):
|
||||
"""
|
||||
线程安全的连接会话。
|
||||
|
||||
:return: 同步会话对象
|
||||
"""
|
||||
return get_session()
|
||||
|
||||
@classmethod
|
||||
def get_aio_session(cls):
|
||||
"""
|
||||
异步连接会话。
|
||||
|
||||
:return: 异步会话对象
|
||||
"""
|
||||
return get_aio_session()
|
||||
|
||||
@classmethod
|
||||
def ping(cls):
|
||||
"""
|
||||
尝试连接数据库服务器。
|
||||
|
||||
:return: 连接结果
|
||||
"""
|
||||
engine.connect_engine().connect()
|
||||
|
||||
@classmethod
|
||||
async def tables_in_db(cls):
|
||||
"""
|
||||
取得数据库中所有表名称。
|
||||
:return: 表名称列表
|
||||
"""
|
||||
_query = select(text('table_name')).select_from(
|
||||
text('information_schema.tables')
|
||||
).where(
|
||||
text('table_schema = database()')
|
||||
)
|
||||
_names = list()
|
||||
_session = cls.get_aio_session()
|
||||
try:
|
||||
_result: Result = await _session.execute(_query)
|
||||
_names = list(_result.scalars().all())
|
||||
finally:
|
||||
await _session.close()
|
||||
return _names
|
||||
|
||||
@classmethod
|
||||
async def is_table_exist(cls, table_name: str):
|
||||
"""
|
||||
判断表是否存在。
|
||||
:param table_name: 要判断的表名称
|
||||
"""
|
||||
_query = select(func.count()).select_from(
|
||||
text('information_schema.tables')
|
||||
).where(
|
||||
text('table_schema = database()'), text('table_name = :table_name')
|
||||
).params(
|
||||
table_name=table_name
|
||||
)
|
||||
_count = 0
|
||||
_session = cls.get_aio_session()
|
||||
try:
|
||||
_result: Result = await _session.execute(_query)
|
||||
_count = int(_result.scalar())
|
||||
finally:
|
||||
await _session.close()
|
||||
return _count > 0
|
||||
|
||||
@classmethod
|
||||
def table(cls) -> Table:
|
||||
"""
|
||||
取得当前模型对象所对应的表对象。
|
||||
|
||||
:return: 表对象
|
||||
"""
|
||||
tables = cls.metadata.tables
|
||||
assert cls.__tablename__ is not None, '属性:cls.__tablename__ 未找到.'
|
||||
assert cls.__tablename__ in tables, ('数据库中未找到表:%s.' % cls.__tablename__)
|
||||
return tables[cls.__tablename__]
|
||||
|
||||
@classmethod
|
||||
def columns(cls) -> List[Column]:
|
||||
"""
|
||||
当前对象配置的所有列。
|
||||
|
||||
:return: 所有列
|
||||
"""
|
||||
cols = cls.table().columns
|
||||
assert isinstance(cols, ColumnCollection), '%s 列类型错误' % cls.__name__
|
||||
return list(cols)
|
||||
|
||||
@classmethod
|
||||
def instrument_attr(cls, column_name: str) -> Optional[InstrumentedAttribute]:
|
||||
"""
|
||||
取得查询条件绑定属性,用于配置查询条件表达式。
|
||||
|
||||
:param column_name: 列名称
|
||||
:return: 类配置的列信息,即查询条件绑定属性
|
||||
"""
|
||||
attr = getattr(cls, column_name)
|
||||
if not isinstance(attr, InstrumentedAttribute):
|
||||
return None
|
||||
return attr
|
||||
|
||||
@classmethod
|
||||
def label(cls, column: Union[Column, Any]):
|
||||
"""
|
||||
以字段的 comment 作为标签返回。
|
||||
|
||||
:return: 指定列的备注
|
||||
"""
|
||||
return column.comment
|
||||
|
||||
@classmethod
|
||||
def labels(cls) -> dict:
|
||||
"""
|
||||
取得所有列的名称描述。
|
||||
|
||||
:return: 所有列的描述字典,结构为:{`field`: `comment`}
|
||||
"""
|
||||
_label_dict = {}
|
||||
for _col in cls.columns():
|
||||
_label_dict[_col.key] = _col.comment
|
||||
return _label_dict
|
||||
|
||||
@classmethod
|
||||
def field(cls, column: Union[Column, Any]) -> str:
|
||||
"""
|
||||
取得字段名称。
|
||||
|
||||
:param column: 列对象
|
||||
:return: 字段名称
|
||||
"""
|
||||
return column.key
|
||||
|
||||
@classmethod
|
||||
def fields(cls) -> List[str]:
|
||||
"""
|
||||
取得所有字段名称列表。
|
||||
|
||||
:return: 所有字段名称列表
|
||||
"""
|
||||
return [col.key for col in cls.columns()]
|
||||
|
||||
@classmethod
|
||||
def new_id(cls, **kwargs) -> str:
|
||||
"""
|
||||
用GUID作为新的ID。
|
||||
|
||||
:return: ID
|
||||
"""
|
||||
return guid()
|
||||
|
||||
@property
|
||||
def is_new(self):
|
||||
"""
|
||||
通过检查是否已经存在ID判断是否为新建对象。
|
||||
|
||||
:return: 是否为新建对象
|
||||
"""
|
||||
return hasattr(self, '_is_new') and self._is_new
|
||||
|
||||
@classmethod
|
||||
def raw_sql(cls, query) -> StrSQLCompiler:
|
||||
"""
|
||||
显示编译后的原始 SQL 命令。
|
||||
|
||||
:return: 编译后的 SQL 命令文本
|
||||
"""
|
||||
raw_sql = query.compile(compile_kwargs={'literal_binds': True})
|
||||
return raw_sql
|
||||
|
||||
@classmethod
|
||||
def row_count(cls, *where_expressions: BinaryExpression) -> int:
|
||||
"""
|
||||
按条件取得记录行数。
|
||||
|
||||
:param where_expressions: 查询条件
|
||||
:return: 返回记录行数
|
||||
"""
|
||||
query = select(func.count(1)).select_from(cls).where(*where_expressions)
|
||||
session = cls.get_session()
|
||||
count = session.execute(query).scalars().first()
|
||||
session.close()
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
def exist_by_id(cls, id_val: Union[str, int]):
|
||||
"""
|
||||
检查是否有存在的主键。
|
||||
|
||||
:param id_val: 主键值
|
||||
"""
|
||||
_wheres = [cls.instrument_attr('id') == id_val]
|
||||
count: int = cls.row_count(*_wheres)
|
||||
return count >= 1
|
||||
|
||||
@classmethod
|
||||
def search_wheres(cls, likes: Optional[dict[str, str]] = None, **kwargs):
|
||||
"""
|
||||
按参数组织查询条件。
|
||||
|
||||
:param likes: 需要执行模糊查询的列
|
||||
:param kwargs: 属性填充参数
|
||||
:return: 查询条件列表
|
||||
"""
|
||||
if likes is None:
|
||||
likes = []
|
||||
_query_model = cls().copy_from_dict(kwargs)
|
||||
return _query_model.filter_expressions(likes=likes)
|
||||
|
||||
@classmethod
|
||||
def datalist_query(cls, row_list: List[dict], condition_cols: List[Column]) -> Optional[Select]:
|
||||
"""
|
||||
根据数据列表、查询条件列生成查询。仅支持单表查询。
|
||||
|
||||
:param row_list: 数据列表
|
||||
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
|
||||
:return: 查询到的数据模型列表
|
||||
"""
|
||||
if not row_list or not condition_cols:
|
||||
return None
|
||||
|
||||
_condition_fields = [_col.key for _col in condition_cols]
|
||||
|
||||
# 验证字段存在
|
||||
for field in _condition_fields:
|
||||
if not hasattr(cls, field):
|
||||
raise AttributeError(f"Field '{field}' not found in {cls.__name__}")
|
||||
|
||||
# 构建值列表的元组形式
|
||||
_values_tuples = []
|
||||
for _row in row_list:
|
||||
_val_list = []
|
||||
for _f in _condition_fields:
|
||||
_val = udict.get_with_default(_row, _f, '')
|
||||
# 单独处理日期时间格式,但是此处固定格式
|
||||
if isinstance(_val, datetime.datetime):
|
||||
_val = _val.strftime(LOCAL_DATETIME_FORMAT)
|
||||
if isinstance(_val, datetime.date):
|
||||
_val = _val.strftime(LOCAL_DATE_FORMAT)
|
||||
_val_list.append(_val)
|
||||
|
||||
# 过滤掉全为空的元组
|
||||
if any(v != '' for v in _val_list):
|
||||
_values_tuples.append(tuple(_val_list))
|
||||
|
||||
if not _values_tuples:
|
||||
return None
|
||||
|
||||
# 去重,提高性能
|
||||
_values_tuples = list(set(_values_tuples))
|
||||
|
||||
# 使用参数化查询
|
||||
if len(condition_cols) == 1:
|
||||
# 单字段查询
|
||||
field = getattr(cls, _condition_fields[0])
|
||||
_single_values = [t[0] for t in _values_tuples]
|
||||
query = select(cls).where(field.in_(_single_values))
|
||||
else:
|
||||
# 多字段组合查询
|
||||
conditions = [getattr(cls, field) for field in _condition_fields]
|
||||
query = select(cls).where(tuple_(*conditions).in_(_values_tuples))
|
||||
|
||||
return query
|
||||
|
||||
@classmethod
|
||||
def dataframe_query(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
|
||||
"""
|
||||
根据数据列表、查询条件列生成查询。仅支持单表查询。
|
||||
|
||||
:param row_df: 数据框架
|
||||
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
|
||||
:return: 查询到的数据模型列表
|
||||
"""
|
||||
if row_df.empty or not condition_cols:
|
||||
return None
|
||||
|
||||
_condition_fields = [_col.key for _col in condition_cols]
|
||||
|
||||
# 验证字段存在
|
||||
for field in _condition_fields:
|
||||
if not hasattr(cls, field):
|
||||
raise AttributeError(f"Field '{field}' not found in {cls.__name__}")
|
||||
|
||||
# 去除重复行并删除包含 NaN 的行
|
||||
row_df = row_df[_condition_fields].drop_duplicates().dropna()
|
||||
|
||||
if row_df.empty:
|
||||
return None
|
||||
|
||||
if len(condition_cols) == 1:
|
||||
# 单字段查询
|
||||
_field_name = _condition_fields[0]
|
||||
field = getattr(cls, _field_name)
|
||||
# 直接获取 Series 的值
|
||||
_values_list = row_df[_field_name].tolist()
|
||||
if not _values_list:
|
||||
return None
|
||||
query = select(cls).where(field.in_(_values_list))
|
||||
else:
|
||||
# 多字段组合查询 - 使用向量化操作避免 iterrows()
|
||||
_values_tuples = [tuple(row) for row in row_df.values]
|
||||
# 再次去重确保
|
||||
_values_tuples = list(set(_values_tuples))
|
||||
|
||||
if not _values_tuples:
|
||||
return None
|
||||
|
||||
conditions = [getattr(cls, field) for field in _condition_fields]
|
||||
query = select(cls).where(tuple_(*conditions).in_(_values_tuples))
|
||||
|
||||
return query
|
||||
|
||||
@classmethod
|
||||
def find_by_id(cls, id_val: Union[str, int], reset_session: Optional[bool] = True) -> Optional['BaseAdapter']:
|
||||
"""
|
||||
根据主键ID查找数据,确认有 id 字段后方可使用。
|
||||
|
||||
:param id_val: 主键值
|
||||
:param reset_session: 重置会话连接
|
||||
:return: 数据模型对象
|
||||
"""
|
||||
query = select(cls).where(cls.instrument_attr('id') == id_val)
|
||||
session = cls.get_session()
|
||||
model = session.execute(query).scalars().first()
|
||||
session.close()
|
||||
|
||||
if reset_session and isinstance(model, BaseAdapter):
|
||||
model.reset_session()
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def find_by_created(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']:
|
||||
"""
|
||||
按创建时间搜索。本方法使用了固定字段名称,确认有 created_at 字段后方可使用。
|
||||
|
||||
:param from_t: 开始时间
|
||||
:param to_t: 结束时间
|
||||
"""
|
||||
if to_t is None:
|
||||
to_t = from_t
|
||||
|
||||
query = select(cls).where(cls.instrument_attr('created_at').between(from_t, to_t))
|
||||
session = cls.get_session()
|
||||
rows = session.execute(query).scalars().all()
|
||||
session.close()
|
||||
return rows
|
||||
|
||||
@classmethod
|
||||
def find_by_updated(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']:
|
||||
"""
|
||||
按更新时间搜索。本方法使用了固定字段名称,确认有 updated_at 字段后方可使用。
|
||||
|
||||
:param from_t: 开始时间
|
||||
:param to_t: 结束时间
|
||||
"""
|
||||
if to_t is None:
|
||||
to_t = from_t
|
||||
|
||||
query = select(cls).where(cls.instrument_attr('updated_at').between(from_t, to_t))
|
||||
session = cls.get_session()
|
||||
rows = session.execute(query).scalars().all()
|
||||
session.close()
|
||||
return rows
|
||||
|
||||
@classmethod
|
||||
def find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]) -> Sequence['BaseAdapter']:
|
||||
"""
|
||||
根据数据列表查询已经在数据库中的数据。
|
||||
|
||||
:param row_list: 数据列表
|
||||
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
|
||||
:return: 查询到的数据模型列表
|
||||
"""
|
||||
rows: Sequence[cls] = []
|
||||
_query = cls.datalist_query(row_list=row_list, condition_cols=condition_cols)
|
||||
if _query is not None:
|
||||
session = cls.get_session()
|
||||
rows = session.execute(_query).scalars().all()
|
||||
session.close()
|
||||
|
||||
return rows
|
||||
|
||||
@classmethod
|
||||
def find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
|
||||
"""
|
||||
根据数据框架查询已经在数据库中的数据。
|
||||
|
||||
:param row_df: 数据框架
|
||||
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
|
||||
:return: 查询到的数据模型列表
|
||||
"""
|
||||
_query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols)
|
||||
if _query is not None:
|
||||
session = cls.get_session()
|
||||
_rows = session.connection().execute(_query).all()
|
||||
session.close()
|
||||
return pd.DataFrame(_rows)
|
||||
|
||||
return None
|
||||
|
||||
def load(self, model_data: dict):
|
||||
"""
|
||||
从字典对象载入数据。
|
||||
该方法仅从字典中读取与对象属性对应的数据,忽略其他数据。
|
||||
注意:与 copyXXX 方法不同,该方法跳过 model_data 中的 None 值,保留原始值不变
|
||||
|
||||
:param model_data: 包含数据的字典对象
|
||||
"""
|
||||
for attr in self.__dir__():
|
||||
if model_data.get(attr, None) is not None:
|
||||
self.__setattr__(attr, model_data.get(attr))
|
||||
return self
|
||||
|
||||
def copy_from(self, source: 'BaseAdapter', mapping_fields: Optional[List[str]] = None):
|
||||
"""
|
||||
从源数据模型复制数据,仅复制相同字段的数据。
|
||||
|
||||
:param source: 源数据对象
|
||||
:param mapping_fields: 映射字段列表,若为空,则从源数据模型中获取字段列表
|
||||
"""
|
||||
|
||||
# 源对象字段列表
|
||||
if mapping_fields is None:
|
||||
mapping_fields = [column.key for column in source.columns()]
|
||||
|
||||
for column in self.columns():
|
||||
# 若源对象中不包含同名字段,则跳过
|
||||
if column.key not in mapping_fields:
|
||||
continue
|
||||
|
||||
# 取出对象属性值
|
||||
attr_val = source.ins_value(column.key)
|
||||
setattr(self, column.key, attr_val)
|
||||
|
||||
return self
|
||||
|
||||
def copy_from_dict(self, source: dict, mapping_keys: Union[list, tuple, set] = None, skip_none: bool = False):
|
||||
"""
|
||||
从源字典对象复制数据,仅复制相同键值的数据。
|
||||
|
||||
:param source: 源数据对象
|
||||
:param mapping_keys: 映射键列表,若为空,则从源字典对象中获取键列表
|
||||
:param skip_none: 是否跳过 None 值,即若源数据对象中属性值为 None 时,跳过
|
||||
:return 返回自身
|
||||
"""
|
||||
|
||||
# 源对象字段列表
|
||||
if mapping_keys is None:
|
||||
mapping_keys = source.keys()
|
||||
|
||||
for column in self.columns():
|
||||
if column.key not in mapping_keys:
|
||||
# 若源对象中不包含同名字段,则跳过
|
||||
continue
|
||||
|
||||
# 取出对象属性值
|
||||
attr_val = source.get(column.key, None)
|
||||
if skip_none and attr_val is None:
|
||||
# 若设置了跳过 None 且此时属性值为空时,跳过
|
||||
continue
|
||||
setattr(self, column.key, attr_val)
|
||||
|
||||
return self
|
||||
|
||||
def inspect(self) -> InstanceState:
|
||||
"""
|
||||
取得对象的监视对象。
|
||||
|
||||
:return: 监视对象
|
||||
"""
|
||||
return inspect(self)
|
||||
|
||||
def session(self) -> Session:
|
||||
"""
|
||||
取得当前对象的连接会话对象。
|
||||
|
||||
:return: 连接会话对象
|
||||
"""
|
||||
return self.inspect().session
|
||||
|
||||
def has_session(self) -> bool:
|
||||
"""
|
||||
检测对象是否已有连接会话对象。
|
||||
|
||||
:return: 有则返回 True 否则返回 False
|
||||
"""
|
||||
return self.session() is not None
|
||||
|
||||
def add_session(self):
|
||||
"""
|
||||
将对象加入到连接会话对象。
|
||||
"""
|
||||
if not self.has_session():
|
||||
self.get_session().add(self)
|
||||
return self.session()
|
||||
|
||||
def close_session(self):
|
||||
"""
|
||||
关闭会话。
|
||||
"""
|
||||
if self.has_session():
|
||||
self.session().close()
|
||||
|
||||
def reset_session(self):
|
||||
"""
|
||||
重置对象的连接会话。
|
||||
"""
|
||||
self.close_session()
|
||||
self.add_session()
|
||||
|
||||
def ins_value(self, column_name: str) -> Any:
|
||||
"""
|
||||
取得查询条件绑定的值,该值与对象属性一致。与::
|
||||
self.column_name
|
||||
或
|
||||
self['column_name']
|
||||
|
||||
:param column_name: 列名
|
||||
:return: 对象属性值
|
||||
"""
|
||||
# 取出对象属性值
|
||||
attr_val = getattr(self, column_name)
|
||||
if attr_val is None:
|
||||
return None
|
||||
return attr_val
|
||||
|
||||
def filter_expressions(self, likes: Optional[dict[str, str]] = None) -> List[BinaryExpression]:
|
||||
"""
|
||||
自动生成查询条件表达式。
|
||||
参数 likes 为需要执行模糊查询的字段字典::
|
||||
|
||||
{
|
||||
field_name1: '%{}%',
|
||||
field_name2: '{}%',
|
||||
field_name3: '%{}',
|
||||
...
|
||||
}
|
||||
|
||||
:param likes: 需要执行模糊查询的字段字典
|
||||
:return: 包含查询条件表达式的列表
|
||||
"""
|
||||
expressions = list()
|
||||
for column in self.columns():
|
||||
# 取出对象属性值
|
||||
attr_val = self.ins_value(column.key)
|
||||
# 若属性值为空,不增加条件,跳出
|
||||
if attr_val in (None, ''):
|
||||
continue
|
||||
|
||||
# 取出类属性
|
||||
attr = self.instrument_attr(column.key)
|
||||
# 若类属性类型不正确,不增加条件,跳出
|
||||
if attr is None:
|
||||
continue
|
||||
|
||||
# 取出列的数据类型
|
||||
column_type = column.type
|
||||
if isinstance(column_type, (String, Integer, Numeric, ForeignKey)):
|
||||
# 针对属性数据类型,进行不同处理
|
||||
if isinstance(attr_val, (list, tuple, set)):
|
||||
if not attr_val:
|
||||
# 跳过0长数组
|
||||
continue
|
||||
# 数组使用 in
|
||||
expressions.append(attr.in_(attr_val))
|
||||
elif isinstance(attr_val, (int, float)):
|
||||
# 数值类型
|
||||
expressions.append(attr == attr_val)
|
||||
else:
|
||||
# 字符类型,使用 like
|
||||
if likes and attr.key in likes:
|
||||
_f_str = likes[attr.key]
|
||||
expressions.append(attr.like(_f_str.format(attr_val)))
|
||||
else:
|
||||
expressions.append(attr == attr_val)
|
||||
elif isinstance(column_type, (DateTime, Date)):
|
||||
if isinstance(attr_val, (list, tuple, set)):
|
||||
if len(attr_val) == 2:
|
||||
# 如果长度为 2 则使用 between and 方式
|
||||
expressions.append(attr.between(attr_val[0], attr_val[1]))
|
||||
else:
|
||||
# 数组使用 in
|
||||
expressions.append(attr.in_(attr_val))
|
||||
else:
|
||||
expressions.append(attr.between(attr_val, attr_val))
|
||||
else:
|
||||
# 其他类型,用等号
|
||||
expressions.append(attr == attr_val)
|
||||
|
||||
return expressions
|
||||
|
||||
def gen_query(self, likes: Optional[dict[str, str]] = None):
|
||||
"""
|
||||
根据自身参数生成查询对象。
|
||||
|
||||
:param likes: 需要执行模糊查询的字段字典
|
||||
:return: 查询对象
|
||||
"""
|
||||
cls = self.__class__
|
||||
expressions = self.filter_expressions(likes)
|
||||
_query: Select = select(cls).where(*expressions)
|
||||
return _query
|
||||
|
||||
def find(self) -> Sequence['BaseAdapter']:
|
||||
"""
|
||||
根据自身参数,查询数据库。
|
||||
|
||||
:return: 查询到的结果对象
|
||||
"""
|
||||
session = self.get_session()
|
||||
_model_list = session.execute(self.gen_query()).scalars().all()
|
||||
session.close()
|
||||
return _model_list
|
||||
|
||||
def find_first(self) -> 'BaseAdapter':
|
||||
"""
|
||||
根据自身参数,查询数据库,仅查询第一条。
|
||||
|
||||
:return: 查询到的结果对象
|
||||
"""
|
||||
session = self.get_session()
|
||||
_model = session.execute(self.gen_query()).scalars().first()
|
||||
session.close()
|
||||
return _model
|
||||
|
||||
def find_piece(self, *where: BinaryExpression, offset=0, limit=500, is_desc=False,
|
||||
likes: Optional[dict[str, str]] = None):
|
||||
"""
|
||||
根据自身参数,查询数据库。
|
||||
|
||||
:param where: 查询条件
|
||||
:param offset: 偏移量
|
||||
:param limit: 读取数量
|
||||
:param is_desc: 是否逆序排列
|
||||
:param likes: 模糊条件
|
||||
:return: 查询到的结果对象
|
||||
"""
|
||||
clz = self.__class__
|
||||
expressions = self.filter_expressions(likes=likes)
|
||||
if where is not None:
|
||||
expressions += where
|
||||
|
||||
query = select(clz).where(*expressions)
|
||||
if limit > 0:
|
||||
query = query.limit(limit=limit)
|
||||
if offset >= 0:
|
||||
query = query.offset(offset=offset)
|
||||
if is_desc:
|
||||
if hasattr(clz, 'id'):
|
||||
query = query.order_by(desc(clz.id))
|
||||
|
||||
session = self.get_session()
|
||||
_model_list = session.execute(query).scalars().all()
|
||||
session.close()
|
||||
return _model_list
|
||||
|
||||
def before_save(self):
|
||||
"""
|
||||
保存前的动作,一般应当在子类中覆盖该方法,增加在保存前应当执行的动作。
|
||||
"""
|
||||
if self.is_new:
|
||||
if hasattr(self, 'id'):
|
||||
setattr(self, 'id', self.new_id())
|
||||
|
||||
if hasattr(self, 'created_at'):
|
||||
setattr(self, 'created_at', datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT))
|
||||
|
||||
if hasattr(self, 'updated_at'):
|
||||
setattr(self, 'updated_at', datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT))
|
||||
|
||||
return self
|
||||
|
||||
def save(self, auto_expunge: Optional[bool] = True, session: Optional[Session] = None):
|
||||
"""
|
||||
保存数据模型对象,若该对象尚未有连接会话,则自动加入连接会话。
|
||||
|
||||
:param auto_expunge: 自动刷新对象并将其移出连接会话,不提供外部 session 时有效
|
||||
:param session: 会话对象,主要用于事务
|
||||
:return: 保存状态
|
||||
"""
|
||||
_has_session: bool = True
|
||||
"""
|
||||
该参数用于说明是否提供了外部 session 对象,默认为 True 时表示提供。
|
||||
"""
|
||||
|
||||
self.before_save()
|
||||
if session is None:
|
||||
_has_session = False
|
||||
session = self.add_session()
|
||||
|
||||
try:
|
||||
session.add(self)
|
||||
if not _has_session:
|
||||
# 使用新会话时,直接提交
|
||||
session.commit()
|
||||
self._is_new = False
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise e
|
||||
else:
|
||||
if auto_expunge and not _has_session:
|
||||
session.refresh(self)
|
||||
session.expunge(self)
|
||||
return True
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
数据模型转字典。递归处理内部类型对象。
|
||||
|
||||
:return: 转换后的字典数据
|
||||
"""
|
||||
# 模型数据字典
|
||||
m_dict = {}
|
||||
|
||||
# 遍历处理内部转换
|
||||
for _key, _val in dict(self.__dict__).items():
|
||||
if f'{_key}'.startswith('_'):
|
||||
# 跳过所有私有属性
|
||||
continue
|
||||
|
||||
if isinstance(_val, BaseAdapter):
|
||||
# 内部数据对象数据,直接转换
|
||||
m_dict[_key] = _val.to_dict()
|
||||
elif isinstance(_val, list):
|
||||
# 遍历转换数据对象列表
|
||||
_tmp_list = []
|
||||
for _i, _v in enumerate(_val):
|
||||
if isinstance(_v, BaseAdapter):
|
||||
_tmp_list.append(_v.to_dict())
|
||||
else:
|
||||
_tmp_list.append(_v)
|
||||
m_dict[_key] = _tmp_list
|
||||
elif isinstance(_val, dict):
|
||||
# 遍历转换数据对象字典
|
||||
_tmp_dict = {}
|
||||
for _ik, _iv in _val.items():
|
||||
if isinstance(_iv, BaseAdapter):
|
||||
_tmp_dict[_ik] = _iv.to_dict()
|
||||
else:
|
||||
_tmp_dict[_ik] = _iv
|
||||
m_dict[_key] = _tmp_dict
|
||||
else:
|
||||
# 其他属性直接赋值
|
||||
m_dict[_key] = _val
|
||||
|
||||
return m_dict
|
||||
Reference in New Issue
Block a user