""" 数据访问适配器。主要封装了数据访问的基础方法。 """ import datetime import random import uuid from typing import Optional, Any, Union, List, Sequence import pandas as pd from sqlalchemy import Column, String, DateTime, Date, inspect, Table, Integer, select, Numeric, ForeignKey, func, \ text, desc, Result, Engine, tuple_ from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, async_sessionmaker from sqlalchemy.orm import sessionmaker, Session, InstanceState, InstrumentedAttribute, registry from sqlalchemy.sql import Select from sqlalchemy.sql.base import ColumnCollection from sqlalchemy.sql.compiler import StrSQLCompiler from sqlalchemy.sql.elements import BinaryExpression from paste.db import engine from paste.util import udict LOCAL_DATE_FORMAT = '%Y-%m-%d' LOCAL_TIME_FORMAT = '%H:%M:%S' LOCAL_DATETIME_FORMAT = f'{LOCAL_DATE_FORMAT} {LOCAL_TIME_FORMAT}' def guid() -> str: """ 生成 GUID。 :return: GUID """ return str(uuid.uuid5(uuid.NAMESPACE_DNS, str(uuid.uuid1()) + str(random.random()))).replace('-', '') def get_session() -> Session: """ 获取线程安全的连接会话对象。 :return: 连接会话对象 """ _engine: Engine = engine.connect_engine() session_maker = sessionmaker(bind=_engine) return session_maker() def get_aio_session() -> AsyncSession: """ 获取异步连接会话对象。 :return: 连接会话对象 """ _engine: AsyncEngine = engine.async_connect_engine() session_maker = async_sessionmaker(bind=_engine, class_=AsyncSession) return session_maker() Base = registry().generate_base() """ 取得 SQLAlchemy ORM 的声明式基类。 """ class BaseAdapter(Base): """ 适配器类,集成数据库引擎、连接、会话、元数据操作、部分 DML 数据处理等功能。 主要实现了同步执行方法。 """ __abstract__ = True def __init__(self, **kwargs): """ 构造模型对象。 :param args: :param kwargs: """ self._is_new = True super().__init__(**kwargs) @classmethod def get_session(cls): """ 线程安全的连接会话。 :return: 同步会话对象 """ return get_session() @classmethod def get_aio_session(cls): """ 异步连接会话。 :return: 异步会话对象 """ return get_aio_session() @classmethod def ping(cls): """ 尝试连接数据库服务器。 :return: 连接结果 """ engine.connect_engine().connect() @classmethod async def tables_in_db(cls): """ 取得数据库中所有表名称。 :return: 表名称列表 """ _query = select(text('table_name')).select_from( text('information_schema.tables') ).where( text('table_schema = database()') ) _names = list() _session = cls.get_aio_session() try: _result: Result = await _session.execute(_query) _names = list(_result.scalars().all()) finally: await _session.close() return _names @classmethod async def is_table_exist(cls, table_name: str): """ 判断表是否存在。 :param table_name: 要判断的表名称 """ _query = select(func.count()).select_from( text('information_schema.tables') ).where( text('table_schema = database()'), text('table_name = :table_name') ).params( table_name=table_name ) _count = 0 _session = cls.get_aio_session() try: _result: Result = await _session.execute(_query) _count = int(_result.scalar()) finally: await _session.close() return _count > 0 @classmethod def table(cls) -> Table: """ 取得当前模型对象所对应的表对象。 :return: 表对象 """ tables = cls.metadata.tables assert cls.__tablename__ is not None, '属性:cls.__tablename__ 未找到.' assert cls.__tablename__ in tables, ('数据库中未找到表:%s.' % cls.__tablename__) return tables[cls.__tablename__] @classmethod def columns(cls) -> List[Column]: """ 当前对象配置的所有列。 :return: 所有列 """ cols = cls.table().columns assert isinstance(cols, ColumnCollection), '%s 列类型错误' % cls.__name__ return list(cols) @classmethod def instrument_attr(cls, column_name: str) -> Optional[InstrumentedAttribute]: """ 取得查询条件绑定属性,用于配置查询条件表达式。 :param column_name: 列名称 :return: 类配置的列信息,即查询条件绑定属性 """ attr = getattr(cls, column_name) if not isinstance(attr, InstrumentedAttribute): return None return attr @classmethod def label(cls, column: Union[Column, Any]): """ 以字段的 comment 作为标签返回。 :return: 指定列的备注 """ return column.comment @classmethod def labels(cls) -> dict: """ 取得所有列的名称描述。 :return: 所有列的描述字典,结构为:{`field`: `comment`} """ _label_dict = {} for _col in cls.columns(): _label_dict[_col.key] = _col.comment return _label_dict @classmethod def field(cls, column: Union[Column, Any]) -> str: """ 取得字段名称。 :param column: 列对象 :return: 字段名称 """ return column.key @classmethod def fields(cls) -> List[str]: """ 取得所有字段名称列表。 :return: 所有字段名称列表 """ return [col.key for col in cls.columns()] @classmethod def new_id(cls, **kwargs) -> str: """ 用GUID作为新的ID。 :return: ID """ return guid() @property def is_new(self): """ 通过检查是否已经存在ID判断是否为新建对象。 :return: 是否为新建对象 """ return hasattr(self, '_is_new') and self._is_new @classmethod def raw_sql(cls, query) -> StrSQLCompiler: """ 显示编译后的原始 SQL 命令。 :return: 编译后的 SQL 命令文本 """ raw_sql = query.compile(compile_kwargs={'literal_binds': True}) return raw_sql @classmethod def row_count(cls, *where_expressions: BinaryExpression) -> int: """ 按条件取得记录行数。 :param where_expressions: 查询条件 :return: 返回记录行数 """ query = select(func.count(1)).select_from(cls).where(*where_expressions) session = cls.get_session() count = session.execute(query).scalars().first() session.close() return count @classmethod def exist_by_id(cls, id_val: Union[str, int]): """ 检查是否有存在的主键。 :param id_val: 主键值 """ _wheres = [cls.instrument_attr('id') == id_val] count: int = cls.row_count(*_wheres) return count >= 1 @classmethod def search_wheres(cls, likes: Optional[dict[str, str]] = None, **kwargs): """ 按参数组织查询条件。 :param likes: 需要执行模糊查询的列 :param kwargs: 属性填充参数 :return: 查询条件列表 """ if likes is None: likes = [] _query_model = cls().copy_from_dict(kwargs) return _query_model.filter_expressions(likes=likes) @classmethod def datalist_query(cls, row_list: List[dict], condition_cols: List[Column]) -> Optional[Select]: """ 根据数据列表、查询条件列生成查询。仅支持单表查询。 :param row_list: 数据列表 :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同 :return: 查询到的数据模型列表 """ if not row_list or not condition_cols: return None _condition_fields = [_col.key for _col in condition_cols] # 验证字段存在 for field in _condition_fields: if not hasattr(cls, field): raise AttributeError(f"Field '{field}' not found in {cls.__name__}") # 构建值列表的元组形式 _values_tuples = [] for _row in row_list: _val_list = [] for _f in _condition_fields: _val = udict.get_with_default(_row, _f, '') # 单独处理日期时间格式,但是此处固定格式 if isinstance(_val, datetime.datetime): _val = _val.strftime(LOCAL_DATETIME_FORMAT) if isinstance(_val, datetime.date): _val = _val.strftime(LOCAL_DATE_FORMAT) _val_list.append(_val) # 过滤掉全为空的元组 if any(v != '' for v in _val_list): _values_tuples.append(tuple(_val_list)) if not _values_tuples: return None # 去重,提高性能 _values_tuples = list(set(_values_tuples)) # 使用参数化查询 if len(condition_cols) == 1: # 单字段查询 field = getattr(cls, _condition_fields[0]) _single_values = [t[0] for t in _values_tuples] query = select(cls).where(field.in_(_single_values)) else: # 多字段组合查询 conditions = [getattr(cls, field) for field in _condition_fields] query = select(cls).where(tuple_(*conditions).in_(_values_tuples)) return query @classmethod def dataframe_query(cls, row_df: pd.DataFrame, condition_cols: List[Column]): """ 根据数据列表、查询条件列生成查询。仅支持单表查询。 :param row_df: 数据框架 :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同 :return: 查询到的数据模型列表 """ if row_df.empty or not condition_cols: return None _condition_fields = [_col.key for _col in condition_cols] # 验证字段存在 for field in _condition_fields: if not hasattr(cls, field): raise AttributeError(f"Field '{field}' not found in {cls.__name__}") # 去除重复行并删除包含 NaN 的行 row_df = row_df[_condition_fields].drop_duplicates().dropna() if row_df.empty: return None if len(condition_cols) == 1: # 单字段查询 _field_name = _condition_fields[0] field = getattr(cls, _field_name) # 直接获取 Series 的值 _values_list = row_df[_field_name].tolist() if not _values_list: return None query = select(cls).where(field.in_(_values_list)) else: # 多字段组合查询 - 使用向量化操作避免 iterrows() _values_tuples = [tuple(row) for row in row_df.values] # 再次去重确保 _values_tuples = list(set(_values_tuples)) if not _values_tuples: return None conditions = [getattr(cls, field) for field in _condition_fields] query = select(cls).where(tuple_(*conditions).in_(_values_tuples)) return query @classmethod def find_by_id(cls, id_val: Union[str, int], reset_session: Optional[bool] = True) -> Optional['BaseAdapter']: """ 根据主键ID查找数据,确认有 id 字段后方可使用。 :param id_val: 主键值 :param reset_session: 重置会话连接 :return: 数据模型对象 """ query = select(cls).where(cls.instrument_attr('id') == id_val) session = cls.get_session() model = session.execute(query).scalars().first() session.close() if reset_session and isinstance(model, BaseAdapter): model.reset_session() return model @classmethod def find_by_created(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']: """ 按创建时间搜索。本方法使用了固定字段名称,确认有 created_at 字段后方可使用。 :param from_t: 开始时间 :param to_t: 结束时间 """ if to_t is None: to_t = from_t query = select(cls).where(cls.instrument_attr('created_at').between(from_t, to_t)) session = cls.get_session() rows = session.execute(query).scalars().all() session.close() return rows @classmethod def find_by_updated(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']: """ 按更新时间搜索。本方法使用了固定字段名称,确认有 updated_at 字段后方可使用。 :param from_t: 开始时间 :param to_t: 结束时间 """ if to_t is None: to_t = from_t query = select(cls).where(cls.instrument_attr('updated_at').between(from_t, to_t)) session = cls.get_session() rows = session.execute(query).scalars().all() session.close() return rows @classmethod def find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]) -> Sequence['BaseAdapter']: """ 根据数据列表查询已经在数据库中的数据。 :param row_list: 数据列表 :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同 :return: 查询到的数据模型列表 """ rows: Sequence[cls] = [] _query = cls.datalist_query(row_list=row_list, condition_cols=condition_cols) if _query is not None: session = cls.get_session() rows = session.execute(_query).scalars().all() session.close() return rows @classmethod def find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]): """ 根据数据框架查询已经在数据库中的数据。 :param row_df: 数据框架 :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同 :return: 查询到的数据模型列表 """ _query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols) if _query is not None: session = cls.get_session() _rows = session.connection().execute(_query).all() session.close() return pd.DataFrame(_rows) return None def load(self, model_data: dict): """ 从字典对象载入数据。 该方法仅从字典中读取与对象属性对应的数据,忽略其他数据。 注意:与 copyXXX 方法不同,该方法跳过 model_data 中的 None 值,保留原始值不变 :param model_data: 包含数据的字典对象 """ for attr in self.__dir__(): if model_data.get(attr, None) is not None: self.__setattr__(attr, model_data.get(attr)) return self def copy_from(self, source: 'BaseAdapter', mapping_fields: Optional[List[str]] = None): """ 从源数据模型复制数据,仅复制相同字段的数据。 :param source: 源数据对象 :param mapping_fields: 映射字段列表,若为空,则从源数据模型中获取字段列表 """ # 源对象字段列表 if mapping_fields is None: mapping_fields = [column.key for column in source.columns()] for column in self.columns(): # 若源对象中不包含同名字段,则跳过 if column.key not in mapping_fields: continue # 取出对象属性值 attr_val = source.ins_value(column.key) setattr(self, column.key, attr_val) return self def copy_from_dict(self, source: dict, mapping_keys: Union[list, tuple, set] = None, skip_none: bool = False): """ 从源字典对象复制数据,仅复制相同键值的数据。 :param source: 源数据对象 :param mapping_keys: 映射键列表,若为空,则从源字典对象中获取键列表 :param skip_none: 是否跳过 None 值,即若源数据对象中属性值为 None 时,跳过 :return 返回自身 """ # 源对象字段列表 if mapping_keys is None: mapping_keys = source.keys() for column in self.columns(): if column.key not in mapping_keys: # 若源对象中不包含同名字段,则跳过 continue # 取出对象属性值 attr_val = source.get(column.key, None) if skip_none and attr_val is None: # 若设置了跳过 None 且此时属性值为空时,跳过 continue setattr(self, column.key, attr_val) return self def inspect(self) -> InstanceState: """ 取得对象的监视对象。 :return: 监视对象 """ return inspect(self) def session(self) -> Session: """ 取得当前对象的连接会话对象。 :return: 连接会话对象 """ return self.inspect().session def has_session(self) -> bool: """ 检测对象是否已有连接会话对象。 :return: 有则返回 True 否则返回 False """ return self.session() is not None def add_session(self): """ 将对象加入到连接会话对象。 """ if not self.has_session(): self.get_session().add(self) return self.session() def close_session(self): """ 关闭会话。 """ if self.has_session(): self.session().close() def reset_session(self): """ 重置对象的连接会话。 """ self.close_session() self.add_session() def ins_value(self, column_name: str) -> Any: """ 取得查询条件绑定的值,该值与对象属性一致。与:: self.column_name 或 self['column_name'] :param column_name: 列名 :return: 对象属性值 """ # 取出对象属性值 attr_val = getattr(self, column_name) if attr_val is None: return None return attr_val def filter_expressions(self, likes: Optional[dict[str, str]] = None) -> List[BinaryExpression]: """ 自动生成查询条件表达式。 参数 likes 为需要执行模糊查询的字段字典:: { field_name1: '%{}%', field_name2: '{}%', field_name3: '%{}', ... } :param likes: 需要执行模糊查询的字段字典 :return: 包含查询条件表达式的列表 """ expressions = list() for column in self.columns(): # 取出对象属性值 attr_val = self.ins_value(column.key) # 若属性值为空,不增加条件,跳出 if attr_val in (None, ''): continue # 取出类属性 attr = self.instrument_attr(column.key) # 若类属性类型不正确,不增加条件,跳出 if attr is None: continue # 取出列的数据类型 column_type = column.type if isinstance(column_type, (String, Integer, Numeric, ForeignKey)): # 针对属性数据类型,进行不同处理 if isinstance(attr_val, (list, tuple, set)): if not attr_val: # 跳过0长数组 continue # 数组使用 in expressions.append(attr.in_(attr_val)) elif isinstance(attr_val, (int, float)): # 数值类型 expressions.append(attr == attr_val) else: # 字符类型,使用 like if likes and attr.key in likes: _f_str = likes[attr.key] expressions.append(attr.like(_f_str.format(attr_val))) else: expressions.append(attr == attr_val) elif isinstance(column_type, (DateTime, Date)): if isinstance(attr_val, (list, tuple, set)): if len(attr_val) == 2: # 如果长度为 2 则使用 between and 方式 expressions.append(attr.between(attr_val[0], attr_val[1])) else: # 数组使用 in expressions.append(attr.in_(attr_val)) else: expressions.append(attr.between(attr_val, attr_val)) else: # 其他类型,用等号 expressions.append(attr == attr_val) return expressions def gen_query(self, likes: Optional[dict[str, str]] = None): """ 根据自身参数生成查询对象。 :param likes: 需要执行模糊查询的字段字典 :return: 查询对象 """ cls = self.__class__ expressions = self.filter_expressions(likes) _query: Select = select(cls).where(*expressions) return _query def find(self) -> Sequence['BaseAdapter']: """ 根据自身参数,查询数据库。 :return: 查询到的结果对象 """ session = self.get_session() _model_list = session.execute(self.gen_query()).scalars().all() session.close() return _model_list def find_first(self) -> 'BaseAdapter': """ 根据自身参数,查询数据库,仅查询第一条。 :return: 查询到的结果对象 """ session = self.get_session() _model = session.execute(self.gen_query()).scalars().first() session.close() return _model def find_piece(self, *where: BinaryExpression, offset=0, limit=500, is_desc=False, likes: Optional[dict[str, str]] = None): """ 根据自身参数,查询数据库。 :param where: 查询条件 :param offset: 偏移量 :param limit: 读取数量 :param is_desc: 是否逆序排列 :param likes: 模糊条件 :return: 查询到的结果对象 """ clz = self.__class__ expressions = self.filter_expressions(likes=likes) if where is not None: expressions += where query = select(clz).where(*expressions) if limit > 0: query = query.limit(limit=limit) if offset >= 0: query = query.offset(offset=offset) if is_desc: if hasattr(clz, 'id'): query = query.order_by(desc(clz.id)) session = self.get_session() _model_list = session.execute(query).scalars().all() session.close() return _model_list def before_save(self): """ 保存前的动作,一般应当在子类中覆盖该方法,增加在保存前应当执行的动作。 """ if self.is_new: if hasattr(self, 'id'): setattr(self, 'id', self.new_id()) if hasattr(self, 'created_at'): setattr(self, 'created_at', datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT)) if hasattr(self, 'updated_at'): setattr(self, 'updated_at', datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT)) return self def save(self, auto_expunge: Optional[bool] = True, session: Optional[Session] = None): """ 保存数据模型对象,若该对象尚未有连接会话,则自动加入连接会话。 :param auto_expunge: 自动刷新对象并将其移出连接会话,不提供外部 session 时有效 :param session: 会话对象,主要用于事务 :return: 保存状态 """ _has_session: bool = True """ 该参数用于说明是否提供了外部 session 对象,默认为 True 时表示提供。 """ self.before_save() if session is None: _has_session = False session = self.add_session() try: session.add(self) if not _has_session: # 使用新会话时,直接提交 session.commit() self._is_new = False except Exception as e: session.rollback() raise e else: if auto_expunge and not _has_session: session.refresh(self) session.expunge(self) return True def to_dict(self) -> dict: """ 数据模型转字典。递归处理内部类型对象。 :return: 转换后的字典数据 """ # 模型数据字典 m_dict = {} # 遍历处理内部转换 for _key, _val in dict(self.__dict__).items(): if f'{_key}'.startswith('_'): # 跳过所有私有属性 continue if isinstance(_val, BaseAdapter): # 内部数据对象数据,直接转换 m_dict[_key] = _val.to_dict() elif isinstance(_val, list): # 遍历转换数据对象列表 _tmp_list = [] for _i, _v in enumerate(_val): if isinstance(_v, BaseAdapter): _tmp_list.append(_v.to_dict()) else: _tmp_list.append(_v) m_dict[_key] = _tmp_list elif isinstance(_val, dict): # 遍历转换数据对象字典 _tmp_dict = {} for _ik, _iv in _val.items(): if isinstance(_iv, BaseAdapter): _tmp_dict[_ik] = _iv.to_dict() else: _tmp_dict[_ik] = _iv m_dict[_key] = _tmp_dict else: # 其他属性直接赋值 m_dict[_key] = _val return m_dict