Squashed 'paste-framework/' content from commit 34e8684
git-subtree-dir: paste-framework git-subtree-split: 34e8684c4bc3cebbe177509f42ab4ef5b5425a7a
This commit is contained in:
Executable
+830
@@ -0,0 +1,830 @@
|
||||
"""
|
||||
数据访问适配器。主要封装了数据访问的基础方法。
|
||||
"""
|
||||
import datetime
|
||||
import random
|
||||
import uuid
|
||||
from typing import Optional, Any, Union, List, Sequence
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import Column, String, DateTime, Date, inspect, Table, Integer, select, Numeric, ForeignKey, func, \
|
||||
text, desc, Result, Engine, tuple_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, async_sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker, Session, InstanceState, InstrumentedAttribute, registry
|
||||
from sqlalchemy.sql import Select
|
||||
from sqlalchemy.sql.base import ColumnCollection
|
||||
from sqlalchemy.sql.compiler import StrSQLCompiler
|
||||
from sqlalchemy.sql.elements import BinaryExpression
|
||||
|
||||
from paste.db import engine
|
||||
from paste.util import udict
|
||||
|
||||
# strftime/strptime patterns shared by the data-access helpers.
LOCAL_DATE_FORMAT = '%Y-%m-%d'
LOCAL_TIME_FORMAT = '%H:%M:%S'
# Date pattern and time pattern separated by one space.
LOCAL_DATETIME_FORMAT = ' '.join((LOCAL_DATE_FORMAT, LOCAL_TIME_FORMAT))
|
||||
|
||||
|
||||
def guid() -> str:
    """
    Generate a GUID.

    :return: the GUID as 32 hexadecimal characters, without dashes
    """
    # Name-based UUID seeded with a time-based UUID followed by a random float.
    seed = f'{uuid.uuid1()}{random.random()}'
    return uuid.uuid5(uuid.NAMESPACE_DNS, seed).hex
|
||||
|
||||
|
||||
def get_session() -> Session:
    """
    Create a synchronous session bound to the project engine.

    A new ``sessionmaker`` and a new session are built on every call.

    :return: session object
    """
    factory = sessionmaker(bind=engine.connect_engine())
    return factory()
|
||||
|
||||
|
||||
def get_aio_session() -> AsyncSession:
    """
    Create an asynchronous session bound to the project async engine.

    A new ``async_sessionmaker`` and a new session are built on every call.

    :return: asynchronous session object
    """
    factory = async_sessionmaker(bind=engine.async_connect_engine(), class_=AsyncSession)
    return factory()
|
||||
|
||||
|
||||
# Declarative base class of the SQLAlchemy ORM, produced by a fresh registry.
Base = registry().generate_base()
|
||||
|
||||
|
||||
class BaseAdapter(Base):
|
||||
"""
|
||||
适配器类,集成数据库引擎、连接、会话、元数据操作、部分 DML 数据处理等功能。
|
||||
主要实现了同步执行方法。
|
||||
"""
|
||||
__abstract__ = True
|
||||
|
||||
def __init__(self, **kwargs):
    """
    Build a model object.

    :param kwargs: keyword arguments forwarded to the parent constructor
    """
    # Objects created through the constructor start out flagged as new.
    self._is_new = True
    super().__init__(**kwargs)
|
||||
|
||||
@classmethod
def get_session(cls):
    """
    Obtain a synchronous session from the module-level ``get_session`` helper.

    :return: synchronous session object
    """
    _session = get_session()
    return _session
|
||||
|
||||
@classmethod
def get_aio_session(cls):
    """
    Obtain an asynchronous session from the module-level ``get_aio_session`` helper.

    :return: asynchronous session object
    """
    _session = get_aio_session()
    return _session
|
||||
|
||||
@classmethod
def ping(cls):
    """
    Try to connect to the database server.

    The connection is opened only to prove the server is reachable and is
    released again before returning.

    :return: None; an error raised while connecting propagates to the caller
    """
    # The context manager closes the connection instead of leaving it open.
    with engine.connect_engine().connect():
        pass
|
||||
|
||||
@classmethod
async def tables_in_db(cls):
    """
    Fetch the names of all tables in the current database.

    Queries ``information_schema.tables`` filtered by ``database()``.

    :return: list of table names
    """
    _stmt = (
        select(text('table_name'))
        .select_from(text('information_schema.tables'))
        .where(text('table_schema = database()'))
    )
    _session = cls.get_aio_session()
    try:
        _rows: Result = await _session.execute(_stmt)
        return list(_rows.scalars().all())
    finally:
        # Always release the session, including when the query fails.
        await _session.close()
|
||||
|
||||
@classmethod
async def is_table_exist(cls, table_name: str):
    """
    Tell whether a table exists in the current database.

    :param table_name: name of the table to look for
    :return: True when at least one matching row exists in ``information_schema.tables``
    """
    _stmt = (
        select(func.count())
        .select_from(text('information_schema.tables'))
        .where(text('table_schema = database()'), text('table_name = :table_name'))
        .params(table_name=table_name)
    )
    _session = cls.get_aio_session()
    try:
        _rows: Result = await _session.execute(_stmt)
        _found = int(_rows.scalar())
        return _found > 0
    finally:
        # Always release the session, including when the query fails.
        await _session.close()
|
||||
|
||||
@classmethod
def table(cls) -> Table:
    """
    Return the table object mapped to this model class.

    The table is looked up by ``__tablename__`` in ``cls.metadata.tables``.

    :return: table object
    :raises AssertionError: when ``__tablename__`` is missing or None, or when
        no table of that name is registered in the metadata
    """
    # Raised explicitly so the checks also run under ``python -O``.
    _name = getattr(cls, '__tablename__', None)
    if _name is None:
        raise AssertionError('属性:cls.__tablename__ 未找到.')
    _tables = cls.metadata.tables
    if _name not in _tables:
        raise AssertionError('数据库中未找到表:%s.' % _name)
    return _tables[_name]
|
||||
|
||||
@classmethod
def columns(cls) -> List[Column]:
    """
    Return all columns configured on the mapped table.

    :return: list of columns
    """
    _collection = cls.table().columns
    assert isinstance(_collection, ColumnCollection), '%s 列类型错误' % cls.__name__
    return [_column for _column in _collection]
|
||||
|
||||
@classmethod
def instrument_attr(cls, column_name: str) -> Optional[InstrumentedAttribute]:
    """
    Return the class attribute used to build query condition expressions.

    :param column_name: column name
    :return: the class attribute when it is an ``InstrumentedAttribute``, otherwise None
    """
    _candidate = getattr(cls, column_name)
    return _candidate if isinstance(_candidate, InstrumentedAttribute) else None
|
||||
|
||||
@classmethod
def label(cls, column: Union[Column, Any]):
    """
    Return the comment of a column, used as its label.

    :param column: column object
    :return: the ``comment`` attribute of the column
    """
    _comment = column.comment
    return _comment
|
||||
|
||||
@classmethod
def labels(cls) -> dict:
    """
    Return the comments of all columns.

    :return: dictionary shaped as {`field`: `comment`}
    """
    return {_column.key: _column.comment for _column in cls.columns()}
|
||||
|
||||
@classmethod
def field(cls, column: Union[Column, Any]) -> str:
    """
    Return the field name of a column.

    :param column: column object
    :return: the ``key`` attribute of the column
    """
    _name = column.key
    return _name
|
||||
|
||||
@classmethod
def fields(cls) -> List[str]:
    """
    Return the names of all fields.

    :return: list holding the ``key`` of every column, in column order
    """
    _names = []
    for _column in cls.columns():
        _names.append(_column.key)
    return _names
|
||||
|
||||
@classmethod
def new_id(cls, **kwargs) -> str:
    """
    Produce a new ID from a GUID.

    :param kwargs: accepted but not used by this implementation
    :return: ID
    """
    _identifier = guid()
    return _identifier
|
||||
|
||||
@property
def is_new(self):
    """
    Tell whether the object is flagged as newly created.

    :return: the value of ``_is_new``, or False when the attribute is absent
    """
    return getattr(self, '_is_new', False)
|
||||
|
||||
@classmethod
def raw_sql(cls, query) -> StrSQLCompiler:
    """
    Compile a query with its bound values rendered inline.

    :param query: object exposing ``compile``
    :return: result of ``query.compile`` called with ``literal_binds`` enabled
    """
    return query.compile(compile_kwargs={'literal_binds': True})
|
||||
|
||||
@classmethod
def row_count(cls, *where_expressions: BinaryExpression) -> int:
    """
    Count the rows matching the given conditions.

    :param where_expressions: query conditions
    :return: number of rows
    """
    query = select(func.count(1)).select_from(cls).where(*where_expressions)
    session = cls.get_session()
    try:
        return session.execute(query).scalars().first()
    finally:
        # Release the session even when the query raises.
        session.close()
|
||||
|
||||
@classmethod
def exist_by_id(cls, id_val: Union[str, int]):
    """
    Check whether a row with the given primary key exists.

    :param id_val: primary key value
    :return: True when ``row_count`` reports at least one match on ``id``
    """
    _matches: int = cls.row_count(cls.instrument_attr('id') == id_val)
    return _matches >= 1
|
||||
|
||||
@classmethod
def search_wheres(cls, likes: Optional[dict[str, str]] = None, **kwargs):
    """
    Build query conditions from keyword arguments.

    A throwaway model is filled from ``kwargs`` and asked for its filter
    expressions.

    :param likes: columns that use fuzzy matching, mapped to their pattern
    :param kwargs: values used to fill the model attributes
    :return: list of query conditions
    """
    # Default to an empty dict so the value matches the declared mapping type.
    _likes = {} if likes is None else likes
    _query_model = cls().copy_from_dict(kwargs)
    return _query_model.filter_expressions(likes=_likes)
|
||||
|
||||
@classmethod
def datalist_query(cls, row_list: List[dict], condition_cols: List[Column]) -> Optional[Select]:
    """
    Build a single-table query from a list of rows and the condition columns.

    :param row_list: list of row dictionaries
    :param condition_cols: columns used as conditions; their order defines the
        order of the values in each condition tuple
    :return: the query, or None when there are no rows, no columns, or no
        row with a non-empty condition value
    :raises AttributeError: when a condition column is not an attribute of the class
    """
    if not (row_list and condition_cols):
        return None

    _keys = [_column.key for _column in condition_cols]

    # Reject unknown fields before touching the data.
    for _key in _keys:
        if not hasattr(cls, _key):
            raise AttributeError(f"Field '{_key}' not found in {cls.__name__}")

    def _normalize(_raw):
        # Date and datetime values are rendered with the fixed local formats.
        if isinstance(_raw, datetime.datetime):
            return _raw.strftime(LOCAL_DATETIME_FORMAT)
        if isinstance(_raw, datetime.date):
            return _raw.strftime(LOCAL_DATE_FORMAT)
        return _raw

    _combos = []
    for _row in row_list:
        _combo = [_normalize(udict.get_with_default(_row, _key, '')) for _key in _keys]
        # Skip rows whose condition values are all empty strings.
        if any(_item != '' for _item in _combo):
            _combos.append(tuple(_combo))

    if not _combos:
        return None

    # Remove duplicated condition tuples.
    _combos = list(set(_combos))

    if len(condition_cols) == 1:
        # Single column: plain IN over the scalar values.
        _attr = getattr(cls, _keys[0])
        return select(cls).where(_attr.in_([_combo[0] for _combo in _combos]))

    # Several columns: IN over tuples of values.
    _attrs = [getattr(cls, _key) for _key in _keys]
    return select(cls).where(tuple_(*_attrs).in_(_combos))
|
||||
|
||||
@classmethod
def dataframe_query(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
    """
    Build a single-table query from a data frame and the condition columns.

    :param row_df: data frame holding the condition values
    :param condition_cols: columns used as conditions; their order defines the
        order of the values in each condition tuple
    :return: the query, or None when nothing is left to query
    :raises AttributeError: when a condition column is not an attribute of the class
    """
    if row_df.empty or not condition_cols:
        return None

    _keys = [_column.key for _column in condition_cols]

    # Reject unknown fields before touching the data.
    for _key in _keys:
        if not hasattr(cls, _key):
            raise AttributeError(f"Field '{_key}' not found in {cls.__name__}")

    # Keep the condition columns only, without duplicates and without NaN rows.
    _frame = row_df[_keys].drop_duplicates().dropna()
    if _frame.empty:
        return None

    if len(condition_cols) == 1:
        # Single column: plain IN over the column values.
        _key = _keys[0]
        _attr = getattr(cls, _key)
        _values = _frame[_key].tolist()
        if not _values:
            return None
        return select(cls).where(_attr.in_(_values))

    # Several columns: IN over de-duplicated tuples of values.
    _combos = list(set([tuple(_record) for _record in _frame.values]))
    if not _combos:
        return None

    _attrs = [getattr(cls, _key) for _key in _keys]
    return select(cls).where(tuple_(*_attrs).in_(_combos))
|
||||
|
||||
@classmethod
def find_by_id(cls, id_val: Union[str, int], reset_session: Optional[bool] = True) -> Optional['BaseAdapter']:
    """
    Find a row by primary key; the model must have an ``id`` field.

    :param id_val: primary key value
    :param reset_session: when true, attach the found model to a fresh session
    :return: the model object, or None when no row matches
    """
    query = select(cls).where(cls.instrument_attr('id') == id_val)
    session = cls.get_session()
    try:
        model = session.execute(query).scalars().first()
    finally:
        # Release the query session even when the query raises.
        session.close()

    if reset_session and isinstance(model, BaseAdapter):
        model.reset_session()

    return model
|
||||
|
||||
@classmethod
def find_by_created(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']:
    """
    Search by creation time; the model must have a ``created_at`` field.

    :param from_t: start time
    :param to_t: end time; when None the range collapses to ``from_t``
    :return: matching model objects
    """
    if to_t is None:
        to_t = from_t

    query = select(cls).where(cls.instrument_attr('created_at').between(from_t, to_t))
    session = cls.get_session()
    try:
        return session.execute(query).scalars().all()
    finally:
        # Release the session even when the query raises.
        session.close()
|
||||
|
||||
@classmethod
def find_by_updated(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']:
    """
    Search by update time; the model must have an ``updated_at`` field.

    :param from_t: start time
    :param to_t: end time; when None the range collapses to ``from_t``
    :return: matching model objects
    """
    if to_t is None:
        to_t = from_t

    query = select(cls).where(cls.instrument_attr('updated_at').between(from_t, to_t))
    session = cls.get_session()
    try:
        return session.execute(query).scalars().all()
    finally:
        # Release the session even when the query raises.
        session.close()
|
||||
|
||||
@classmethod
def find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]) -> Sequence['BaseAdapter']:
    """
    Look up the rows of a data list that already exist in the database.

    :param row_list: list of row dictionaries
    :param condition_cols: columns used as conditions; their order defines the
        order of the values in each condition tuple
    :return: matching model objects; an empty list when no query could be built
    """
    _query = cls.datalist_query(row_list=row_list, condition_cols=condition_cols)
    if _query is None:
        return []

    session = cls.get_session()
    try:
        return session.execute(_query).scalars().all()
    finally:
        # Release the session even when the query raises.
        session.close()
|
||||
|
||||
@classmethod
def find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
    """
    Look up the rows of a data frame that already exist in the database.

    :param row_df: data frame holding the condition values
    :param condition_cols: columns used as conditions; their order defines the
        order of the values in each condition tuple
    :return: a data frame built from the fetched rows, or None when no query
        could be built
    """
    _query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols)
    if _query is None:
        return None

    session = cls.get_session()
    try:
        _rows = session.connection().execute(_query).all()
    finally:
        # Release the session even when the query raises.
        session.close()
    return pd.DataFrame(_rows)
|
||||
|
||||
def load(self, model_data: dict):
    """
    Load data from a dictionary.

    Only keys that match a name listed by the object's ``__dir__`` are read;
    other keys are ignored. Unlike the ``copy_*`` methods, None values in
    ``model_data`` are skipped so the current value is kept.

    :param model_data: dictionary holding the data
    :return: the object itself
    """
    for _name in self.__dir__():
        _incoming = model_data.get(_name, None)
        if _incoming is None:
            continue
        setattr(self, _name, _incoming)
    return self
|
||||
|
||||
def copy_from(self, source: 'BaseAdapter', mapping_fields: Optional[List[str]] = None):
    """
    Copy data from another model, limited to fields both sides share.

    :param source: source model object
    :param mapping_fields: names of the fields to copy; when None, the column
        keys of the source model are used
    :return: the object itself
    """
    if mapping_fields is None:
        # Fall back to every column declared by the source.
        mapping_fields = [_column.key for _column in source.columns()]

    for _column in self.columns():
        _name = _column.key
        # Only fields that are also listed for the source get copied.
        if _name in mapping_fields:
            setattr(self, _name, source.ins_value(_name))

    return self
|
||||
|
||||
def copy_from_dict(self, source: dict, mapping_keys: Union[list, tuple, set] = None, skip_none: bool = False):
    """
    Copy data from a dictionary, limited to keys that match a column.

    :param source: source dictionary
    :param mapping_keys: keys to copy; when None, the keys of ``source`` are used
    :param skip_none: when true, None values in ``source`` are not copied
    :return: the object itself
    """
    _allowed = source.keys() if mapping_keys is None else mapping_keys

    for _column in self.columns():
        _name = _column.key
        if _name not in _allowed:
            # No entry of that name on the source side.
            continue

        _incoming = source.get(_name, None)
        if _incoming is None and skip_none:
            # The caller asked to keep the current value over None.
            continue
        setattr(self, _name, _incoming)

    return self
|
||||
|
||||
def inspect(self) -> InstanceState:
    """
    Return the inspection state of this object.

    :return: result of calling the module-level ``inspect`` on the object
    """
    _state = inspect(self)
    return _state
|
||||
|
||||
def session(self) -> Session:
    """
    Return the session this object currently belongs to.

    :return: the ``session`` attribute of the object's inspection state
    """
    _state = self.inspect()
    return _state.session
|
||||
|
||||
def has_session(self) -> bool:
    """
    Tell whether the object already belongs to a session.

    :return: True when ``session()`` returns something other than None
    """
    _current = self.session()
    return _current is not None
|
||||
|
||||
def add_session(self):
    """
    Attach the object to a session when it has none.

    :return: the session the object belongs to afterwards
    """
    if self.has_session():
        return self.session()
    # No session yet: obtain one and register the object with it.
    self.get_session().add(self)
    return self.session()
|
||||
|
||||
def close_session(self):
    """
    Close the session of the object, if it has one.
    """
    if not self.has_session():
        return
    self.session().close()
|
||||
|
||||
def reset_session(self):
    """
    Reset the session of the object.

    The current session is closed first, then the object is attached again
    through ``add_session``.
    """
    self.close_session()
    self.add_session()
|
||||
|
||||
def ins_value(self, column_name: str) -> Any:
    """
    Return the value bound to a column, which is the plain attribute value.

    Equivalent to reading ``self.column_name``.

    :param column_name: column name
    :return: attribute value, None included
    """
    return getattr(self, column_name)
|
||||
|
||||
def filter_expressions(self, likes: Optional[dict[str, str]] = None) -> List[BinaryExpression]:
    """
    Build query condition expressions from the attribute values of the object.

    ``likes`` maps field names to the format pattern used for fuzzy matching::

        {
            field_name1: '%{}%',
            field_name2: '{}%',
            field_name3: '%{}',
            ...
        }

    Attributes whose value is None or an empty string produce no condition.

    :param likes: fields that use fuzzy matching, mapped to their pattern
    :return: list of query condition expressions
    """
    expressions = []
    for column in self.columns():
        attr_val = self.ins_value(column.key)
        # Unset attributes add no condition.
        if attr_val in (None, ''):
            continue

        attr = self.instrument_attr(column.key)
        # Attributes that cannot be used in expressions add no condition.
        if attr is None:
            continue

        is_collection = isinstance(attr_val, (list, tuple, set))
        column_type = column.type
        if isinstance(column_type, (String, Integer, Numeric, ForeignKey)):
            if is_collection:
                # Collections use IN; empty collections are skipped.
                if attr_val:
                    expressions.append(attr.in_(attr_val))
            elif isinstance(attr_val, (int, float)):
                expressions.append(attr == attr_val)
            elif likes and attr.key in likes:
                # Text value with a configured pattern: LIKE.
                expressions.append(attr.like(likes[attr.key].format(attr_val)))
            else:
                expressions.append(attr == attr_val)
        elif isinstance(column_type, (DateTime, Date)):
            if is_collection and len(attr_val) == 2:
                # Exactly two values mean a BETWEEN range. A set cannot be
                # indexed, so its two bounds are taken in sorted order.
                if isinstance(attr_val, set):
                    low, high = sorted(attr_val)
                else:
                    low, high = attr_val
                expressions.append(attr.between(low, high))
            elif is_collection:
                expressions.append(attr.in_(attr_val))
            else:
                expressions.append(attr.between(attr_val, attr_val))
        else:
            # Any other column type: plain equality.
            expressions.append(attr == attr_val)

    return expressions
|
||||
|
||||
def gen_query(self, likes: Optional[dict[str, str]] = None):
    """
    Build a query from the attribute values of the object.

    :param likes: fields that use fuzzy matching, mapped to their pattern
    :return: query selecting the object's class, filtered by its filter expressions
    """
    _model_cls = self.__class__
    _conditions = self.filter_expressions(likes)
    return select(_model_cls).where(*_conditions)
|
||||
|
||||
def find(self) -> Sequence['BaseAdapter']:
    """
    Query the database using the attribute values of the object as conditions.

    :return: all matching model objects
    """
    session = self.get_session()
    try:
        return session.execute(self.gen_query()).scalars().all()
    finally:
        # Release the session even when the query raises.
        session.close()
|
||||
|
||||
def find_first(self) -> 'BaseAdapter':
    """
    Query the database using the attribute values of the object, first row only.

    :return: the first matching model object, or None when nothing matches
    """
    session = self.get_session()
    try:
        return session.execute(self.gen_query()).scalars().first()
    finally:
        # Release the session even when the query raises.
        session.close()
|
||||
|
||||
def find_piece(self, *where: BinaryExpression, offset=0, limit=500, is_desc=False,
               likes: Optional[dict[str, str]] = None):
    """
    Query one slice of rows using the attribute values of the object.

    :param where: extra query conditions, added to the generated ones
    :param offset: offset; applied when it is zero or positive
    :param limit: number of rows to read; applied when it is positive
    :param is_desc: when true and the class has an ``id`` attribute, order by
        ``id`` descending
    :param likes: fields that use fuzzy matching, mapped to their pattern
    :return: matching model objects
    """
    clz = self.__class__
    # Combine generated and caller-supplied conditions.
    expressions = [*self.filter_expressions(likes=likes), *where]

    query = select(clz).where(*expressions)
    if limit > 0:
        query = query.limit(limit=limit)
    if offset >= 0:
        query = query.offset(offset=offset)
    if is_desc and hasattr(clz, 'id'):
        query = query.order_by(desc(clz.id))

    session = self.get_session()
    try:
        return session.execute(query).scalars().all()
    finally:
        # Release the session even when the query raises.
        session.close()
|
||||
|
||||
def before_save(self):
    """
    Hook run before saving; subclasses usually override it to add their own steps.

    For new objects, ``id`` receives a fresh ID and ``created_at`` the current
    time, each only when the attribute exists. ``updated_at`` is refreshed on
    every call when the attribute exists.

    :return: the object itself
    """
    # One timestamp for both fields, so a new row gets identical values.
    _stamp = datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT)

    if self.is_new:
        if hasattr(self, 'id'):
            setattr(self, 'id', self.new_id())

        if hasattr(self, 'created_at'):
            setattr(self, 'created_at', _stamp)

    if hasattr(self, 'updated_at'):
        setattr(self, 'updated_at', _stamp)

    return self
|
||||
|
||||
def save(self, auto_expunge: Optional[bool] = True, session: Optional[Session] = None):
    """
    Save the model object, attaching it to a session first when it has none.

    With an external ``session`` the object is only added; committing is left
    to the caller. Without one, the object's own session is used and committed.

    :param auto_expunge: refresh the object and remove it from the session
        after the commit; only applies when no external session is given
    :param session: external session, mainly used for transactions
    :return: True on success
    :raises Exception: whatever the session raises, after a rollback
    """
    self.before_save()

    _external = session is not None
    _created_here = False
    if not _external:
        # Remember whether the session is created by this call, so that only
        # such a session gets closed below.
        _created_here = not self.has_session()
        session = self.add_session()

    try:
        session.add(self)
        if not _external:
            # An internal session is committed right away.
            session.commit()
        self._is_new = False
    except Exception:
        session.rollback()
        raise

    if auto_expunge and not _external:
        try:
            session.refresh(self)
            session.expunge(self)
        finally:
            # The refresh starts new work on the session; close a session
            # created here so it does not stay open after the object is detached.
            if _created_here:
                session.close()
    return True
|
||||
|
||||
def to_dict(self) -> dict:
    """
    Convert the model to a dictionary.

    Nested models are converted too when they are stored directly, inside a
    list, or as values of a dict. Attributes whose name starts with an
    underscore are left out.

    :return: dictionary with the converted data
    """
    def _plain(_item):
        # Models become dictionaries, anything else is kept as is.
        return _item.to_dict() if isinstance(_item, BaseAdapter) else _item

    _result = {}
    for _name, _value in dict(self.__dict__).items():
        if f'{_name}'.startswith('_'):
            # Private attributes are not exported.
            continue

        if isinstance(_value, BaseAdapter):
            _result[_name] = _value.to_dict()
        elif isinstance(_value, list):
            _result[_name] = [_plain(_entry) for _entry in _value]
        elif isinstance(_value, dict):
            _result[_name] = {_key: _plain(_entry) for _key, _entry in _value.items()}
        else:
            _result[_name] = _value

    return _result
|
||||
@@ -0,0 +1,629 @@
|
||||
"""
|
||||
数据模型基础类,继承于数据表。集成了模型的基础功能,如数据验证、错误消息、数据影射、对象比较等功能。
|
||||
"""
|
||||
import datetime
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from typing import Union, Any, Optional, Callable
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import Column, text, desc
|
||||
|
||||
from paste.db import baseadapter
|
||||
from paste.db.basetable import BaseTable
|
||||
from paste.util import udict, ustr
|
||||
from paste.util.pagination import Pagination
|
||||
from paste.util.snow_id import IdWorker
|
||||
|
||||
# Aliases of the formatting patterns defined by the adapter module.
LOCAL_DATE_FORMAT, LOCAL_TIME_FORMAT, LOCAL_DATETIME_FORMAT = (
    baseadapter.LOCAL_DATE_FORMAT,
    baseadapter.LOCAL_TIME_FORMAT,
    baseadapter.LOCAL_DATETIME_FORMAT,
)
|
||||
|
||||
|
||||
class BaseModel(BaseTable):
|
||||
"""
|
||||
数据模型基类。集成了验证辅助功能。
|
||||
"""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
@classmethod
def new_id(cls, datacenter_id: int = 1, worker_id: int = 1, sequence: int = 0) -> int:
    """
    Obtain a Snow ID worker and produce an ID value from it.

    :param datacenter_id: data center (machine region) ID
    :param worker_id: machine ID
    :param sequence: starting sequence number
    :return: new Snow ID value
    """
    _worker = IdWorker.get_id_worker(datacenter_id, worker_id, sequence)
    return _worker.get_id()
|
||||
|
||||
@classmethod
def now(cls):
    """
    Return the current time as a formatted string.

    :return: current time rendered with ``LOCAL_DATETIME_FORMAT``
    """
    _current = datetime.datetime.now()
    return _current.strftime(LOCAL_DATETIME_FORMAT)
|
||||
|
||||
@classmethod
def is_len(cls, v: str, length: int):
    """
    Check the length of a value's text form, e.g. whether it has 18 characters.

    Mainly used for data validation.

    :param v: value to check
    :param length: expected length
    :return: True when the text form is non-empty and has exactly ``length``
        characters; False for None and float NaN
    """
    # Reject missing values before stringifying: otherwise None would count
    # as the 4-character text "None" and NaN as "nan".
    if v is None or (isinstance(v, float) and v != v):
        return False
    _text = f"{v}"
    return not cls.is_empty_or_none(_text) and len(_text) == length
|
||||
|
||||
@classmethod
def is_in_range(cls, v: Union[int, float], v_min: Union[int, float], v_max: Union[int, float]):
    """
    Check that a number lies between a minimum and a maximum.

    Mainly used for data validation.

    :param v: value to check
    :param v_min: minimum (inclusive)
    :param v_max: maximum (inclusive)
    :return: True when inside the range, otherwise False
    """
    if cls.is_empty_or_none(v):
        return False
    return v_min <= v <= v_max
|
||||
|
||||
@classmethod
def is_in_items(cls, v: Union[int, float], items: list = None):
    """
    Check that a value is one of the given items.

    Mainly used for data validation.

    :param v: value to check
    :param items: all allowed items
    :return: True when ``items`` is given and contains ``v``, otherwise False
    """
    if items is None:
        return False
    return v in items
|
||||
|
||||
@classmethod
def is_empty_or_none(cls, v: Any):
    """
    Check for None, a missing value, or an empty string.

    :param v: content to check
    :return: True for None, for a scalar that ``pd.isna`` reports as missing,
        or when the text form of ``v`` is '', otherwise False
    """
    if v is None:
        return True
    # pd.isna returns an array for list-likes, and the truth value of an
    # array is ambiguous, so only scalars are probed.
    if pd.api.types.is_scalar(v) and pd.isna(v):
        return True
    return f"{v}" == ''
|
||||
|
||||
@classmethod
def not_empty_or_none(cls, v: Any):
    """
    Opposite of ``is_empty_or_none``.

    :param v: content to check
    :return: negation of ``is_empty_or_none(v)``
    """
    _empty = cls.is_empty_or_none(v)
    return not _empty
|
||||
|
||||
@classmethod
def is_digit(cls, v: str):
    """
    Check whether the text form of a value consists of digits only.

    :param v: content to check
    :return: result of ``str.isdigit`` on the text form of ``v``
    """
    return f"{v}".isdigit()
|
||||
|
||||
@classmethod
def is_decimal(cls, v: str):
    """
    Check whether a value is a decimal number; integers also pass.

    Commas are removed before checking. Every part around the decimal point
    must pass ``is_digit``.

    :param v: content to check
    :return: True for a decimal or integer, otherwise False
    """
    _parts = f"{v}".replace(',', '').split('.')
    # More than one decimal point (e.g. "1.2.3") is not a number.
    if len(_parts) > 2:
        return False
    return all(cls.is_digit(_part) for _part in _parts)
|
||||
|
||||
@classmethod
def is_datetime(cls, v: str):
    """
    Check whether a value is in date-time format.

    :param v: content to check
    :return: True when ``v`` parses with ``LOCAL_DATETIME_FORMAT``; False when
        parsing raises any exception
    """
    try:
        datetime.datetime.strptime(v, LOCAL_DATETIME_FORMAT)
        return True
    except Exception:
        return False
|
||||
|
||||
@classmethod
def is_date(cls, v: str):
    """
    Check whether a value is in date format.

    :param v: content to check
    :return: True when ``v`` parses with ``LOCAL_DATE_FORMAT``; False when
        parsing raises any exception
    """
    try:
        datetime.datetime.strptime(v, LOCAL_DATE_FORMAT)
        return True
    except Exception:
        return False
|
||||
|
||||
@classmethod
def is_time(cls, v: str):
    """
    Check whether a value is in time format.

    :param v: content to check
    :return: True when ``v`` parses with ``LOCAL_TIME_FORMAT``; False when
        parsing raises any exception
    """
    try:
        datetime.datetime.strptime(v, LOCAL_TIME_FORMAT)
        return True
    except Exception:
        return False
|
||||
|
||||
@classmethod
def error_empty_msg(cls, column: Union[Column, Any]):
    """
    Message for an empty value. Mainly used for data validation errors.

    :param column: column the message is about
    :return: error message built from the column label
    """
    _label = cls.label(column=column)
    return '%s必须包含内容.' % _label
|
||||
|
||||
@classmethod
def error_date_msg(cls, column: Union[Column, Any]):
    """
    Message for a wrong date format. Mainly used for data validation errors.

    :param column: column the message is about
    :return: error message built from the column label
    """
    _label = cls.label(column=column)
    return '%s必须是日期.' % _label
|
||||
|
||||
@classmethod
def error_datetime_msg(cls, column: Union[Column, Any]):
    """
    Message for a wrong date-time format. Mainly used for data validation errors.

    :param column: column the message is about
    :return: error message built from the column label
    """
    _label = cls.label(column=column)
    return '%s必须是日期时间.' % _label
|
||||
|
||||
@classmethod
def error_time_msg(cls, column: Union[Column, Any]):
    """
    Message for a wrong time format. Mainly used for data validation errors.

    :param column: column the message is about
    :return: error message built from the column label
    """
    _label = cls.label(column=column)
    return '%s必须是时间.' % _label
|
||||
|
||||
@classmethod
def error_decimal_msg(cls, column: Union[Column, Any]):
    """
    Message for a value that is not a float or double. Mainly used for data
    validation errors.

    :param column: column the message is about
    :return: error message built from the column label
    """
    # The message says "double precision" (双精度); the former text had a typo.
    return '%s必须是浮点或双精度类型.' % cls.label(column=column)
|
||||
|
||||
@classmethod
def error_format_msg(cls, column: Union[Column, Any]):
    """
    Message for a wrong format. Mainly used for data validation errors.

    :param column: column the message is about
    :return: error message built from the column label
    """
    _label = cls.label(column=column)
    return '%s格式错误.' % _label
|
||||
|
||||
@classmethod
def error_int_msg(cls, column: Union[Column, Any]):
    """
    Message for a value that is not an integer. Mainly used for data
    validation errors.

    :param column: column the message is about
    :return: error message built from the column label
    """
    _label = cls.label(column=column)
    return '%s必须是整数.' % _label
|
||||
|
||||
@classmethod
def error_len_msg(cls, column: Union[Column, Any], length: int):
    """
    Message for a wrong length. Mainly used for data validation errors.

    :param column: column the message is about
    :param length: required length
    :return: error message built from the column label
    """
    _parts = (cls.label(column=column), length)
    return '%s必须是%d位.' % _parts
|
||||
|
||||
@classmethod
def error_in_range_msg(cls, column: Union[Column, Any],
                       v_min: Union[int, float], v_max: Union[int, float]):
    """
    Message for a value outside its range. Mainly used for data value
    validation errors.

    :param column: column the message is about
    :param v_min: lower bound shown in the message
    :param v_max: upper bound shown in the message
    :return: error message built from the column label
    """
    _parts = (cls.label(column=column), f"{v_min}", f"{v_max}")
    return '%s必须在:[%s,%s] 范围内.' % _parts
|
||||
|
||||
@classmethod
def error_in_items_msg(cls, column: Union[Column, Any], items: list = None):
    """
    Message for a value outside the allowed items. Mainly used for data item
    validation errors.

    :param column: column the message is about
    :param items: allowed items; when None a generic message is returned
    :return: error message built from the column label
    """
    _label = cls.label(column=column)
    if items is None:
        return '%s超出范围.' % _label
    # Items may be numbers, so each one is rendered as text before joining.
    _listing = ','.join(f"{_item}" for _item in items)
    return '%s必须在:[%s] 范围内.' % (_label, _listing)
|
||||
|
||||
@classmethod
def error_str_msg(cls, column: Union[Column, Any]):
    """
    Message for a value that is not a string. Mainly used for data
    validation errors.

    :param column: column the message is about
    :return: error message built from the column label
    """
    _label = cls.label(column=column)
    return '%s必须是字符串' % _label
|
||||
|
||||
# Field validator configuration, shaped as: column -> validator tuple.
# Elements of the validator tuple:
#   item 0: the verify function, or a (verify function, message function) tuple
#           when a message function is configured as well.
#   item 1: bool, whether None values are skipped.
#   items 2..n: extra arguments for the verify and message functions; apart
#           from their first argument both functions must take the same ones.
field_validators: dict[Column, tuple] = {}
|
||||
|
||||
@classmethod
def validate_fields(cls, row: dict) -> list[dict[str, str]]:
    """
    Validate the fields of a row against the ``field_validators`` configuration.

    :param row: data to validate
    :return: list of single-entry dictionaries mapping a field key to its
        error message; empty when every check passes
    :raises AssertionError: when a configured verify function is not callable
    """
    _errors: list[dict[str, str]] = []
    for _column, _validator in cls.field_validators.items():
        # First two items: verify function (optionally paired with a message
        # function) and the skip-None flag.
        _verify_func, _skip_null = _validator[:2]
        _message_func = None
        if isinstance(_verify_func, tuple):
            _verify_func, _message_func = _verify_func

        _value = udict.get_with_default(row, _column.key, None)
        if _value is None and _skip_null:
            continue

        # Raised explicitly so the check also runs under ``python -O``.
        if not callable(_verify_func):
            raise AssertionError('验证器配置错误.')

        # Remaining items are shared by the verify and message functions.
        _args = _validator[2:]
        if _verify_func(_value, *_args):
            continue

        if callable(_message_func):
            _message = _message_func(_column, *_args)
        else:
            _message = f'{_column.key} 字段数据错误.'
        _errors.append({_column.key: _message})
    return _errors
|
||||
|
||||
@classmethod
def validate_dict(cls, row: dict, row_list: list[dict], err_list: list[dict]):
    """
    Validate one row and file it into the matching list; the row is not changed.

    This implementation runs no checks of its own: the row goes to
    ``row_list`` unless appending raises ``TypeError``, in which case it goes
    to ``err_list``.

    :param row: row data to validate
    :param row_list: list receiving rows that pass
    :param err_list: list receiving rows that fail
    """
    _accept = row_list.append
    try:
        _accept(row)
    except TypeError:
        err_list.append(row)
|
||||
|
||||
@classmethod
|
||||
def validate_dict_list(cls, row_list: list[dict]) -> tuple[list[dict], list[dict]]:
    """
    Validate every row of ``row_list`` through ``cls.validate_dict``.

    Two fresh collectors are created for each call, so results from earlier
    calls are never carried over.

    :param row_list: rows to validate
    :return: ``(accepted_rows, rejected_rows)``
    """
    _accepted: list[dict] = []
    _rejected: list[dict] = []
    for _item in row_list:
        cls.validate_dict(row=_item, row_list=_accepted, err_list=_rejected)
    return _accepted, _rejected
|
||||
|
||||
@classmethod
|
||||
def mapping_data_struct(cls, source: Optional[dict], mapping: Optional[dict]):
    """
    Build a new dict from ``source`` following the rules in ``mapping``.

    ``mapping`` has the shape ``{target_key: rule}``. Example::

        dict_key_mapping = {
            'mainId': 'id',
            'mainCycle': lambda dict_obj: MAIN_CYCLE_LABELS.get(dict_obj['main_cycle'], ''),
            'fileList': {
                '__name__': 'main_files',
                '__mapping__': {
                    'fileName': 'file_name',
                    'filePath': 'file_url',
                },
            },
        }

    A rule is interpreted by its type:

    1. ``str``: key to read from ``source`` (missing keys yield None). To emit
       a literal string, wrap it in a callable; a bare string is always taken
       as a key.
    2. callable: called with ``source``; its return value is used.
    3. ``dict`` containing both ``__name__`` and ``__mapping__``: the value at
       ``source[__name__]`` is mapped recursively with ``__mapping__`` when it
       is a dict, item by item when it is a list, and copied as-is otherwise.
    4. anything else (including a dict without both markers): used verbatim.

    :param source: source data dict
    :param mapping: mapping rules
    :return: the converted dict, or None when either argument is None
    """
    if source is None or mapping is None:
        return None

    def _resolve(_rule):
        # Plain key lookup.
        if isinstance(_rule, str):
            return source.get(_rule, None)
        # Function / lambda rule.
        if isinstance(_rule, Callable):
            return _rule(source)
        # Non-dict rules and dicts without both markers are literal values.
        if not isinstance(_rule, dict):
            return _rule
        if '__name__' not in _rule or '__mapping__' not in _rule:
            return _rule

        # Nested rule: fetch the child data and its own mapping.
        _child = source.get(_rule.get('__name__'), None)
        _child_mapping = _rule.get('__mapping__', None)
        if isinstance(_child, dict):
            return cls.mapping_data_struct(_child, _child_mapping)
        if isinstance(_child, list):
            return [cls.mapping_data_struct(_entry, _child_mapping) for _entry in _child]
        # Neither dict nor list: copy through unchanged.
        return _child

    return {_tar_attr: _resolve(_src_attr) for _tar_attr, _src_attr in mapping.items()}
|
||||
|
||||
@classmethod
|
||||
def transform(cls, rows: list[dict], mapping: dict):
    """
    Convert every dict in ``rows`` with ``cls.mapping_data_struct``.

    ``mapping`` follows the ``{target_key: rule}`` structure accepted by
    ``mapping_data_struct``: a string rule reads that key from the row, a
    callable rule is called with the row, a dict rule carrying ``__name__``
    and ``__mapping__`` describes nested data, and any other value is used
    as-is. Example::

        dict_key_mapping = {
            'mainId': 'id',
            'mainCycle': lambda dict_obj: MAIN_CYCLE_LABELS.get(dict_obj['main_cycle'], ''),
            'fileList': {
                '__name__': 'main_files',
                '__mapping__': {'fileName': 'file_name', 'filePath': 'file_url'},
            },
        }

    :param rows: list of source dicts
    :param mapping: mapping rules
    :return: list with one converted dict per input row
    """
    return [cls.mapping_data_struct(_row, mapping) for _row in rows]
|
||||
|
||||
@classmethod
|
||||
def convert(cls, dataframe: pd.DataFrame, mapping: dict):
    """
    Build a new DataFrame from ``dataframe`` following the rules in ``mapping``.

    ``mapping`` has the shape ``{target_column: rule}``. Example::

        dict_key_mapping = {
            'devUseNo': 'dev_use_no',
            'mainCycle': lambda row: MAIN_CYCLE_LABELS.get(row['main_cycle'], ''),
            'source': 42,
        }

    A rule is interpreted by its type:

    1. ``str``: name of the source column to copy.
    2. callable: applied to each row (``DataFrame.apply(..., axis=1)``).
    3. anything else: the value itself is assigned to the whole target column.

    Unlike the dict-based mapping, nested (recursive) rules are not supported.

    :param dataframe: source DataFrame
    :param mapping: mapping rules
    :return: the converted DataFrame, sharing the source index
    """
    # Start from the source index so a constant rule listed first is broadcast
    # to every row instead of fixing an empty index that drops later columns.
    _tar_df = pd.DataFrame(index=dataframe.index)
    for _tar_attr, _src_attr in mapping.items():
        if isinstance(_src_attr, str):
            _tar_df[_tar_attr] = dataframe[_src_attr]
        elif callable(_src_attr):
            _tar_df[_tar_attr] = dataframe.apply(_src_attr, axis=1)
        else:
            # Use the configured value; the column name was assigned here before.
            _tar_df[_tar_attr] = _src_attr
    return _tar_df
|
||||
|
||||
@classmethod
|
||||
def is_equal(cls, data_dict: dict, data_model: 'BaseModel', skip_kes: list[str] = None, decimals: str = '0.00'):
    """
    Compare the values in ``data_dict`` with the same-named attributes of ``data_model``.

    Keys listed in ``skip_kes``, keys whose new value is None and keys that are
    missing from ``data_model.__dict__`` are ignored. Remaining values are
    normalized according to the type of the stored (old) value and then
    compared. Equal data usually means the model needs no update.

    :param data_dict: new values keyed by attribute name
    :param data_model: model instance whose ``__dict__`` supplies the old values
    :param skip_kes: keys excluded from the comparison
    :param decimals: quantize pattern for Decimal/float values (default two places)
    :return: tuple ``(all_equal, per_key)`` where ``per_key`` maps every compared
        key to True/False
    """
    is_equal = True
    equal_dict: dict = {}
    if skip_kes is None:
        skip_kes = []

    for _key, _new_val in data_dict.items():
        if _key in skip_kes:
            continue

        if _new_val is None:
            # A None new value is never compared
            continue

        if _key not in data_model.__dict__:
            # Attribute is not present on the model instance
            continue

        _old_val = data_model.__dict__.get(_key, None)
        if isinstance(_old_val, (Decimal, float)):
            # Compare both sides as Decimals rounded half-up to ``decimals``
            _old_val = Decimal(f"{_old_val}").quantize(Decimal(decimals), rounding=ROUND_HALF_UP)
            _new_val = Decimal(f"{_new_val}").quantize(Decimal(decimals), rounding=ROUND_HALF_UP)
        elif isinstance(_old_val, int):
            _new_val = int(_new_val)
        elif isinstance(_old_val, datetime.datetime):
            # Checked before datetime.date because datetime is a subclass of date
            _old_val = _old_val.strftime(LOCAL_DATETIME_FORMAT)
            # presumably ustr.to_datetime returns None when no format matches — verify
            _datetime = ustr.to_datetime(_new_val, [LOCAL_DATETIME_FORMAT, LOCAL_DATE_FORMAT])
            _new_val = _datetime.strftime(LOCAL_DATETIME_FORMAT) if _datetime is not None else f"{_new_val}"
        elif isinstance(_old_val, datetime.date):
            _old_val = _old_val.strftime(LOCAL_DATE_FORMAT)
            _date = ustr.to_datetime(_new_val, [LOCAL_DATE_FORMAT, LOCAL_DATETIME_FORMAT])
            _new_val = _date.strftime(LOCAL_DATE_FORMAT) if _date is not None else f"{_new_val}"
        else:
            # Everything else is compared as text; an old None becomes ''
            _old_val = f"{_old_val}" if _old_val is not None else ''
            # NOTE(review): a float new value is truncated to int before being
            # stringified, so 1.5 compares as "1" — confirm this is intended
            if isinstance(_new_val, float):
                _new_val = int(_new_val)
            _new_val = f"{_new_val}"

        _isFieldEqual = _new_val == _old_val
        is_equal = is_equal and _isFieldEqual
        equal_dict[_key] = _isFieldEqual

    return is_equal, equal_dict
|
||||
|
||||
@classmethod
|
||||
async def page_info(cls, *where_clause, page_size: int = 20):
    """
    Compute paging figures for the rows matching ``where_clause``.

    :param where_clause: filter expressions forwarded to ``cls.async_row_count``
    :param page_size: number of rows per page
    :return: ``(page_count, row_count)``
    """
    _total = await cls.async_row_count(*where_clause)
    _page_count = Pagination(row_count=_total).pages(page_size=page_size)
    return _page_count, _total
|
||||
|
||||
@classmethod
|
||||
def sort_clauses(cls, sort_d: dict):
    """
    Turn a ``{field_name: direction}`` dict into a list of ordering clauses.

    Expected structure::

        {
            'field_name1': 'asc',
            'field_name2': 'desc',
        }

    ``''``, ``'asc'`` and ``'ascend'`` produce an ascending clause, ``'desc'``
    and ``'descend'`` a descending one; any other direction is ignored.

    NOTE(review): field names are passed to ``text()`` verbatim, so callers
    should only supply trusted names.

    :param sort_d: ordering definition
    :return: list of ordering clauses in the dict's iteration order
    """
    _ascending = ('', 'asc', 'ascend')
    _descending = ('desc', 'descend')
    _ordering = []
    for _field, _direction in sort_d.items():
        if _direction in _ascending:
            _ordering.append(text(_field))
        if _direction in _descending:
            _ordering.append(desc(text(_field)))
    return _ordering
|
||||
Executable
+337
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
集成了表的基本操作。
|
||||
"""
|
||||
import datetime
|
||||
from typing import Optional, Union, List, Sequence
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import func, desc, Column, select, util, Result
|
||||
from sqlalchemy.engine import ScalarResult, CursorResult
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
|
||||
from sqlalchemy.sql import Select
|
||||
from sqlalchemy.sql.elements import BinaryExpression
|
||||
from sqlalchemy.sql.operators import ColumnOperators
|
||||
|
||||
from paste.db import baseadapter, engine
|
||||
|
||||
|
||||
class BaseTable(baseadapter.BaseAdapter):
|
||||
"""
|
||||
表对象(数据模型)基类。集成基础数据操作功能。
|
||||
实现异步执行方法。
|
||||
"""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
@classmethod
|
||||
async def raw_execute(cls, query, session: AsyncSession=None,
                      params=None, options=None) -> Optional[CursorResult]:
    """
    Execute a raw query on an async connection and return the cursor result.

    Calling ``all()`` on the returned object yields the rows. The connection
    is committed after the statement ran. When no session is passed, a new one
    is created and closed again before returning; previously that session was
    never closed. When a session is passed, its connection is committed and
    closed exactly as before.

    :param query: statement to execute
    :param session: database session; a temporary one is created when None
    :param params: statement parameters
    :param options: execution options (defaults to an empty mapping)
    :return: the CursorResult produced by the connection
    """
    if options is None:
        options = util.EMPTY_DICT

    _owns_session = session is None
    if _owns_session:
        session = cls.get_aio_session()

    connection = None
    try:
        connection = await session.connection()
        cursor_result: CursorResult = await connection.execute(query, params, execution_options=options)
        await connection.commit()
        return cursor_result
    finally:
        if _owns_session:
            # Closing the session also releases the connection it handed out.
            await session.close()
        elif connection is not None:
            # NOTE(review): this closes a connection owned by the caller's
            # session (behavior kept as-is) — confirm callers rely on it.
            await connection.close()
|
||||
|
||||
@classmethod
|
||||
async def orm_execute(cls, query, session: AsyncSession=None, params=None, options=None) -> Optional[Result]:
    """
    Execute a query through an async session and return the result object.

    Calling ``all()`` on the returned object yields the rows. A session
    created here (because none was passed) is closed before returning; a
    caller-supplied session is left open.

    :param query: statement to execute
    :param session: database session; a temporary one is created when None
    :param params: statement parameters
    :param options: execution options
    :return: the Result produced by the session
    """
    _owns_session = session is None
    _active = cls.get_aio_session() if _owns_session else session
    try:
        return await _active.execute(query, params, execution_options=options)
    finally:
        if _owns_session:
            await _active.close()
|
||||
|
||||
@classmethod
|
||||
async def orm_execute_scalars(cls, query, session: AsyncSession=None, **kwargs) -> ScalarResult:
    """
    Execute a query via ``cls.orm_execute`` and return ``result.scalars()``.

    :param query: statement to execute
    :param session: database session, forwarded to ``orm_execute``
    :param kwargs: extra keyword arguments forwarded to ``orm_execute``
    :return: the scalar view of the result
    """
    return (await cls.orm_execute(query, session, **kwargs)).scalars()
|
||||
|
||||
@classmethod
|
||||
async def query_all(cls, query: Select, session: AsyncSession=None) -> List['BaseTable']:
    """
    Execute an ORM query and return ``scalars().all()`` as a list.

    Because scalars are used, a query selecting several columns yields only
    the first column of each row.

    :param query: select statement
    :param session: database session
    :return: list of model objects (or first-column values)
    """
    _found = (await cls.orm_execute_scalars(query, session)).all()
    return list(_found)
|
||||
|
||||
@classmethod
|
||||
async def query_first(cls, query: Select, session: AsyncSession=None) -> 'BaseTable':
    """
    Execute an ORM query and return ``scalars().first()``.

    :param query: select statement
    :param session: database session
    :return: whatever ``first()`` yields for the query (a single object, not a list)
    """
    return (await cls.orm_execute_scalars(query, session)).first()
|
||||
|
||||
@classmethod
|
||||
async def query_count(cls, query: Select, is_only_count: bool = False, session: AsyncSession=None) -> int:
    """
    Count the rows the given query would return.

    :param query: select statement to count
    :param is_only_count: when True, replace the selected columns with
        ``count(1)``; otherwise wrap the query as a subquery and count over it
    :param session: database session
    :return: number of rows
    """
    _counter = func.count(1)
    if is_only_count:
        count_query = query.with_only_columns(_counter)
    else:
        count_query = select(_counter).select_from(query.subquery())

    # Run the counting statement and read its single value.
    _scalars = await cls.orm_execute_scalars(count_query, session)
    return _scalars.first()
|
||||
|
||||
@classmethod
|
||||
async def query_as_df(cls, query, session: AsyncSession=None, params=None, options=None):
    """
    Run a raw query through ``cls.raw_execute`` and return the rows as a DataFrame.

    Tip: to keep numeric precision or simplify later processing, convert data
    types directly in SQL (for example with ``func.ifnull`` or ``func.convert``).

    :param query: statement to execute
    :param session: database session
    :param params: statement parameters
    :param options: execution options
    :return: DataFrame whose columns are the result keys
    """
    result = await cls.raw_execute(query, session, params, options)
    _rows = result.all()
    _columns = pd.Series(result.keys())
    return pd.DataFrame(_rows, columns=_columns)
|
||||
|
||||
@classmethod
|
||||
async def async_row_count(cls, *where_expressions: Union[ColumnOperators, BinaryExpression],
                          session: AsyncSession=None) -> int:
    """
    Count the rows of this model's table that match the given conditions.

    :param where_expressions: filter expressions combined in the WHERE clause
    :param session: database session
    :return: number of matching rows
    """
    _count_query = select(func.count(1)).select_from(cls).where(*where_expressions)
    return (await cls.orm_execute_scalars(_count_query, session)).first()
|
||||
|
||||
@classmethod
|
||||
async def async_exist_by_id(cls, id_val: Union[str, int]):
    """
    Tell whether at least one row has the given ``id`` value.

    :param id_val: primary key value to look for
    :return: True when one or more rows match
    """
    _matches: int = await cls.async_row_count(cls.instrument_attr('id') == id_val)
    return _matches >= 1
|
||||
|
||||
@classmethod
|
||||
async def async_find_by_id(cls, id_val: Union[str, int]) -> Optional['BaseTable']:
    """
    Load the model whose ``id`` equals ``id_val``; requires an ``id`` field.

    :param id_val: primary key value
    :return: the result of ``cls.query_first`` for that condition
    """
    _by_id = select(cls).where(cls.instrument_attr('id') == id_val)
    return await cls.query_first(query=_by_id)
|
||||
|
||||
@classmethod
|
||||
async def async_find_by_created(cls, from_t: datetime.datetime,
                                to_t: Optional[datetime.datetime] = None) -> List['BaseTable']:
    """
    Find rows whose ``created_at`` lies between two moments.

    The field name is fixed, so the model must have a ``created_at`` field.

    :param from_t: start of the range
    :param to_t: end of the range; defaults to ``from_t``
    :return: list of matching models
    """
    _upper = from_t if to_t is None else to_t
    _column = cls.instrument_attr('created_at')
    return await cls.query_all(select(cls).where(_column.between(from_t, _upper)))
|
||||
|
||||
@classmethod
|
||||
async def async_find_by_updated(cls, from_t: datetime.datetime,
                                to_t: Optional[datetime.datetime] = None) -> List['BaseTable']:
    """
    Find rows whose ``updated_at`` lies between two moments.

    The field name is fixed, so the model must have an ``updated_at`` field.

    :param from_t: start of the range
    :param to_t: end of the range; defaults to ``from_t``
    :return: list of matching models
    """
    _upper = from_t if to_t is None else to_t
    _column = cls.instrument_attr('updated_at')
    return await cls.query_all(select(cls).where(_column.between(from_t, _upper)))
|
||||
|
||||
@classmethod
|
||||
async def async_find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]):
    """
    Look up the models that already exist in the database for a list of rows.

    The statement is built by ``cls.datalist_query``; when that returns None
    nothing is executed.

    :param row_list: row dicts to look up
    :param condition_cols: columns used as the lookup condition (their order
        should match the index order)
    :return: list of models found, empty when no query was built
    """
    _lookup = cls.datalist_query(row_list=row_list, condition_cols=condition_cols)
    if _lookup is None:
        return []
    return await cls.query_all(query=_lookup)
|
||||
|
||||
@classmethod
|
||||
async def async_find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
    """
    Look up the rows that already exist in the database for a DataFrame.

    The statement is built by ``cls.dataframe_query``; when that returns None
    nothing is executed.

    :param row_df: DataFrame with the rows to look up
    :param condition_cols: columns used as the lookup condition (their order
        should match the index order)
    :return: DataFrame built from the scalar results, or None when no query
        was built
    """
    _lookup = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols)
    if _lookup is None:
        return None

    _scalars = await cls.orm_execute_scalars(query=_lookup)
    return pd.DataFrame(_scalars.all())
|
||||
|
||||
async def async_find(self, likes: Optional[dict[str, str]] = None) -> List['BaseTable']:
    """
    Query the database using this instance's own values as the filter.

    :param likes: fuzzy-match conditions forwarded to ``filter_expressions``
    :return: list of matching models
    """
    _criteria = self.filter_expressions(likes=likes)
    _model_cls = self.__class__
    return await self.query_all(select(_model_cls).where(*_criteria))
|
||||
|
||||
async def async_find_first(self, likes: Optional[dict[str, str]] = None) -> Optional['BaseTable']:
    """
    Query the database using this instance's own values and return only the first hit.

    :param likes: fuzzy-match conditions forwarded to ``filter_expressions``
    :return: the result of ``query_first`` for the built statement
    """
    _criteria = self.filter_expressions(likes=likes)
    _model_cls = self.__class__
    return await self.query_first(select(_model_cls).where(*_criteria))
|
||||
|
||||
async def async_find_piece(self, *where: Union[ColumnOperators, BinaryExpression],
                           offset=0, limit=500, is_desc=False,
                           likes: Optional[dict[str, str]] = None) -> List['BaseTable']:
    """
    Query one slice of rows using this instance's values plus extra conditions.

    :param where: additional filter expressions
    :param offset: number of rows to skip; applied when >= 0
    :param limit: maximum number of rows; applied when > 0
    :param is_desc: order by ``id`` descending when the model has an ``id``
    :param likes: fuzzy-match conditions forwarded to ``filter_expressions``
    :return: list of matching models
    """
    _model_cls = self.__class__
    _criteria = self.filter_expressions(likes=likes)
    # ``where`` is a (possibly empty) tuple, so it can always be appended.
    _criteria += where

    _stmt = select(_model_cls).where(*_criteria)
    if limit > 0:
        _stmt = _stmt.limit(limit=limit)
    if offset >= 0:
        _stmt = _stmt.offset(offset=offset)
    if is_desc and hasattr(_model_cls, 'id'):
        _stmt = _stmt.order_by(desc(_model_cls.id))

    return await self.query_all(_stmt)
|
||||
|
||||
async def async_save(self, auto_expunge: Optional[bool] = True, session: Optional[AsyncSession] = None):
    """
    Save this model object.

    ``before_save()`` and ``close_session()`` are called first, then the object
    is added to the session. Without an external session a new one is obtained,
    committed and closed here; with an external session only ``add`` is
    performed and committing is left to the caller. On any exception the
    session is rolled back and the exception re-raised.

    :param auto_expunge: refresh the object and expunge it from the session
        after saving; only effective when no external session is supplied
    :param session: external session, mainly used for transactions
    :return: True on success
    """
    # True means an external session was supplied by the caller.
    _has_session: bool = True

    self.before_save()
    self.close_session()
    if session is None:
        _has_session = False
        session = self.get_aio_session()

    try:
        session.add(self)
        if not _has_session:
            # A session created here is committed right away
            await session.commit()
        self._is_new = False
    except Exception as e:
        await session.rollback()
        raise e
    else:
        if auto_expunge and not _has_session:
            await session.refresh(self)
            session.expunge(self)
        return True
    finally:
        if not _has_session:
            # A session created here is closed here as well
            await session.close()
|
||||
|
||||
|
||||
def create_all_tables():
    """
    Create every table registered in the adapter registry's metadata,
    using the global connection engine.
    """
    _metadata = baseadapter.registry.metadata
    _metadata.create_all(engine.connect_engine())
|
||||
Executable
+43
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
从配置文件读取数据引擎连接信息,连接数据库。
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.engine.mock import MockConnection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||
|
||||
from paste.core import config
|
||||
|
||||
ASYNC_CONNECTOR_ENGINE = None
|
||||
GLOBAL_CONNECTOR_ENGINE = None
|
||||
|
||||
|
||||
def connect_engine() -> Union[MockConnection, Engine]:
    """
    Return the global database engine, creating it on first use.

    The engine URL is read from the ``db_engine.engine`` configuration entry
    and its keyword options from ``db_engine.engine_option``.

    :return: the shared engine
    """
    global GLOBAL_CONNECTOR_ENGINE
    if GLOBAL_CONNECTOR_ENGINE is not None:
        return GLOBAL_CONNECTOR_ENGINE

    _url = config.get_config('db_engine.engine')
    _options = config.get_config('db_engine.engine_option')
    GLOBAL_CONNECTOR_ENGINE = create_engine(_url, **_options)
    return GLOBAL_CONNECTOR_ENGINE
|
||||
|
||||
|
||||
def async_connect_engine() -> AsyncEngine:
    """
    Return the global async database engine, creating it on first use.

    The engine URL is read from the ``db_engine.async_engine`` configuration
    entry and its keyword options from ``db_engine.engine_option``.

    :return: the shared async engine
    """
    global ASYNC_CONNECTOR_ENGINE
    if ASYNC_CONNECTOR_ENGINE is not None:
        return ASYNC_CONNECTOR_ENGINE

    _url = config.get_config('db_engine.async_engine')
    _options = config.get_config('db_engine.engine_option')
    ASYNC_CONNECTOR_ENGINE = create_async_engine(_url, **_options)
    return ASYNC_CONNECTOR_ENGINE
|
||||
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
生成数据模型代码,注意这里显式排除了在配置文件 RBAC 中配置的数据表。
|
||||
如果不需要排除,可直接使用 sqlacodegen 自带的命令。
|
||||
"""
|
||||
import subprocess
|
||||
from os import path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from paste.core import config
|
||||
from paste.core.logging import echo_log
|
||||
from paste.db.basetable import BaseTable
|
||||
|
||||
exclude_tables = list(config.get_config('rbac.table').values())
|
||||
"""
|
||||
需要排除的数据表,默认排除 RBAC 数据表,这部分表格在 RBAC 模块中已经配置好了。
|
||||
"""
|
||||
|
||||
|
||||
def db_engin():
    """
    Return the value of the ``db_engine.engine`` configuration entry.
    """
    _engine_url = config.get_config('db_engine.engine')
    return _engine_url
|
||||
|
||||
|
||||
async def sqlacodegen(is_exclude_rbac_table: bool = True):
    """
    Generate the model source file by invoking the ``sqlacodegen`` command.

    Table names come from ``BaseTable.tables_in_db()``. The output is written
    to ``models/db_models.py`` under the current directory. Nothing is run when
    no table remains. A non-zero exit status of the command is logged as an
    error instead of being reported as success.

    :param is_exclude_rbac_table: drop the tables listed in ``exclude_tables``
        (the RBAC tables) before generating; enabled by default
    """
    # Function-scope import: the module itself only imports echo_log.
    from logging import ERROR

    _table_names = await BaseTable.tables_in_db()

    # Remove the RBAC tables
    if _table_names and is_exclude_rbac_table:
        _tables_df = pd.DataFrame(_table_names)
        # Keep only the rows whose first column is not an excluded table
        _name_df: pd.DataFrame = _tables_df.loc[~_tables_df.iloc[:, 0].isin(exclude_tables)]
        # Back to a plain list of table names
        _table_names = _name_df[0].tolist()

    # Also covers a None result, which len() could not handle.
    if not _table_names:
        return

    echo_log(f"将为以下表生成数据模型:{_table_names}")
    _tables = f"--tables={','.join(_table_names)}"
    # Generated into the models directory under the current directory
    _outfile = f"--outfile={path.join(path.curdir, 'models', 'db_models.py')}"

    _engin = db_engin()
    _exit_code = subprocess.call(['sqlacodegen', _engin, _tables, _outfile])
    if _exit_code != 0:
        echo_log(f"生成失败,sqlacodegen 退出码:{_exit_code}", level=ERROR)
        return
    echo_log("生成完成.")
|
||||
@@ -0,0 +1,995 @@
|
||||
"""
|
||||
封装了 Python 对 Redis 的基本操作。
|
||||
同时处理了 Java 在操作 Redis 后留下的字节码问题。
|
||||
"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
import pathlib
|
||||
import random
|
||||
import types
|
||||
from logging import ERROR, WARNING
|
||||
from typing import Optional, Callable, Awaitable, Union, Tuple, Dict
|
||||
|
||||
import javaobj
|
||||
import redis
|
||||
from redis.asyncio import ConnectionPool, StrictRedis
|
||||
from redis.client import Pipeline
|
||||
|
||||
from paste.core import aio_pool, config, logging
|
||||
from paste.util.snow_id import IdWorker
|
||||
|
||||
|
||||
class LuaScriptManager:
|
||||
"""
|
||||
Lua 脚本管理器。
|
||||
负责加载、缓存和执行 Lua 脚本。
|
||||
"""
|
||||
|
||||
# 默认 Lua 脚本内容(作为内置默认值,无需外部文件)
|
||||
DEFAULT_SCRIPTS = {
|
||||
"stock_decr": """
|
||||
-- 扣减库存(原子操作)
|
||||
-- KEYS[1]: 库存 key
|
||||
-- ARGV[1]: 扣减数量
|
||||
-- 返回值: 1=成功, 0=库存不足, -1=key不存在
|
||||
|
||||
local key = KEYS[1]
|
||||
local quantity = tonumber(ARGV[1])
|
||||
|
||||
local current = redis.call('GET', key)
|
||||
if not current then
|
||||
return -1
|
||||
end
|
||||
|
||||
current = tonumber(current)
|
||||
if current >= quantity then
|
||||
redis.call('DECRBY', key, quantity)
|
||||
return 1
|
||||
else
|
||||
return 0
|
||||
end
|
||||
""",
|
||||
|
||||
"stock_incr": """
|
||||
-- 增加库存(原子操作)
|
||||
-- KEYS[1]: 库存 key
|
||||
-- ARGV[1]: 增加数量
|
||||
-- 返回值: 当前库存
|
||||
|
||||
local key = KEYS[1]
|
||||
local quantity = tonumber(ARGV[1])
|
||||
|
||||
redis.call('INCRBY', key, quantity)
|
||||
return redis.call('GET', key)
|
||||
""",
|
||||
|
||||
"stock_peek": """
|
||||
-- 查看库存(原子操作)
|
||||
-- KEYS[1]: 库存 key
|
||||
-- 返回值: 当前库存
|
||||
|
||||
local key = KEYS[1]
|
||||
local current = redis.call('GET', key)
|
||||
|
||||
if not current then
|
||||
return -1
|
||||
end
|
||||
return tonumber(current)
|
||||
""",
|
||||
}
|
||||
|
||||
_scripts: Dict[str, Tuple[str, str]] = {} # name -> (sha, script_content)
|
||||
_script_dir: Optional[pathlib.Path] = None
|
||||
_use_external_files: bool = False # 是否使用外部文件
|
||||
|
||||
@classmethod
|
||||
def set_script_dir(cls, script_dir: str, use_external: bool = True):
    """
    Configure where Lua scripts are looked up.

    :param script_dir: path of the script directory; a falsy value clears it
    :param use_external: whether external files should be used (False means
        the built-in default scripts are used)
    """
    if script_dir:
        cls._script_dir = pathlib.Path(script_dir)
    else:
        cls._script_dir = None
    cls._use_external_files = use_external
|
||||
|
||||
@classmethod
|
||||
async def load_script(cls, redis_client: StrictRedis, script_name: str) -> str:
    """
    Load a Lua script, cache it and register it with Redis.

    An external file ``<script_dir>/<script_name>.lua`` is preferred when
    external files are enabled and the file exists; otherwise the built-in
    script from ``DEFAULT_SCRIPTS`` is used. Registering the script with Redis
    is best-effort: a failure is logged as a warning and does not prevent the
    SHA from being returned.

    :param redis_client: Redis client
    :param script_name: script name (for example ``stock_decr``)
    :return: SHA-1 hex digest of the script content
    :raises ValueError: when no external file is used and no built-in script
        with that name exists
    """
    script_content = None

    # Try the external file first
    if cls._use_external_files and cls._script_dir:
        script_path = cls._script_dir / f"{script_name}.lua"
        if script_path.exists():
            script_content = script_path.read_text(encoding='utf-8')
            logging.echo_log(f"Lua 脚本从外部文件加载: {script_path}")

    # Fall back to the built-in default
    if script_content is None:
        if script_name not in cls.DEFAULT_SCRIPTS:
            raise ValueError(f"脚本不存在: {script_name},且无内置默认值")
        script_content = cls.DEFAULT_SCRIPTS[script_name]
        logging.echo_log(f"Lua 脚本使用内置默认值: {script_name}")

    # Compute the SHA locally and cache it with the content
    sha = hashlib.sha1(script_content.encode()).hexdigest()
    cls._scripts[script_name] = (sha, script_content)

    # Pre-register with Redis; failure here must not block later use
    # but is now logged instead of being silently dropped.
    try:
        await redis_client.script_load(script_content)
    except Exception as e:
        logging.echo_log(f"Lua 脚本预加载失败: {script_name}: {e}", level=WARNING)

    return sha
|
||||
|
||||
@classmethod
|
||||
async def load_default_scripts(cls, redis_client: StrictRedis):
    """
    Load every script named in ``DEFAULT_SCRIPTS`` through ``load_script``.

    :param redis_client: Redis client
    """
    _defaults = cls.DEFAULT_SCRIPTS
    for _name in _defaults:
        await cls.load_script(redis_client, _name)
    logging.echo_log(f"已加载 {len(_defaults)} 个默认 Lua 脚本")
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, redis_client: StrictRedis, script_name: str,
                  keys: list, args: list) -> any:
    """
    Run a cached Lua script with ``EVALSHA``.

    A script that is not cached yet is loaded first. When Redis answers with
    a ``NOSCRIPT`` error, the script is registered again and ``EVALSHA`` is
    retried once; any other response error is re-raised.

    :param redis_client: Redis client
    :param script_name: script name
    :param keys: values for ``KEYS``
    :param args: values for ``ARGV``
    :return: result of the script
    """
    if script_name not in cls._scripts:
        # Not loaded yet: load it now
        await cls.load_script(redis_client, script_name)

    sha, script_content = cls._scripts[script_name]
    _numkeys = len(keys)
    _argv = (*keys, *args)

    try:
        return await redis_client.evalsha(sha, _numkeys, *_argv)
    except redis.ResponseError as e:
        if "NOSCRIPT" not in str(e):
            raise
        # Script is unknown to the server: register it again and retry
        await redis_client.script_load(script_content)
        return await redis_client.evalsha(sha, _numkeys, *_argv)
|
||||
|
||||
@classmethod
|
||||
async def reload_script(cls, redis_client: StrictRedis, script_name: str) -> str:
    """
    Drop the cached entry of a Lua script (if any) and load it again.

    :param redis_client: Redis client
    :param script_name: script name
    :return: SHA returned by ``load_script``
    """
    cls._scripts.pop(script_name, None)
    return await cls.load_script(redis_client, script_name)
|
||||
|
||||
|
||||
class Redis:
|
||||
"""
|
||||
Redis 基础操作。
|
||||
"""
|
||||
|
||||
connect_pool: Optional[ConnectionPool] = None
|
||||
|
||||
prefix = b'\xac\xed\x00\x05'
|
||||
utf_flag = b'\x74'
|
||||
|
||||
lua_scripts = LuaScriptManager
|
||||
"""Lua 脚本管理器。"""
|
||||
|
||||
@classmethod
|
||||
def is_java_serialized(cls, bs: Union[bytes, str]):
    """
    Tell whether ``bs`` starts with the Java serialization prefix.

    :param bs: value to inspect; anything that is not ``bytes`` yields False
    :return: True when the first four bytes equal ``cls.prefix``
    """
    return isinstance(bs, bytes) and bs[:4] == cls.prefix
|
||||
|
||||
@classmethod
|
||||
async def get_pool(cls) -> ConnectionPool:
    """
    Return the shared Redis connection pool, creating it on first use from
    the ``redis.connection`` configuration entry.

    :return: connection pool
    """
    if cls.connect_pool is not None:
        return cls.connect_pool

    _settings = config.get_config("redis.connection")
    cls.connect_pool = ConnectionPool.from_url(**_settings)
    return cls.connect_pool
|
||||
|
||||
@classmethod
|
||||
async def close_pool(cls):
    """
    Disconnect the shared connection pool, if one exists, and forget it.
    """
    _pool = cls.connect_pool
    if _pool is None:
        return
    await _pool.disconnect()
    cls.connect_pool = None
|
||||
|
||||
@classmethod
|
||||
async def get_redis(cls) -> StrictRedis:
    """
    Create a Redis client bound to the shared connection pool.

    NOTE(review): redis-py may ignore the socket/health-check keyword
    arguments when an explicit ``connection_pool`` is supplied — confirm, and
    move them into the pool configuration if so.

    :return: Redis client
    """
    _client_options = {
        'socket_timeout': 5,
        'socket_connect_timeout': 5,
        'health_check_interval': 30,
        'socket_keepalive': True,
    }
    _pool = await cls.get_pool()
    return StrictRedis(connection_pool=_pool, **_client_options)
|
||||
|
||||
@classmethod
|
||||
async def ping(cls):
    """
    Test the connection by sending PING.

    :return: the client's ``ping()`` result
    """
    _client = await cls.get_redis()
    async with _client as _redis:
        _alive = await _redis.ping()
    return _alive
|
||||
|
||||
@classmethod
|
||||
async def get_pipe(cls, transaction: bool = True, shard_hint=None) -> Pipeline:
    """
    Create a pipeline object from a Redis client.

    NOTE(review): the pipeline is returned after the client's ``async with``
    block has exited — confirm the pipeline stays usable in that state.

    :param transaction: forwarded to ``pipeline()``
    :param shard_hint: forwarded to ``pipeline()``
    :return: pipeline object
    """
    _client = await cls.get_redis()
    async with _client as _redis:
        _pipe = _redis.pipeline(transaction=transaction, shard_hint=shard_hint)
        return _pipe
|
||||
|
||||
# ========== Lua 脚本初始化 ==========
|
||||
|
||||
@classmethod
|
||||
async def init_lua_scripts(cls, script_dir: str = None, use_external: bool = False):
    """
    Initialize the Lua scripts; intended to be called once at application start.

    :param script_dir: external script directory (optional); when given it is
        passed to the script manager together with ``use_external``
    :param use_external: whether external files are used; default False uses
        the built-in scripts
    """
    _manager = cls.lua_scripts
    if script_dir:
        _manager.set_script_dir(script_dir, use_external)

    _client = await cls.get_redis()
    async with _client as _redis:
        await cls.lua_scripts.load_default_scripts(_redis)
|
||||
|
||||
# ========== 库存核心方法(原子操作) ==========
|
||||
|
||||
@classmethod
|
||||
async def stock_decr(cls, stock_key: str, quantity: int = 1) -> Tuple[bool, str]:
    """
    Deduct stock through the ``stock_decr`` Lua script.

    Script result 1 means success, 0 means insufficient stock and any other
    value is reported as "item does not exist". Exceptions are logged and
    turned into a failure result.

    :param stock_key: stock key (may be a shard key such as ``stock:iPhone15:shard:0``)
    :param quantity: amount to deduct
    :return: ``(success, message)``
    """
    _client = await cls.get_redis()
    async with _client as _redis:
        try:
            _code = await cls.lua_scripts.execute(
                _redis, "stock_decr", keys=[stock_key], args=[quantity]
            )
            if _code == 1:
                return True, "扣减成功"
            if _code == 0:
                return False, "库存不足"
            return False, "商品不存在"
        except Exception as e:
            logging.echo_log(f"扣减库存异常: {e}", level=ERROR, is_log_exc=True)
            return False, f"系统异常: {e}"
|
||||
|
||||
@classmethod
|
||||
async def stock_incr(cls, stock_key: str, quantity: int = 1) -> Tuple[bool, int]:
    """
    Add stock through the ``stock_incr`` Lua script (returns, restocking, ...).

    Exceptions are logged and reported as ``(False, 0)``.

    :param stock_key: stock key
    :param quantity: amount to add
    :return: ``(success, current_stock)``
    """
    _client = await cls.get_redis()
    async with _client as _redis:
        try:
            _current = await cls.lua_scripts.execute(
                _redis, "stock_incr", keys=[stock_key], args=[quantity]
            )
            return True, int(_current)
        except Exception as e:
            logging.echo_log(f"增加库存异常: {e}", level=ERROR, is_log_exc=True)
            return False, 0
|
||||
|
||||
@classmethod
|
||||
async def stock_peek(cls, stock_key: str) -> int:
    """
    Read the remaining stock through the ``stock_peek`` Lua script.

    A negative script result and any exception (which is logged) are both
    reported as 0.

    :param stock_key: stock key
    :return: remaining stock
    """
    _client = await cls.get_redis()
    async with _client as _redis:
        try:
            _remaining = await cls.lua_scripts.execute(
                _redis, "stock_peek", keys=[stock_key], args=[]
            )
            if _remaining >= 0:
                return int(_remaining)
            return 0
        except Exception as e:
            logging.echo_log(f"查询库存异常: {e}", level=ERROR, is_log_exc=True)
            return 0
|
||||
|
||||
# ========== 库存分片辅助方法 ==========
|
||||
|
||||
@classmethod
def get_shard_key(cls, sku_id: str, shard_id: int) -> str:
    """
    Build the Redis key of one stock shard.

    Recommended key layout: {domain}:{entity}:{identifier}:{shard/dimension}:{extension}

    :param sku_id: item id
    :param shard_id: shard id
    :return: shard key in the form ``stock:<sku_id>:shard:<shard_id>``
    """
    return "stock:{}:shard:{}".format(sku_id, shard_id)
|
||||
|
||||
@classmethod
def get_user_shard(cls, sku_id: str, user_id: str, shard_count: int = 10) -> str:
    """
    Map a user to one stock shard and return that shard's key.

    The mapping uses CRC32 of the UTF-8 encoded user id, so the same user
    lands on the same shard in every process and after every restart.
    The built-in ``hash()`` is salted per interpreter process for strings
    and therefore cannot provide a stable mapping.

    :param sku_id: item id
    :param user_id: user id
    :param shard_count: total number of shards, must be at least 1
    :return: shard key
    :raises ValueError: if ``shard_count`` is smaller than 1
    """
    import zlib

    if shard_count < 1:
        raise ValueError(f"shard_count must be a positive integer, got {shard_count}")

    # crc32 returns an unsigned 32-bit integer, so the modulo is never negative.
    digest = zlib.crc32(str(user_id).encode("utf-8"))
    return cls.get_shard_key(sku_id, digest % shard_count)
|
||||
|
||||
@classmethod
async def init_sharded_stock(cls, sku_id: str, total_stock: int, shard_count: int = 10):
    """
    Spread the total stock over the shards and write each shard to Redis.

    Every shard receives ``total_stock // shard_count``; the first
    ``total_stock % shard_count`` shards receive one extra unit so that
    the shard values add up to the total.

    :param sku_id: item id
    :param total_stock: total stock to distribute
    :param shard_count: number of shards
    """
    per_shard, extra = divmod(total_stock, shard_count)

    async with await cls.get_redis() as client:
        for index in range(shard_count):
            key = cls.get_shard_key(sku_id, index)
            amount = per_shard + 1 if index < extra else per_shard
            await client.set(key, amount)
            logging.echo_log(f"初始化分片 {index}: {key} = {amount}")
|
||||
|
||||
# ========== 基础 KV 操作 ==========
|
||||
|
||||
@classmethod
async def keys(cls):
    """
    Return every key the Redis ``keys()`` call reports.

    :return: the keys as returned by the client
    """
    async with await cls.get_redis() as client:
        return await client.keys()
|
||||
|
||||
@classmethod
async def show_keys(cls):
    """
    Print every key to the console.

    Bytes keys are printed as ``<decoded text> => <raw bytes>``. For keys
    that ``is_java_serialized`` recognises, the first 7 bytes are dropped
    before decoding. Non-bytes keys are printed as they are.
    """
    for raw in await cls.keys():
        if not isinstance(raw, bytes):
            print(raw)
            continue
        # Skip the 7 leading bytes of a Java-serialized key before decoding.
        text_part = raw[7:] if cls.is_java_serialized(raw) else raw
        print(text_part.decode('utf-8'), '=>', raw)
|
||||
|
||||
@classmethod
async def get(cls, key: Union[bytes, str]):
    """
    Read a value from Redis, trying the plain key first and then its
    Java-serialized form.

    If the plain key yields nothing and the key itself is not already
    Java-serialized, a second lookup is made with
    ``prefix + utf_flag + <2-byte big-endian length> + <key bytes>``.
    A bytes value that is Java-serialized is decoded with ``javaobj.loads``;
    every other value is returned untouched.

    :param key: Redis key name
    :return: stored value, or ``None`` when neither lookup finds anything
    """
    async with await cls.get_redis() as client:
        value = await client.get(key)

        if value is None and not cls.is_java_serialized(key):
            raw_key = key.encode('utf-8') if isinstance(key, str) else key
            length_field = len(raw_key).to_bytes(2, 'big')
            java_key = cls.prefix + cls.utf_flag + length_field + raw_key
            value = await client.get(java_key)

        if value is None:
            return value

        if isinstance(value, bytes) and cls.is_java_serialized(value):
            return javaobj.loads(value)
        return value
|
||||
|
||||
@classmethod
async def set(cls, key: str, value: Union[bytes, str, int, float], ex: Union[int, None] = None):
    """
    Store a key/value pair.

    :param key: Redis key name
    :param value: value to store
    :param ex: expiry passed through to the client's ``set`` call as ``ex``;
        ``None`` sets no expiry
    :return: whatever the client's ``set`` call returns
    """
    # The original annotation ``any`` named the builtin function rather than a
    # type; the union lists the scalar kinds a Redis value can be written from.
    async with await cls.get_redis() as _redis:
        return await _redis.set(key, value, ex=ex)
|
||||
|
||||
@classmethod
async def delete(cls, key: str):
    """
    Delete a key.

    :param key: Redis key name
    :return: whatever the client's ``delete`` call returns
    """
    async with await cls.get_redis() as client:
        removed = await client.delete(key)
        return removed
|
||||
|
||||
@classmethod
async def exists(cls, key: str) -> bool:
    """
    Check whether a key exists.

    :param key: Redis key name
    :return: ``True`` when the client's ``exists`` call reports a count above zero
    """
    async with await cls.get_redis() as client:
        found = await client.exists(key)
        return found > 0
|
||||
|
||||
@classmethod
async def expire(cls, key: str, seconds: int):
    """
    Set the time-to-live of a key.

    :param key: Redis key name
    :param seconds: lifetime in seconds
    :return: whatever the client's ``expire`` call returns
    """
    async with await cls.get_redis() as client:
        applied = await client.expire(key, seconds)
        return applied
|
||||
|
||||
@classmethod
async def incr(cls, key: str) -> int:
    """
    Atomically increment the counter stored at a key.

    :param key: Redis key name
    :return: whatever the client's ``incr`` call returns
    """
    async with await cls.get_redis() as client:
        counter = await client.incr(key)
        return counter
|
||||
|
||||
# ========== 回调处理 ==========
|
||||
|
||||
@classmethod
def get_func_name(cls, func):
    """
    Return a printable name for a callable.

    Plain functions report their own name; bound methods and
    ``classmethod``/``staticmethod`` wrappers report the name of the wrapped
    function; any other object with ``__call__`` reports its class name.
    Everything else falls back to ``str(func)``.

    :param func: callable object
    :return: name of the callable
    """
    if isinstance(func, types.FunctionType):
        return func.__name__

    # Bound methods and descriptor wrappers both expose the function as __func__.
    if isinstance(func, types.MethodType) or isinstance(func, (classmethod, staticmethod)):
        return func.__func__.__name__

    if hasattr(func, '__call__'):
        return func.__class__.__name__

    return str(func)
|
||||
|
||||
@classmethod
async def callback(cls, func: Callable, message_key: str, is_delete=False):
    """
    Load the hash stored at a message key, hand it to the callback, and
    optionally delete the message afterwards.

    The message is deleted only when ``is_delete`` is set and no exception
    occurred before that point. Redis errors and callback errors are logged
    and not raised.

    :param func: callback; may return a plain value or an awaitable
    :param message_key: message key
    :param is_delete: whether to delete the message after processing
    :return: callback result, or ``None`` when the message is empty or
        no callback was given
    """
    outcome = None
    async with await cls.get_redis() as client:
        try:
            payload = await client.hgetall(message_key)
            if payload:
                if func:
                    # Run the callback; await the result when it is awaitable.
                    outcome = func(payload)
                    if isinstance(outcome, Awaitable):
                        outcome = await outcome

                if is_delete:
                    # Reached only when nothing above raised.
                    await client.delete(message_key)
                    logging.echo_log(f"消息已删除: {message_key};数据为:{payload}.")
            else:
                logging.echo_log(f"警告: 空消息数据 {message_key}.", level=WARNING)
        except redis.RedisError as e:
            logging.echo_log(f"Redis 操作异常: {e}.", level=ERROR, is_log_exc=True)
        except Exception as e:
            logging.echo_log(
                f"执行回调异常:{e};方法:{cls.get_func_name(func)};消息: {message_key}.",
                level=ERROR, is_log_exc=True
            )
    return outcome
|
||||
|
||||
|
||||
class PubSubActor(Redis):
    """
    Publish/subscribe actor built on Redis Pub/Sub plus one hash per message.

    ``publish`` stores the payload in a hash and announces its id on a
    channel. ``subscribe`` waits for announcements, runs a callback on the
    stored payload and, depending on ``is_delete``, removes the message.
    """

    def __init__(self, hash_name: str):
        """
        :param hash_name: base name from which the hash prefix and the channel name are derived
        """
        self.hash_name = f"{hash_name}_HASH_NAME"
        self.channel = f"{hash_name}_CHANNEL"

        # Graceful-exit flag of the current subscribe session.
        self.running = False

        # Master switch: once set, run_forever stops reconnecting.
        self.stopping = False

    async def publish(self, data: dict) -> str:
        """
        Write the payload to Redis and announce it on the channel.

        :param data: payload stored as a hash
        :return: message id
        """
        async with await self.get_redis() as _redis:
            # A snowflake id is used as the hash key suffix.
            _random_num = random.randint(1000, 9999)
            _id = IdWorker.get_id_worker(3, 3, _random_num).get_id()

            # Store the payload first so a subscriber can always find it.
            await _redis.hset(f"{self.hash_name}:{_id}", mapping=data)
            # Announce the new message.
            await _redis.publish(self.channel, _id)
            return _id

    async def subscribe(self, func: Callable = None, is_delete=False):
        """
        Listen on the channel and dispatch each announced message.

        :param func: callback run on the stored payload
        :param is_delete: whether to delete the message after the callback ran
        """
        async with await self.get_redis() as _redis:
            _pubsub = _redis.pubsub()
            await _pubsub.subscribe(self.channel)

            try:
                self.running = True

                while not self.stopping and self.running:
                    try:
                        # get_message applies the timeout itself. This replaces
                        # creating a fresh listen() generator on every iteration
                        # and cancelling its pending read through wait_for.
                        message = await _pubsub.get_message(timeout=60.0)
                    except redis.exceptions.ConnectionError as e:
                        # Connection error: propagate so the caller reconnects.
                        logging.echo_log(f"检测到连接错误: {e}. 将触发重连...", level=ERROR, is_log_exc=True)
                        raise

                    if message is None:
                        # Nothing arrived within the timeout; the connection is alive.
                        logging.echo_log("心跳:连接正常,继续监听...")
                        continue

                    if message["type"] != "message":
                        continue

                    message_id = message["data"]
                    if isinstance(message_id, bytes):
                        # Without response decoding the id arrives as bytes; formatting
                        # it directly would produce a key such as "NAME:b'123'".
                        message_id = message_id.decode('utf-8')
                    message_key = f"{self.hash_name}:{message_id}"

                    try:
                        # Run in the background so one slow callback does not
                        # hold up the messages behind it.
                        await aio_pool.run_background_task(self.callback(func, message_key, is_delete), 10)
                    except Exception as e:
                        # Report the failure, then move on to the next message.
                        logging.echo_log(
                            f"提交后台任务失败: {e};消息: {message_key}.", level=ERROR, is_log_exc=True
                        )
                        continue
            except (asyncio.CancelledError, KeyboardInterrupt):
                logging.echo_log("收到退出信号,停止监听...")
                self.running = False
                raise
            except Exception as e:
                logging.echo_log(f"监听会话因错误结束: {e}.", level=ERROR, is_log_exc=True)
                raise
            finally:
                self.running = False
                try:
                    await _pubsub.unsubscribe(self.channel)
                    await _pubsub.close()
                except Exception as close_err:
                    logging.echo_log(f"资源关闭异常: {close_err}.")
                finally:
                    logging.echo_log("监听已完全停止.")

    async def run_forever(self, func: Callable = None, is_delete=False):
        """
        Keep a subscribe session alive, reconnecting after failures.

        A failed session is retried after 10 seconds. The loop ends when
        ``stopping`` is set or the task is cancelled.

        :param func: callback run on the stored payload
        :param is_delete: whether to delete the message after the callback ran
        """
        while not self.stopping:
            try:
                logging.echo_log("启动新的监听会话...")
                await self.subscribe(func, is_delete)
            except (asyncio.CancelledError, KeyboardInterrupt):
                logging.echo_log("收到退出信号,停止监听...")
                self.stopping = True
                break
            except Exception as e:
                logging.echo_log(f"监听会话因未知错误结束: {e}. 10秒后重试...", level=ERROR, is_log_exc=True)

            if self.stopping:
                logging.echo_log("总开关已打开,停止重连.")
                break

            logging.echo_log("等待重新连接...")
            try:
                # sleep is itself cancellable, so no extra wrapper is needed.
                await asyncio.sleep(10)
            except asyncio.CancelledError:
                logging.echo_log("等待期间被取消,准备退出.")
                break

        logging.echo_log("监听服务已完全停止.")

    async def history(self, func: Callable = None, is_delete=False):
        """
        Process the messages that are already stored.

        Only keys under this actor's own ``<hash_name>:`` prefix are visited.
        An unfiltered key listing would run the callback on, and with
        ``is_delete`` remove, unrelated keys in the same database.

        :param func: callback run on the stored payload
        :param is_delete: whether to delete the message after the callback ran
        """
        async with await self.get_redis() as _redis:
            _keys = await _redis.keys(f"{self.hash_name}:*")
            for _k in _keys:
                try:
                    # Run in the background so one slow callback does not
                    # hold up the messages behind it.
                    await aio_pool.run_background_task(self.callback(func, _k, is_delete), 10)
                except Exception as e:
                    # Report the failure, then move on to the next message.
                    logging.echo_log(
                        f"提交后台任务失败: {e};消息: {_k}.", level=ERROR, is_log_exc=True
                    )
                    continue

    def subscribe_stop(self):
        """Ask the subscribe loop and ``run_forever`` to stop."""
        self.running = False
        self.stopping = True
|
||||
|
||||
|
||||
class StreamActor(Redis):
    """
    Stream actor: publishes and consumes messages through Redis Streams.

    Consumption uses a consumer group, so messages are acknowledged
    explicitly and unacknowledged ones stay pending. The method layout
    mirrors ``PubSubActor`` so the two can be swapped.

    On start, ``subscribe`` first tries to reclaim and reprocess stale
    pending tasks.
    """

    @classmethod
    def actor_config(cls, config_path: str):
        """
        Read the stream settings found at a configuration path.

        The stream name is the upper-cased last path segment. Group and
        consumer names fall back to ``<STREAM>_GROUP`` and
        ``<STREAM>_CONSUMER``; the consumer name always gets a generated id
        appended.

        :param config_path: dot-separated configuration path down to the stream key
        :return: ``(stream_name, group_name, consumer_name)``
        """
        _stream_name = config_path.split(".")[-1].upper()
        # A missing section falls back to the defaults instead of failing on None.
        _stream_config = config.get_config(config_path) or {}
        _group_name = _stream_config.get('group', f"{_stream_name}_GROUP")
        _consumer_name = _stream_config.get('consumer', f"{_stream_name}_CONSUMER")
        _snow_id = IdWorker.get_id_worker().get_id()
        _consumer_name = f"{_consumer_name}_{_snow_id}"
        return _stream_name, _group_name, _consumer_name

    @classmethod
    def new_actor(cls, config_path: str):
        """
        Create a stream actor from a configuration section.

        :param config_path: dot-separated configuration path down to the stream key
        :return: actor instance
        """
        _stream_name, _group_name, _consumer_name = cls.actor_config(config_path)
        return cls(_stream_name, _group_name, _consumer_name)

    def __init__(self, stream_name: str, group_name: str, consumer_name: str):
        """
        Initialise the stream actor.

        :param stream_name: name of the Redis Stream
        :param group_name: name of the consumer group
        :param consumer_name: name of this consumer
        """
        self.stream_name = stream_name
        self.group_name = group_name
        self.consumer_name = consumer_name

        # Graceful-exit flag of the current subscribe session.
        self.running = False

        # Master switch: once set, run_forever stops reconnecting.
        self.stopping = False

    async def _ensure_group_exists(self):
        """
        Create the consumer group, tolerating the case that it already exists.

        :raises redis.exceptions.ResponseError: for any other server error
        """
        # The connection is released on exit, like every other method here.
        async with await self.get_redis() as _redis:
            try:
                await _redis.xgroup_create(
                    name=self.stream_name,
                    groupname=self.group_name,
                    id='0',  # consume from the beginning of the stream
                    mkstream=True  # create the stream when it does not exist
                )
                logging.echo_log(f"消费者组 '{self.group_name}' 已创建.")
            except redis.exceptions.ResponseError as e:
                reason = str(e)
                # BUSYGROUP is the error code the server uses for an existing group.
                if "BUSYGROUP" in reason or "Consumer Group name already exists" in reason:
                    logging.echo_log(f"消费者组 '{self.group_name}' 已存在.")
                else:
                    raise

    async def publish(self, data: dict) -> str:
        """
        Append the payload to the stream as one message.

        :param data: field/value mapping written to the stream
        :return: message id
        """
        async with await self.get_redis() as _redis:
            # The message id is generated while the entry is added.
            message_id = await _redis.xadd(name=self.stream_name, fields=data)
            logging.echo_log(f"消息已发布至 Stream '{self.stream_name}',ID: {message_id};数据为:{data}.")
            return message_id

    async def reclaim_stale_tasks(self, func: Callable, is_delete: bool, stale_threshold_ms: int = 5 * 60 * 1000):
        """
        Find stale pending tasks, claim them for this consumer and reprocess them.

        At most 10 tasks are handled per call. Errors while listing or
        claiming are logged and end the call without raising.

        :param func: business callback used to process each task
        :param is_delete: whether to acknowledge and delete a message after processing
        :param stale_threshold_ms: idle time in milliseconds after which a
            pending task counts as stale
        """
        async with await self.get_redis() as _redis:
            # 1. Find stale tasks.
            try:
                stale_tasks = await _redis.xpending_range(
                    name=self.stream_name,
                    groupname=self.group_name,
                    min='-',
                    max='+',
                    count=10,  # bounded so startup is not blocked for long
                    idle=stale_threshold_ms
                )
            except Exception as e:
                logging.echo_log(f"检查僵尸任务时出错: {e}", level=ERROR, is_log_exc=True)
                return

            # One check covers both an empty reply and an unexpected reply type.
            if not stale_tasks or not isinstance(stale_tasks, list):
                logging.echo_log(f"未发现空闲超过 {stale_threshold_ms / 1000} 秒的僵尸任务.")
                return

            message_ids = [task['message_id'] for task in stale_tasks]
            logging.echo_log(f"发现 {len(message_ids)} 个僵尸任务,尝试认领并重新处理...")

            # 2. Claim the tasks.
            try:
                reclaimed_messages = await _redis.xclaim(
                    name=self.stream_name,
                    groupname=self.group_name,
                    consumername=self.consumer_name,  # claim for this consumer
                    min_idle_time=stale_threshold_ms,
                    message_ids=message_ids,
                    justid=False  # the payload is needed for processing
                )
            except Exception as e:
                logging.echo_log(f"认领僵尸任务时出错: {e}", level=ERROR, is_log_exc=True)
                return

            if not reclaimed_messages:
                logging.echo_log("未能成功认领任何僵尸任务.")
                return

            logging.echo_log(f"成功认领 {len(reclaimed_messages)} 个僵尸任务,开始处理.")

            # 3. Process through the same wrapper as live messages.
            for message_id, message_data in reclaimed_messages:
                await self._callback_wrapper(
                    func=func,
                    message_id=message_id,
                    message_data=message_data,
                    is_delete=is_delete
                )

    async def history(self, func: Callable = None, is_delete: bool = False):
        """
        Startup recovery: reprocess tasks that stayed pending for 5 minutes or more.

        The defaults match ``PubSubActor.history`` so either actor can be
        called the same way.

        :param func: business callback used to process each task
        :param is_delete: whether to acknowledge and delete a message after processing
        """
        logging.echo_log("执行启动恢复程序,检查僵尸任务...")
        await self.reclaim_stale_tasks(func=func, is_delete=is_delete, stale_threshold_ms=5 * 60 * 1000)
        logging.echo_log("启动恢复程序执行完毕.")

    async def subscribe(self, func: Callable = None, is_delete=False):
        """
        Read new messages from the consumer group and dispatch them.

        The consumer group is created if needed and stale tasks are
        recovered before listening starts.

        :param func: business callback run on each message
        :param is_delete: whether to acknowledge and delete a message after processing
        """
        await self._ensure_group_exists()

        # Recover stale tasks first, using the same callback settings.
        await self.history(func=func, is_delete=is_delete)

        async with await self.get_redis() as _redis:
            try:
                self.running = True
                logging.echo_log("僵尸任务恢复完成,开始监听新消息...")

                while not self.stopping and self.running:
                    try:
                        # Blocking read of new messages.
                        streams = await _redis.xreadgroup(
                            groupname=self.group_name,
                            consumername=self.consumer_name,
                            streams={self.stream_name: '>'},  # '>' reads only new messages
                            count=1,
                            block=5000  # 5-second timeout doubles as a heartbeat
                        )
                    except redis.exceptions.ConnectionError as e:
                        # Connection error: propagate so the caller reconnects.
                        logging.echo_log(f"检测到连接错误: {e}. 将触发重连...", level=ERROR, is_log_exc=True)
                        raise

                    if not streams:
                        # Nothing arrived within the timeout; the connection is alive.
                        logging.echo_log("心跳:连接正常,继续监听...")
                        continue

                    # Dispatch every returned entry, not only the first one.
                    for _stream, messages in streams:
                        for message_id, message_data in messages:
                            logging.echo_log(f"收到新消息: ID={message_id}, 数据={message_data}")
                            try:
                                # Run in the background so one slow callback does
                                # not hold up the messages behind it.
                                await aio_pool.run_background_task(
                                    self._callback_wrapper(func, message_id, message_data, is_delete), 10
                                )
                            except Exception as e:
                                # The message stays unacknowledged, i.e. pending.
                                logging.echo_log(
                                    f"提交后台任务失败: {e};消息: {message_id}.", level=ERROR, is_log_exc=True
                                )
                                continue
            except (asyncio.CancelledError, KeyboardInterrupt):
                logging.echo_log("收到退出信号,停止监听...")
                self.running = False
                raise
            except Exception as e:
                logging.echo_log(f"监听会话因错误结束: {e}.", level=ERROR, is_log_exc=True)
                raise
            finally:
                self.running = False
                logging.echo_log("Stream 监听已完全停止.")

    async def _callback_wrapper(self, func: Callable, message_id: str, message_data: dict, is_delete: bool):
        """
        Run the business callback on one stream message and optionally
        acknowledge and delete it.

        The base ``callback`` cannot be reused because it loads and deletes
        a hash; this method applies the same error handling with stream
        operations. Redis errors and callback errors are logged, not raised.

        :param func: business callback; may return a plain value or an awaitable
        :param message_id: id of the stream message
        :param message_data: payload of the stream message
        :param is_delete: whether to acknowledge and delete the message after processing
        :return: callback result, or ``None`` when no callback was given
        """
        # Without a callback nothing can be processed; return before any
        # acknowledgement so the task is not lost.
        if not func:
            logging.echo_log(f"警告: 收到消息 {message_id} 但未提供业务回调函数,消息将被忽略.", level=WARNING)
            return

        result = None
        async with await self.get_redis() as _redis:
            try:
                # Run the callback; await the result when it is awaitable.
                result = func(message_data)
                if isinstance(result, Awaitable):
                    result = await result

                if is_delete:
                    # Remove the message from the pending list first ...
                    await _redis.xack(self.stream_name, self.group_name, message_id)
                    # ... then delete it from the stream.
                    await _redis.xdel(self.stream_name, message_id)
                    logging.echo_log(f"消息已确认 (ACK): {message_id};数据为:{message_data}.")
            except redis.RedisError as e:
                logging.echo_log(f"Redis 操作异常: {e}.", level=ERROR, is_log_exc=True)
            except Exception as e:
                logging.echo_log(
                    f"执行回调异常:{e};方法:{self.get_func_name(func)};消息: {message_id}.",
                    level=ERROR, is_log_exc=True
                )
        return result

    async def run_forever(self, func: Callable = None, is_delete=False):
        """
        Keep a subscribe session alive, reconnecting after failures.

        A failed session is retried after 10 seconds. The loop ends when
        ``stopping`` is set or the task is cancelled.

        :param func: business callback run on each message
        :param is_delete: whether to acknowledge and delete a message after processing
        """
        while not self.stopping:
            try:
                logging.echo_log("启动新的监听会话...")
                await self.subscribe(func, is_delete)
            except (asyncio.CancelledError, KeyboardInterrupt):
                logging.echo_log("收到退出信号,停止监听...")
                self.stopping = True
                break
            except Exception as e:
                logging.echo_log(f"监听会话因未知错误结束: {e}. 10秒后重试...", level=ERROR, is_log_exc=True)

            if self.stopping:
                logging.echo_log("总开关已打开,停止重连.")
                break

            logging.echo_log("等待重新连接...")
            try:
                await asyncio.sleep(10)
            except asyncio.CancelledError:
                logging.echo_log("等待期间被取消,准备退出.")
                break

        logging.echo_log("监听服务已完全停止.")

    def subscribe_stop(self):
        """Ask the subscribe loop and ``run_forever`` to stop."""
        self.running = False
        self.stopping = True
|
||||
Reference in New Issue
Block a user