import asyncio
import io
import json
import logging
import mimetypes
import os
import random
import sys
import time
from asyncio import Task
from typing import Optional, Callable, Awaitable, Dict, Any
from urllib.parse import urlencode

from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPClientError, HTTPResponse
from tornado.web import RequestHandler

from paste.core.logging import echo_log
from paste.util.encoder import JsonDumpsEncoder

_global_http_client: Optional[AsyncHTTPClient] = None

# HTTP methods for which a request body is expected. For any other method an
# empty body is forwarded as ``None`` instead of ``b''`` (see async_forward).
_BODY_METHODS = ('POST', 'PUT', 'PATCH')


def get_http_client():
    """Return the module-wide shared HTTP client, creating it on first use.

    Sharing one client avoids repeatedly creating and destroying clients.
    """
    global _global_http_client
    if _global_http_client is None:
        _global_http_client = AsyncHTTPClient()
    return _global_http_client


async def close_http_client():
    """Close the module-wide shared HTTP client, if one has been created."""
    global _global_http_client
    if _global_http_client is not None:
        _global_http_client.close()
        _global_http_client = None


async def _maybe_await(result):
    """Await ``result`` when it is awaitable; otherwise return it unchanged.

    Lets every callback hook be either a plain function or a coroutine function.
    """
    if isinstance(result, Awaitable):
        return await result
    return result


def _body_length(body) -> int:
    """Return the number of bytes/characters in ``body`` (0 when it is None)."""
    return len(body) if body is not None else 0


def _decode_body(body) -> str:
    """Decode a response body for logging without ever raising.

    Undecodable bytes are replaced so that a non-UTF-8 error page cannot turn
    a logged failure into an unhandled ``UnicodeDecodeError``.
    """
    if body is None:
        return ''
    if isinstance(body, bytes):
        return body.decode('utf-8', errors='replace')
    return str(body)


def _log_request(request: HTTPRequest):
    """Log the method, URL and body length of an outgoing request."""
    echo_log(f'请求地址:{request.method}: {request.url}.')
    echo_log(f'主体长度:{_body_length(request.body)}.')


async def _handle_request_failure(request: HTTPRequest, error: BaseException,
                                  retry_queue, on_error, requeue: bool):
    """Queue a failed request for retry (when allowed) and run ``on_error``."""
    if requeue and retry_queue is not None:
        await retry_queue.put(request)
    if on_error:
        await _maybe_await(on_error(request, error, retry_queue))


async def async_request(request: HTTPRequest, before_request: Callable = None, after_request: Callable = None,
                        retry_queue: asyncio.Queue[HTTPRequest] = None, is_log_exc=True, on_error: Callable = None):
    """
    Submit a request asynchronously and return the response object.

    :param request: the request to send
    :param before_request: hook run before sending; called as
        ``before_request(request, retry_queue, is_log_exc=is_log_exc)``
    :param after_request: hook run after a response is received; called as
        ``after_request(response, retry_queue)`` (``retry_queue`` may be None)
    :param retry_queue: when given, requests that fail are put on this queue
    :param is_log_exc: whether to write log records
    :param on_error: hook run after a failure; called as
        ``on_error(request, exception, retry_queue)`` (``retry_queue`` may be None)
    :return: the response object, or None when the request failed. A 302 or 412
        error response is returned as a normal response.
    """
    # Stays None until the client has been obtained; failures raised before that
    # point (e.g. inside before_request) are never put on the retry queue.
    _http_client: Optional[AsyncHTTPClient] = None
    try:
        if before_request:
            await _maybe_await(before_request(request, retry_queue, is_log_exc=is_log_exc))
        if is_log_exc:
            _log_request(request)

        _http_client = get_http_client()
        _response: HTTPResponse = await _http_client.fetch(request=request)
        if after_request:
            await _maybe_await(after_request(_response, retry_queue))
        return _response
    except HTTPClientError as e:
        if e.response is not None and e.response.code in (302, 412):
            # The response object is still available for these codes; hand it back.
            return e.response
        if is_log_exc:
            echo_log(f'请求错误:{e},地址:{request.url}', level=logging.ERROR, is_log_exc=True)
            if e.response is not None:
                echo_log(f'响应内容:{_decode_body(e.response.body)}', level=logging.ERROR)
        await _handle_request_failure(request, e, retry_queue, on_error,
                                      requeue=_http_client is not None)
    except ConnectionError as e:
        if is_log_exc:
            echo_log(f'连接错误:{e}', level=logging.ERROR, is_log_exc=True)
        await _handle_request_failure(request, e, retry_queue, on_error,
                                      requeue=_http_client is not None)
    except Exception as e:
        if is_log_exc:
            echo_log(f'未知错误:{e}', level=logging.ERROR, is_log_exc=True)
        await _handle_request_failure(request, e, retry_queue, on_error,
                                      requeue=_http_client is not None)
    return None


async def async_concurrency(request_queue: Optional[asyncio.Queue[HTTPRequest]], con_count=10, retry=5,
                            before_request: Callable = None, after_request: Callable = None,
                            after_done: Callable = None, is_log_exc=True,
                            retry_queue: asyncio.Queue[HTTPRequest] = None,
                            response_list: list[HTTPResponse] = None,
                            on_error: Callable = None, wait_after: int = None):
    """
    Send queued requests concurrently in batches, retrying failed ones.

    By default 10 requests run per batch and each request is attempted at most
    5 times in total (the first attempt plus 4 retries).

    :param request_queue: queue of requests to send; None is treated as empty
    :param con_count: number of requests per concurrent batch (values below 1 are treated as 1)
    :param retry: total number of attempts per request
    :param before_request: hook run before each request is sent
    :param after_request: hook run after each response; called as
        ``after_request(response, retry_queue)``
    :param after_done: hook run once everything has finished; called with
        ``response_list``. Failed requests are not in the list; observe them
        through ``on_error``
    :param is_log_exc: whether to log errors
    :param retry_queue: queue that receives failed requests; created when omitted
    :param response_list: list that responses are appended to; created when omitted
    :param on_error: hook run after a failure; called as
        ``on_error(request, exception, retry_queue)``
    :param wait_after: seconds to sleep after each batch, to give the remote
        server time to process; should be kept short
    :return: the result of ``after_done`` when it returns a non-None value,
        otherwise the response list
    """
    if request_queue is None:
        request_queue = asyncio.Queue()
    if retry_queue is None:
        retry_queue = asyncio.Queue()
    if response_list is None:
        response_list = []
    # A batch size below 1 would never drain the queue and loop forever.
    batch_size = max(1, con_count)

    while not request_queue.empty():
        # Build one batch of tasks; a list keeps responses in submission order.
        batch: list[Task] = []
        while len(batch) < batch_size and not request_queue.empty():
            pending = request_queue.get_nowait()
            setattr(pending, 'max_retry', retry)
            batch.append(asyncio.create_task(async_request(
                request=pending, before_request=before_request, after_request=after_request,
                retry_queue=retry_queue, is_log_exc=is_log_exc, on_error=on_error
            )))

        # async_request returns None for a failed request; keep only real responses.
        results = await asyncio.gather(*batch)
        response_list.extend(item for item in results if item is not None)

        if wait_after:
            await asyncio.sleep(wait_after)

        # Drain the current round completely before scheduling any retries.
        if not request_queue.empty():
            continue

        # Move failed requests back to the request queue while attempts remain.
        while not retry_queue.empty():
            failed = retry_queue.get_nowait()
            attempts = getattr(failed, 'retry', 0) + 1
            if attempts < retry:
                setattr(failed, 'retry', attempts)
                await request_queue.put(failed)
        if is_log_exc and not request_queue.empty():
            echo_log(f'启动重试,共有:{request_queue.qsize()} 个请求启动重试.')

    echo_log('所有请求执行完毕,任务结束.')

    # Everything, retries included, has finished: run the completion hook.
    if after_done:
        hook_result = await _maybe_await(after_done(response_list))
        if hook_result is not None:
            return hook_result
    return response_list


async def async_forward(handler: RequestHandler, forward_url: str, is_log_exc=True, request_timeout=60):
    """
    Forward the request received by ``handler`` to another URL.

    :param handler: the handler object that received the request
    :param forward_url: target URL to forward to
    :param is_log_exc: whether to write log records
    :param request_timeout: request timeout passed to the outgoing request
    :return: the response of the forwarded request
    :raises HTTPClientError: when the forwarded request returns an error response
    :raises ConnectionError: when the connection fails
    """
    incoming = handler.request
    headers = {
        'Accept': '*/*',
        'Content-Type': incoming.headers.get('Content-Type', 'application/json'),
        'timestamp': f'{int(time.time() * 1000)}',
    }
    # Forward the token only when present: a None header value cannot be
    # serialised and would make the outgoing request fail.
    access_token = incoming.headers.get('Access-Token')
    if access_token is not None:
        headers['Access-Token'] = access_token

    # tornado rejects a non-None body on methods that do not expect one
    # (unless allow_nonstandard_methods is set), so an empty body becomes None.
    body = incoming.body
    if not body and incoming.method not in _BODY_METHODS:
        body = None

    _request = HTTPRequest(
        url=forward_url,
        method=incoming.method,
        headers=headers,
        body=body,
        request_timeout=request_timeout,
    )
    _http_client = get_http_client()
    try:
        if is_log_exc:
            _log_request(_request)
        _response: HTTPResponse = await _http_client.fetch(request=_request)
        return _response
    except HTTPClientError as e:
        if is_log_exc:
            echo_log(f'请求错误:{e},地址:{_request.url}', level=logging.ERROR, is_log_exc=True)
            if e.response is not None:
                echo_log(f'响应内容:{_decode_body(e.response.body)}', level=logging.ERROR)
        raise
    except ConnectionError as e:
        if is_log_exc:
            echo_log(f'连接错误:{e}', level=logging.ERROR, is_log_exc=True)
        raise
    except Exception as e:
        if is_log_exc:
            echo_log(f'未知错误:{e}', level=logging.ERROR, is_log_exc=True)
        raise


def _encode_multipart(fields: Dict[str, Any], boundary: str) -> bytes:
    """Encode ``fields`` as a multipart/form-data body using ``boundary``.

    File-like values (``io.IOBase``) become file parts; everything else becomes
    a text part (booleans as ``true``/``false``, None as an empty string).
    """
    chunks = []
    for key, value in fields.items():
        chunks.append(f'--{boundary}\r\n'.encode())
        if isinstance(value, io.IOBase):
            if value.seekable():
                value.seek(0)  # make sure the whole file is read
            # ``name`` may be a full path or an integer file descriptor; send
            # only a base name, falling back to the field name.
            filename = os.path.basename(str(getattr(value, 'name', key))) or key
            content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
            content = value.read()
            if isinstance(content, str):
                # File opened in text mode.
                content = content.encode('utf-8')
            chunks.append(
                f'Content-Disposition: form-data; name="{key}"; filename="{filename}"\r\n'.encode())
            chunks.append(f'Content-Type: {content_type}\r\n\r\n'.encode())
            chunks.append(content)
        else:
            if isinstance(value, bool):
                text = str(value).lower()
            elif value is None:
                text = ''
            else:
                text = str(value)
            chunks.append(f'Content-Disposition: form-data; name="{key}"\r\n\r\n'.encode())
            chunks.append(text.encode('utf-8'))
        chunks.append(b'\r\n')
    chunks.append(f'--{boundary}--\r\n'.encode())
    return b''.join(chunks)


def build_http_request(
        url: str,
        body: Optional[Dict[str, Any]] = None,
        method: str = 'POST',
        timeout: Optional[float] = None,
        follow_redirects: bool = True,
        use_form: bool = False,
        extra_headers: Optional[Dict[str, str]] = None,
        **kwargs
) -> HTTPRequest:
    """
    Build a ``tornado.httpclient.HTTPRequest``.

    GET and POST are supported:
    - GET: parameters are sent in the URL query string
    - POST: parameters are sent as a JSON body, or as a form when ``use_form``
      is true (multipart/form-data when any value is a file object)

    :param url: full request URL
    :param body: dict of parameters; query parameters for GET, JSON or form data for POST
    :param method: HTTP method, only 'GET' or 'POST'
    :param timeout: request timeout in seconds
    :param follow_redirects: whether to follow redirects
    :param use_form: when true, POST as a form; otherwise POST as JSON
    :param extra_headers: optional extra headers (Cookie, Authorization, ...);
        they override the defaults
    :param kwargs: further ``HTTPRequest`` arguments; they override anything built here
    :return: the ``HTTPRequest`` object
    :raises ValueError: when ``method`` is not supported
    """
    if method not in ('GET', 'POST'):
        raise ValueError(f"Unsupported HTTP method: {method}")

    body = body or {}

    # Base headers; extra_headers take priority over these defaults.
    headers = {
        'Accept': '*/*',
        'Accept-Encoding': 'gzip, deflate',
        'Accept-Language': 'zh-CN,zh;q=0.9',
        'Connection': 'keep-alive',
        'Content-Type': 'application/x-www-form-urlencoded; charset=UTF-8',
        'X-Requested-With': 'XMLHttpRequest',
    }
    if extra_headers:
        headers.update(extra_headers)

    req_params = {
        'url': url,
        'method': method,
        'headers': headers,
        'follow_redirects': follow_redirects,
    }
    if timeout:
        req_params['request_timeout'] = timeout

    if method == 'GET':
        if body:
            # Extend an existing query string instead of adding a second '?'.
            if '?' not in url:
                separator = '?'
            elif url.endswith(('?', '&')):
                separator = ''
            else:
                separator = '&'
            req_params['url'] = f"{url}{separator}{urlencode(body)}"
        # A GET carries no body, so body-describing headers are dropped.
        for name in ('Content-Type', 'Content-Length', 'Transfer-Encoding'):
            headers.pop(name, None)
    elif use_form:
        if any(isinstance(v, io.IOBase) for v in body.values()):
            boundary = f"----WebKitFormBoundary{random.randint(10000000, 99999999)}"
            payload = _encode_multipart(body, boundary)
            headers['Content-Type'] = f'multipart/form-data; boundary={boundary}'
        else:
            # Plain form without files.
            payload = urlencode(body).encode('utf-8')
            headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'
        req_params['body'] = payload
        headers['Content-Length'] = str(len(payload))
    else:
        payload = json.dumps(body, cls=JsonDumpsEncoder, ensure_ascii=False).encode('utf-8')
        req_params['body'] = payload
        headers['Content-Length'] = str(len(payload))
        # The default Content-Type is form-urlencoded; label a JSON body as JSON
        # unless the caller chose a Content-Type explicitly.
        caller_set_type = any(name.lower() == 'content-type' for name in (extra_headers or {}))
        if not caller_set_type:
            headers['Content-Type'] = 'application/json; charset=UTF-8'

    req_params.update(kwargs)
    return HTTPRequest(**req_params)