首次提交
This commit is contained in:
Executable
+236
@@ -0,0 +1,236 @@
|
||||
import importlib
|
||||
import pkgutil
|
||||
from types import ModuleType
|
||||
from typing import Optional, Any
|
||||
|
||||
import tornado
|
||||
from tornado.routing import URLSpec
|
||||
from tornado.web import OutputTransform, _RuleList
|
||||
|
||||
|
||||
class Application(tornado.web.Application):
|
||||
"""
|
||||
从 Tornado 派生的应用程序类。
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def modules_iterator(cls, package: [str, ModuleType]):
|
||||
"""
|
||||
从 package 包装载所有模块。这里返回模块迭代器。
|
||||
若为字符串,则直接从目录中载入模块;若为模块,则根据模块的参数装载。
|
||||
|
||||
:param package: 包,允许为路径或包(模块对象)
|
||||
:return: 模块迭代器
|
||||
"""
|
||||
if isinstance(package, str):
|
||||
package = importlib.import_module(package)
|
||||
|
||||
# 模块迭代器,能够遍历出所有子包中的子模块
|
||||
return pkgutil.walk_packages(package.__path__, f"{package.__name__}.")
|
||||
|
||||
@classmethod
|
||||
def fetch_handlers(cls, module: ModuleType):
|
||||
"""
|
||||
查找模块中所有的请求处理类,即类 RequestHandler 的子类。
|
||||
|
||||
:param module: 模块
|
||||
:return: [(路由模式, 请求处理类)]
|
||||
"""
|
||||
# 判断是否是有效的 RequestHandler 类,且是 RequestHandler 的子类
|
||||
def is_handler(handler_cls):
|
||||
return isinstance(handler_cls, type) and issubclass(handler_cls, tornado.web.RequestHandler)
|
||||
|
||||
# 判断是否拥有 route_pattern 模式属性,且该属性值为字符串类型
|
||||
def has_pattern(handler_cls):
|
||||
return hasattr(handler_cls, 'route_pattern') and isinstance(getattr(handler_cls, 'route_pattern'), str)
|
||||
|
||||
handlers: list[tuple[str, ModuleType]] = []
|
||||
# 迭代模块成员
|
||||
for _n in dir(module):
|
||||
_cls = getattr(module, _n)
|
||||
is_hdl = is_handler(_cls)
|
||||
has_pat = has_pattern(_cls)
|
||||
if is_hdl and has_pat:
|
||||
_route = _cls.route_pattern
|
||||
handlers.append((_route, _cls))
|
||||
|
||||
return handlers
|
||||
|
||||
@classmethod
|
||||
def load_ui_modules(cls, ui_modules_config):
|
||||
"""
|
||||
将JSON配置中的模块字符串转换为实际的类。
|
||||
"""
|
||||
loaded_modules = {}
|
||||
for name, path in ui_modules_config.items():
|
||||
try:
|
||||
module_path, class_name = path.rsplit('.', 1)
|
||||
module = importlib.import_module(module_path)
|
||||
loaded_modules[name] = getattr(module, class_name)
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise RuntimeError(f"Failed to load UIModule {name} from {path}: {str(e)}")
|
||||
return loaded_modules
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: Optional[_RuleList] = None,
|
||||
handlers_pkg: [str, ModuleType] = None,
|
||||
uri_prefix: str = "",
|
||||
**settings: Any
|
||||
) -> None:
|
||||
"""
|
||||
重写应用程序构造函数,增加自动装载功能。
|
||||
|
||||
:param handlers: 请求处理路由配置列表
|
||||
:param handlers_pkg: 执行自动装载的请求处理类所在包
|
||||
:param uri_prefix: URI 前缀
|
||||
:param settings: 其他配置
|
||||
"""
|
||||
|
||||
self.routes: list[(URLSpec, _RuleList)] = []
|
||||
"""
|
||||
请求处理路由列表。
|
||||
"""
|
||||
|
||||
if uri_prefix:
|
||||
uri_prefix = uri_prefix if uri_prefix.startswith('/') else f"/{uri_prefix}"
|
||||
|
||||
self.uri_prefix = uri_prefix
|
||||
"""
|
||||
统一资源标识符前缀。仅支持动态加载的请求处理类。
|
||||
"""
|
||||
|
||||
# 合并构造参数中的请求处理路由
|
||||
if handlers:
|
||||
self.routes.extend(handlers)
|
||||
|
||||
# 动态加载请求处理类,并执行合并
|
||||
if handlers_pkg:
|
||||
self.routes.extend(self.load_handlers(handlers_pkg=handlers_pkg))
|
||||
|
||||
self.before_create()
|
||||
|
||||
super().__init__(handlers=self.routes, **settings)
|
||||
|
||||
def before_create(self):
|
||||
"""
|
||||
在创建应用之前执行。
|
||||
"""
|
||||
pass
|
||||
|
||||
def load_handlers(self, handlers_pkg: [str, ModuleType] = None):
|
||||
"""
|
||||
从 handlers_pkg 指定的包装载所有模块,分析出所有请求处理类和路由路径,并返回。
|
||||
|
||||
:param handlers_pkg: 模块根目录,允许为路径或包(模块对象)
|
||||
:return: 动态装载的所有路由配置
|
||||
"""
|
||||
_routes = []
|
||||
|
||||
if handlers_pkg is None:
|
||||
return _routes
|
||||
|
||||
# 迭代器装载所有子包中的子模块
|
||||
modules_itr = self.modules_iterator(package=handlers_pkg)
|
||||
for _file_finder, _module_name, _is_package in modules_itr:
|
||||
if _is_package:
|
||||
continue
|
||||
|
||||
_module = importlib.import_module(_module_name)
|
||||
_handlers = self.fetch_handlers(module=_module)
|
||||
for _hdl in _handlers:
|
||||
_pattern, _cls = str(_hdl[0]), _hdl[1]
|
||||
_pattern = _pattern if _pattern.startswith('/') else f"/{_pattern}"
|
||||
_url_spec = tornado.web.url(
|
||||
pattern=f"{self.uri_prefix}{_pattern}", handler=_cls, name=_cls.__name__
|
||||
)
|
||||
_routes.append(_url_spec)
|
||||
|
||||
return _routes
|
||||
|
||||
|
||||
class ApplicationSwagger(Application):
|
||||
"""
|
||||
从框架 Application 派生,增加对 Swagger 的支持。
|
||||
"""
|
||||
|
||||
swagger_schema = ""
|
||||
"""
|
||||
在 Swagger 注入时,保存 json schema。
|
||||
"""
|
||||
|
||||
swagger_home_template = ""
|
||||
"""
|
||||
在 Swagger 注入时,保存 Ui 页面内容。
|
||||
"""
|
||||
|
||||
swagger_url = "/docs"
|
||||
"""
|
||||
Swagger URL。
|
||||
"""
|
||||
|
||||
swagger_api_base_url = "/"
|
||||
"""
|
||||
Swagger API base URL。
|
||||
"""
|
||||
|
||||
swagger_title = ""
|
||||
"""
|
||||
Swagger 页面标题。
|
||||
"""
|
||||
|
||||
swagger_description = ""
|
||||
"""
|
||||
Swagger 页面描述。
|
||||
"""
|
||||
|
||||
swagger_api_version = ""
|
||||
"""
|
||||
Swagger 页面版本。
|
||||
"""
|
||||
|
||||
swagger_contact = ""
|
||||
"""
|
||||
Swagger 页面联系方式。
|
||||
"""
|
||||
|
||||
swagger_schemes = ["http", "https"]
|
||||
"""
|
||||
Swagger 协议方案。
|
||||
"""
|
||||
|
||||
def __init__(self, **settings: Any) -> None:
|
||||
self.swagger_schema = settings.get('swagger_schema', self.swagger_schema)
|
||||
self.swagger_home_template = settings.get('swagger_home_template', self.swagger_home_template)
|
||||
|
||||
self.swagger_url = settings.get('swagger_url', self.swagger_url)
|
||||
self.swagger_url = self.swagger_url if self.swagger_url.startswith('/') else f"/{self.swagger_url}"
|
||||
|
||||
self.swagger_api_base_url = settings.get('swagger_api_base_url', self.swagger_api_base_url)
|
||||
self.swagger_api_base_url = self.swagger_api_base_url if self.swagger_api_base_url.startswith('/') \
|
||||
else f"/{self.swagger_api_base_url}"
|
||||
|
||||
self.swagger_title = settings.get('swagger_title', self.swagger_title)
|
||||
self.swagger_description = settings.get('swagger_description', self.swagger_description)
|
||||
self.swagger_api_version = settings.get('swagger_api_version', self.swagger_api_version)
|
||||
self.swagger_contact = settings.get('swagger_contact', self.swagger_contact)
|
||||
self.swagger_schemes = settings.get('swagger_schemes', self.swagger_schemes)
|
||||
|
||||
super().__init__(**settings)
|
||||
|
||||
def before_create(self):
|
||||
_swagger_url = f"{self.uri_prefix}{self.swagger_url}"
|
||||
_swagger_api_base_url = f"{self.uri_prefix}{self.swagger_api_base_url}"
|
||||
|
||||
from paste.web.swagger import setup_swagger
|
||||
setup_swagger(
|
||||
app=self,
|
||||
routes=self.routes,
|
||||
swagger_url=_swagger_url,
|
||||
api_base_url=_swagger_api_base_url,
|
||||
title=self.swagger_title,
|
||||
description=self.swagger_description,
|
||||
api_version=self.swagger_api_version,
|
||||
contact=self.swagger_contact,
|
||||
schemes=self.swagger_schemes,
|
||||
)
|
||||
@@ -0,0 +1,205 @@
|
||||
import functools
|
||||
import logging
|
||||
from typing import Awaitable
|
||||
|
||||
from jwt import ExpiredSignatureError, InvalidSignatureError, InvalidTokenError
|
||||
|
||||
from paste.security import token
|
||||
from paste.web.handler import RequestHandler
|
||||
|
||||
|
||||
def route(route_pattern: str):
|
||||
"""
|
||||
路由装饰器。为类增加 route_pattern 属性,并赋值。
|
||||
|
||||
:param route_pattern: URL 路径模式
|
||||
"""
|
||||
|
||||
def wrapper(cls: type[RequestHandler]):
|
||||
cls.route_pattern = route_pattern
|
||||
return cls
|
||||
|
||||
# 标记已经被 route 装饰
|
||||
setattr(wrapper, '__route__', True)
|
||||
return wrapper
|
||||
|
||||
|
||||
def auth_token(func):
|
||||
"""
|
||||
令牌验证装饰器,用于 :class:`tornado.web.RequestHandler` 子类中的 get()/post() 等需要执行权限验证的方法,以便在正
|
||||
式执行方法前利用客户端提交的令牌进行鉴权。
|
||||
|
||||
当执行该装饰器解码令牌后,将更新 RequestHandler 对象的 token_payload 属性数据。其次,若能通过令牌中配载的 user_id 取
|
||||
得用户数据,则还将设置 current_user 属性。
|
||||
|
||||
该装饰器仅用于校验令牌的有效性,并取得用户信息,不负责校验用户的具体权限,若要验证权限需使用 @auth_permission 装饰器。
|
||||
|
||||
使用方式如下::
|
||||
|
||||
@auth_token
|
||||
async def post(self):
|
||||
pass
|
||||
|
||||
要求在请求的 Headers 中必须包含 Access-Token,且内容由 encode_token 方法签发。
|
||||
|
||||
:param func: 被装饰的函数对象,不需要手动传入该参数
|
||||
:return: 装饰后的函数对象
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
req_handler: RequestHandler = args[0]
|
||||
try:
|
||||
# 请求头
|
||||
req_headers = dict(req_handler.request.headers)
|
||||
|
||||
# 取出 Token
|
||||
access_token = req_headers.get('Access-Token', None)
|
||||
if access_token in (None, b'', ''):
|
||||
raise InvalidTokenError(f'请求地址:{req_handler.request.uri}')
|
||||
|
||||
# 如果采用 OAuth2 规范,这里应当调用远程 API 执行解码,解码后返回用户信息
|
||||
|
||||
# 用解码后的 Token 字典更新 Handler 中的的 token_dict
|
||||
token_payload = token.decode_token(access_token)
|
||||
req_handler.token_payload.update(token_payload)
|
||||
|
||||
# 根据 Token 读取用户对象,并设置到请求处理对象(控制器)
|
||||
_user_id = req_handler.token_param('user_id')
|
||||
if _user_id and req_handler.user_class:
|
||||
_user = await req_handler.user_class.async_find_by_id(_user_id)
|
||||
if _user is None:
|
||||
raise InvalidTokenError()
|
||||
req_handler.current_user = _user
|
||||
await req_handler.after_auth_token(token_payload)
|
||||
|
||||
# 兼容同步或异步方法
|
||||
_result = func(*args, **kwargs)
|
||||
if isinstance(_result, Awaitable):
|
||||
_result = await _result
|
||||
return _result
|
||||
except ExpiredSignatureError as e:
|
||||
e.args = ('令牌已过期,请求被拒绝.',)
|
||||
req_handler.response_error(e, status_code=403, api_status_code=403)
|
||||
req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True)
|
||||
return None
|
||||
except InvalidSignatureError as e:
|
||||
e.args = ('令牌签名错误,请求被拒绝.',)
|
||||
req_handler.response_error(e, status_code=403, api_status_code=403)
|
||||
req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True)
|
||||
return None
|
||||
except InvalidTokenError as e:
|
||||
e.args = ('令牌错误,请求被拒绝.',)
|
||||
req_handler.response_error(e, status_code=401, api_status_code=401)
|
||||
req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True)
|
||||
return None
|
||||
except Exception as e:
|
||||
req_handler.response_error(e, status_code=501, api_status_code=501)
|
||||
req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True)
|
||||
return None
|
||||
|
||||
# 标记已经被 auth_token 装饰
|
||||
setattr(wrapper, '__auth_token__', True)
|
||||
return wrapper
|
||||
|
||||
|
||||
def auth_permission(func):
|
||||
"""
|
||||
权限检查装饰器。若不启用 RBAC 则不应用该装饰器。
|
||||
用于检查用户是否有执行某个操作的具体权限。该装饰器须跟随在 @auth_token 装饰器的后面使用。
|
||||
|
||||
:param func: 被装饰的函数对象,不需要手动传入该参数
|
||||
:return: 装饰后的函数对象
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
#
|
||||
# 为了能在不使用 RBAC 的系统中正常运行,这里的引用必须放在函数中
|
||||
# 否则初始化过程中 RBAC 数据模型会尝试读取数据表配置,发生错误
|
||||
#
|
||||
from paste.rbac.rbac_user import RbacUser, Supervisors
|
||||
|
||||
req_handler: RequestHandler = args[0]
|
||||
try:
|
||||
# 验证当前用户状态
|
||||
_user: RbacUser = req_handler.current_user
|
||||
assert _user is not None, f"无效令牌或未登录,无权执行:{req_handler.route_pattern} 操作."
|
||||
|
||||
# 类型检测
|
||||
_right_type = isinstance(_user, RbacUser)
|
||||
assert _right_type, f"当前用户类型错误,必须为 RbacUser 的子类."
|
||||
|
||||
if _user.username not in Supervisors:
|
||||
# 验证用户权限状态
|
||||
_has_permission = await _user.has_permission(req_handler.route_pattern)
|
||||
assert _has_permission, f"当前用户 {_user.username} 无权执行:{req_handler.route_pattern} 操作."
|
||||
|
||||
# 兼容同步或异步方法
|
||||
_result = func(*args, **kwargs)
|
||||
if isinstance(_result, Awaitable):
|
||||
_result = await _result
|
||||
return _result
|
||||
except AssertionError as e:
|
||||
req_handler.response_error(e, status_code=401, api_status_code=401)
|
||||
req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True)
|
||||
return None
|
||||
except Exception as e:
|
||||
req_handler.response_error(e, status_code=501, api_status_code=501)
|
||||
req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True)
|
||||
return None
|
||||
|
||||
# 标记已经被 auth_permission 装饰
|
||||
setattr(wrapper, '__auth_permission__', True)
|
||||
return wrapper
|
||||
|
||||
|
||||
def auth_rule(func):
|
||||
"""
|
||||
规则检查装饰器。若不启用规则验证,则不应用该装饰器。
|
||||
用于对用户按规则验证。该装饰器须跟随在 @auth_token 装饰器的后面使用。
|
||||
|
||||
:param func: 被装饰的函数对象,不需要手动传入该参数
|
||||
:return: 装饰后的函数对象
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
#
|
||||
# 为了能在不使用 RBAC 的系统中正常运行,这里的引用必须放在函数中
|
||||
# 否则初始化过程中 RBAC 数据模型会尝试读取数据表配置,发生错误
|
||||
#
|
||||
from paste.rbac.rbac_user import RbacUser, Supervisors
|
||||
|
||||
req_handler: RequestHandler = args[0]
|
||||
try:
|
||||
# 验证当前用户状态
|
||||
_user: RbacUser = req_handler.current_user
|
||||
assert _user is not None, f"无效令牌或未登录,无权执行:{req_handler.route_pattern} 操作."
|
||||
|
||||
# 类型检测
|
||||
_right_type = isinstance(_user, RbacUser)
|
||||
assert _right_type, f"当前用户类型错误,必须为 RbacUser 的子类."
|
||||
|
||||
if _user.username not in Supervisors:
|
||||
# 验证用户权限状态
|
||||
_user_can = await _user.can(req_handler.route_pattern, **kwargs)
|
||||
assert _user_can, f"当前用户 {_user.username} 无权执行:{req_handler.route_pattern} 操作(规则验证不通过)."
|
||||
|
||||
# 兼容同步或异步方法
|
||||
_result = func(*args, **kwargs)
|
||||
if isinstance(_result, Awaitable):
|
||||
_result = await _result
|
||||
return _result
|
||||
except AssertionError as e:
|
||||
req_handler.response_error(e, status_code=401, api_status_code=401)
|
||||
req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True)
|
||||
return None
|
||||
except Exception as e:
|
||||
req_handler.response_error(e, status_code=501, api_status_code=501)
|
||||
req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True)
|
||||
return None
|
||||
|
||||
# 标记已经被 auth_rule 装饰
|
||||
setattr(wrapper, '__auth_rule__', True)
|
||||
return wrapper
|
||||
@@ -0,0 +1,61 @@
|
||||
from wtforms_tornado import Form
|
||||
|
||||
|
||||
class ModelForm(Form):
|
||||
"""
|
||||
模型表单。派生后主要处理以下内容::
|
||||
|
||||
有可能在 formdata 中出现的非列表类型,统一转为列表类型。
|
||||
"""
|
||||
|
||||
def __init__(self, formdata=None, obj=None, prefix="", data=None, meta=None, **kwargs):
|
||||
"""
|
||||
构造模型表单。
|
||||
|
||||
:param formdata: 来自客户端的输入数据,通常为 request.form 或等效数据。应该提供一个 multi-dict 接口来获取给定键的值列表。
|
||||
:param obj: 从该对象上与表单字段属性匹配的属性中获取现有数据。仅在未传递 formdata 时使用。
|
||||
:param prefix: 如果提供,所有字段的名称都将以值为前缀。这是为了区分单个页面上的多个表单。这只会影响匹配输入数据的 HTML 名称,而不会影响匹配现有数据的 Python 名称。
|
||||
:param data: 从该 dict 中与表单字段属性匹配的键中获取现有数据,如果 obj 也有匹配的属性,则它优先。仅在未传递 formdata 时使用。
|
||||
:param meta: 要在此窗体的 :attr: meta 实例上重写的属性 dict。
|
||||
:param kwargs: 与 data 合并以允许将现有数据作为参数传递。覆盖 data 中的任何重复键。仅在未传递 formdata 时使用。
|
||||
"""
|
||||
if isinstance(formdata, dict):
|
||||
# 对有可能在 formdata 中出现的非列表类型,统一转为列表类型
|
||||
formdata = {k: list(v) if isinstance(v, (list, tuple, set)) else [f'{v}'] for k, v in formdata.items()}
|
||||
|
||||
# 启动父类构造
|
||||
super(Form, self).__init__(formdata=formdata, obj=obj, prefix=prefix, data=data, meta=meta, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def list_to_field_list(cls, formdata, field_name: str, separator: str = '-'):
|
||||
"""
|
||||
将 list 数据转换为符合 FieldList 的赋值规则的字段数据,即转换为 字段名+隔符+下标 格式的表单数据。
|
||||
|
||||
:param formdata: 来自客户端的输入数据
|
||||
:param field_name: 字段名
|
||||
:param separator: 分隔符,默认与 FieldList 一致为 "-" 符号
|
||||
"""
|
||||
# 当以 JSON 数组传入时,转换为以 - 连接的字段项,以符合 FieldList 的赋值规则
|
||||
_value_list = formdata.get(field_name, [])
|
||||
if _value_list and isinstance(_value_list, list):
|
||||
for _idx, _item in enumerate(_value_list):
|
||||
formdata[f"{field_name}{separator}{_idx}"] = _item
|
||||
|
||||
def validate_form(self, auto_raise: bool = True):
|
||||
"""
|
||||
验证表单数据。
|
||||
|
||||
:param auto_raise: 当该参数为 True 时,若验证不成功,抛出验证异常。
|
||||
:return: 验证结果
|
||||
"""
|
||||
validate_result = {}
|
||||
|
||||
if self.validate():
|
||||
return True, validate_result
|
||||
else:
|
||||
validate_result.update(self.errors.items())
|
||||
|
||||
if auto_raise:
|
||||
raise Exception('数据验证错误!', {'form_data': self.data, 'form_errors': validate_result})
|
||||
|
||||
return False, validate_result
|
||||
Executable
+249
@@ -0,0 +1,249 @@
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC
|
||||
from collections import namedtuple
|
||||
from typing import Optional, Union, Any, Type
|
||||
|
||||
import tornado.web
|
||||
|
||||
from paste.core import config
|
||||
from paste.db.basemodel import BaseModel
|
||||
from paste.util.encoder import JsonDumpsEncoder
|
||||
from paste.core.logging import echo_log
|
||||
|
||||
|
||||
def init_user_class():
|
||||
"""
|
||||
从配置文件初始化用户类。默认采用 rbac.RbacUser。
|
||||
"""
|
||||
|
||||
try:
|
||||
# 若没有配置 RBAC 直接返回 None
|
||||
_rbac_cfg = config.get_config('rbac.user_class', None)
|
||||
except AssertionError:
|
||||
return None
|
||||
|
||||
_cfg_user_class: str = config.get_config('rbac.user_class', None)
|
||||
if _cfg_user_class is not None:
|
||||
_parts = _cfg_user_class.split('.')
|
||||
_module_name = '.'.join(_parts[:-1])
|
||||
_user_module = importlib.import_module(_module_name)
|
||||
_user_class = getattr(_user_module, _parts[-1])
|
||||
return _user_class
|
||||
|
||||
from paste.rbac.rbac_user import RbacUser
|
||||
return RbacUser
|
||||
|
||||
|
||||
class RequestHandler(tornado.web.RequestHandler, ABC):
|
||||
"""
|
||||
请求控制父类。
|
||||
"""
|
||||
|
||||
route_pattern: Optional[str] = None
|
||||
"""
|
||||
URL 路径模式。由装饰器 web.decorators.route 赋值,在 base.Application.load_handler_module 自动加载时调用,作为访问
|
||||
路径,设置到 Application 中。
|
||||
"""
|
||||
|
||||
user_class: Type[BaseModel] = init_user_class()
|
||||
"""
|
||||
用户数据处理类。装饰器 web.decorators.auth_token 执行令牌验证时调用该类,用于创建用户对象,并保存在 current_user 属性中。
|
||||
注意:这里仅初始化类,而不创建对象。该类允许用户继承扩展,然后自行配置。主要用于执行有关用户的数据操作。
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def log(cls, msg: Union[str, Exception], level: int = logging.INFO, is_log_exc: bool = False):
|
||||
"""
|
||||
输出日志文本。
|
||||
|
||||
:param msg: 消息内容,当是 Exception 对象时,从 args 中取出第一项作为消息
|
||||
:param level: 消息等级
|
||||
:param is_log_exc: 是否输出异常信息到日志文件
|
||||
"""
|
||||
echo_log(msg=msg, level=level, is_log_exc=is_log_exc)
|
||||
|
||||
@classmethod
|
||||
def dict_to_namedtuple(cls, name, data):
|
||||
"""
|
||||
递归转换字典和列表中的字典为 namedtuple 对象。
|
||||
|
||||
参数:
|
||||
name: 用于创建 namedtuple 的名称
|
||||
data: 要转换的数据,可以是 dict、list 或基本类型
|
||||
|
||||
返回:
|
||||
转换后的 namedtuple 对象或列表
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
# 处理字典类型
|
||||
NT = namedtuple(name, data.keys())
|
||||
return NT(**{
|
||||
k: cls.dict_to_namedtuple(k, v)
|
||||
for k, v in data.items()
|
||||
})
|
||||
elif isinstance(data, list):
|
||||
# 处理列表类型:递归转换列表中的每个元素
|
||||
return [
|
||||
cls.dict_to_namedtuple(f"{name}_item", item)
|
||||
if isinstance(item, (dict, list)) else item
|
||||
for item in data
|
||||
]
|
||||
else:
|
||||
# 基本类型直接返回
|
||||
return data
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.rule_kwargs = {}
|
||||
"""
|
||||
规则参数,用于在控制器和规则之间做数据交换
|
||||
"""
|
||||
|
||||
self.token_payload: dict[str: Any] = {}
|
||||
"""
|
||||
令牌配载数据字典。装饰器 web.decorators.auth_token 执行令牌验证时解码并赋值。 在 HandlerRequest 子类中
|
||||
只要配置 auth_token 装饰即可使用该配载数据。
|
||||
|
||||
其结构为::
|
||||
|
||||
{
|
||||
'iss': private_iss,
|
||||
'iat': datetime.datetime.utcnow(),
|
||||
'exp': datetime.datetime.utcnow() + datetime.timedelta(days=7),
|
||||
'params': {
|
||||
'id': user_id,
|
||||
'username': username
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
async def after_auth_token(self, token_payload: dict):
|
||||
"""
|
||||
在验证 Token 后调用的函数,子类可覆盖。
|
||||
|
||||
:param token_payload: Token 数据项
|
||||
"""
|
||||
pass
|
||||
|
||||
def token_params(self) -> dict:
|
||||
"""
|
||||
取出 Token 中的参数字典。
|
||||
|
||||
:return: 参数字典
|
||||
"""
|
||||
return self.token_payload.get('params', {})
|
||||
|
||||
def token_param(self, key):
|
||||
"""
|
||||
取出 Token 参数字典中的参数。
|
||||
|
||||
:param key: 参数名称
|
||||
"""
|
||||
return self.token_params().get(key, None)
|
||||
|
||||
def set_default_headers(self):
|
||||
"""
|
||||
设置默认的请求头。
|
||||
"""
|
||||
request_headers = dict(self.request.headers)
|
||||
allow_headers = [
|
||||
'Accept', 'Content-Type', 'Origin', 'Access-Token', 'ClientId', 'Timestamp', 'Verify-Hash', 'Security-Key'
|
||||
]
|
||||
allow_methods = [
|
||||
'OPTIONS', 'GET', 'POST'
|
||||
]
|
||||
allow_origins = [
|
||||
request_headers.get('Origin', '*')
|
||||
]
|
||||
content_type = [
|
||||
request_headers.get('Content-type', 'application/json')
|
||||
]
|
||||
response_header_cfg = {
|
||||
'Access-Control-Allow-Headers': ','.join(set(allow_headers)),
|
||||
'Access-Control-Allow-Methods': ','.join(set(allow_methods)),
|
||||
'Access-Control-Allow-Origin': ','.join(set(allow_origins)),
|
||||
'Access-Control-Allow-Credentials': 'true',
|
||||
'Content-type': ','.join(set(content_type)),
|
||||
}
|
||||
for _k, _v in response_header_cfg.items():
|
||||
self.set_header(_k, _v)
|
||||
|
||||
def get_current_user(self) -> Any:
|
||||
if not hasattr(self, '_current_user'):
|
||||
if self.user_class is not None:
|
||||
# 设置了用户类,但是未创建对象的,这里默认创建空用户对象
|
||||
setattr(self, '_current_user', self.user_class())
|
||||
else:
|
||||
setattr(self, '_current_user', None)
|
||||
return self._current_user
|
||||
|
||||
def options(self):
|
||||
"""
|
||||
处理跨域请求中的 OPTIONS 预检。
|
||||
"""
|
||||
self.set_status(status_code=200)
|
||||
self.finish()
|
||||
|
||||
def request_arguments(self):
|
||||
"""
|
||||
取得所有请求参数。若 self.request.arguments 中有参数,则优先读取。
|
||||
若无参数,则从 self.request.body 读取,且该参数必须为 JSON 结构。
|
||||
|
||||
:return: 请求参数字典
|
||||
"""
|
||||
_args: dict[str: Any] = dict()
|
||||
if len(self.request.arguments) > 0:
|
||||
# 按 Form 提交时,从 Form 参数中读取命令,命令参数从 request.arguments 读取
|
||||
for _n, _v in self.request.arguments.items():
|
||||
if isinstance(_v, list):
|
||||
# 对数组进行分解
|
||||
if len(_v) == 1:
|
||||
_args[_n] = _v[0].decode("utf-8")
|
||||
else:
|
||||
_args[_n] = [__v.decode("utf-8") for __v in _v]
|
||||
else:
|
||||
_args[_n] = f"{_v}"
|
||||
else:
|
||||
# 非 Form 提交时,从 Body 解析命令,命令参数从 body.params 读取
|
||||
_body = self.request.body if self.request.body else '{}'
|
||||
_args = json.loads(_body)
|
||||
return _args
|
||||
|
||||
def response_ok(self, **kwargs):
|
||||
"""
|
||||
成功响应内容。
|
||||
|
||||
:param kwargs: 参数
|
||||
"""
|
||||
self.set_status(status_code=200)
|
||||
chunk = {'code': 200, 'status': 'OK'}
|
||||
chunk.update(kwargs)
|
||||
self.write(json.dumps(chunk, cls=JsonDumpsEncoder, ensure_ascii=False))
|
||||
self.set_header('Content-Type', 'application/json')
|
||||
|
||||
def response_error(self, e: Exception, status_code: int = 200, api_status_code: int = None, **kwargs):
|
||||
"""
|
||||
错误响应内容。
|
||||
|
||||
:param e: 异常对象
|
||||
:param status_code: HTTP/HTTPS 响应状态码
|
||||
:param api_status_code: API 状态码,若不提供则使用 status_code 参数
|
||||
"""
|
||||
if api_status_code is None:
|
||||
api_status_code = status_code
|
||||
|
||||
self.set_status(status_code=status_code)
|
||||
chunk = {'code': api_status_code, 'status': 'error'}
|
||||
chunk.update(kwargs)
|
||||
if len(e.args) > 0 and isinstance(e.args[0], str):
|
||||
chunk['message'] = e.args[0]
|
||||
if len(e.args) > 1:
|
||||
if isinstance(e.args[1], dict):
|
||||
chunk.update(e.args[1])
|
||||
elif isinstance(e.args[1], list):
|
||||
chunk['errors'] = e.args[1]
|
||||
self.write(json.dumps(chunk, cls=JsonDumpsEncoder, ensure_ascii=False))
|
||||
self.set_header('Content-Type', 'application/json')
|
||||
@@ -0,0 +1,212 @@
|
||||
import ast
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import threading
|
||||
from typing import Tuple, List, Any, Optional, Dict
|
||||
|
||||
from tornado.template import Loader, Template
|
||||
from tornado.web import UIModule
|
||||
|
||||
|
||||
class ParamAwareUIModuleDataWarehouse:
|
||||
"""
|
||||
预处理数据仓库。
|
||||
数据用唯一调用 ID 作为 Key 存储。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._store = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def prepare(self, module_name: str, call_id: str, data: Any):
|
||||
"""存储预处理数据"""
|
||||
with self._lock:
|
||||
self._store.setdefault(module_name, {})[call_id] = data
|
||||
|
||||
def fetch(self, module_name: str, call_id: str) -> Any:
|
||||
"""获取预处理数据"""
|
||||
with self._lock:
|
||||
return self._store.get(module_name, {}).get(call_id)
|
||||
|
||||
|
||||
warehouse = ParamAwareUIModuleDataWarehouse()
|
||||
"""
|
||||
全局单例,参数感知预处理仓库。
|
||||
"""
|
||||
|
||||
|
||||
class ParamAwareUIModule(UIModule):
|
||||
"""
|
||||
参数感知 UIModule 父类。
|
||||
1、子类应当实现 async_prepare 方法完成数据预处理,该方法在 Handler 执行过程中会根据模板文件的配置调用完成数据初始化,模板中配置的参数会传给该方法。
|
||||
2、原有的 render 方法作为从数据仓库中获取数据,调用 render_with_data 方法完成渲染,已无需在子类中实现,模板中配置的参数也会传给该方法。
|
||||
3、子类应当实现 render_with_data 方法完成渲染,预处理数据通过参数 prepared_data 传入。
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def generate_call_id(cls, module_name: str, kwargs: dict) -> str:
|
||||
"""根据模块名和参数生成唯一调用ID"""
|
||||
param_str = ",".join(f"{k}={v}" for k, v in sorted(kwargs.items()))
|
||||
return hashlib.md5(f"{module_name}|{param_str}".encode()).hexdigest()
|
||||
|
||||
async def async_prepare(self, **kwargs) -> Any:
|
||||
"""子类实现异步数据加载,用静态方法避免参数缺失"""
|
||||
raise NotImplementedError
|
||||
|
||||
def render(self, **kwargs):
|
||||
"""自动关联预处理数据"""
|
||||
call_id = self.generate_call_id(self.__class__.__name__, kwargs)
|
||||
prepared_data = warehouse.fetch(self.__class__.__name__, call_id)
|
||||
return self.render_with_data(prepared_data, **kwargs)
|
||||
|
||||
def render_with_data(self, prepared_data: Any, **kwargs):
|
||||
"""子类实现具体渲染逻辑"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UIModuleCallAnalyzer(ast.NodeVisitor):
|
||||
"""
|
||||
用于分析 Tornado 模板生成的 Python 代码,从中解析出对 UIModule 的名称和实际调用参数。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.calls = [] # 存储 (module_class_name, kwargs)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
"""
|
||||
匹配 _tt_tmp = _tt_modules.XxxModule(...) 模式。
|
||||
|
||||
:param node:
|
||||
:return:
|
||||
"""
|
||||
if (isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Attribute)
|
||||
and isinstance(node.value.func.value, ast.Name) and node.value.func.value.id == '_tt_modules'):
|
||||
module_class = node.value.func.attr
|
||||
kwargs = self._extract_kwargs(node.value)
|
||||
self.calls.append((module_class, kwargs))
|
||||
|
||||
@classmethod
|
||||
def _extract_kwargs(cls, call_node: ast.Call) -> dict:
|
||||
"""
|
||||
安全提取调用参数。
|
||||
|
||||
:param call_node: 调用节点
|
||||
:return: 实际参数
|
||||
"""
|
||||
kwargs = {}
|
||||
|
||||
# 处理位置参数 (Tornado不会生成这种情况)
|
||||
for arg in call_node.args:
|
||||
if isinstance(arg, ast.Constant):
|
||||
kwargs.setdefault('_pos_args', []).append(arg.s)
|
||||
|
||||
# 处理关键字参数
|
||||
for kw in call_node.keywords:
|
||||
if isinstance(kw.value, (ast.Constant, ast.Constant, ast.Constant)):
|
||||
kwargs[kw.arg] = ast.literal_eval(ast.unparse(kw.value))
|
||||
elif isinstance(kw.value, ast.Name) and kw.value.id in ('True', 'False', 'None'):
|
||||
kwargs[kw.arg] = ast.literal_eval(kw.value.id)
|
||||
|
||||
return kwargs
|
||||
|
||||
@classmethod
|
||||
def ui_module_calls(cls, template_code: str) -> List[Tuple[str, dict]]:
|
||||
"""
|
||||
从模板生成的 Python 代码中提取 UIModule 的调用。
|
||||
|
||||
:param template_code: 模板生成的函数。
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(template_code)
|
||||
analyzer = cls()
|
||||
analyzer.visit(tree)
|
||||
return analyzer.calls
|
||||
except:
|
||||
return []
|
||||
|
||||
|
||||
class ParamAwareTemplate(Template):
|
||||
"""
|
||||
参数感知模板。
|
||||
重写 _generate_python 方法,从 Tornado 模板编译生成的 Python 代码中分析出 UIModule 调用参数。
|
||||
提供 prepare_ui_modules 方法在 Handler 中 load 完成后预处理数据,预处理得到的数据会保存在数据仓库中。
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.ui_module_calls = []
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _generate_python(self, *args, **kwargs):
|
||||
code = super()._generate_python(*args, **kwargs)
|
||||
self.ui_module_calls = UIModuleCallAnalyzer.ui_module_calls(code)
|
||||
return code
|
||||
|
||||
async def prepare_ui_modules(self, template: 'ParamAwareTemplate', ui_modules: dict[UIModule]):
|
||||
"""执行模板中所有UIModule的异步预处理"""
|
||||
tasks = []
|
||||
|
||||
for module_name, kwargs in template.ui_module_calls:
|
||||
module_class = ui_modules.get(module_name)
|
||||
if not hasattr(module_class, 'async_prepare'):
|
||||
continue
|
||||
|
||||
call_id = ParamAwareUIModule.generate_call_id(module_name, kwargs)
|
||||
task = asyncio.create_task(
|
||||
self._prepare_single(module_class, call_id, kwargs)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _prepare_single(self, module_class, call_id, kwargs):
|
||||
"""单个模块的预处理流程"""
|
||||
try:
|
||||
_ui_modulr: ParamAwareUIModule = module_class(handler=self.namespace.get('handler'))
|
||||
data = await _ui_modulr.async_prepare(**kwargs)
|
||||
warehouse.prepare(module_class.__name__, call_id, data)
|
||||
except Exception as e:
|
||||
warehouse.prepare(module_class.__name__, call_id, {
|
||||
"__error__": str(e)
|
||||
})
|
||||
|
||||
|
||||
class ParamAwareLoader(Loader):
|
||||
"""
|
||||
参数感知装载器,也是本代码文件中主要对外开放的类。
|
||||
重写 _create_template 方法,用参数感知模板替换原有模板。
|
||||
重写 load 明确返回参数感知模板。
|
||||
"""
|
||||
|
||||
def __init__(self, root_directory: str, **kwargs: Any) -> None:
|
||||
super().__init__(root_directory, **kwargs)
|
||||
self.templates = {} # type: Dict[str, ParamAwareTemplate]
|
||||
|
||||
def _create_template(self, name: str) -> ParamAwareTemplate:
|
||||
path = os.path.join(self.root, name)
|
||||
with open(path, "rb") as f:
|
||||
template = ParamAwareTemplate(f.read(), name=name, loader=self)
|
||||
return template
|
||||
|
||||
def load(self, name: str, parent_path: Optional[str] = None) -> ParamAwareTemplate:
|
||||
"""Loads a template."""
|
||||
name = self.resolve_path(name, parent_path=parent_path)
|
||||
with self.lock:
|
||||
if name not in self.templates:
|
||||
self.templates[name] = self._create_template(name)
|
||||
return self.templates[name]
|
||||
|
||||
async def load_with_prepare(self, name: str) -> ParamAwareTemplate:
|
||||
"""
|
||||
加载模板,并完成数据准备。
|
||||
|
||||
:param name: 模板名称
|
||||
:return: 完成数据准备的模板
|
||||
"""
|
||||
template = self.load(name)
|
||||
_modules = self.namespace.get('modules', None)
|
||||
if _modules and hasattr(_modules, 'ui_modules'):
|
||||
_ui_modules = _modules.ui_modules
|
||||
await template.prepare_ui_modules(template, _ui_modules)
|
||||
return template
|
||||
@@ -0,0 +1,364 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from asyncio import Task
|
||||
from typing import Optional, Callable, Awaitable, Dict, Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPClientError, HTTPResponse
|
||||
from tornado.web import RequestHandler
|
||||
|
||||
from paste.core.logging import echo_log
|
||||
from paste.util.encoder import JsonDumpsEncoder
|
||||
|
||||
|
||||
_global_http_client: Optional[AsyncHTTPClient] = None
|
||||
|
||||
|
||||
def get_http_client():
|
||||
"""获取全局共享的 HTTP 客户端,避免重复创建和销毁。"""
|
||||
global _global_http_client
|
||||
if _global_http_client is None:
|
||||
_global_http_client = AsyncHTTPClient()
|
||||
return _global_http_client
|
||||
|
||||
|
||||
async def close_http_client():
|
||||
"""关闭全局 HTTP 客户端。"""
|
||||
global _global_http_client
|
||||
if _global_http_client:
|
||||
_global_http_client.close()
|
||||
_global_http_client = None
|
||||
|
||||
|
||||
async def async_request(request: HTTPRequest, before_request: Callable = None, after_request: Callable = None,
|
||||
retry_queue: asyncio.Queue[HTTPRequest] = None, is_log_exc=True,
|
||||
on_error: Callable = None):
|
||||
"""
|
||||
异步提交请求,返回响应数据对象。如提供回调函数,则将响应对象作为回调函数的参数传入,并执行。
|
||||
|
||||
:param request: 请求对象
|
||||
:param before_request: 在提交请求前要处理的行为
|
||||
:param after_request: 请求后的回调函数,回调参数为:HTTPResponse 响应对象、重试请求队列(无该参数则为 None)
|
||||
:param retry_queue: 重试队列,若传入该参数,则失败的请求会放入该队列
|
||||
:param is_log_exc: 是否记录日志
|
||||
:param on_error: 发生异常后的处理,回调参数为:请求对象、异常对象、重试请求队列(无该参数则为 None)
|
||||
:return: 响应数据对象
|
||||
"""
|
||||
_http_client: Optional[AsyncHTTPClient] = None
|
||||
try:
|
||||
# 执行请求前的回调函数
|
||||
if before_request:
|
||||
_before_result = before_request(request, retry_queue, is_log_exc=is_log_exc)
|
||||
if isinstance(_before_result, Awaitable):
|
||||
# 处理回调协程
|
||||
await _before_result
|
||||
|
||||
if is_log_exc:
|
||||
# 记录请求信息
|
||||
echo_log(f'请求地址:{request.method}: {request.url}.')
|
||||
echo_log(f'主体长度:{sys.getsizeof(request.body)}.')
|
||||
|
||||
# 在此之前的异常,不加入重试队列
|
||||
_http_client = get_http_client()
|
||||
_response: HTTPResponse = await _http_client.fetch(request=request)
|
||||
if after_request:
|
||||
_after_result = after_request(_response, retry_queue)
|
||||
if isinstance(_after_result, Awaitable):
|
||||
# 处理协程回调
|
||||
await _after_result
|
||||
return _response
|
||||
except HTTPClientError as e:
|
||||
if e.response and e.response.code in (302, 412):
|
||||
# 这里依然可以拿到响应对象,继续返回
|
||||
return e.response
|
||||
if is_log_exc:
|
||||
echo_log(f'请求错误:{e},地址:{request.url}', level=logging.ERROR, is_log_exc=True)
|
||||
if e.response is not None:
|
||||
echo_log(f'响应内容:{e.response.body.decode()}', level=logging.ERROR)
|
||||
if retry_queue is not None and _http_client is not None:
|
||||
await retry_queue.put(request)
|
||||
if on_error:
|
||||
_err_result = on_error(request, e, retry_queue)
|
||||
if isinstance(_err_result, Awaitable):
|
||||
# 处理协程回调
|
||||
await _err_result
|
||||
except ConnectionError as e:
|
||||
if is_log_exc:
|
||||
echo_log(f'连接错误:{e}', level=logging.ERROR, is_log_exc=True)
|
||||
if retry_queue is not None and _http_client is not None:
|
||||
await retry_queue.put(request)
|
||||
if on_error:
|
||||
_err_result = on_error(request, e, retry_queue)
|
||||
if isinstance(_err_result, Awaitable):
|
||||
# 处理协程回调
|
||||
await _err_result
|
||||
except Exception as e:
|
||||
if is_log_exc:
|
||||
echo_log(f'未知错误:{e}', level=logging.ERROR, is_log_exc=True)
|
||||
if retry_queue is not None and _http_client is not None:
|
||||
await retry_queue.put(request)
|
||||
if on_error:
|
||||
_err_result = on_error(request, e, retry_queue)
|
||||
if isinstance(_err_result, Awaitable):
|
||||
# 处理协程回调
|
||||
await _err_result
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def async_concurrency(request_queue: Optional[asyncio.Queue[HTTPRequest]], con_count=10, retry=5,
|
||||
before_request: Callable = None,
|
||||
after_request: Callable = None, after_done: Callable = None,
|
||||
is_log_exc=True,
|
||||
retry_queue: asyncio.Queue[HTTPRequest] = None,
|
||||
response_list: list[HTTPResponse] = None,
|
||||
on_error: Callable = None,
|
||||
wait_after: int = None):
|
||||
"""
|
||||
异步并发请求,默认并发 10 个请求,且默认合计尝试 5 次(除第 1 次外,再尝试 4 次)。
|
||||
|
||||
:param request_queue: 请求队列
|
||||
:param con_count: 每批并发请求数量
|
||||
:param retry: 总尝试次数,默认尝试 5 次
|
||||
:param before_request: 在提交请求前要处理的行为
|
||||
:param after_request: 请求后的回调函数,回调参数为:HTTPResponse 响应对象、重试请求队列(无该参数则为 None)
|
||||
:param after_done: 所有任务都完成后的回调,回调参数为:response_list,发生异常的请求不在列表中,应通过 on_error 回调获取
|
||||
:param is_log_exc: 是否记录异常日志
|
||||
:param retry_queue: 重试队列,若传入该参数,则失败的请求会放入该队列
|
||||
:param response_list: 响应列表
|
||||
:param on_error: 发生异常后的处理,回调参数为:请求对象、异常对象、重试请求队列(无该参数则为 None)
|
||||
:param wait_after: 在请求完成后的等待时间,应当考虑请求服务器的处理时间,必要时可设置等待时间,但是不易设置过长一般 1~3 秒
|
||||
:return 若设置了 after_done 且有返回值,则返回,否则返响应列表
|
||||
"""
|
||||
if retry_queue is None:
|
||||
retry_queue = asyncio.Queue()
|
||||
|
||||
if response_list is None:
|
||||
response_list = []
|
||||
|
||||
while not request_queue.empty():
|
||||
# 按配置,读取队列,创建任务组
|
||||
_tasks: set[Task] = set()
|
||||
for _i in range(con_count):
|
||||
if request_queue.empty():
|
||||
break
|
||||
|
||||
_request = await request_queue.get()
|
||||
setattr(_request, 'max_retry', retry)
|
||||
_task = asyncio.create_task(async_request(
|
||||
request=_request, before_request=before_request, after_request=after_request,
|
||||
retry_queue=retry_queue, is_log_exc=is_log_exc, on_error=on_error
|
||||
))
|
||||
_tasks.add(_task)
|
||||
|
||||
# 执行,并等待任务组完成
|
||||
response_list += await asyncio.gather(*_tasks)
|
||||
# 处理等待
|
||||
if wait_after:
|
||||
await asyncio.sleep(wait_after)
|
||||
|
||||
# 检查任务(包含重试任务)是否完成,完成则返回,否则继续
|
||||
if not request_queue.empty():
|
||||
continue
|
||||
|
||||
# 任然有需要重试的请求
|
||||
while not retry_queue.empty():
|
||||
_request = await retry_queue.get()
|
||||
_retry = getattr(_request, 'retry', 0) + 1
|
||||
if _retry < retry:
|
||||
setattr(_request, 'retry', _retry)
|
||||
await request_queue.put(_request)
|
||||
|
||||
if is_log_exc and not request_queue.empty():
|
||||
echo_log(f'启动重试,共有:{request_queue.qsize()} 个请求启动重试.')
|
||||
|
||||
echo_log(f'所有请求执行完毕,任务结束.')
|
||||
# 所有请求包括重试都已经完成,执行回调
|
||||
_result = None
|
||||
if after_done:
|
||||
_after_done_result = after_done(response_list)
|
||||
if isinstance(_after_done_result, Awaitable):
|
||||
# 处理协程回调
|
||||
_result = await _after_done_result
|
||||
else:
|
||||
# 普通函数调用
|
||||
_result = _after_done_result
|
||||
# 钩子有返回时,返回钩子处理结果
|
||||
if _result is not None:
|
||||
return _result
|
||||
return response_list
|
||||
|
||||
|
||||
async def async_forward(handler: RequestHandler, forward_url: str, is_log_exc=True, request_timeout=60):
|
||||
"""
|
||||
转发请求。
|
||||
|
||||
:param handler: 收到请求的控制器对象
|
||||
:param forward_url: 要转发的目标地址
|
||||
:param is_log_exc: 是否记录日志
|
||||
:param request_timeout: 超时时长
|
||||
:return: 转发响应结果
|
||||
"""
|
||||
_req_params = {
|
||||
'body': handler.request.body,
|
||||
'headers': {
|
||||
'Accept': '*/*',
|
||||
'Access-Token': handler.request.headers.get('Access-Token'),
|
||||
'Content-Type': handler.request.headers.get('Content-Type', 'application/json'),
|
||||
'timestamp': f'{int(time.time() * 1000)}'
|
||||
},
|
||||
'method': handler.request.method,
|
||||
'request_timeout': request_timeout,
|
||||
'url': forward_url,
|
||||
}
|
||||
_request = HTTPRequest(**_req_params)
|
||||
|
||||
_http_client = get_http_client()
|
||||
try:
|
||||
if is_log_exc:
|
||||
# 记录请求信息
|
||||
echo_log(f'请求地址:{_request.method}: {_request.url}.')
|
||||
echo_log(f'主体长度:{sys.getsizeof(_request.body)}.')
|
||||
_response: HTTPResponse = await _http_client.fetch(request=_request)
|
||||
return _response
|
||||
except HTTPClientError as e:
|
||||
if is_log_exc:
|
||||
echo_log(f'请求错误:{e},地址:{_request.url}', level=logging.ERROR, is_log_exc=True)
|
||||
if e.response is not None:
|
||||
echo_log(f'响应内容:{e.response.body.decode()}', level=logging.ERROR)
|
||||
raise e
|
||||
except ConnectionError as e:
|
||||
if is_log_exc:
|
||||
echo_log(f'连接错误:{e}', level=logging.ERROR, is_log_exc=True)
|
||||
raise e
|
||||
except Exception as e:
|
||||
if is_log_exc:
|
||||
echo_log(f'未知错误:{e}', level=logging.ERROR, is_log_exc=True)
|
||||
raise e
|
||||
|
||||
|
||||
def build_http_request(
|
||||
url: str,
|
||||
body: Optional[Dict[str, Any]] = None,
|
||||
method: str = 'POST',
|
||||
timeout: Optional[float] = None,
|
||||
follow_redirects: bool = True,
|
||||
use_form: bool = False,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
** kwargs
|
||||
) -> HTTPRequest:
|
||||
"""
|
||||
构建一个 tornado.httpclient.HTTPRequest 对象。
|
||||
|
||||
支持 GET 和 POST 方法:
|
||||
- GET: 参数通过 URL 查询字符串传递
|
||||
- POST: 参数通过 JSON body 或 form 表单传递(由 use_form 控制)
|
||||
|
||||
:param url: 请求的完整 URL
|
||||
:param body: 请求体(字典),GET 时为查询参数,POST 时为 JSON 或 form 数据
|
||||
:param method: HTTP 方法,仅支持 'GET' 或 'POST'
|
||||
:param timeout: 请求超时时间(秒)
|
||||
:param follow_redirects: 是否跟随重定向
|
||||
:param use_form: 如果为 True,POST 时使用 application/x-www-form-urlencoded 格式;否则使用 JSON
|
||||
:param extra_headers: 可选的额外请求头,用于传入 Cookie、Authorization 等
|
||||
:param kwargs: 其他参数,符合 tornado.httpclient.HTTPRequest 参数要求
|
||||
:return: tornado.httpclient.HTTPRequest 对象
|
||||
:raises ValueError: 当 method 不合法时抛出
|
||||
"""
|
||||
if method not in ('GET', 'POST'):
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
body = body or {}
|
||||
|
||||
# 基础头
|
||||
headers = {
|
||||
'Accept': '*/*',
|
||||
'Accept-Encoding': 'gzip, deflate',
|
||||
'Accept-Language': 'zh-CN,zh;q=0.9',
|
||||
'Connection': 'keep-alive',
|
||||
'Content-Type': 'application/x-www-form-urlencoded; charset=UTF-8',
|
||||
'X-Requested-With': 'XMLHttpRequest',
|
||||
}
|
||||
|
||||
# 合并额外头(优先级:extra_headers > DEFAULT_HEADERS)
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
req_params = {
|
||||
'url': url,
|
||||
'method': method,
|
||||
'headers': headers,
|
||||
'follow_redirects': follow_redirects,
|
||||
}
|
||||
|
||||
if timeout:
|
||||
req_params['request_timeout'] = timeout
|
||||
|
||||
if method == 'GET':
|
||||
# GET 方法:参数拼接到 URL
|
||||
if body:
|
||||
req_params['url'] = f"{url}?{urlencode(body)}"
|
||||
req_params.pop('body', None)
|
||||
req_params['headers'].pop('Content-Type', None)
|
||||
req_params['headers'].pop('Content-Length', None)
|
||||
req_params['headers'].pop('Transfer-Encoding', None)
|
||||
else:
|
||||
# POST 方法
|
||||
if use_form:
|
||||
# 检查 body 中是否有文件对象
|
||||
has_files = any(isinstance(v, io.IOBase) for v in body.values())
|
||||
|
||||
if has_files:
|
||||
# 构建 multipart/form-data
|
||||
boundary = f"----WebKitFormBoundary{random.randint(10000000, 99999999)}"
|
||||
multipart_body = []
|
||||
|
||||
for key, value in body.items():
|
||||
if isinstance(value, io.IOBase):
|
||||
# 文件对象
|
||||
value.seek(0) # 确保从头读取
|
||||
filename = getattr(value, 'name', key) # 尝试获取文件名
|
||||
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
|
||||
|
||||
multipart_body.append(f'--{boundary}\r\n'.encode())
|
||||
multipart_body.append(f'Content-Disposition: form-data; name="{key}"; filename="{filename}"\r\n'.encode())
|
||||
multipart_body.append(f'Content-Type: {content_type}\r\n\r\n'.encode())
|
||||
multipart_body.append(value.read())
|
||||
multipart_body.append(b'\r\n')
|
||||
else:
|
||||
# 普通文本
|
||||
multipart_body.append(f'--{boundary}\r\n'.encode())
|
||||
multipart_body.append(f'Content-Disposition: form-data; name="{key}"\r\n\r\n'.encode())
|
||||
# 处理布尔值和 None
|
||||
if isinstance(value, bool):
|
||||
value = str(value).lower()
|
||||
elif value is None:
|
||||
value = ''
|
||||
else:
|
||||
value = str(value)
|
||||
multipart_body.append(value.encode('utf-8'))
|
||||
multipart_body.append(b'\r\n')
|
||||
|
||||
multipart_body.append(f'--{boundary}--\r\n'.encode())
|
||||
req_params['body'] = b''.join(multipart_body)
|
||||
headers['Content-Length'] = str(len(req_params['body']))
|
||||
headers['Content-Type'] = f'multipart/form-data; boundary={boundary}'
|
||||
else:
|
||||
# 普通 form 表单(无文件)
|
||||
req_params['body'] = urlencode(body).encode('utf-8')
|
||||
headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'
|
||||
headers['Content-Length'] = str(len(req_params['body']))
|
||||
else:
|
||||
body_bytes = json.dumps(body, cls=JsonDumpsEncoder, ensure_ascii=False).encode('utf-8')
|
||||
req_params['body'] = body_bytes
|
||||
headers['Content-Length'] = str(len(body_bytes))
|
||||
# 保持 application/json
|
||||
|
||||
req_params.update(kwargs)
|
||||
return HTTPRequest(**req_params)
|
||||
@@ -0,0 +1,98 @@
|
||||
import os
|
||||
import typing
|
||||
|
||||
import tornado
|
||||
import tornado.web
|
||||
from tornado_swagger._builders import generate_doc_from_endpoints
|
||||
from tornado_swagger._handlers import TornadoBaseHandler
|
||||
from tornado_swagger.const import API_SWAGGER_2
|
||||
from tornado_swagger.setup import STATIC_PATH
|
||||
|
||||
from paste.web.application import ApplicationSwagger
|
||||
|
||||
|
||||
class SwaggerUiHandler(TornadoBaseHandler):
|
||||
"""
|
||||
自定义 Ui,支持从应用程序读取文档页面。
|
||||
主要是为了允许不同的应用具有不同的接口描述页面。
|
||||
"""
|
||||
|
||||
def get(self):
|
||||
if hasattr(self.application, 'swagger_home_template'):
|
||||
self.write(self.application.swagger_home_template)
|
||||
else:
|
||||
self.write(
|
||||
f'类型错误,无法从应用程序读取 swagger_home_template 属性,'
|
||||
f'请使用 ApplicationSwagger 以支持 Swagger。'
|
||||
)
|
||||
|
||||
|
||||
class SwaggerSpecHandler(TornadoBaseHandler):
|
||||
"""
|
||||
自定义 Spec,支持从应用程序读取 Schema。
|
||||
主要是为了允许不同的应用具有不同的接口描述页面。
|
||||
"""
|
||||
|
||||
def get(self):
|
||||
if hasattr(self.application, 'swagger_schema'):
|
||||
self.write(self.application.swagger_schema)
|
||||
else:
|
||||
self.write(
|
||||
f'类型错误,无法从应用程序读取 swagger_schema 属性,'
|
||||
f'请使用 ApplicationSwagger 以支持 Swagger。'
|
||||
)
|
||||
|
||||
|
||||
def setup_swagger(
|
||||
app: ApplicationSwagger,
|
||||
routes: typing.List[tornado.web.URLSpec],
|
||||
*,
|
||||
swagger_url: str = "/api/doc",
|
||||
api_base_url: str = "/",
|
||||
description: str = "Swagger API definition",
|
||||
api_version: str = "1.0.0",
|
||||
title: str = "Swagger API",
|
||||
contact: str = "",
|
||||
schemes: list = None,
|
||||
security_definitions: dict = None,
|
||||
security: list = None,
|
||||
display_models: bool = True,
|
||||
api_definition_version: str = API_SWAGGER_2
|
||||
):
|
||||
"""
|
||||
注入 Swagger ui 到应用程序路由。
|
||||
"""
|
||||
|
||||
swagger_schema = generate_doc_from_endpoints(
|
||||
routes,
|
||||
api_base_url=api_base_url,
|
||||
description=description,
|
||||
api_version=api_version,
|
||||
title=title,
|
||||
contact=contact,
|
||||
schemes=schemes,
|
||||
security_definitions=security_definitions,
|
||||
security=security,
|
||||
api_definition_version=api_definition_version,
|
||||
)
|
||||
|
||||
_swagger_ui_url = f"/{swagger_url}" if not swagger_url.startswith("/") else swagger_url
|
||||
_base_swagger_ui_url = _swagger_ui_url.rstrip("/")
|
||||
_swagger_spec_url = f"{_swagger_ui_url}/swagger.json"
|
||||
|
||||
routes[:0] = [
|
||||
tornado.web.url(_swagger_ui_url, SwaggerUiHandler),
|
||||
tornado.web.url(f"{_base_swagger_ui_url}/", SwaggerUiHandler),
|
||||
tornado.web.url(_swagger_spec_url, SwaggerSpecHandler),
|
||||
]
|
||||
|
||||
app.swagger_schema = swagger_schema
|
||||
|
||||
with open(os.path.join(STATIC_PATH, "ui.html"), "r", encoding="utf-8") as f:
|
||||
app.swagger_home_template = (
|
||||
f.read().replace(
|
||||
"{{ SWAGGER_URL }}", _swagger_spec_url
|
||||
).replace(
|
||||
"{{ DISPLAY_MODELS }}", str(-1 if not display_models else 1)
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,130 @@
|
||||
from abc import ABC
|
||||
from typing import Optional, Awaitable, Any, Type
|
||||
|
||||
from tornado import websocket
|
||||
|
||||
import tornado.websocket
|
||||
|
||||
from paste.db.basemodel import BaseModel
|
||||
from paste.web.handler import init_user_class
|
||||
|
||||
|
||||
class WebSocketHandler(tornado.websocket.WebSocketHandler, ABC):
|
||||
"""
|
||||
WebSocketHandler 的派生父类,主要增加了 send 方法,用于向客户端发送数据。
|
||||
"""
|
||||
|
||||
_web_sockets: set['WebSocketHandler'] = set()
|
||||
"""
|
||||
用于全局保存所有的客户端连接。
|
||||
"""
|
||||
|
||||
user_class: Type[BaseModel] = init_user_class()
|
||||
"""
|
||||
用户数据处理类。装饰器 web.decorators.auth_token 执行令牌验证时调用该类,用于创建用户对象,并保存在 current_user 属性中。
|
||||
注意:这里仅初始化类,而不创建对象。该类允许用户继承扩展,然后自行配置。主要用于执行有关用户的数据操作。
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def add_socket(cls, web_socket):
|
||||
"""
|
||||
加入 WebSocket 集合。
|
||||
|
||||
:param web_socket: 要加入的 WebSocketHandler 对象
|
||||
"""
|
||||
assert hasattr(web_socket, 'send')
|
||||
cls._web_sockets.add(web_socket)
|
||||
|
||||
@classmethod
|
||||
def get_sockets(cls):
|
||||
"""
|
||||
取得 WebSocket 集合。
|
||||
"""
|
||||
return cls._web_sockets
|
||||
|
||||
@classmethod
|
||||
def has_sockets(cls):
|
||||
"""
|
||||
连接队列中是否还有连接。
|
||||
"""
|
||||
return True if cls._web_sockets else False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.token_payload: dict[str: Any] = {}
|
||||
"""
|
||||
令牌配载数据字典。装饰器 web.decorators.auth_token 执行令牌验证时解码并赋值。 在 HandlerRequest 子类中
|
||||
只要配置 auth_token 装饰即可使用该配载数据。
|
||||
|
||||
其结构为::
|
||||
|
||||
{
|
||||
'iss': private_iss,
|
||||
'iat': datetime.datetime.utcnow(),
|
||||
'exp': datetime.datetime.utcnow() + datetime.timedelta(days=7),
|
||||
'params': {
|
||||
'id': user_id,
|
||||
'username': username
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def token_params(self) -> dict:
|
||||
"""
|
||||
取出 Token 中的参数字典。
|
||||
|
||||
:return: 参数字典
|
||||
"""
|
||||
return self.token_payload.get('params', {})
|
||||
|
||||
def token_param(self, key):
|
||||
"""
|
||||
取出 Token 参数字典中的参数。
|
||||
|
||||
:param key: 参数名称
|
||||
"""
|
||||
return self.token_params().get(key, None)
|
||||
|
||||
def is_connected(self):
|
||||
"""
|
||||
检查当前WebSocket连接是否打开。
|
||||
"""
|
||||
return self.ws_connection is not None and self.ws_connection.stream is not None
|
||||
|
||||
def select_subprotocol(self, subprotocols: [str]) -> Optional[str]:
|
||||
"""
|
||||
选择子协议字符串。注意::
|
||||
|
||||
1、该方法返回的数据必须位于 subprotocols 数组中;
|
||||
2、若有 subprotocols 参数传入,默认始终返回第 0 项;
|
||||
3、用于验证的 Token 始终放在子协议的最后一项,读取该数据设置到 request.headers 中;
|
||||
|
||||
:param subprotocols: 子协议数组,当前端传入字符串时,该数组仅有一项
|
||||
:return: 选择的子协议
|
||||
"""
|
||||
if subprotocols:
|
||||
_token = subprotocols[-1]
|
||||
self.request.headers.add('Access-Token', _token)
|
||||
|
||||
return subprotocols[0]
|
||||
return None
|
||||
|
||||
def on_close(self):
|
||||
"""
|
||||
关闭连接时,从集合中删除客户端连接。
|
||||
"""
|
||||
if self in self._web_sockets:
|
||||
self._web_sockets.remove(self)
|
||||
|
||||
def send(self) -> Optional[Awaitable[None]]:
|
||||
"""
|
||||
向客户端发送数据。必须在子类中加以实现。
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def data_received(self, chunk: bytes):
|
||||
pass
|
||||
|
||||
def check_origin(self, origin):
|
||||
return True
|
||||
Reference in New Issue
Block a user