Files
d3i-szct/paste/db/redis.py
T
zwf 4729698049 Squashed 'paste-framework/' content from commit 34e8684
git-subtree-dir: paste-framework
git-subtree-split: 34e8684c4bc3cebbe177509f42ab4ef5b5425a7a
2026-06-02 19:09:22 +08:00

996 lines
36 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
封装了 Python 对 Redis 的基本操作。
同时处理了 Java 在操作 Redis 后留下的字节码问题。
"""
import asyncio
import hashlib
import pathlib
import random
import types
from logging import ERROR, WARNING
from typing import Optional, Callable, Awaitable, Union, Tuple, Dict
import javaobj
import redis
from redis.asyncio import ConnectionPool, StrictRedis
from redis.client import Pipeline
from paste.core import aio_pool, config, logging
from paste.util.snow_id import IdWorker
class LuaScriptManager:
"""
Lua 脚本管理器。
负责加载、缓存和执行 Lua 脚本。
"""
# 默认 Lua 脚本内容(作为内置默认值,无需外部文件)
DEFAULT_SCRIPTS = {
"stock_decr": """
-- 扣减库存(原子操作)
-- KEYS[1]: 库存 key
-- ARGV[1]: 扣减数量
-- 返回值: 1=成功, 0=库存不足, -1=key不存在
local key = KEYS[1]
local quantity = tonumber(ARGV[1])
local current = redis.call('GET', key)
if not current then
return -1
end
current = tonumber(current)
if current >= quantity then
redis.call('DECRBY', key, quantity)
return 1
else
return 0
end
""",
"stock_incr": """
-- 增加库存(原子操作)
-- KEYS[1]: 库存 key
-- ARGV[1]: 增加数量
-- 返回值: 当前库存
local key = KEYS[1]
local quantity = tonumber(ARGV[1])
redis.call('INCRBY', key, quantity)
return redis.call('GET', key)
""",
"stock_peek": """
-- 查看库存(原子操作)
-- KEYS[1]: 库存 key
-- 返回值: 当前库存
local key = KEYS[1]
local current = redis.call('GET', key)
if not current then
return -1
end
return tonumber(current)
""",
}
_scripts: Dict[str, Tuple[str, str]] = {} # name -> (sha, script_content)
_script_dir: Optional[pathlib.Path] = None
_use_external_files: bool = False # 是否使用外部文件
@classmethod
def set_script_dir(cls, script_dir: str, use_external: bool = True):
"""
设置 Lua 脚本目录
:param script_dir: 脚本目录路径
:param use_external: 是否使用外部文件(False 则使用内置默认脚本)
"""
cls._script_dir = pathlib.Path(script_dir) if script_dir else None
cls._use_external_files = use_external
@classmethod
async def load_script(cls, redis_client: StrictRedis, script_name: str) -> str:
"""
加载并注册 Lua 脚本
优先使用外部文件,不存在则使用内置默认脚本
:param redis_client: Redis 客户端
:param script_name: 脚本名称(如 stock_decr
:return: 脚本 SHA
"""
script_content = None
# 尝试从外部文件加载
if cls._use_external_files and cls._script_dir:
script_path = cls._script_dir / f"{script_name}.lua"
if script_path.exists():
with open(script_path, 'r', encoding='utf-8') as f:
script_content = f.read()
logging.echo_log(f"Lua 脚本从外部文件加载: {script_path}")
# 使用内置默认脚本
if script_content is None:
if script_name not in cls.DEFAULT_SCRIPTS:
raise ValueError(f"脚本不存在: {script_name},且无内置默认值")
script_content = cls.DEFAULT_SCRIPTS[script_name]
logging.echo_log(f"Lua 脚本使用内置默认值: {script_name}")
# 计算 SHA
sha = hashlib.sha1(script_content.encode()).hexdigest()
# 缓存脚本
cls._scripts[script_name] = (sha, script_content)
# 预加载到 Redis
try:
await redis_client.script_load(script_content)
except Exception:
pass # 预加载失败不影响后续使用
return sha
@classmethod
async def load_default_scripts(cls, redis_client: StrictRedis):
"""
加载所有默认脚本
"""
for script_name in cls.DEFAULT_SCRIPTS.keys():
await cls.load_script(redis_client, script_name)
logging.echo_log(f"已加载 {len(cls.DEFAULT_SCRIPTS)} 个默认 Lua 脚本")
@classmethod
async def execute(cls, redis_client: StrictRedis, script_name: str,
keys: list, args: list) -> any:
"""
执行 Lua 脚本
优先使用 evalsha(性能更好),失败则降级到 eval
:param redis_client: Redis 客户端
:param script_name: 脚本名称
:param keys: KEYS 参数列表
:param args: ARGV 参数列表
:return: 脚本执行结果
"""
if script_name not in cls._scripts:
# 脚本未加载,尝试加载
await cls.load_script(redis_client, script_name)
sha, script_content = cls._scripts[script_name]
try:
return await redis_client.evalsha(sha, len(keys), *keys, *args)
except redis.ResponseError as e:
if "NOSCRIPT" in str(e):
# 重新加载并重试
await redis_client.script_load(script_content)
return await redis_client.evalsha(sha, len(keys), *keys, *args)
else:
raise
@classmethod
async def reload_script(cls, redis_client: StrictRedis, script_name: str) -> str:
"""重新加载指定的 Lua 脚本"""
if script_name in cls._scripts:
del cls._scripts[script_name]
return await cls.load_script(redis_client, script_name)
class Redis:
"""
Redis 基础操作。
"""
connect_pool: Optional[ConnectionPool] = None
prefix = b'\xac\xed\x00\x05'
utf_flag = b'\x74'
lua_scripts = LuaScriptManager
"""Lua 脚本管理器。"""
@classmethod
def is_java_serialized(cls, bs: Union[bytes, str]):
"""
判断是否为 Java 序列化后的数据。
:param bs: 字节流
"""
if not isinstance(bs, bytes):
return False
return bs[:4] == cls.prefix
@classmethod
async def get_pool(cls) -> ConnectionPool:
"""
取得 Redis 连接池。
:return: 连接池对象
"""
if cls.connect_pool is None:
_conn_params = config.get_config("redis.connection")
cls.connect_pool = ConnectionPool.from_url(**_conn_params)
return cls.connect_pool
@classmethod
async def close_pool(cls):
if cls.connect_pool is not None:
await cls.connect_pool.disconnect()
cls.connect_pool = None
@classmethod
async def get_redis(cls) -> StrictRedis:
"""
取得数据库对象。
:return: 数据库对象
"""
_pool = await cls.get_pool()
return StrictRedis(
connection_pool=_pool,
socket_timeout=5,
socket_connect_timeout=5,
health_check_interval=30,
socket_keepalive=True
)
@classmethod
async def ping(cls):
"""
测试连接。
:return: 测试结果
"""
async with await cls.get_redis() as _redis:
return await _redis.ping()
@classmethod
async def get_pipe(cls, transaction: bool = True, shard_hint=None) -> Pipeline:
"""
取得管道对象。
"""
async with await cls.get_redis() as _redis:
return _redis.pipeline(transaction=transaction, shard_hint=shard_hint)
# ========== Lua 脚本初始化 ==========
@classmethod
async def init_lua_scripts(cls, script_dir: str = None, use_external: bool = False):
"""
初始化 Lua 脚本
建议在应用启动时调用一次
:param script_dir: 外部脚本目录(可选)
:param use_external: 是否使用外部文件,默认 False 使用内置脚本
"""
if script_dir:
cls.lua_scripts.set_script_dir(script_dir, use_external)
async with await cls.get_redis() as _redis:
await cls.lua_scripts.load_default_scripts(_redis)
# ========== 库存核心方法(原子操作) ==========
@classmethod
async def stock_decr(cls, stock_key: str, quantity: int = 1) -> Tuple[bool, str]:
"""
扣减库存(原子操作)
使用 Lua 脚本保证原子性,防止超卖
:param stock_key: 库存 Key(支持分片,如 stock:iPhone15:shard:0
:param quantity: 扣减数量
:return: (是否成功, 消息)
"""
async with await cls.get_redis() as _redis:
try:
result = await cls.lua_scripts.execute(
_redis,
"stock_decr",
keys=[stock_key],
args=[quantity]
)
if result == 1:
return True, "扣减成功"
elif result == 0:
return False, "库存不足"
else:
return False, "商品不存在"
except Exception as e:
logging.echo_log(f"扣减库存异常: {e}", level=ERROR, is_log_exc=True)
return False, f"系统异常: {e}"
@classmethod
async def stock_incr(cls, stock_key: str, quantity: int = 1) -> Tuple[bool, int]:
"""
增加库存(原子操作)
用于退货入库、补货等场景
:param stock_key: 库存 Key
:param quantity: 增加数量
:return: (是否成功, 当前库存)
"""
async with await cls.get_redis() as _redis:
try:
result = await cls.lua_scripts.execute(
_redis,
"stock_incr",
keys=[stock_key],
args=[quantity]
)
return True, int(result)
except Exception as e:
logging.echo_log(f"增加库存异常: {e}", level=ERROR, is_log_exc=True)
return False, 0
@classmethod
async def stock_peek(cls, stock_key: str) -> int:
"""
查看剩余库存(原子操作)
:param stock_key: 库存 Key
:return: 剩余库存
"""
async with await cls.get_redis() as _redis:
try:
result = await cls.lua_scripts.execute(
_redis,
"stock_peek",
keys=[stock_key],
args=[]
)
return int(result) if result >= 0 else 0
except Exception as e:
logging.echo_log(f"查询库存异常: {e}", level=ERROR, is_log_exc=True)
return 0
# ========== 库存分片辅助方法 ==========
@classmethod
def get_shard_key(cls, sku_id: str, shard_id: int) -> str:
"""
获取分片 Key。
推荐格式:{业务域}:{实体}:{唯一标识}:{分片/维度}:{扩展}
:param sku_id: 商品ID
:param shard_id: 分片ID
:return: 分片 Key
"""
return f"stock:{sku_id}:shard:{shard_id}"
@classmethod
def get_user_shard(cls, sku_id: str, user_id: str, shard_count: int = 10) -> str:
"""
根据用户ID获取分片 Key
:param sku_id: 商品ID
:param user_id: 用户ID
:param shard_count: 分片总数
:return: 分片 Key
"""
shard = hash(user_id) % shard_count
return cls.get_shard_key(sku_id, shard)
@classmethod
async def init_sharded_stock(cls, sku_id: str, total_stock: int, shard_count: int = 10):
"""
初始化分片库存
:param sku_id: 商品ID
:param total_stock: 总库存
:param shard_count: 分片数量
"""
base = total_stock // shard_count
remainder = total_stock % shard_count
async with await cls.get_redis() as _redis:
for i in range(shard_count):
shard_key = cls.get_shard_key(sku_id, i)
stock = base + (1 if i < remainder else 0)
await _redis.set(shard_key, stock)
logging.echo_log(f"初始化分片 {i}: {shard_key} = {stock}")
# ========== 基础 KV 操作 ==========
@classmethod
async def keys(cls):
"""
取得所有的 Key。
"""
async with await cls.get_redis() as _redis:
_keys = await _redis.keys()
return _keys
@classmethod
async def show_keys(cls):
"""
控制台显示所有的 Keys。
"""
_keys = await cls.keys()
for _key in _keys:
if isinstance(_key, bytes):
if cls.is_java_serialized(_key):
print(_key[7:].decode('utf-8'), '=>', _key)
else:
print(_key.decode('utf-8'), '=>', _key)
else:
print(_key)
@classmethod
async def get(cls, key: Union[bytes, str]):
"""
多种方式读取 Redis 中的数据。
:param key: Redis Key 名称
:return: 数据内容
"""
async with await cls.get_redis() as _redis:
_result = await _redis.get(key)
if _result is None and not cls.is_java_serialized(key):
if isinstance(key, str):
key_bytes = key.encode('utf-8')
else:
key_bytes = key
_key = cls.prefix + cls.utf_flag + len(key_bytes).to_bytes(2, 'big') + key_bytes
_result = await _redis.get(_key)
if _result is None:
return _result
if isinstance(_result, bytes) and cls.is_java_serialized(_result):
return javaobj.loads(_result)
else:
return _result
@classmethod
async def set(cls, key: str, value: any, ex: int = None):
"""
设置键值对
"""
async with await cls.get_redis() as _redis:
return await _redis.set(key, value, ex=ex)
@classmethod
async def delete(cls, key: str):
"""
删除键
"""
async with await cls.get_redis() as _redis:
return await _redis.delete(key)
@classmethod
async def exists(cls, key: str) -> bool:
"""
检查键是否存在
"""
async with await cls.get_redis() as _redis:
return await _redis.exists(key) > 0
@classmethod
async def expire(cls, key: str, seconds: int):
"""
设置过期时间
"""
async with await cls.get_redis() as _redis:
return await _redis.expire(key, seconds)
@classmethod
async def incr(cls, key: str) -> int:
"""
原子递增
"""
async with await cls.get_redis() as _redis:
return await _redis.incr(key)
# ========== 回调处理 ==========
@classmethod
def get_func_name(cls, func):
"""
得到方法名称。
:param func: 方法对象
:return: 方法名称
"""
if isinstance(func, types.FunctionType):
return func.__name__
elif isinstance(func, types.MethodType):
return func.__func__.__name__
elif isinstance(func, (classmethod, staticmethod)):
return func.__func__.__name__
elif hasattr(func, '__call__'):
return func.__class__.__name__
else:
return str(func)
@classmethod
async def callback(cls, func: Callable, message_key: str, is_delete=False):
"""
根据消息 KEY 读取数据,并执行回调函数,如果回调函数正确执行,则根据参数 is_delete 判断删除消息。
:param func: 回调函数
:param message_key: 消息 KEY
:param is_delete: 是否删除处理过的消息
"""
result = None
async with await cls.get_redis() as _redis:
try:
message_data = await _redis.hgetall(message_key)
if not message_data:
logging.echo_log(f"警告: 空消息数据 {message_key}.", level=WARNING)
return result
if func:
# 处理回调
result = func(message_data)
# 处理协程
if isinstance(result, Awaitable):
result = await result
if is_delete:
# 回调正确执行,且设置为删除删除的,才会删除消息
await _redis.delete(message_key)
logging.echo_log(f"消息已删除: {message_key};数据为:{message_data}.")
except redis.RedisError as e:
logging.echo_log(f"Redis 操作异常: {e}.", level=ERROR, is_log_exc=True)
except Exception as e:
logging.echo_log(
f"执行回调异常:{e};方法:{cls.get_func_name(func)};消息: {message_key}.",
level=ERROR, is_log_exc=True
)
return result
class PubSubActor(Redis):
"""
发布订阅执行器。用于发布消息和订阅消息。
订阅采用阻塞式读取,可以在读取到数据后,执行回调方法,并根据参数确定是否删除历史消息。
"""
def __init__(self, hash_name: str):
self.hash_name = f"{hash_name}_HASH_NAME"
self.channel = f"{hash_name}_CHANNEL"
self.running = False
"""
优雅退出控制标志
"""
self.stopping = False
"""
控制整个 run_forever 循环退出
"""
async def publish(self, data: dict) -> str:
"""
数据写入 Redis 并发布消息。
:param data: 写入 Redis 的数据
:return: 消息ID
"""
async with await self.get_redis() as _redis:
# 生成雪花 ID 作为 Hash Key
_random_num = random.randint(1000, 9999)
_id = IdWorker.get_id_worker(3, 3, _random_num).get_id()
# 写入Redis hash
await _redis.hset(f"{self.hash_name}:{_id}", mapping=data)
# 发布新消息通知
await _redis.publish(self.channel, _id)
return _id
async def subscribe(self, func: Callable = None, is_delete=False):
"""
监听消息。
:param func: 监听回调程序
:param is_delete: 回调执行完毕后,是否删除消息
"""
async with await self.get_redis() as _redis:
_pubsub = _redis.pubsub()
await _pubsub.subscribe(self.channel)
try:
self.running = True
# 使用 while 循环,而不是直接 async for,以便加入超时控制
while not self.stopping and self.running:
try:
# 每次循环都重新获取迭代器
listen_iter = _pubsub.listen()
message = await asyncio.wait_for(listen_iter.__anext__(), timeout=60.0)
if message["type"] != "message":
continue
message_id = message["data"]
message_key = f"{self.hash_name}:{message_id}"
try:
# 隔离处理回调异常
# 采用后台运行的方式处理,防止消息排队,提高消息处理性能
await aio_pool.run_background_task(self.callback(func, message_key, is_delete), 10)
# await self.callback(func, message_key, is_delete=is_delete)
except Exception:
# 继续处理下条消息
continue
except asyncio.TimeoutError:
# 超时,是心跳成功的标志
logging.echo_log("心跳:连接正常,继续监听...")
continue
except redis.exceptions.ConnectionError as e:
# 连接错误,触发重连
logging.echo_log(f"检测到连接错误: {e}. 将触发重连...", level=ERROR, is_log_exc=True)
raise e
except StopAsyncIteration:
# pubsub 正常关闭
logging.echo_log("PubSub 迭代器已停止.")
break
except (asyncio.CancelledError, KeyboardInterrupt):
logging.echo_log("收到退出信号,停止监听...")
self.running = False
raise
except Exception as e:
logging.echo_log(f"监听会话因错误结束: {e}.", level=ERROR, is_log_exc=True)
raise e
finally:
self.running = False
try:
await _pubsub.unsubscribe(self.channel)
await _pubsub.close()
except Exception as close_err:
logging.echo_log(f"资源关闭异常: {close_err}.")
finally:
logging.echo_log("监听已完全停止.")
async def run_forever(self, func: Callable = None, is_delete=False):
"""
持久运行的监听器,包含自动重连逻辑和优雅退出。
"""
while not self.stopping:
try:
logging.echo_log("启动新的监听会话...")
await self.subscribe(func, is_delete)
except (asyncio.CancelledError, KeyboardInterrupt):
logging.echo_log("收到退出信号,停止监听...")
self.stopping = True
break
except Exception as e:
logging.echo_log(f"监听会话因未知错误结束: {e}. 10秒后重试...", level=ERROR, is_log_exc=True)
if self.stopping:
logging.echo_log("总开关已打开,停止重连.")
break
logging.echo_log("等待重新连接...")
try:
# 关键改动:直接使用 sleep,它本身就是可中断的
await asyncio.sleep(10)
except asyncio.CancelledError:
logging.echo_log("等待期间被取消,准备退出.")
break
logging.echo_log("监听服务已完全停止.")
async def history(self, func: Callable = None, is_delete=False):
"""
处理历史数据。
:param func: 监听回调程序
:param is_delete: 回调执行完毕后,是否删除消息
"""
async with await self.get_redis() as _redis:
_keys = await _redis.keys()
for _k in _keys:
try:
# 隔离处理回调异常
# 采用后台运行的方式处理,防止消息排队,提高消息处理性能
await aio_pool.run_background_task(self.callback(func, _k, is_delete), 10)
# await self.callback(func, _k, is_delete=is_delete)
except Exception:
# 继续处理下条消息
continue
def subscribe_stop(self):
self.running = False
self.stopping = True
class StreamActor(Redis):
"""
流执行器。使用 Redis Streams 实现发布消息和消费消息。
消费采用消费者组模式,支持消息确认和可靠传递。
方法结构与 PubSubActor 保持一致,便于无缝替换。
此版本集成了启动时的僵尸任务自动恢复功能。
"""
@classmethod
def actor_config(cls, config_path: str):
"""
根据路径,取得配置信息。
:param config_path: 配置路径,配置文件中,直到 stream 的 Key,用点【.】分隔
:return:
"""
_stream_name = config_path.split(".")[-1].upper()
_stream_config = config.get_config(config_path)
_group_name = _stream_config.get('group', f"{_stream_name}_GROUP")
_consumer_name = _stream_config.get('consumer', f"{_stream_name}_CONSUMER")
_snow_id = IdWorker.get_id_worker().get_id()
_consumer_name = f"{_consumer_name}_{_snow_id}"
return _stream_name, _group_name, _consumer_name
@classmethod
def new_actor(cls, config_path: str):
"""
根据配置文件中的配置小节创建流执行器。
:param config_path: 配置路径,配置文件中,直到 stream 的 Key,用点【.】分隔
:return: 执行器对象
"""
_stream_name, _group_name, _consumer_name = cls.actor_config(config_path)
return cls(_stream_name, _group_name, _consumer_name)
def __init__(self, stream_name: str, group_name: str, consumer_name: str):
"""
初始化流执行器。
:param stream_name: Redis Stream 的名称
:param group_name: 消费者组的名称
:param consumer_name: 当前消费者的名称
"""
self.stream_name = stream_name
self.group_name = group_name
self.consumer_name = consumer_name
self.running = False
"""
优雅退出控制标志
"""
self.stopping = False
"""
控制整个 run_forever 循环退出
"""
async def _ensure_group_exists(self):
"""确保消费者组已存在,如果不存在则创建。"""
try:
_redis = await self.get_redis()
await _redis.xgroup_create(
name=self.stream_name,
groupname=self.group_name,
id='0', # 从头开始消费
mkstream=True # Stream 不存在时自动创建
)
logging.echo_log(f"消费者组 '{self.group_name}' 已创建.")
except redis.exceptions.ResponseError as e:
if "Consumer Group name already exists" in str(e):
logging.echo_log(f"消费者组 '{self.group_name}' 已存在.")
else:
raise
async def publish(self, data: dict) -> str:
"""
将数据作为消息写入 Redis Stream。
:param data: 写入 Stream 的数据字典
:return: 消息ID
"""
async with await self.get_redis() as _redis:
# 添加时会自动生成唯一的消息ID
message_id = await _redis.xadd(name=self.stream_name, fields=data)
logging.echo_log(f"消息已发布至 Stream '{self.stream_name}'ID: {message_id};数据为:{data}.")
return message_id
async def reclaim_stale_tasks(self, func: Callable, is_delete: bool, stale_threshold_ms: int = 5 * 60 * 1000):
"""
检查并尝试重新处理僵尸任务。
Args:
func (Callable): 用于处理任务的业务回调函数。
is_delete (bool): 处理成功后是否确认消息。
stale_threshold_ms (int): 判定为僵尸任务的空闲时间阈值(毫秒)。
"""
async with await self.get_redis() as _redis:
# 1. 发现僵尸任务
try:
stale_tasks = await _redis.xpending_range(
name=self.stream_name,
groupname=self.group_name,
min='-',
max='+',
count=10, # 每次最多处理10个僵尸任务,避免启动时阻塞太久
idle=stale_threshold_ms
)
except Exception as e:
logging.echo_log(f"检查僵尸任务时出错: {e}", level=ERROR, is_log_exc=True)
return
if not stale_tasks:
logging.echo_log(f"未发现空闲超过 {stale_threshold_ms / 1000} 秒的僵尸任务.")
return
if not stale_tasks or not isinstance(stale_tasks, list):
logging.echo_log(f"未发现空闲超过 {stale_threshold_ms / 1000} 秒的僵尸任务.")
return
message_ids = [task['message_id'] for task in stale_tasks]
logging.echo_log(f"发现 {len(message_ids)} 个僵尸任务,尝试认领并重新处理...")
# 2. 认领任务
try:
reclaimed_messages = await _redis.xclaim(
name=self.stream_name,
groupname=self.group_name,
consumername=self.consumer_name, # 认领给自己
min_idle_time=stale_threshold_ms,
message_ids=message_ids,
justid=False # 我们需要消息内容来处理
)
except Exception as e:
logging.echo_log(f"认领僵尸任务时出错: {e}", level=ERROR, is_log_exc=True)
return
if not reclaimed_messages:
logging.echo_log("未能成功认领任何僵尸任务.")
return
logging.echo_log(f"成功认领 {len(reclaimed_messages)} 个僵尸任务,开始处理.")
# 3. 处理被认领的任务
for message_id, message_data in reclaimed_messages:
# 使用我们已有的 _callback_wrapper 来处理,保证逻辑一致
await self._callback_wrapper(
func=func,
message_id=message_id,
message_data=message_data,
is_delete=is_delete
)
async def history(self, func: Callable, is_delete: bool):
"""
启动时的恢复程序。
检查并处理长时间未完成的僵尸任务,确保系统健壮性。
"""
logging.echo_log("执行启动恢复程序,检查僵尸任务...")
# 将 func 和 is_delete 传递下去
await self.reclaim_stale_tasks(func=func, is_delete=is_delete, stale_threshold_ms=5 * 60 * 1000)
logging.echo_log("启动恢复程序执行完毕.")
async def subscribe(self, func: Callable = None, is_delete=False):
"""
从消费者组中监听并处理新消息。
启动时会先执行恢复程序,处理僵尸任务。
"""
await self._ensure_group_exists()
# === 核心改动:启动时先执行恢复,并传入回调参数 ===
await self.history(func=func, is_delete=is_delete)
async with await self.get_redis() as _redis:
try:
self.running = True
logging.echo_log("僵尸任务恢复完成,开始监听新消息...")
while not self.stopping and self.running:
try:
# 阻塞读取新消息
streams = await _redis.xreadgroup(
groupname=self.group_name,
consumername=self.consumer_name,
streams={self.stream_name: '>'}, # '>' 表示只读取新消息
count=1,
block=5000 # 5秒超时,类似 PubSub 的心跳
)
if not streams:
# 超时,是心跳成功的标志
logging.echo_log("心跳:连接正常,继续监听...")
continue
# 解析消息
stream, messages = streams[0]
message_id, message_data = messages[0]
logging.echo_log(f"收到新消息: ID={message_id}, 数据={message_data}")
try:
# 隔离处理回调异常
# 采用后台运行的方式处理,防止消息排队,提高消息处理性能
await aio_pool.run_background_task(
self._callback_wrapper(func, message_id, message_data, is_delete), 10
)
except Exception:
# 回调处理失败,消息未被确认,将留在队列中稍后重试
continue
except redis.exceptions.ConnectionError as e:
# 连接错误,触发重连
logging.echo_log(f"检测到连接错误: {e}. 将触发重连...", level=ERROR, is_log_exc=True)
raise e
except (asyncio.CancelledError, KeyboardInterrupt):
logging.echo_log("收到退出信号,停止监听...")
self.running = False
raise
except Exception as e:
logging.echo_log(f"监听会话因错误结束: {e}.", level=ERROR, is_log_exc=True)
raise e
finally:
self.running = False
logging.echo_log("Stream 监听已完全停止.")
async def _callback_wrapper(self, func: Callable, message_id: str, message_data: dict, is_delete: bool):
"""
一个包装器,用于将 Stream 的消息处理逻辑适配到基类的 callback 方法签名上。
这样做可以复用基类 callback 中的异常处理逻辑。
"""
# 如果没有提供回调函数,则无法处理,直接返回,避免丢失任务
if not func:
logging.echo_log(f"警告: 收到消息 {message_id} 但未提供业务回调函数,消息将被忽略.", level=WARNING)
return
# 不能直接调用基类的 callback,因为它会尝试删除
# 在这里复制它的异常处理逻辑,但使用 Stream 的操作
result = None
async with await self.get_redis() as _redis:
try:
if func:
# 处理回调
result = func(message_data)
# 处理协程
if isinstance(result, Awaitable):
result = await result
if is_delete:
# 先从 PENDING 列表中移除
await _redis.xack(self.stream_name, self.group_name, message_id)
# 再从 Stream 中逻辑删除
await _redis.xdel(self.stream_name, message_id)
logging.echo_log(f"消息已确认 (ACK): {message_id};数据为:{message_data}.")
except redis.RedisError as e:
logging.echo_log(f"Redis 操作异常: {e}.", level=ERROR, is_log_exc=True)
except Exception as e:
logging.echo_log(
f"执行回调异常:{e};方法:{self.get_func_name(func)};消息: {message_id}.",
level=ERROR, is_log_exc=True
)
return result
async def run_forever(self, func: Callable = None, is_delete=False):
"""
持久运行的监听器,包含自动重连逻辑和优雅退出。
"""
while not self.stopping:
try:
logging.echo_log("启动新的监听会话...")
await self.subscribe(func, is_delete)
except (asyncio.CancelledError, KeyboardInterrupt):
logging.echo_log("收到退出信号,停止监听...")
self.stopping = True
break
except Exception as e:
logging.echo_log(f"监听会话因未知错误结束: {e}. 10秒后重试...", level=ERROR, is_log_exc=True)
if self.stopping:
logging.echo_log("总开关已打开,停止重连.")
break
logging.echo_log("等待重新连接...")
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
logging.echo_log("等待期间被取消,准备退出.")
break
logging.echo_log("监听服务已完全停止.")
def subscribe_stop(self):
self.running = False
self.stopping = True