from sqlalchemy import select, delete from paste.rbac.rbac_item_child import RbacItemChild from paste.rbac.rbac_models import RbacItemModel class RbacItem(RbacItemModel): """ 授权项。 分为角色和权限两类,分别由对应子类实现。 """ TYPE_ROLE = 1 """ 角色类型。 """ TYPE_PERMISSION = 2 """ 权限类型。 """ @classmethod async def all_parents(cls, item_names: set[str]): """ 按层次,递归查询授权项列表中所有授权项的父授权项名称。 :param item_names: 授权项名称列表 :return: 所有各层级的父授权项 """ _query = select(RbacItemChild.parent).where(RbacItemChild.child.in_(item_names)) _rows = await cls.query_all(_query) _item_names: set[str] = set(_rows) if _item_names: _parents = await cls.all_parents(_item_names) _item_names = _item_names.union(_parents) return _item_names @classmethod async def all_children(cls, item_names: set[str]): """ 按层次,递归查询授权项列表中所有授权项的子授权项名称。 :param item_names: 授权项名称列表 :return: 所有各层级的子授权项 """ _query = select(RbacItemChild.child).where(RbacItemChild.parent.in_(item_names)) _rows = await cls.query_all(_query) _item_names: set[str] = set(_rows) if _item_names: _children = await cls.all_children(_item_names) _item_names = _item_names.union(_children) return _item_names @classmethod async def find_by_name(cls, name: str): """ 根据授权项名称查找授权项。 :param name: 授权项名称 :return: 授权项 """ _query = select(cls).where(cls.name == name) _model: cls = await cls.query_first(_query) return _model async def add_children(self, item_names: set[str]): """ 增加子授权项,自动剔除已包含的角色或授权项。角色和权限都能增加子授权项。 :param item_names: 待分配的子授权项名称列表 """ # 首先根据授权项名称列表查出所有的授权项,剔除错误的名称 _query = select(RbacItem.name).where(RbacItem.name.in_(item_names)) _rows = await self.query_all(_query) item_names: set[str] = set(_rows) # 取得所有祖先,确保所有子授权项,没有出现在祖先中,防止循环授权 _all_parents = await self.all_parents(item_names=item_names) # 查库,过滤已经存在的子授权项 _query = select(RbacItemChild.child).where( RbacItemChild.parent == self.name, RbacItemChild.child.in_(item_names) ) _rows = await self.query_all(_query) _exist_children: set[str] = set(_rows) # 创建模型列表,剔除已包含的角色或授权项,剔除出现在祖先中的授权项,以及剔除自身 _new_children = [ RbacItemChild(**{RbacItemChild.parent.key: self.name, RbacItemChild.child.key: _name}) for _name in item_names if _name not in _exist_children and _name not in _all_parents and _name != self.name ] # 保存数据 _session = self.get_aio_session() try: _session.add_all(_new_children) await _session.commit() except Exception as e: await _session.rollback() raise e finally: await _session.close() async def get_children(self): """ 通过中间关系查询所有子授权项。注意:查询的是直接子授权项,不包含继承获得的授权项。 :return: 子权限项列表 """ _query = select(RbacItem).join( RbacItemChild, RbacItemChild.child == RbacItem.name ).where( RbacItemChild.parent == self.name, ) _item_model_list: list[RbacItem] = await self.query_all(_query) return _item_model_list async def remove_children(self, item_names: set[str]): """ 删除子授权项。 :param item_names: 子授权项名称集合 :return: 成功删除的项数 """ _delete = delete(RbacItemChild).where(RbacItemChild.parent == self.name, RbacItemChild.child.in_(item_names)) _result = await self.raw_execute(_delete) return _result.rowcount