Files
2026-06-02 16:26:10 +08:00

831 lines
26 KiB
Python
Executable File

"""
数据访问适配器。主要封装了数据访问的基础方法。
"""
import datetime
import random
import uuid
from typing import Optional, Any, Union, List, Sequence
import pandas as pd
from sqlalchemy import Column, String, DateTime, Date, inspect, Table, Integer, select, Numeric, ForeignKey, func, \
text, desc, Result, Engine, tuple_
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, async_sessionmaker
from sqlalchemy.orm import sessionmaker, Session, InstanceState, InstrumentedAttribute, registry
from sqlalchemy.sql import Select
from sqlalchemy.sql.base import ColumnCollection
from sqlalchemy.sql.compiler import StrSQLCompiler
from sqlalchemy.sql.elements import BinaryExpression
from paste.db import engine
from paste.util import udict
LOCAL_DATE_FORMAT = '%Y-%m-%d'
LOCAL_TIME_FORMAT = '%H:%M:%S'
LOCAL_DATETIME_FORMAT = f'{LOCAL_DATE_FORMAT} {LOCAL_TIME_FORMAT}'
def guid() -> str:
"""
生成 GUID。
:return: GUID
"""
return str(uuid.uuid5(uuid.NAMESPACE_DNS, str(uuid.uuid1()) + str(random.random()))).replace('-', '')
def get_session() -> Session:
"""
获取线程安全的连接会话对象。
:return: 连接会话对象
"""
_engine: Engine = engine.connect_engine()
session_maker = sessionmaker(bind=_engine)
return session_maker()
def get_aio_session() -> AsyncSession:
"""
获取异步连接会话对象。
:return: 连接会话对象
"""
_engine: AsyncEngine = engine.async_connect_engine()
session_maker = async_sessionmaker(bind=_engine, class_=AsyncSession)
return session_maker()
Base = registry().generate_base()
"""
取得 SQLAlchemy ORM 的声明式基类。
"""
class BaseAdapter(Base):
"""
适配器类,集成数据库引擎、连接、会话、元数据操作、部分 DML 数据处理等功能。
主要实现了同步执行方法。
"""
__abstract__ = True
def __init__(self, **kwargs):
"""
构造模型对象。
:param args:
:param kwargs:
"""
self._is_new = True
super().__init__(**kwargs)
@classmethod
def get_session(cls):
"""
线程安全的连接会话。
:return: 同步会话对象
"""
return get_session()
@classmethod
def get_aio_session(cls):
"""
异步连接会话。
:return: 异步会话对象
"""
return get_aio_session()
@classmethod
def ping(cls):
"""
尝试连接数据库服务器。
:return: 连接结果
"""
engine.connect_engine().connect()
@classmethod
async def tables_in_db(cls):
"""
取得数据库中所有表名称。
:return: 表名称列表
"""
_query = select(text('table_name')).select_from(
text('information_schema.tables')
).where(
text('table_schema = database()')
)
_names = list()
_session = cls.get_aio_session()
try:
_result: Result = await _session.execute(_query)
_names = list(_result.scalars().all())
finally:
await _session.close()
return _names
@classmethod
async def is_table_exist(cls, table_name: str):
"""
判断表是否存在。
:param table_name: 要判断的表名称
"""
_query = select(func.count()).select_from(
text('information_schema.tables')
).where(
text('table_schema = database()'), text('table_name = :table_name')
).params(
table_name=table_name
)
_count = 0
_session = cls.get_aio_session()
try:
_result: Result = await _session.execute(_query)
_count = int(_result.scalar())
finally:
await _session.close()
return _count > 0
@classmethod
def table(cls) -> Table:
"""
取得当前模型对象所对应的表对象。
:return: 表对象
"""
tables = cls.metadata.tables
assert cls.__tablename__ is not None, '属性:cls.__tablename__ 未找到.'
assert cls.__tablename__ in tables, ('数据库中未找到表:%s.' % cls.__tablename__)
return tables[cls.__tablename__]
@classmethod
def columns(cls) -> List[Column]:
"""
当前对象配置的所有列。
:return: 所有列
"""
cols = cls.table().columns
assert isinstance(cols, ColumnCollection), '%s 列类型错误' % cls.__name__
return list(cols)
@classmethod
def instrument_attr(cls, column_name: str) -> Optional[InstrumentedAttribute]:
"""
取得查询条件绑定属性,用于配置查询条件表达式。
:param column_name: 列名称
:return: 类配置的列信息,即查询条件绑定属性
"""
attr = getattr(cls, column_name)
if not isinstance(attr, InstrumentedAttribute):
return None
return attr
@classmethod
def label(cls, column: Union[Column, Any]):
"""
以字段的 comment 作为标签返回。
:return: 指定列的备注
"""
return column.comment
@classmethod
def labels(cls) -> dict:
"""
取得所有列的名称描述。
:return: 所有列的描述字典,结构为:{`field`: `comment`}
"""
_label_dict = {}
for _col in cls.columns():
_label_dict[_col.key] = _col.comment
return _label_dict
@classmethod
def field(cls, column: Union[Column, Any]) -> str:
"""
取得字段名称。
:param column: 列对象
:return: 字段名称
"""
return column.key
@classmethod
def fields(cls) -> List[str]:
"""
取得所有字段名称列表。
:return: 所有字段名称列表
"""
return [col.key for col in cls.columns()]
@classmethod
def new_id(cls, **kwargs) -> str:
"""
用GUID作为新的ID。
:return: ID
"""
return guid()
@property
def is_new(self):
"""
通过检查是否已经存在ID判断是否为新建对象。
:return: 是否为新建对象
"""
return hasattr(self, '_is_new') and self._is_new
@classmethod
def raw_sql(cls, query) -> StrSQLCompiler:
"""
显示编译后的原始 SQL 命令。
:return: 编译后的 SQL 命令文本
"""
raw_sql = query.compile(compile_kwargs={'literal_binds': True})
return raw_sql
@classmethod
def row_count(cls, *where_expressions: BinaryExpression) -> int:
"""
按条件取得记录行数。
:param where_expressions: 查询条件
:return: 返回记录行数
"""
query = select(func.count(1)).select_from(cls).where(*where_expressions)
session = cls.get_session()
count = session.execute(query).scalars().first()
session.close()
return count
@classmethod
def exist_by_id(cls, id_val: Union[str, int]):
"""
检查是否有存在的主键。
:param id_val: 主键值
"""
_wheres = [cls.instrument_attr('id') == id_val]
count: int = cls.row_count(*_wheres)
return count >= 1
@classmethod
def search_wheres(cls, likes: Optional[dict[str, str]] = None, **kwargs):
"""
按参数组织查询条件。
:param likes: 需要执行模糊查询的列
:param kwargs: 属性填充参数
:return: 查询条件列表
"""
if likes is None:
likes = []
_query_model = cls().copy_from_dict(kwargs)
return _query_model.filter_expressions(likes=likes)
@classmethod
def datalist_query(cls, row_list: List[dict], condition_cols: List[Column]) -> Optional[Select]:
"""
根据数据列表、查询条件列生成查询。仅支持单表查询。
:param row_list: 数据列表
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
:return: 查询到的数据模型列表
"""
if not row_list or not condition_cols:
return None
_condition_fields = [_col.key for _col in condition_cols]
# 验证字段存在
for field in _condition_fields:
if not hasattr(cls, field):
raise AttributeError(f"Field '{field}' not found in {cls.__name__}")
# 构建值列表的元组形式
_values_tuples = []
for _row in row_list:
_val_list = []
for _f in _condition_fields:
_val = udict.get_with_default(_row, _f, '')
# 单独处理日期时间格式,但是此处固定格式
if isinstance(_val, datetime.datetime):
_val = _val.strftime(LOCAL_DATETIME_FORMAT)
if isinstance(_val, datetime.date):
_val = _val.strftime(LOCAL_DATE_FORMAT)
_val_list.append(_val)
# 过滤掉全为空的元组
if any(v != '' for v in _val_list):
_values_tuples.append(tuple(_val_list))
if not _values_tuples:
return None
# 去重,提高性能
_values_tuples = list(set(_values_tuples))
# 使用参数化查询
if len(condition_cols) == 1:
# 单字段查询
field = getattr(cls, _condition_fields[0])
_single_values = [t[0] for t in _values_tuples]
query = select(cls).where(field.in_(_single_values))
else:
# 多字段组合查询
conditions = [getattr(cls, field) for field in _condition_fields]
query = select(cls).where(tuple_(*conditions).in_(_values_tuples))
return query
@classmethod
def dataframe_query(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
"""
根据数据列表、查询条件列生成查询。仅支持单表查询。
:param row_df: 数据框架
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
:return: 查询到的数据模型列表
"""
if row_df.empty or not condition_cols:
return None
_condition_fields = [_col.key for _col in condition_cols]
# 验证字段存在
for field in _condition_fields:
if not hasattr(cls, field):
raise AttributeError(f"Field '{field}' not found in {cls.__name__}")
# 去除重复行并删除包含 NaN 的行
row_df = row_df[_condition_fields].drop_duplicates().dropna()
if row_df.empty:
return None
if len(condition_cols) == 1:
# 单字段查询
_field_name = _condition_fields[0]
field = getattr(cls, _field_name)
# 直接获取 Series 的值
_values_list = row_df[_field_name].tolist()
if not _values_list:
return None
query = select(cls).where(field.in_(_values_list))
else:
# 多字段组合查询 - 使用向量化操作避免 iterrows()
_values_tuples = [tuple(row) for row in row_df.values]
# 再次去重确保
_values_tuples = list(set(_values_tuples))
if not _values_tuples:
return None
conditions = [getattr(cls, field) for field in _condition_fields]
query = select(cls).where(tuple_(*conditions).in_(_values_tuples))
return query
@classmethod
def find_by_id(cls, id_val: Union[str, int], reset_session: Optional[bool] = True) -> Optional['BaseAdapter']:
"""
根据主键ID查找数据,确认有 id 字段后方可使用。
:param id_val: 主键值
:param reset_session: 重置会话连接
:return: 数据模型对象
"""
query = select(cls).where(cls.instrument_attr('id') == id_val)
session = cls.get_session()
model = session.execute(query).scalars().first()
session.close()
if reset_session and isinstance(model, BaseAdapter):
model.reset_session()
return model
@classmethod
def find_by_created(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']:
"""
按创建时间搜索。本方法使用了固定字段名称,确认有 created_at 字段后方可使用。
:param from_t: 开始时间
:param to_t: 结束时间
"""
if to_t is None:
to_t = from_t
query = select(cls).where(cls.instrument_attr('created_at').between(from_t, to_t))
session = cls.get_session()
rows = session.execute(query).scalars().all()
session.close()
return rows
@classmethod
def find_by_updated(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']:
"""
按更新时间搜索。本方法使用了固定字段名称,确认有 updated_at 字段后方可使用。
:param from_t: 开始时间
:param to_t: 结束时间
"""
if to_t is None:
to_t = from_t
query = select(cls).where(cls.instrument_attr('updated_at').between(from_t, to_t))
session = cls.get_session()
rows = session.execute(query).scalars().all()
session.close()
return rows
@classmethod
def find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]) -> Sequence['BaseAdapter']:
"""
根据数据列表查询已经在数据库中的数据。
:param row_list: 数据列表
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
:return: 查询到的数据模型列表
"""
rows: Sequence[cls] = []
_query = cls.datalist_query(row_list=row_list, condition_cols=condition_cols)
if _query is not None:
session = cls.get_session()
rows = session.execute(_query).scalars().all()
session.close()
return rows
@classmethod
def find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
"""
根据数据框架查询已经在数据库中的数据。
:param row_df: 数据框架
:param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
:return: 查询到的数据模型列表
"""
_query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols)
if _query is not None:
session = cls.get_session()
_rows = session.connection().execute(_query).all()
session.close()
return pd.DataFrame(_rows)
return None
def load(self, model_data: dict):
"""
从字典对象载入数据。
该方法仅从字典中读取与对象属性对应的数据,忽略其他数据。
注意:与 copyXXX 方法不同,该方法跳过 model_data 中的 None 值,保留原始值不变
:param model_data: 包含数据的字典对象
"""
for attr in self.__dir__():
if model_data.get(attr, None) is not None:
self.__setattr__(attr, model_data.get(attr))
return self
def copy_from(self, source: 'BaseAdapter', mapping_fields: Optional[List[str]] = None):
"""
从源数据模型复制数据,仅复制相同字段的数据。
:param source: 源数据对象
:param mapping_fields: 映射字段列表,若为空,则从源数据模型中获取字段列表
"""
# 源对象字段列表
if mapping_fields is None:
mapping_fields = [column.key for column in source.columns()]
for column in self.columns():
# 若源对象中不包含同名字段,则跳过
if column.key not in mapping_fields:
continue
# 取出对象属性值
attr_val = source.ins_value(column.key)
setattr(self, column.key, attr_val)
return self
def copy_from_dict(self, source: dict, mapping_keys: Union[list, tuple, set] = None, skip_none: bool = False):
"""
从源字典对象复制数据,仅复制相同键值的数据。
:param source: 源数据对象
:param mapping_keys: 映射键列表,若为空,则从源字典对象中获取键列表
:param skip_none: 是否跳过 None 值,即若源数据对象中属性值为 None 时,跳过
:return 返回自身
"""
# 源对象字段列表
if mapping_keys is None:
mapping_keys = source.keys()
for column in self.columns():
if column.key not in mapping_keys:
# 若源对象中不包含同名字段,则跳过
continue
# 取出对象属性值
attr_val = source.get(column.key, None)
if skip_none and attr_val is None:
# 若设置了跳过 None 且此时属性值为空时,跳过
continue
setattr(self, column.key, attr_val)
return self
def inspect(self) -> InstanceState:
"""
取得对象的监视对象。
:return: 监视对象
"""
return inspect(self)
def session(self) -> Session:
"""
取得当前对象的连接会话对象。
:return: 连接会话对象
"""
return self.inspect().session
def has_session(self) -> bool:
"""
检测对象是否已有连接会话对象。
:return: 有则返回 True 否则返回 False
"""
return self.session() is not None
def add_session(self):
"""
将对象加入到连接会话对象。
"""
if not self.has_session():
self.get_session().add(self)
return self.session()
def close_session(self):
"""
关闭会话。
"""
if self.has_session():
self.session().close()
def reset_session(self):
"""
重置对象的连接会话。
"""
self.close_session()
self.add_session()
def ins_value(self, column_name: str) -> Any:
"""
取得查询条件绑定的值,该值与对象属性一致。与::
self.column_name
self['column_name']
:param column_name: 列名
:return: 对象属性值
"""
# 取出对象属性值
attr_val = getattr(self, column_name)
if attr_val is None:
return None
return attr_val
def filter_expressions(self, likes: Optional[dict[str, str]] = None) -> List[BinaryExpression]:
"""
自动生成查询条件表达式。
参数 likes 为需要执行模糊查询的字段字典::
{
field_name1: '%{}%',
field_name2: '{}%',
field_name3: '%{}',
...
}
:param likes: 需要执行模糊查询的字段字典
:return: 包含查询条件表达式的列表
"""
expressions = list()
for column in self.columns():
# 取出对象属性值
attr_val = self.ins_value(column.key)
# 若属性值为空,不增加条件,跳出
if attr_val in (None, ''):
continue
# 取出类属性
attr = self.instrument_attr(column.key)
# 若类属性类型不正确,不增加条件,跳出
if attr is None:
continue
# 取出列的数据类型
column_type = column.type
if isinstance(column_type, (String, Integer, Numeric, ForeignKey)):
# 针对属性数据类型,进行不同处理
if isinstance(attr_val, (list, tuple, set)):
if not attr_val:
# 跳过0长数组
continue
# 数组使用 in
expressions.append(attr.in_(attr_val))
elif isinstance(attr_val, (int, float)):
# 数值类型
expressions.append(attr == attr_val)
else:
# 字符类型,使用 like
if likes and attr.key in likes:
_f_str = likes[attr.key]
expressions.append(attr.like(_f_str.format(attr_val)))
else:
expressions.append(attr == attr_val)
elif isinstance(column_type, (DateTime, Date)):
if isinstance(attr_val, (list, tuple, set)):
if len(attr_val) == 2:
# 如果长度为 2 则使用 between and 方式
expressions.append(attr.between(attr_val[0], attr_val[1]))
else:
# 数组使用 in
expressions.append(attr.in_(attr_val))
else:
expressions.append(attr.between(attr_val, attr_val))
else:
# 其他类型,用等号
expressions.append(attr == attr_val)
return expressions
def gen_query(self, likes: Optional[dict[str, str]] = None):
"""
根据自身参数生成查询对象。
:param likes: 需要执行模糊查询的字段字典
:return: 查询对象
"""
cls = self.__class__
expressions = self.filter_expressions(likes)
_query: Select = select(cls).where(*expressions)
return _query
def find(self) -> Sequence['BaseAdapter']:
"""
根据自身参数,查询数据库。
:return: 查询到的结果对象
"""
session = self.get_session()
_model_list = session.execute(self.gen_query()).scalars().all()
session.close()
return _model_list
def find_first(self) -> 'BaseAdapter':
"""
根据自身参数,查询数据库,仅查询第一条。
:return: 查询到的结果对象
"""
session = self.get_session()
_model = session.execute(self.gen_query()).scalars().first()
session.close()
return _model
def find_piece(self, *where: BinaryExpression, offset=0, limit=500, is_desc=False,
likes: Optional[dict[str, str]] = None):
"""
根据自身参数,查询数据库。
:param where: 查询条件
:param offset: 偏移量
:param limit: 读取数量
:param is_desc: 是否逆序排列
:param likes: 模糊条件
:return: 查询到的结果对象
"""
clz = self.__class__
expressions = self.filter_expressions(likes=likes)
if where is not None:
expressions += where
query = select(clz).where(*expressions)
if limit > 0:
query = query.limit(limit=limit)
if offset >= 0:
query = query.offset(offset=offset)
if is_desc:
if hasattr(clz, 'id'):
query = query.order_by(desc(clz.id))
session = self.get_session()
_model_list = session.execute(query).scalars().all()
session.close()
return _model_list
def before_save(self):
"""
保存前的动作,一般应当在子类中覆盖该方法,增加在保存前应当执行的动作。
"""
if self.is_new:
if hasattr(self, 'id'):
setattr(self, 'id', self.new_id())
if hasattr(self, 'created_at'):
setattr(self, 'created_at', datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT))
if hasattr(self, 'updated_at'):
setattr(self, 'updated_at', datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT))
return self
def save(self, auto_expunge: Optional[bool] = True, session: Optional[Session] = None):
"""
保存数据模型对象,若该对象尚未有连接会话,则自动加入连接会话。
:param auto_expunge: 自动刷新对象并将其移出连接会话,不提供外部 session 时有效
:param session: 会话对象,主要用于事务
:return: 保存状态
"""
_has_session: bool = True
"""
该参数用于说明是否提供了外部 session 对象,默认为 True 时表示提供。
"""
self.before_save()
if session is None:
_has_session = False
session = self.add_session()
try:
session.add(self)
if not _has_session:
# 使用新会话时,直接提交
session.commit()
self._is_new = False
except Exception as e:
session.rollback()
raise e
else:
if auto_expunge and not _has_session:
session.refresh(self)
session.expunge(self)
return True
def to_dict(self) -> dict:
"""
数据模型转字典。递归处理内部类型对象。
:return: 转换后的字典数据
"""
# 模型数据字典
m_dict = {}
# 遍历处理内部转换
for _key, _val in dict(self.__dict__).items():
if f'{_key}'.startswith('_'):
# 跳过所有私有属性
continue
if isinstance(_val, BaseAdapter):
# 内部数据对象数据,直接转换
m_dict[_key] = _val.to_dict()
elif isinstance(_val, list):
# 遍历转换数据对象列表
_tmp_list = []
for _i, _v in enumerate(_val):
if isinstance(_v, BaseAdapter):
_tmp_list.append(_v.to_dict())
else:
_tmp_list.append(_v)
m_dict[_key] = _tmp_list
elif isinstance(_val, dict):
# 遍历转换数据对象字典
_tmp_dict = {}
for _ik, _iv in _val.items():
if isinstance(_iv, BaseAdapter):
_tmp_dict[_ik] = _iv.to_dict()
else:
_tmp_dict[_ik] = _iv
m_dict[_key] = _tmp_dict
else:
# 其他属性直接赋值
m_dict[_key] = _val
return m_dict