首次提交
This commit is contained in:
@@ -0,0 +1,995 @@
|
||||
"""
|
||||
封装了 Python 对 Redis 的基本操作。
|
||||
同时处理了 Java 在操作 Redis 后留下的字节码问题。
|
||||
"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
import pathlib
|
||||
import random
|
||||
import types
|
||||
from logging import ERROR, WARNING
|
||||
from typing import Optional, Callable, Awaitable, Union, Tuple, Dict
|
||||
|
||||
import javaobj
|
||||
import redis
|
||||
from redis.asyncio import ConnectionPool, StrictRedis
|
||||
from redis.client import Pipeline
|
||||
|
||||
from paste.core import aio_pool, config, logging
|
||||
from paste.util.snow_id import IdWorker
|
||||
|
||||
|
||||
class LuaScriptManager:
|
||||
"""
|
||||
Lua 脚本管理器。
|
||||
负责加载、缓存和执行 Lua 脚本。
|
||||
"""
|
||||
|
||||
# 默认 Lua 脚本内容(作为内置默认值,无需外部文件)
|
||||
DEFAULT_SCRIPTS = {
|
||||
"stock_decr": """
|
||||
-- 扣减库存(原子操作)
|
||||
-- KEYS[1]: 库存 key
|
||||
-- ARGV[1]: 扣减数量
|
||||
-- 返回值: 1=成功, 0=库存不足, -1=key不存在
|
||||
|
||||
local key = KEYS[1]
|
||||
local quantity = tonumber(ARGV[1])
|
||||
|
||||
local current = redis.call('GET', key)
|
||||
if not current then
|
||||
return -1
|
||||
end
|
||||
|
||||
current = tonumber(current)
|
||||
if current >= quantity then
|
||||
redis.call('DECRBY', key, quantity)
|
||||
return 1
|
||||
else
|
||||
return 0
|
||||
end
|
||||
""",
|
||||
|
||||
"stock_incr": """
|
||||
-- 增加库存(原子操作)
|
||||
-- KEYS[1]: 库存 key
|
||||
-- ARGV[1]: 增加数量
|
||||
-- 返回值: 当前库存
|
||||
|
||||
local key = KEYS[1]
|
||||
local quantity = tonumber(ARGV[1])
|
||||
|
||||
redis.call('INCRBY', key, quantity)
|
||||
return redis.call('GET', key)
|
||||
""",
|
||||
|
||||
"stock_peek": """
|
||||
-- 查看库存(原子操作)
|
||||
-- KEYS[1]: 库存 key
|
||||
-- 返回值: 当前库存
|
||||
|
||||
local key = KEYS[1]
|
||||
local current = redis.call('GET', key)
|
||||
|
||||
if not current then
|
||||
return -1
|
||||
end
|
||||
return tonumber(current)
|
||||
""",
|
||||
}
|
||||
|
||||
_scripts: Dict[str, Tuple[str, str]] = {} # name -> (sha, script_content)
|
||||
_script_dir: Optional[pathlib.Path] = None
|
||||
_use_external_files: bool = False # 是否使用外部文件
|
||||
|
||||
@classmethod
|
||||
def set_script_dir(cls, script_dir: str, use_external: bool = True):
|
||||
"""
|
||||
设置 Lua 脚本目录
|
||||
|
||||
:param script_dir: 脚本目录路径
|
||||
:param use_external: 是否使用外部文件(False 则使用内置默认脚本)
|
||||
"""
|
||||
cls._script_dir = pathlib.Path(script_dir) if script_dir else None
|
||||
cls._use_external_files = use_external
|
||||
|
||||
@classmethod
|
||||
async def load_script(cls, redis_client: StrictRedis, script_name: str) -> str:
|
||||
"""
|
||||
加载并注册 Lua 脚本
|
||||
优先使用外部文件,不存在则使用内置默认脚本
|
||||
|
||||
:param redis_client: Redis 客户端
|
||||
:param script_name: 脚本名称(如 stock_decr)
|
||||
:return: 脚本 SHA
|
||||
"""
|
||||
script_content = None
|
||||
|
||||
# 尝试从外部文件加载
|
||||
if cls._use_external_files and cls._script_dir:
|
||||
script_path = cls._script_dir / f"{script_name}.lua"
|
||||
if script_path.exists():
|
||||
with open(script_path, 'r', encoding='utf-8') as f:
|
||||
script_content = f.read()
|
||||
logging.echo_log(f"Lua 脚本从外部文件加载: {script_path}")
|
||||
|
||||
# 使用内置默认脚本
|
||||
if script_content is None:
|
||||
if script_name not in cls.DEFAULT_SCRIPTS:
|
||||
raise ValueError(f"脚本不存在: {script_name},且无内置默认值")
|
||||
script_content = cls.DEFAULT_SCRIPTS[script_name]
|
||||
logging.echo_log(f"Lua 脚本使用内置默认值: {script_name}")
|
||||
|
||||
# 计算 SHA
|
||||
sha = hashlib.sha1(script_content.encode()).hexdigest()
|
||||
|
||||
# 缓存脚本
|
||||
cls._scripts[script_name] = (sha, script_content)
|
||||
|
||||
# 预加载到 Redis
|
||||
try:
|
||||
await redis_client.script_load(script_content)
|
||||
except Exception:
|
||||
pass # 预加载失败不影响后续使用
|
||||
|
||||
return sha
|
||||
|
||||
@classmethod
|
||||
async def load_default_scripts(cls, redis_client: StrictRedis):
|
||||
"""
|
||||
加载所有默认脚本
|
||||
"""
|
||||
for script_name in cls.DEFAULT_SCRIPTS.keys():
|
||||
await cls.load_script(redis_client, script_name)
|
||||
logging.echo_log(f"已加载 {len(cls.DEFAULT_SCRIPTS)} 个默认 Lua 脚本")
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, redis_client: StrictRedis, script_name: str,
|
||||
keys: list, args: list) -> any:
|
||||
"""
|
||||
执行 Lua 脚本
|
||||
优先使用 evalsha(性能更好),失败则降级到 eval
|
||||
|
||||
:param redis_client: Redis 客户端
|
||||
:param script_name: 脚本名称
|
||||
:param keys: KEYS 参数列表
|
||||
:param args: ARGV 参数列表
|
||||
:return: 脚本执行结果
|
||||
"""
|
||||
if script_name not in cls._scripts:
|
||||
# 脚本未加载,尝试加载
|
||||
await cls.load_script(redis_client, script_name)
|
||||
|
||||
sha, script_content = cls._scripts[script_name]
|
||||
|
||||
try:
|
||||
return await redis_client.evalsha(sha, len(keys), *keys, *args)
|
||||
except redis.ResponseError as e:
|
||||
if "NOSCRIPT" in str(e):
|
||||
# 重新加载并重试
|
||||
await redis_client.script_load(script_content)
|
||||
return await redis_client.evalsha(sha, len(keys), *keys, *args)
|
||||
else:
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
async def reload_script(cls, redis_client: StrictRedis, script_name: str) -> str:
|
||||
"""重新加载指定的 Lua 脚本"""
|
||||
if script_name in cls._scripts:
|
||||
del cls._scripts[script_name]
|
||||
return await cls.load_script(redis_client, script_name)
|
||||
|
||||
|
||||
class Redis:
|
||||
"""
|
||||
Redis 基础操作。
|
||||
"""
|
||||
|
||||
connect_pool: Optional[ConnectionPool] = None
|
||||
|
||||
prefix = b'\xac\xed\x00\x05'
|
||||
utf_flag = b'\x74'
|
||||
|
||||
lua_scripts = LuaScriptManager
|
||||
"""Lua 脚本管理器。"""
|
||||
|
||||
@classmethod
|
||||
def is_java_serialized(cls, bs: Union[bytes, str]):
|
||||
"""
|
||||
判断是否为 Java 序列化后的数据。
|
||||
|
||||
:param bs: 字节流
|
||||
"""
|
||||
if not isinstance(bs, bytes):
|
||||
return False
|
||||
return bs[:4] == cls.prefix
|
||||
|
||||
@classmethod
|
||||
async def get_pool(cls) -> ConnectionPool:
|
||||
"""
|
||||
取得 Redis 连接池。
|
||||
|
||||
:return: 连接池对象
|
||||
"""
|
||||
if cls.connect_pool is None:
|
||||
_conn_params = config.get_config("redis.connection")
|
||||
cls.connect_pool = ConnectionPool.from_url(**_conn_params)
|
||||
return cls.connect_pool
|
||||
|
||||
@classmethod
|
||||
async def close_pool(cls):
|
||||
if cls.connect_pool is not None:
|
||||
await cls.connect_pool.disconnect()
|
||||
cls.connect_pool = None
|
||||
|
||||
@classmethod
|
||||
async def get_redis(cls) -> StrictRedis:
|
||||
"""
|
||||
取得数据库对象。
|
||||
|
||||
:return: 数据库对象
|
||||
"""
|
||||
_pool = await cls.get_pool()
|
||||
return StrictRedis(
|
||||
connection_pool=_pool,
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5,
|
||||
health_check_interval=30,
|
||||
socket_keepalive=True
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def ping(cls):
|
||||
"""
|
||||
测试连接。
|
||||
|
||||
:return: 测试结果
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
return await _redis.ping()
|
||||
|
||||
@classmethod
|
||||
async def get_pipe(cls, transaction: bool = True, shard_hint=None) -> Pipeline:
|
||||
"""
|
||||
取得管道对象。
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
return _redis.pipeline(transaction=transaction, shard_hint=shard_hint)
|
||||
|
||||
# ========== Lua 脚本初始化 ==========
|
||||
|
||||
@classmethod
|
||||
async def init_lua_scripts(cls, script_dir: str = None, use_external: bool = False):
|
||||
"""
|
||||
初始化 Lua 脚本
|
||||
建议在应用启动时调用一次
|
||||
|
||||
:param script_dir: 外部脚本目录(可选)
|
||||
:param use_external: 是否使用外部文件,默认 False 使用内置脚本
|
||||
"""
|
||||
if script_dir:
|
||||
cls.lua_scripts.set_script_dir(script_dir, use_external)
|
||||
|
||||
async with await cls.get_redis() as _redis:
|
||||
await cls.lua_scripts.load_default_scripts(_redis)
|
||||
|
||||
# ========== 库存核心方法(原子操作) ==========
|
||||
|
||||
@classmethod
|
||||
async def stock_decr(cls, stock_key: str, quantity: int = 1) -> Tuple[bool, str]:
|
||||
"""
|
||||
扣减库存(原子操作)
|
||||
使用 Lua 脚本保证原子性,防止超卖
|
||||
|
||||
:param stock_key: 库存 Key(支持分片,如 stock:iPhone15:shard:0)
|
||||
:param quantity: 扣减数量
|
||||
:return: (是否成功, 消息)
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
try:
|
||||
result = await cls.lua_scripts.execute(
|
||||
_redis,
|
||||
"stock_decr",
|
||||
keys=[stock_key],
|
||||
args=[quantity]
|
||||
)
|
||||
|
||||
if result == 1:
|
||||
return True, "扣减成功"
|
||||
elif result == 0:
|
||||
return False, "库存不足"
|
||||
else:
|
||||
return False, "商品不存在"
|
||||
except Exception as e:
|
||||
logging.echo_log(f"扣减库存异常: {e}", level=ERROR, is_log_exc=True)
|
||||
return False, f"系统异常: {e}"
|
||||
|
||||
@classmethod
|
||||
async def stock_incr(cls, stock_key: str, quantity: int = 1) -> Tuple[bool, int]:
|
||||
"""
|
||||
增加库存(原子操作)
|
||||
用于退货入库、补货等场景
|
||||
|
||||
:param stock_key: 库存 Key
|
||||
:param quantity: 增加数量
|
||||
:return: (是否成功, 当前库存)
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
try:
|
||||
result = await cls.lua_scripts.execute(
|
||||
_redis,
|
||||
"stock_incr",
|
||||
keys=[stock_key],
|
||||
args=[quantity]
|
||||
)
|
||||
return True, int(result)
|
||||
except Exception as e:
|
||||
logging.echo_log(f"增加库存异常: {e}", level=ERROR, is_log_exc=True)
|
||||
return False, 0
|
||||
|
||||
@classmethod
|
||||
async def stock_peek(cls, stock_key: str) -> int:
|
||||
"""
|
||||
查看剩余库存(原子操作)
|
||||
|
||||
:param stock_key: 库存 Key
|
||||
:return: 剩余库存
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
try:
|
||||
result = await cls.lua_scripts.execute(
|
||||
_redis,
|
||||
"stock_peek",
|
||||
keys=[stock_key],
|
||||
args=[]
|
||||
)
|
||||
return int(result) if result >= 0 else 0
|
||||
except Exception as e:
|
||||
logging.echo_log(f"查询库存异常: {e}", level=ERROR, is_log_exc=True)
|
||||
return 0
|
||||
|
||||
# ========== 库存分片辅助方法 ==========
|
||||
|
||||
@classmethod
|
||||
def get_shard_key(cls, sku_id: str, shard_id: int) -> str:
|
||||
"""
|
||||
获取分片 Key。
|
||||
推荐格式:{业务域}:{实体}:{唯一标识}:{分片/维度}:{扩展}
|
||||
|
||||
:param sku_id: 商品ID
|
||||
:param shard_id: 分片ID
|
||||
:return: 分片 Key
|
||||
"""
|
||||
return f"stock:{sku_id}:shard:{shard_id}"
|
||||
|
||||
@classmethod
|
||||
def get_user_shard(cls, sku_id: str, user_id: str, shard_count: int = 10) -> str:
|
||||
"""
|
||||
根据用户ID获取分片 Key
|
||||
|
||||
:param sku_id: 商品ID
|
||||
:param user_id: 用户ID
|
||||
:param shard_count: 分片总数
|
||||
:return: 分片 Key
|
||||
"""
|
||||
shard = hash(user_id) % shard_count
|
||||
return cls.get_shard_key(sku_id, shard)
|
||||
|
||||
@classmethod
|
||||
async def init_sharded_stock(cls, sku_id: str, total_stock: int, shard_count: int = 10):
|
||||
"""
|
||||
初始化分片库存
|
||||
|
||||
:param sku_id: 商品ID
|
||||
:param total_stock: 总库存
|
||||
:param shard_count: 分片数量
|
||||
"""
|
||||
base = total_stock // shard_count
|
||||
remainder = total_stock % shard_count
|
||||
|
||||
async with await cls.get_redis() as _redis:
|
||||
for i in range(shard_count):
|
||||
shard_key = cls.get_shard_key(sku_id, i)
|
||||
stock = base + (1 if i < remainder else 0)
|
||||
await _redis.set(shard_key, stock)
|
||||
logging.echo_log(f"初始化分片 {i}: {shard_key} = {stock}")
|
||||
|
||||
# ========== 基础 KV 操作 ==========
|
||||
|
||||
@classmethod
|
||||
async def keys(cls):
|
||||
"""
|
||||
取得所有的 Key。
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
_keys = await _redis.keys()
|
||||
return _keys
|
||||
|
||||
@classmethod
|
||||
async def show_keys(cls):
|
||||
"""
|
||||
控制台显示所有的 Keys。
|
||||
"""
|
||||
_keys = await cls.keys()
|
||||
for _key in _keys:
|
||||
if isinstance(_key, bytes):
|
||||
if cls.is_java_serialized(_key):
|
||||
print(_key[7:].decode('utf-8'), '=>', _key)
|
||||
else:
|
||||
print(_key.decode('utf-8'), '=>', _key)
|
||||
else:
|
||||
print(_key)
|
||||
|
||||
@classmethod
|
||||
async def get(cls, key: Union[bytes, str]):
|
||||
"""
|
||||
多种方式读取 Redis 中的数据。
|
||||
|
||||
:param key: Redis Key 名称
|
||||
:return: 数据内容
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
_result = await _redis.get(key)
|
||||
|
||||
if _result is None and not cls.is_java_serialized(key):
|
||||
if isinstance(key, str):
|
||||
key_bytes = key.encode('utf-8')
|
||||
else:
|
||||
key_bytes = key
|
||||
_key = cls.prefix + cls.utf_flag + len(key_bytes).to_bytes(2, 'big') + key_bytes
|
||||
_result = await _redis.get(_key)
|
||||
|
||||
if _result is None:
|
||||
return _result
|
||||
|
||||
if isinstance(_result, bytes) and cls.is_java_serialized(_result):
|
||||
return javaobj.loads(_result)
|
||||
else:
|
||||
return _result
|
||||
|
||||
@classmethod
|
||||
async def set(cls, key: str, value: any, ex: int = None):
|
||||
"""
|
||||
设置键值对
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
return await _redis.set(key, value, ex=ex)
|
||||
|
||||
@classmethod
|
||||
async def delete(cls, key: str):
|
||||
"""
|
||||
删除键
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
return await _redis.delete(key)
|
||||
|
||||
@classmethod
|
||||
async def exists(cls, key: str) -> bool:
|
||||
"""
|
||||
检查键是否存在
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
return await _redis.exists(key) > 0
|
||||
|
||||
@classmethod
|
||||
async def expire(cls, key: str, seconds: int):
|
||||
"""
|
||||
设置过期时间
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
return await _redis.expire(key, seconds)
|
||||
|
||||
@classmethod
|
||||
async def incr(cls, key: str) -> int:
|
||||
"""
|
||||
原子递增
|
||||
"""
|
||||
async with await cls.get_redis() as _redis:
|
||||
return await _redis.incr(key)
|
||||
|
||||
# ========== 回调处理 ==========
|
||||
|
||||
@classmethod
|
||||
def get_func_name(cls, func):
|
||||
"""
|
||||
得到方法名称。
|
||||
|
||||
:param func: 方法对象
|
||||
:return: 方法名称
|
||||
"""
|
||||
if isinstance(func, types.FunctionType):
|
||||
return func.__name__
|
||||
elif isinstance(func, types.MethodType):
|
||||
return func.__func__.__name__
|
||||
elif isinstance(func, (classmethod, staticmethod)):
|
||||
return func.__func__.__name__
|
||||
elif hasattr(func, '__call__'):
|
||||
return func.__class__.__name__
|
||||
else:
|
||||
return str(func)
|
||||
|
||||
@classmethod
|
||||
async def callback(cls, func: Callable, message_key: str, is_delete=False):
|
||||
"""
|
||||
根据消息 KEY 读取数据,并执行回调函数,如果回调函数正确执行,则根据参数 is_delete 判断删除消息。
|
||||
|
||||
:param func: 回调函数
|
||||
:param message_key: 消息 KEY
|
||||
:param is_delete: 是否删除处理过的消息
|
||||
"""
|
||||
result = None
|
||||
async with await cls.get_redis() as _redis:
|
||||
try:
|
||||
message_data = await _redis.hgetall(message_key)
|
||||
if not message_data:
|
||||
logging.echo_log(f"警告: 空消息数据 {message_key}.", level=WARNING)
|
||||
return result
|
||||
|
||||
if func:
|
||||
# 处理回调
|
||||
result = func(message_data)
|
||||
# 处理协程
|
||||
if isinstance(result, Awaitable):
|
||||
result = await result
|
||||
|
||||
if is_delete:
|
||||
# 回调正确执行,且设置为删除删除的,才会删除消息
|
||||
await _redis.delete(message_key)
|
||||
logging.echo_log(f"消息已删除: {message_key};数据为:{message_data}.")
|
||||
except redis.RedisError as e:
|
||||
logging.echo_log(f"Redis 操作异常: {e}.", level=ERROR, is_log_exc=True)
|
||||
except Exception as e:
|
||||
logging.echo_log(
|
||||
f"执行回调异常:{e};方法:{cls.get_func_name(func)};消息: {message_key}.",
|
||||
level=ERROR, is_log_exc=True
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class PubSubActor(Redis):
|
||||
"""
|
||||
发布订阅执行器。用于发布消息和订阅消息。
|
||||
订阅采用阻塞式读取,可以在读取到数据后,执行回调方法,并根据参数确定是否删除历史消息。
|
||||
"""
|
||||
|
||||
def __init__(self, hash_name: str):
|
||||
self.hash_name = f"{hash_name}_HASH_NAME"
|
||||
self.channel = f"{hash_name}_CHANNEL"
|
||||
|
||||
self.running = False
|
||||
"""
|
||||
优雅退出控制标志
|
||||
"""
|
||||
|
||||
self.stopping = False
|
||||
"""
|
||||
控制整个 run_forever 循环退出
|
||||
"""
|
||||
|
||||
async def publish(self, data: dict) -> str:
|
||||
"""
|
||||
数据写入 Redis 并发布消息。
|
||||
|
||||
:param data: 写入 Redis 的数据
|
||||
:return: 消息ID
|
||||
"""
|
||||
async with await self.get_redis() as _redis:
|
||||
# 生成雪花 ID 作为 Hash Key
|
||||
_random_num = random.randint(1000, 9999)
|
||||
_id = IdWorker.get_id_worker(3, 3, _random_num).get_id()
|
||||
|
||||
# 写入Redis hash
|
||||
await _redis.hset(f"{self.hash_name}:{_id}", mapping=data)
|
||||
# 发布新消息通知
|
||||
await _redis.publish(self.channel, _id)
|
||||
return _id
|
||||
|
||||
async def subscribe(self, func: Callable = None, is_delete=False):
|
||||
"""
|
||||
监听消息。
|
||||
|
||||
:param func: 监听回调程序
|
||||
:param is_delete: 回调执行完毕后,是否删除消息
|
||||
"""
|
||||
async with await self.get_redis() as _redis:
|
||||
_pubsub = _redis.pubsub()
|
||||
await _pubsub.subscribe(self.channel)
|
||||
|
||||
try:
|
||||
self.running = True
|
||||
|
||||
# 使用 while 循环,而不是直接 async for,以便加入超时控制
|
||||
while not self.stopping and self.running:
|
||||
try:
|
||||
# 每次循环都重新获取迭代器
|
||||
listen_iter = _pubsub.listen()
|
||||
message = await asyncio.wait_for(listen_iter.__anext__(), timeout=60.0)
|
||||
|
||||
if message["type"] != "message":
|
||||
continue
|
||||
|
||||
message_id = message["data"]
|
||||
message_key = f"{self.hash_name}:{message_id}"
|
||||
|
||||
try:
|
||||
# 隔离处理回调异常
|
||||
# 采用后台运行的方式处理,防止消息排队,提高消息处理性能
|
||||
await aio_pool.run_background_task(self.callback(func, message_key, is_delete), 10)
|
||||
# await self.callback(func, message_key, is_delete=is_delete)
|
||||
except Exception:
|
||||
# 继续处理下条消息
|
||||
continue
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时,是心跳成功的标志
|
||||
logging.echo_log("心跳:连接正常,继续监听...")
|
||||
continue
|
||||
except redis.exceptions.ConnectionError as e:
|
||||
# 连接错误,触发重连
|
||||
logging.echo_log(f"检测到连接错误: {e}. 将触发重连...", level=ERROR, is_log_exc=True)
|
||||
raise e
|
||||
except StopAsyncIteration:
|
||||
# pubsub 正常关闭
|
||||
logging.echo_log("PubSub 迭代器已停止.")
|
||||
break
|
||||
except (asyncio.CancelledError, KeyboardInterrupt):
|
||||
logging.echo_log("收到退出信号,停止监听...")
|
||||
self.running = False
|
||||
raise
|
||||
except Exception as e:
|
||||
logging.echo_log(f"监听会话因错误结束: {e}.", level=ERROR, is_log_exc=True)
|
||||
raise e
|
||||
finally:
|
||||
self.running = False
|
||||
try:
|
||||
await _pubsub.unsubscribe(self.channel)
|
||||
await _pubsub.close()
|
||||
except Exception as close_err:
|
||||
logging.echo_log(f"资源关闭异常: {close_err}.")
|
||||
finally:
|
||||
logging.echo_log("监听已完全停止.")
|
||||
|
||||
async def run_forever(self, func: Callable = None, is_delete=False):
|
||||
"""
|
||||
持久运行的监听器,包含自动重连逻辑和优雅退出。
|
||||
"""
|
||||
while not self.stopping:
|
||||
try:
|
||||
logging.echo_log("启动新的监听会话...")
|
||||
await self.subscribe(func, is_delete)
|
||||
except (asyncio.CancelledError, KeyboardInterrupt):
|
||||
logging.echo_log("收到退出信号,停止监听...")
|
||||
self.stopping = True
|
||||
break
|
||||
except Exception as e:
|
||||
logging.echo_log(f"监听会话因未知错误结束: {e}. 10秒后重试...", level=ERROR, is_log_exc=True)
|
||||
|
||||
if self.stopping:
|
||||
logging.echo_log("总开关已打开,停止重连.")
|
||||
break
|
||||
|
||||
logging.echo_log("等待重新连接...")
|
||||
try:
|
||||
# 关键改动:直接使用 sleep,它本身就是可中断的
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
logging.echo_log("等待期间被取消,准备退出.")
|
||||
break
|
||||
|
||||
logging.echo_log("监听服务已完全停止.")
|
||||
|
||||
async def history(self, func: Callable = None, is_delete=False):
|
||||
"""
|
||||
处理历史数据。
|
||||
|
||||
:param func: 监听回调程序
|
||||
:param is_delete: 回调执行完毕后,是否删除消息
|
||||
"""
|
||||
async with await self.get_redis() as _redis:
|
||||
_keys = await _redis.keys()
|
||||
for _k in _keys:
|
||||
try:
|
||||
# 隔离处理回调异常
|
||||
# 采用后台运行的方式处理,防止消息排队,提高消息处理性能
|
||||
await aio_pool.run_background_task(self.callback(func, _k, is_delete), 10)
|
||||
# await self.callback(func, _k, is_delete=is_delete)
|
||||
except Exception:
|
||||
# 继续处理下条消息
|
||||
continue
|
||||
|
||||
def subscribe_stop(self):
|
||||
self.running = False
|
||||
self.stopping = True
|
||||
|
||||
|
||||
class StreamActor(Redis):
|
||||
"""
|
||||
流执行器。使用 Redis Streams 实现发布消息和消费消息。
|
||||
消费采用消费者组模式,支持消息确认和可靠传递。
|
||||
方法结构与 PubSubActor 保持一致,便于无缝替换。
|
||||
|
||||
此版本集成了启动时的僵尸任务自动恢复功能。
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def actor_config(cls, config_path: str):
|
||||
"""
|
||||
根据路径,取得配置信息。
|
||||
|
||||
:param config_path: 配置路径,配置文件中,直到 stream 的 Key,用点【.】分隔
|
||||
:return:
|
||||
"""
|
||||
_stream_name = config_path.split(".")[-1].upper()
|
||||
_stream_config = config.get_config(config_path)
|
||||
_group_name = _stream_config.get('group', f"{_stream_name}_GROUP")
|
||||
_consumer_name = _stream_config.get('consumer', f"{_stream_name}_CONSUMER")
|
||||
_snow_id = IdWorker.get_id_worker().get_id()
|
||||
_consumer_name = f"{_consumer_name}_{_snow_id}"
|
||||
return _stream_name, _group_name, _consumer_name
|
||||
|
||||
@classmethod
|
||||
def new_actor(cls, config_path: str):
|
||||
"""
|
||||
根据配置文件中的配置小节创建流执行器。
|
||||
|
||||
:param config_path: 配置路径,配置文件中,直到 stream 的 Key,用点【.】分隔
|
||||
:return: 执行器对象
|
||||
"""
|
||||
_stream_name, _group_name, _consumer_name = cls.actor_config(config_path)
|
||||
return cls(_stream_name, _group_name, _consumer_name)
|
||||
|
||||
def __init__(self, stream_name: str, group_name: str, consumer_name: str):
|
||||
"""
|
||||
初始化流执行器。
|
||||
|
||||
:param stream_name: Redis Stream 的名称
|
||||
:param group_name: 消费者组的名称
|
||||
:param consumer_name: 当前消费者的名称
|
||||
"""
|
||||
self.stream_name = stream_name
|
||||
self.group_name = group_name
|
||||
self.consumer_name = consumer_name
|
||||
|
||||
self.running = False
|
||||
"""
|
||||
优雅退出控制标志
|
||||
"""
|
||||
|
||||
self.stopping = False
|
||||
"""
|
||||
控制整个 run_forever 循环退出
|
||||
"""
|
||||
|
||||
async def _ensure_group_exists(self):
|
||||
"""确保消费者组已存在,如果不存在则创建。"""
|
||||
try:
|
||||
_redis = await self.get_redis()
|
||||
await _redis.xgroup_create(
|
||||
name=self.stream_name,
|
||||
groupname=self.group_name,
|
||||
id='0', # 从头开始消费
|
||||
mkstream=True # Stream 不存在时自动创建
|
||||
)
|
||||
logging.echo_log(f"消费者组 '{self.group_name}' 已创建.")
|
||||
except redis.exceptions.ResponseError as e:
|
||||
if "Consumer Group name already exists" in str(e):
|
||||
logging.echo_log(f"消费者组 '{self.group_name}' 已存在.")
|
||||
else:
|
||||
raise
|
||||
|
||||
async def publish(self, data: dict) -> str:
|
||||
"""
|
||||
将数据作为消息写入 Redis Stream。
|
||||
|
||||
:param data: 写入 Stream 的数据字典
|
||||
:return: 消息ID
|
||||
"""
|
||||
async with await self.get_redis() as _redis:
|
||||
# 添加时会自动生成唯一的消息ID
|
||||
message_id = await _redis.xadd(name=self.stream_name, fields=data)
|
||||
logging.echo_log(f"消息已发布至 Stream '{self.stream_name}',ID: {message_id};数据为:{data}.")
|
||||
return message_id
|
||||
|
||||
async def reclaim_stale_tasks(self, func: Callable, is_delete: bool, stale_threshold_ms: int = 5 * 60 * 1000):
|
||||
"""
|
||||
检查并尝试重新处理僵尸任务。
|
||||
|
||||
Args:
|
||||
func (Callable): 用于处理任务的业务回调函数。
|
||||
is_delete (bool): 处理成功后是否确认消息。
|
||||
stale_threshold_ms (int): 判定为僵尸任务的空闲时间阈值(毫秒)。
|
||||
"""
|
||||
async with await self.get_redis() as _redis:
|
||||
# 1. 发现僵尸任务
|
||||
try:
|
||||
stale_tasks = await _redis.xpending_range(
|
||||
name=self.stream_name,
|
||||
groupname=self.group_name,
|
||||
min='-',
|
||||
max='+',
|
||||
count=10, # 每次最多处理10个僵尸任务,避免启动时阻塞太久
|
||||
idle=stale_threshold_ms
|
||||
)
|
||||
except Exception as e:
|
||||
logging.echo_log(f"检查僵尸任务时出错: {e}", level=ERROR, is_log_exc=True)
|
||||
return
|
||||
|
||||
if not stale_tasks:
|
||||
logging.echo_log(f"未发现空闲超过 {stale_threshold_ms / 1000} 秒的僵尸任务.")
|
||||
return
|
||||
|
||||
if not stale_tasks or not isinstance(stale_tasks, list):
|
||||
logging.echo_log(f"未发现空闲超过 {stale_threshold_ms / 1000} 秒的僵尸任务.")
|
||||
return
|
||||
message_ids = [task['message_id'] for task in stale_tasks]
|
||||
logging.echo_log(f"发现 {len(message_ids)} 个僵尸任务,尝试认领并重新处理...")
|
||||
|
||||
# 2. 认领任务
|
||||
try:
|
||||
reclaimed_messages = await _redis.xclaim(
|
||||
name=self.stream_name,
|
||||
groupname=self.group_name,
|
||||
consumername=self.consumer_name, # 认领给自己
|
||||
min_idle_time=stale_threshold_ms,
|
||||
message_ids=message_ids,
|
||||
justid=False # 我们需要消息内容来处理
|
||||
)
|
||||
except Exception as e:
|
||||
logging.echo_log(f"认领僵尸任务时出错: {e}", level=ERROR, is_log_exc=True)
|
||||
return
|
||||
|
||||
if not reclaimed_messages:
|
||||
logging.echo_log("未能成功认领任何僵尸任务.")
|
||||
return
|
||||
|
||||
logging.echo_log(f"成功认领 {len(reclaimed_messages)} 个僵尸任务,开始处理.")
|
||||
|
||||
# 3. 处理被认领的任务
|
||||
for message_id, message_data in reclaimed_messages:
|
||||
# 使用我们已有的 _callback_wrapper 来处理,保证逻辑一致
|
||||
await self._callback_wrapper(
|
||||
func=func,
|
||||
message_id=message_id,
|
||||
message_data=message_data,
|
||||
is_delete=is_delete
|
||||
)
|
||||
|
||||
async def history(self, func: Callable, is_delete: bool):
|
||||
"""
|
||||
启动时的恢复程序。
|
||||
检查并处理长时间未完成的僵尸任务,确保系统健壮性。
|
||||
"""
|
||||
logging.echo_log("执行启动恢复程序,检查僵尸任务...")
|
||||
# 将 func 和 is_delete 传递下去
|
||||
await self.reclaim_stale_tasks(func=func, is_delete=is_delete, stale_threshold_ms=5 * 60 * 1000)
|
||||
logging.echo_log("启动恢复程序执行完毕.")
|
||||
|
||||
async def subscribe(self, func: Callable = None, is_delete=False):
|
||||
"""
|
||||
从消费者组中监听并处理新消息。
|
||||
启动时会先执行恢复程序,处理僵尸任务。
|
||||
"""
|
||||
await self._ensure_group_exists()
|
||||
|
||||
# === 核心改动:启动时先执行恢复,并传入回调参数 ===
|
||||
await self.history(func=func, is_delete=is_delete)
|
||||
|
||||
async with await self.get_redis() as _redis:
|
||||
try:
|
||||
self.running = True
|
||||
logging.echo_log("僵尸任务恢复完成,开始监听新消息...")
|
||||
|
||||
while not self.stopping and self.running:
|
||||
try:
|
||||
# 阻塞读取新消息
|
||||
streams = await _redis.xreadgroup(
|
||||
groupname=self.group_name,
|
||||
consumername=self.consumer_name,
|
||||
streams={self.stream_name: '>'}, # '>' 表示只读取新消息
|
||||
count=1,
|
||||
block=5000 # 5秒超时,类似 PubSub 的心跳
|
||||
)
|
||||
|
||||
if not streams:
|
||||
# 超时,是心跳成功的标志
|
||||
logging.echo_log("心跳:连接正常,继续监听...")
|
||||
continue
|
||||
|
||||
# 解析消息
|
||||
stream, messages = streams[0]
|
||||
message_id, message_data = messages[0]
|
||||
logging.echo_log(f"收到新消息: ID={message_id}, 数据={message_data}")
|
||||
|
||||
try:
|
||||
# 隔离处理回调异常
|
||||
# 采用后台运行的方式处理,防止消息排队,提高消息处理性能
|
||||
await aio_pool.run_background_task(
|
||||
self._callback_wrapper(func, message_id, message_data, is_delete), 10
|
||||
)
|
||||
except Exception:
|
||||
# 回调处理失败,消息未被确认,将留在队列中稍后重试
|
||||
continue
|
||||
|
||||
except redis.exceptions.ConnectionError as e:
|
||||
# 连接错误,触发重连
|
||||
logging.echo_log(f"检测到连接错误: {e}. 将触发重连...", level=ERROR, is_log_exc=True)
|
||||
raise e
|
||||
except (asyncio.CancelledError, KeyboardInterrupt):
|
||||
logging.echo_log("收到退出信号,停止监听...")
|
||||
self.running = False
|
||||
raise
|
||||
except Exception as e:
|
||||
logging.echo_log(f"监听会话因错误结束: {e}.", level=ERROR, is_log_exc=True)
|
||||
raise e
|
||||
finally:
|
||||
self.running = False
|
||||
logging.echo_log("Stream 监听已完全停止.")
|
||||
|
||||
async def _callback_wrapper(self, func: Callable, message_id: str, message_data: dict, is_delete: bool):
|
||||
"""
|
||||
一个包装器,用于将 Stream 的消息处理逻辑适配到基类的 callback 方法签名上。
|
||||
这样做可以复用基类 callback 中的异常处理逻辑。
|
||||
"""
|
||||
# 如果没有提供回调函数,则无法处理,直接返回,避免丢失任务
|
||||
if not func:
|
||||
logging.echo_log(f"警告: 收到消息 {message_id} 但未提供业务回调函数,消息将被忽略.", level=WARNING)
|
||||
return
|
||||
|
||||
# 不能直接调用基类的 callback,因为它会尝试删除
|
||||
# 在这里复制它的异常处理逻辑,但使用 Stream 的操作
|
||||
result = None
|
||||
async with await self.get_redis() as _redis:
|
||||
try:
|
||||
if func:
|
||||
# 处理回调
|
||||
result = func(message_data)
|
||||
# 处理协程
|
||||
if isinstance(result, Awaitable):
|
||||
result = await result
|
||||
|
||||
if is_delete:
|
||||
# 先从 PENDING 列表中移除
|
||||
await _redis.xack(self.stream_name, self.group_name, message_id)
|
||||
# 再从 Stream 中逻辑删除
|
||||
await _redis.xdel(self.stream_name, message_id)
|
||||
logging.echo_log(f"消息已确认 (ACK): {message_id};数据为:{message_data}.")
|
||||
except redis.RedisError as e:
|
||||
logging.echo_log(f"Redis 操作异常: {e}.", level=ERROR, is_log_exc=True)
|
||||
except Exception as e:
|
||||
logging.echo_log(
|
||||
f"执行回调异常:{e};方法:{self.get_func_name(func)};消息: {message_id}.",
|
||||
level=ERROR, is_log_exc=True
|
||||
)
|
||||
return result
|
||||
|
||||
async def run_forever(self, func: Callable = None, is_delete=False):
|
||||
"""
|
||||
持久运行的监听器,包含自动重连逻辑和优雅退出。
|
||||
"""
|
||||
while not self.stopping:
|
||||
try:
|
||||
logging.echo_log("启动新的监听会话...")
|
||||
await self.subscribe(func, is_delete)
|
||||
except (asyncio.CancelledError, KeyboardInterrupt):
|
||||
logging.echo_log("收到退出信号,停止监听...")
|
||||
self.stopping = True
|
||||
break
|
||||
except Exception as e:
|
||||
logging.echo_log(f"监听会话因未知错误结束: {e}. 10秒后重试...", level=ERROR, is_log_exc=True)
|
||||
|
||||
if self.stopping:
|
||||
logging.echo_log("总开关已打开,停止重连.")
|
||||
break
|
||||
|
||||
logging.echo_log("等待重新连接...")
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
logging.echo_log("等待期间被取消,准备退出.")
|
||||
break
|
||||
|
||||
logging.echo_log("监听服务已完全停止.")
|
||||
|
||||
def subscribe_stop(self):
|
||||
self.running = False
|
||||
self.stopping = True
|
||||
Reference in New Issue
Block a user