Merge commit '47296980495f8bbfc9493e93de85dd62de6fa6b9' as 'paste-framework'
This commit is contained in:
@@ -0,0 +1,364 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from asyncio import Task
|
||||
from typing import Optional, Callable, Awaitable, Dict, Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPClientError, HTTPResponse
|
||||
from tornado.web import RequestHandler
|
||||
|
||||
from paste.core.logging import echo_log
|
||||
from paste.util.encoder import JsonDumpsEncoder
|
||||
|
||||
|
||||
# Module-level cache for the shared AsyncHTTPClient. It stays None until
# get_http_client() creates the client, and close_http_client() resets it to None.
_global_http_client: Optional[AsyncHTTPClient] = None
|
||||
|
||||
|
||||
def get_http_client():
    """Return the module-wide shared HTTP client, creating it on first use.

    Sharing a single client avoids constructing and tearing down a new
    ``AsyncHTTPClient`` for every request.
    """
    global _global_http_client
    client = _global_http_client
    if client is None:
        # First call (or first call after close_http_client()): build and cache it.
        client = AsyncHTTPClient()
        _global_http_client = client
    return client
|
||||
|
||||
|
||||
async def close_http_client():
    """Close the shared HTTP client, if one is cached, and clear the cache."""
    global _global_http_client
    client = _global_http_client
    if not client:
        # Nothing cached, so there is nothing to close.
        return
    client.close()
    _global_http_client = None
|
||||
|
||||
|
||||
async def _handle_request_failure(request, error, retry_queue, fetch_started, on_error):
    """Queue a failed request for retry (when eligible) and notify the error hook.

    :param request: the request that failed
    :param error: the exception that was raised
    :param retry_queue: retry queue, or None when retries are not tracked
    :param fetch_started: True once the HTTP client had been obtained; failures
        before that point are never queued for retry
    :param on_error: optional hook called as ``on_error(request, error, retry_queue)``;
        an awaitable result is awaited
    """
    if retry_queue is not None and fetch_started:
        await retry_queue.put(request)
    if on_error:
        _err_result = on_error(request, error, retry_queue)
        if isinstance(_err_result, Awaitable):
            # The hook is a coroutine (or returned an awaitable): wait for it.
            await _err_result


async def async_request(request: HTTPRequest, before_request: Callable = None, after_request: Callable = None,
                        retry_queue: asyncio.Queue[HTTPRequest] = None, is_log_exc=True,
                        on_error: Callable = None):
    """
    Submit one request asynchronously and return its response.

    The optional hooks may be plain callables or return awaitables; awaitable
    results are awaited before continuing.

    :param request: the request to send
    :param before_request: hook run before the fetch, called as
        ``before_request(request, retry_queue, is_log_exc=is_log_exc)``
    :param after_request: hook run after a successful fetch, called as
        ``after_request(response, retry_queue)``
    :param retry_queue: when given, a request that fails after the HTTP client
        has been obtained is put on this queue
    :param is_log_exc: whether to log the request details and any error
    :param on_error: hook run on failure, called as
        ``on_error(request, exception, retry_queue)``
    :return: the response object; for an HTTP error carrying a 302 or 412
        response, that response; ``None`` for every other failure
    """
    _http_client: Optional[AsyncHTTPClient] = None
    try:
        # Run the pre-request hook.
        if before_request:
            _before_result = before_request(request, retry_queue, is_log_exc=is_log_exc)
            if isinstance(_before_result, Awaitable):
                await _before_result

        if is_log_exc:
            # Log the request. The body may be absent (None), so fall back to
            # empty bytes; len() reports the real payload length, whereas
            # sys.getsizeof() would report the object's memory footprint.
            _body_length = len(request.body or b'')
            echo_log(f'请求地址:{request.method}: {request.url}.')
            echo_log(f'主体长度:{_body_length}.')

        # Exceptions raised before this point are not added to the retry queue.
        _http_client = get_http_client()
        _response: HTTPResponse = await _http_client.fetch(request=request)
        if after_request:
            _after_result = after_request(_response, retry_queue)
            if isinstance(_after_result, Awaitable):
                await _after_result
        return _response
    except HTTPClientError as e:
        if e.response and e.response.code in (302, 412):
            # A response object is still available for these codes: return it.
            return e.response
        if is_log_exc:
            echo_log(f'请求错误:{e},地址:{request.url}', level=logging.ERROR, is_log_exc=True)
            if e.response is not None:
                # The error body may be missing or not valid UTF-8. Decode
                # leniently so that logging can never raise inside this handler
                # and skip the retry / on_error steps below.
                _error_text = (e.response.body or b'').decode(errors='replace')
                echo_log(f'响应内容:{_error_text}', level=logging.ERROR)
        await _handle_request_failure(request, e, retry_queue, _http_client is not None, on_error)
    except ConnectionError as e:
        if is_log_exc:
            echo_log(f'连接错误:{e}', level=logging.ERROR, is_log_exc=True)
        await _handle_request_failure(request, e, retry_queue, _http_client is not None, on_error)
    except Exception as e:
        # Catch-all boundary: this also receives errors raised by the
        # before_request / after_request hooks.
        if is_log_exc:
            echo_log(f'未知错误:{e}', level=logging.ERROR, is_log_exc=True)
        await _handle_request_failure(request, e, retry_queue, _http_client is not None, on_error)

    return None
|
||||
|
||||
|
||||
async def async_concurrency(request_queue: Optional[asyncio.Queue[HTTPRequest]], con_count=10, retry=5,
                            before_request: Callable = None,
                            after_request: Callable = None, after_done: Callable = None,
                            is_log_exc=True,
                            retry_queue: asyncio.Queue[HTTPRequest] = None,
                            response_list: list[HTTPResponse] = None,
                            on_error: Callable = None,
                            wait_after: int = None):
    """
    Send queued requests concurrently in batches, retrying failed ones.

    By default 10 requests run per batch and each request is attempted at most
    5 times in total (the first attempt plus up to 4 retries).

    :param request_queue: queue of requests to send; ``None`` is treated as an
        empty queue
    :param con_count: number of requests per concurrent batch; values below 1
        are treated as 1 so the queue is always drained
    :param retry: total number of attempts per request, default 5
    :param before_request: hook run before each request is submitted
    :param after_request: hook run after each request, called with the
        HTTPResponse and the retry queue
    :param after_done: hook run once everything (including retries) has
        finished, called with ``response_list``; failed requests are not in the
        list and should be observed through ``on_error``
    :param is_log_exc: whether to log exceptions
    :param retry_queue: queue that failed requests are put on; a new queue is
        created when omitted
    :param response_list: list the responses are appended to; a new list is
        created when omitted
    :param on_error: hook run on failure, called with the request, the
        exception and the retry queue
    :param wait_after: seconds to sleep after each batch, to give the remote
        server time to process; typically 1-3 seconds
    :return: the result of ``after_done`` when it is set and returns a value
        other than ``None``; otherwise the response list
    """
    if request_queue is None:
        request_queue = asyncio.Queue()
    if retry_queue is None:
        retry_queue = asyncio.Queue()
    if response_list is None:
        response_list = []

    # A batch size below 1 would never dequeue anything and spin forever.
    _batch_size = max(1, con_count)

    while not request_queue.empty():
        # Pull up to one batch of requests off the queue and start a task for
        # each. A list (not a set) keeps the results in submission order.
        _tasks: list[Task] = []
        while len(_tasks) < _batch_size and not request_queue.empty():
            _request = await request_queue.get()
            setattr(_request, 'max_retry', retry)
            _tasks.append(asyncio.create_task(async_request(
                request=_request, before_request=before_request, after_request=after_request,
                retry_queue=retry_queue, is_log_exc=is_log_exc, on_error=on_error
            )))

        # Wait for the whole batch. async_request returns None for a failed
        # request; those are dropped so the list only holds real responses.
        _batch_responses = await asyncio.gather(*_tasks)
        response_list.extend(_resp for _resp in _batch_responses if _resp is not None)

        if wait_after:
            await asyncio.sleep(wait_after)

        # More first-pass (or already re-queued) requests remain: keep going.
        if not request_queue.empty():
            continue

        # The request queue is drained: move failed requests back onto it while
        # they still have attempts left.
        while not retry_queue.empty():
            _request = await retry_queue.get()
            _retry = getattr(_request, 'retry', 0) + 1
            if _retry < retry:
                setattr(_request, 'retry', _retry)
                await request_queue.put(_request)

        if is_log_exc and not request_queue.empty():
            echo_log(f'启动重试,共有:{request_queue.qsize()} 个请求启动重试.')

    echo_log('所有请求执行完毕,任务结束.')

    # Everything, retries included, has finished: run the completion hook.
    if after_done:
        _result = after_done(response_list)
        if isinstance(_result, Awaitable):
            _result = await _result
        # A non-None hook result replaces the default return value.
        if _result is not None:
            return _result
    return response_list
|
||||
|
||||
|
||||
async def async_forward(handler: RequestHandler, forward_url: str, is_log_exc=True, request_timeout=60):
    """
    Forward the request received by ``handler`` to another URL.

    The method and body of the incoming request are reused, together with its
    ``Access-Token`` and ``Content-Type`` headers (``Content-Type`` defaults to
    ``application/json``). A millisecond ``timestamp`` header is added.

    :param handler: the handler that received the incoming request
    :param forward_url: target URL to forward to
    :param is_log_exc: whether to log the request details and any error
    :param request_timeout: timeout passed to the outgoing request
    :return: the response of the forwarded request
    :raises Exception: any error raised while fetching is logged (when enabled)
        and re-raised unchanged
    """
    _method = handler.request.method
    _body = handler.request.body
    if not _body and _method not in ('POST', 'PUT', 'PATCH'):
        # Tornado's HTTP client rejects a non-None body (even b'') for methods
        # that do not normally carry one, so forward "no body" as None.
        _body = None

    _headers = {'Accept': '*/*'}
    _access_token = handler.request.headers.get('Access-Token')
    if _access_token is not None:
        # Only forward the token when the caller sent one; a None header value
        # cannot be serialized.
        _headers['Access-Token'] = _access_token
    _headers['Content-Type'] = handler.request.headers.get('Content-Type', 'application/json')
    _headers['timestamp'] = f'{int(time.time() * 1000)}'

    _request = HTTPRequest(
        url=forward_url,
        method=_method,
        headers=_headers,
        body=_body,
        request_timeout=request_timeout,
    )

    _http_client = get_http_client()
    try:
        if is_log_exc:
            # Log the request; len() gives the payload length (sys.getsizeof
            # would give the object's memory footprint instead).
            _body_length = len(_request.body or b'')
            echo_log(f'请求地址:{_request.method}: {_request.url}.')
            echo_log(f'主体长度:{_body_length}.')
        _response: HTTPResponse = await _http_client.fetch(request=_request)
        return _response
    except HTTPClientError as e:
        if is_log_exc:
            echo_log(f'请求错误:{e},地址:{_request.url}', level=logging.ERROR, is_log_exc=True)
            if e.response is not None:
                # Decode leniently so a missing or non-UTF-8 error body cannot
                # replace the original exception with a decoding error.
                _error_text = (e.response.body or b'').decode(errors='replace')
                echo_log(f'响应内容:{_error_text}', level=logging.ERROR)
        raise
    except ConnectionError as e:
        if is_log_exc:
            echo_log(f'连接错误:{e}', level=logging.ERROR, is_log_exc=True)
        raise
    except Exception as e:
        if is_log_exc:
            echo_log(f'未知错误:{e}', level=logging.ERROR, is_log_exc=True)
        raise
|
||||
|
||||
|
||||
def _escape_multipart_param(value):
    """Escape characters that would terminate a quoted Content-Disposition parameter."""
    return str(value).replace('"', '%22').replace('\r', '%0D').replace('\n', '%0A')


def _encode_multipart_form(fields, boundary):
    """Serialize ``fields`` into a multipart/form-data payload and return it as bytes.

    Values that are ``io.IOBase`` instances become file parts; everything else
    becomes a text part (bools lower-cased, ``None`` as an empty string).
    """
    parts = []
    for key, value in fields.items():
        field_name = _escape_multipart_param(key)
        parts.append(f'--{boundary}\r\n'.encode())
        if isinstance(value, io.IOBase):
            value.seek(0)  # always send the file from its beginning
            raw_name = getattr(value, 'name', None)
            if not isinstance(raw_name, str) or not raw_name:
                # Missing, or not text (e.g. the int descriptor of a file
                # opened by fd): fall back to the field name.
                raw_name = str(key)
            # Send only the final path component, never the local directory.
            filename = raw_name.replace('\\', '/').rsplit('/', 1)[-1] or str(key)
            content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
            payload = value.read()
            if isinstance(payload, str):
                # Text-mode file objects return str; the payload must be bytes.
                payload = payload.encode('utf-8')
            parts.append(
                f'Content-Disposition: form-data; name="{field_name}"; '
                f'filename="{_escape_multipart_param(filename)}"\r\n'.encode()
            )
            parts.append(f'Content-Type: {content_type}\r\n\r\n'.encode())
            parts.append(payload)
        else:
            parts.append(f'Content-Disposition: form-data; name="{field_name}"\r\n\r\n'.encode())
            if isinstance(value, bool):
                text = str(value).lower()
            elif value is None:
                text = ''
            else:
                text = str(value)
            parts.append(text.encode('utf-8'))
        parts.append(b'\r\n')
    parts.append(f'--{boundary}--\r\n'.encode())
    return b''.join(parts)


def build_http_request(
        url: str,
        body: Optional[Dict[str, Any]] = None,
        method: str = 'POST',
        timeout: Optional[float] = None,
        follow_redirects: bool = True,
        use_form: bool = False,
        extra_headers: Optional[Dict[str, str]] = None,
        **kwargs
) -> HTTPRequest:
    """
    Build a ``tornado.httpclient.HTTPRequest``.

    Supported methods:
    - GET: ``body`` is sent as the URL query string
    - POST: ``body`` is sent as JSON, or as a form when ``use_form`` is true
      (multipart/form-data if any value is a file object, URL-encoded otherwise)

    :param url: full request URL
    :param body: request data as a dict; query parameters for GET, JSON or form
        fields for POST
    :param method: HTTP method, only ``'GET'`` or ``'POST'``
    :param timeout: request timeout in seconds; a falsy value leaves the
        ``request_timeout`` argument unset
    :param follow_redirects: whether redirects are followed
    :param use_form: for POST, send a form body instead of JSON
    :param extra_headers: additional headers (Cookie, Authorization, ...);
        they override the built-in defaults
    :param kwargs: further ``HTTPRequest`` arguments; they override anything
        computed here
    :return: the ``tornado.httpclient.HTTPRequest`` object
    :raises ValueError: when ``method`` is not supported
    """
    if method not in ('GET', 'POST'):
        raise ValueError(f"Unsupported HTTP method: {method}")

    body = body or {}

    # Default headers.
    headers = {
        'Accept': '*/*',
        'Accept-Encoding': 'gzip, deflate',
        'Accept-Language': 'zh-CN,zh;q=0.9',
        'Connection': 'keep-alive',
        'Content-Type': 'application/x-www-form-urlencoded; charset=UTF-8',
        'X-Requested-With': 'XMLHttpRequest',
    }

    # Caller-supplied headers take precedence over the defaults.
    caller_set_content_type = False
    if extra_headers:
        headers.update(extra_headers)
        caller_set_content_type = any(str(name).lower() == 'content-type' for name in extra_headers)

    req_params = {
        'url': url,
        'method': method,
        'headers': headers,
        'follow_redirects': follow_redirects,
    }

    if timeout:
        req_params['request_timeout'] = timeout

    if method == 'GET':
        # GET: parameters travel in the query string. Extend an existing query
        # string with '&' instead of adding a second '?'.
        if body:
            if '?' not in url:
                separator = '?'
            elif url.endswith(('?', '&')):
                separator = ''
            else:
                separator = '&'
            req_params['url'] = f"{url}{separator}{urlencode(body)}"
        # A GET carries no body, so drop the body-describing headers.
        headers.pop('Content-Type', None)
        headers.pop('Content-Length', None)
        headers.pop('Transfer-Encoding', None)
    elif use_form:
        if any(isinstance(v, io.IOBase) for v in body.values()):
            # At least one file object: send multipart/form-data.
            boundary = f"----WebKitFormBoundary{random.randint(10000000, 99999999)}"
            req_params['body'] = _encode_multipart_form(body, boundary)
            headers['Content-Length'] = str(len(req_params['body']))
            headers['Content-Type'] = f'multipart/form-data; boundary={boundary}'
        else:
            # Plain form without files.
            req_params['body'] = urlencode(body).encode('utf-8')
            headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'
            headers['Content-Length'] = str(len(req_params['body']))
    else:
        body_bytes = json.dumps(body, cls=JsonDumpsEncoder, ensure_ascii=False).encode('utf-8')
        req_params['body'] = body_bytes
        headers['Content-Length'] = str(len(body_bytes))
        if not caller_set_content_type:
            # The default header above describes a form body; label JSON as
            # JSON unless the caller chose a Content-Type explicitly.
            headers['Content-Type'] = 'application/json; charset=utf-8'

    req_params.update(kwargs)
    return HTTPRequest(**req_params)
|
||||
Reference in New Issue
Block a user