commit 291e6fcaae7e3b964d648aaecb9ebffb4eebf42c Author: zwf <2466627138@qq.com> Date: Tue Jun 2 16:26:10 2026 +0800 首次提交 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5753c7c --- /dev/null +++ b/.gitignore @@ -0,0 +1,88 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Caches +*.cache +*.db +*.log + +# Distribution / packaging +.build/ +dist/ +build/ +*.egg-info/ +*.egg/ + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ + +# PyCharm +.idea/ +*.iml +*.ipynb_checkpoints + +# VS Code +.vscode/ + +# Environment +.venv/ +venv/ +ENV/ +env/ +.ENV/ +env.bak/ +venv.bak/ + +# Logs +logs/ + +# Examples 运行时日志(演示用途,不提交) +examples/*/logs/ + +# Config files (local overrides) +config.json +config.local.json + +!examples/*/config.json + +# OS +.DS_Store +Thumbs.db + +# Jupyter Notebook +.ipynb_checkpoints + +# Docker +docker-compose.yml +Dockerfile + +# IDE +*.swp +*.swo +*.swn + +# Backup files +*.bak +*.tmp + +# Project-specific +.PASTE_FRAMEWORK_CACHE/ \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..9cc86f0 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,56 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+ +## [2.0.1] - 2025-04-08 + +### Fixed +- `JsonDumpsEncoder` now handles `NaTType` (pandas NaT) without throwing TypeError +- CORS `Access-Control-Allow-Origin` now defaults to `"*"` instead of echoing the request Origin header (security fix) +- `RbacUser.can()` correctly handles empty permission list +- Fixed race condition in `aio_pool.py` when task queue is full under high concurrency + +### Changed +- Refactored `config.py` — `get_config()` now consistently returns `None` instead of raising on missing optional keys (use `default` parameter for fallback) +- Upgraded minimum Python version from 3.9 to 3.11 +- `pyproject.toml` dependencies cleaned up — core dependencies slimmed down; pandas/numpy/matplotlib/opencv moved to `[project.optional-dependencies]` + +### Removed +- Deprecated `web/legacy_session.py` — use JWT-based `@auth_token` decorator instead + +## [2.0.0] - 2025-03-15 + +### Added +- **Swagger UI auto-generation** (`ApplicationSwagger`) — fully inferred from handler route patterns and decorators +- **Redis StreamActor** — consumer-group-based message processing with automatic zombie task recovery +- **ParamAwareUIModule** — async data pre-loading for Tornado UIModules +- **RBAC with dynamic rules** — rules are serialized Python classes stored in DB, supporting time-based, IP-based, and custom rule chains +- **Snowflake ID generator** (`snow_id.py`) — thread-safe, no external dependency, 10K+ IDs/sec +- **Auto-loading handlers** — `Application.load_handlers()` discovers all handlers in a package via `route_pattern` attribute +- **BaseX encoder/decoder** (`encoder.py`) — auto-detect Base16/32/64/85 + zlib decompression +- **Async DB engine** — SQLAlchemy async session factory with `aiomysql` / `aiosqlite` support +- **Task service with daemonization** — `TaskService` with PID file, start/stop/status commands, and graceful shutdown +- **5 complete examples** under `examples/` — HelloWorld, background tasks, Redis streams, scheduled services, model 
generation +- **pytest test suite** — unit tests (mocked) + integration tests for DB, Redis, RBAC + +### Changed +- Complete project restructuring: introduced `web/`, `core/`, `db/`, `rbac/`, `util/`, `service/` package layout +- `Application` now inherits from `tornado.web.Application` — backward-incompatible refactor from v1.x +- `RequestHandler.response_ok()` and `response_error()` now use `JsonDumpsEncoder` by default +- Logging system rewritten — supports rotating file handlers, per-module loggers, console output + +### Removed +- All v1.x code — this is a ground-up rewrite + +## [1.0.0] - 2024-01-01 + +### Added +- Initial release with basic Tornado wrapper +- Basic JWT auth and password hashing +- Simple `config.json` loader +- Minimal utility functions + +> ⚠️ **Note:** v1.0.0 is deprecated and no longer maintained. Please upgrade to v2.0.0+. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..0ba0c22 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,41 @@ +# Contribution Guidelines + +Thank you for considering contributing to this project! Your involvement helps make this project better for everyone. + +## How to Contribute + +1. **Fork the repository** +2. **Create a new feature branch** + ```bash + git checkout -b feature/your-feature-name + ``` +3. **Make your changes** + Write clean, well-documented code. Follow the project’s coding standards. +4. **Add tests** + All new features or bug fixes must include appropriate unit or integration tests. +5. **Commit your changes** + Use descriptive commit messages following [Conventional Commits](https://www.conventionalcommits.org/). +6. **Push to your fork** + ```bash + git push origin feature/your-feature-name + ``` +7. **Open a Pull Request** + Clearly describe what you changed, why, and how it affects the project. 
+ +## Code Style & Standards + +- Follow the project’s linter rules (e.g., PEP8 for Python, ESLint for JavaScript) +- Keep functions small and focused +- Document public APIs with docstrings or comments +- Never commit secrets or environment files + +## Reporting Issues + +If you find a bug or have a feature request, please open an [Issue](https://github.com/your-repo/issues) with: +- A clear title +- Steps to reproduce (for bugs) +- Expected vs actual behavior + +## License + +By contributing, you agree that your contributions will be licensed under the MIT License. \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..76b0a08 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Paste Contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..1a22148 --- /dev/null +++ b/README.md @@ -0,0 +1,287 @@ +```markdown +┌─────────────────────────────────────────────────┐ +│ │ +│ ██████╗ █████╗ ███████╗████████╗███████╗ │ +│ ██╔══██╗██╔══██╗██╔════╝╚══██╔══╝██╔════╝ │ +│ ██████╔╝███████║███████╗ ██║ █████╗ │ +│ ██╔═══╝ ██╔══██║╚════██║ ██║ ██╔══╝ │ +│ ██║ ██║ ██║███████║ ██║ ███████╗ │ +│ ╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═╝ ╚══════╝ │ +│ │ +│ Python Api-first Scalable Task Engine │ +│ │ +└─────────────────────────────────────────────────┘ +``` + +# PASTE — A Production-Ready Lightweight Python Framework + +> A minimalist, battle-tested Python framework for building api services with **built-in RBAC, JWT, async tasks, Snowflake IDs, Swagger, and modular utilities** — zero boilerplate, maximum control. Built on top of **Tornado**. + +![License](https://img.shields.io/badge/license-MIT-blue) +![Python](https://img.shields.io/badge/python-3.11%2B-blue) +![Build](https://img.shields.io/badge/build-passing-brightgreen) + +--- + +## ✨ Core Features + +| Feature | Description | +|---------|-------------| +| 🚀 **Auto-Loading Handlers** | Define `route_pattern = "/user"` in your handler → API is automatically registered. No router config needed. | +| 🔐 **RBAC with Dynamic Rules** | Permissions stored as serialized Python objects in DB. Define complex rules as classes (e.g., `TimeBasedRule`, `IPWhitelistRule`). | +| 📜 **Swagger UI Auto-Generated** | Built-in Swagger UI at `/docs`. No YAML. Schema is inferred from handlers and decorators. | +| 🔢 **Snowflake ID Generator** | Embedded, thread-safe, no external dependency. Generates 10K+ unique IDs/sec. | +| ⚡ **Async Task Pool with Backpressure** | `run_background_task(coro)` safely manages concurrent background jobs with configurable limits. | +| 🔐 **JWT + Password Hashing** | Stateless authentication via `paste.security.token`. Passwords hashed with PBKDF2-sha256. 
| +| 🧰 **Modular Utility Library (`util/`)** | Clean, testable utilities: `ustr`, `udict`, `ufile`, `pagination`, `snow_id`, `encoder` — no pollution in `__init__.py`. | +| 📦 **Layered Architecture** | `core/` (config, logging, async), `db/` (engine, models), `rbac/` (roles, rules), `web/` (handlers, app), `service/` (business logic) — clean separation. | +| 🛡️ **Safe File Handling** | `sanitize_filename()` ensures cross-platform compatibility (Windows/Linux/macOS). | +| 📊 **JSON Encoder for Complex Types** | Auto-serializes `datetime`, `Decimal`, `BaseModel`, `Enum`, `numpy` — no extra config. | +| 📬 **Redis Stream Actor** | Reliable consumer-group-based message processing with auto zombie-task recovery. | +| 🧩 **ParamAware UI Modules** | Async-preload data for Tornado UIModules — solves the age-old sync-render bottleneck. | + +--- + +## 🛠️ Quick Start + +### 1. Install + +```bash +git clone https://github.com/your-repo/paste.git +cd paste +pip install -e . +``` + +### 2. Run the example + +All ready-to-run examples are located in the `examples/` directory: + +```bash +cd examples/01_hello_world +python main.py +``` + +👉 Open: [http://localhost:9090/hello](http://localhost:9090/hello) + +--- + +## 📚 Examples + +Jump right in with these working examples: + +| Example | Description | Key Features Demonstrated | +|---------|-------------|--------------------------| +| [`01_hello_world`](examples/01_hello_world/) | Minimal paste app | Auto-loading handlers, config, response helpers | +| [`02_background_task`](examples/02_background_task/) | Async background tasks | Task pool, logging, daemonized services | +| [`03_redis_stream`](examples/03_redis_stream/) | Redis Stream publisher & consumer | `StreamActor`, consumer groups, message ACK | +| [`04_tasks_service`](examples/04_tasks_service/) | Scheduled task service | `TaskService`, cron-like scheduling, PID management | +| [`05_gen_models`](examples/05_gen_models/) | Auto-generate DB models from existing tables | 
SQLAlchemy model generation | + +Each example includes a `config.json`, a `handler.py`, and a `main.py` — ready to run. + +--- + +## 🏗️ Framework Architecture + +``` +paste/ # Framework core (do NOT modify) +├── core/ # Foundation layer +│ ├── config.py # Dot-path config loader (get_config("db.engine")) +│ ├── logging.py # Logger with rotating file + console +│ └── aio_pool.py # Async task pool with backpressure +├── db/ # Database layer +│ ├── engine.py # SQLAlchemy engine factory +│ ├── redis.py # Redis connection + StreamActor +│ ├── basemodel.py # Async ORM base model +│ ├── basetable.py # Table reflection utilities +│ ├── baseadapter.py # Result-set adapter +│ └── gen_models.py # Auto-generate model classes +├── web/ # Web layer +│ ├── application.py # Tornado Application with auto-loading +│ ├── handler.py # Base RequestHandler with response helpers +│ ├── decorators.py # @route, @auth_token, @auth_permission +│ ├── swagger.py # Swagger UI auto-generation +│ ├── form.py # WTForms integration +│ ├── param_aware_loader.py # Async UIModule data preloader +│ └── websocket.py # WebSocket support +├── rbac/ # Access control layer +│ ├── rbac_user.py # User with permission inheritance +│ ├── rbac_role.py # Role management +│ ├── rbac_rule.py # Rule engine (pickle-serialized) +│ ├── rbac_item.py # Permission item hierarchy +│ ├── rbac_assignment.py # User-item assignment +│ └── rbac_permission.py # Permission query +├── security/ # Security layer +│ ├── token.py # JWT encode/decode with configurable issuer +│ └── shash.py # PBKDF2-sha256 password hashing +├── service/ # Service layer +│ ├── server.py # Server process management +│ ├── daemonize.py # Unix daemon + PID file +│ └── task_service.py # Scheduled task runner +├── chart/ # Chart generation (optional) +│ ├── bar.py / pie.py / line.py +│ └── ... 
+└── util/ # Utility library + ├── ustr.py / udict.py # String & dict helpers + ├── ufile.py # File operations with sanitized filenames + ├── pagination.py # Page, Pagination, CursorPagination + ├── snow_id.py # Snowflake ID generator + ├── encoder.py # JSON encoder + BaseX decoder + ├── pdf.py / svg.py / xlsx.py + └── ... +``` + +Your application code lives **entirely outside** `paste/`: + +``` +myapp/ +├── main.py # Entry point: create Application, listen, start +├── config.json # DB, logger, RBAC, Tornado configuration +├── apps/ # Your handlers (recommended) +│ └── demo/ +│ ├── __init__.py +│ ├── handler_user.py +│ └── handler_auth.py +├── models/ # Your DB models +│ └── db_models.py +└── service/ # Your background services + └── __init__.py +``` + +--- + +## 🔐 RBAC Example + +```python +# handler.py +from paste.rbac.rbac_user import RbacUser +from paste.web.decorators import route, auth_token, auth_permission +from paste.web.handler import RequestHandler + +@route("/admin/users") +class AdminUserHandler(RequestHandler): + @auth_token + @auth_permission + async def get(self): + users = await RbacUser.query_as_df(RbacUser().gen_query()) + self.response_ok(rows=users.to_dict('records')) +``` + +```python +# rule.py — dynamic rule as a serialized Python class +from datetime import datetime +from paste.rbac.rbac_rule import RbacRule + +class BusinessTimeRule(RbacRule): + async def run(self, **kwargs) -> bool: + hour = datetime.now().hour + return 9 <= hour < 18 # only allow during business hours +``` + +--- + +## ⚙️ Configuration + +All configuration lives in a single `config.json`: + +```json +{ + "tornado": { + "demo": { + "autoreload": false, + "handlers_pkg": "apps.demo", + "port": 9090, + "static_path": "static", + "template_path": "templates", + "swagger_title": "DemoAPI", + "swagger_description": "Demo API", + "swagger_api_version": "1.0.1" + } + }, + "db": { + "engine": { + "engine": "mysql+pymysql://user:pass@localhost:3306/mydb", + "async_engine": 
"mysql+aiomysql://user:pass@localhost:3306/mydb", + "engine_option": { + "echo": false, + "pool_size": 10 + } + } + }, + "redis": { + "connection": "redis://localhost:6379/0", + "streams": { + "user_event": { + "group": "user_event_group", + "consumer": "processor_01" + } + } + }, + "rbac": { + "user_class": "myapp.models.MyRbacUser", + "table": { + "rule": "rbac_rule", + "user": "rbac_user", + "item": "rbac_item", + "assignment": "rbac_assignment", + "item_child": "rbac_item_child" + } + }, + "logger": { + "default": { + "basic": { + "filename": "logs/app.log", + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + "level": 20 + } + } + } +} +``` + +Access any config value with dot-path: + +```python +from paste.core import config +engine_url = config.get_config("db.engine.engine") +port = config.get_config("tornado.demo.port", 9000) +``` + +--- + +## 📣 License + +MIT © 2025 [Wayne Zhang / Organization] + +--- + +## 🌍 Contributing + +Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details on: +- Code style (Black + flake8 + mypy) +- How to submit issues and pull requests + +--- + +## 🧪 Running Tests + +```bash +pytest # Run all tests +pytest tests/unit # Run unit tests only (no external services) +pytest tests/integration # Run integration tests (require DB/Redis) +pytest --cov=paste # Run with coverage report +``` + +--- + +> **💡 Why paste?** +> +> Most frameworks make you write boilerplate. +> `paste` makes you **stop writing boilerplate** — and **start building features**. 
"""Package metadata for ``paste``."""

__package_name__ = "paste"
__version__ = "2.0.1"
__author__ = "Wayne Zhang"
__email__ = "waynezwf@qq.com"


def get_version():
    """Return a one-line banner built from the package name, version and author."""
    banner = "{} version: V{}, written by {}.".format(
        __package_name__,
        __version__,
        __author__,
    )
    return banner
# ---------------------------------------------------------------------------
# paste/chart/__init__.py
# ---------------------------------------------------------------------------

def get_seaborn_colors(n, palette='husl'):
    """
    Generate colors from a Seaborn palette.

    :param n: number of colors requested from the palette
    :param palette: name of the Seaborn palette to sample
    :return: the palette converted with ``as_hex()``
    """
    return seaborn.color_palette(palette, n_colors=n).as_hex()


# ---------------------------------------------------------------------------
# paste/chart/bar.py
# ---------------------------------------------------------------------------

# ===========================
# Helpers: shared font setup and figure export
# ===========================

def _setup_plot():
    """
    Apply the rcParams shared by every bar chart.

    Sets the sans-serif font list to whatever ``ufont.get_fonts()`` returns and
    disables the Unicode minus sign.
    """
    available_font = ufont.get_fonts()
    plt.rcParams.update({
        'font.sans-serif': list(available_font),
        'axes.unicode_minus': False,
    })


def _save_to_base64(dpi: int = 128) -> str:
    """
    Save the current pyplot figure as a base64-encoded SVG data URL.

    The current figure is closed in a ``finally`` clause, so it is released
    even when ``savefig`` raises.

    :param dpi: resolution forwarded to ``savefig``
    :return: ``data:image/svg+xml;base64,...`` string
    """
    buffer = BytesIO()
    try:
        plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight')
    finally:
        # Always drop the figure from pyplot's registry, success or failure.
        plt.close()
    image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/svg+xml;base64,{image_base64}"


# ===========================
# 1. gen_vertical_bars: vertical stacked bar chart
# ===========================

def gen_vertical_bars(
        primary_values: List[float],
        nested_values: List[float],
        x_labels: List[str],
        group_labels: List[str],
        color_palette: str = 'BuPu',
        dpi: int = 128
) -> str:
    """
    Render a vertical stacked bar chart with exactly two groups.

    Data layout:
    - primary_values: [v1, v2, ..., vn]   heights of the bottom bars
    - nested_values:  [v1, v2, ..., vn]   heights of the bars stacked on top
    - x_labels:       ['A', 'B', 'C', ...] one X-axis label per bar
    - group_labels:   ['Group1', 'Group2'] legend names, exactly two entries

    ``primary_values``, ``nested_values`` and ``x_labels`` must all have the
    same length. Only two groups are supported; for more groups see
    ``gen_horizontal_stacked_bars``.

    :param primary_values: values of the bottom bars
    :param nested_values: values of the bars stacked on top of the bottom bars
    :param x_labels: X-axis labels
    :param group_labels: legend labels, length 2
    :param color_palette: Seaborn palette name
    :param dpi: figure resolution
    :return: base64-encoded SVG data URL
    :raises ValueError: if ``group_labels`` does not hold two entries or the
        value/label lists differ in length
    """
    if len(group_labels) != 2:
        raise ValueError("group_labels 必须包含两个元素:[主组, 嵌套组]")

    # Enforce the documented contract up front, before any figure is created.
    if not (len(primary_values) == len(nested_values) == len(x_labels)):
        raise ValueError(
            f"primary_values ({len(primary_values)})、nested_values ({len(nested_values)}) "
            f"与 x_labels ({len(x_labels)}) 的长度必须一致"
        )

    _setup_plot()

    colors = chart.get_seaborn_colors(2, palette=color_palette)
    _fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)

    # Bottom bars.
    bars1 = ax.bar(x_labels, primary_values, label=group_labels[0], color=colors[0])

    # Top bars start where the bottom bars end (bottom=primary_values).
    bars2 = ax.bar(x_labels, nested_values, bottom=primary_values, label=group_labels[1], color=colors[1])

    # Value labels: at the edge of the bottom bars, centered in the top bars.
    ax.bar_label(bars1, label_type='edge', padding=3)
    ax.bar_label(bars2, label_type='center', padding=1)

    # Rotate the X tick labels so they do not overlap.
    plt.xticks(rotation=45, ha='right')

    ax.legend()
    plt.tight_layout()

    return _save_to_base64(dpi)
# ===========================
# 2. gen_horizontal_stacked_bars: horizontal stacked bar chart (multiple groups)
# ===========================

def gen_horizontal_stacked_bars(
        data_shap: np.ndarray,
        x_labels: List[str],
        y_labels: List[str],
        y_data_unit: str = '',
        legend_title: str = '',
        color_palette: str = 'BuPu',
        dpi: int = 128
) -> str:
    """
    Render a horizontal stacked bar chart with any number of groups.

    Data layout:
    - data_shap: 2-D array-like, shape = (len(y_labels), len(x_labels))
        - each row is one bar (one Y category)
        - each column is one stacked segment (one group)
        - e.g. data_shap[2][1] is the value of y_labels[2] in group x_labels[1]
    - x_labels: name of each group (one per column)
    - y_labels: name of each bar (one per row)
    - y_data_unit: unit appended to the per-bar total in the Y tick label

    :param data_shap: 2-D numeric data; nested lists are accepted and
        converted with ``np.asarray``
    :param x_labels: group labels, one per column
    :param y_labels: bar labels, one per row
    :param y_data_unit: unit text shown after each bar's total
    :param legend_title: legend title
    :param color_palette: Seaborn palette name
    :param dpi: figure resolution
    :return: base64-encoded SVG data URL
    :raises ValueError: if the data is not 2-D or the label counts do not
        match its shape
    """
    # Coerce first so plain nested lists are validated instead of failing on `.ndim`.
    data_shap = np.asarray(data_shap)

    if data_shap.ndim != 2:
        raise ValueError("data_shap 必须是二维数组,shape=(Y, X)")

    if len(x_labels) != data_shap.shape[1]:
        raise ValueError(f"x_labels 长度 ({len(x_labels)}) 与 data_shap 列数 ({data_shap.shape[1]}) 不匹配")

    if len(y_labels) != data_shap.shape[0]:
        raise ValueError(f"y_labels 长度 ({len(y_labels)}) 与 data_shap 行数 ({data_shap.shape[0]}) 不匹配")

    _setup_plot()

    colors = chart.get_seaborn_colors(len(x_labels), palette=color_palette)
    fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)

    # Running left edge of the next segment, one entry per bar.
    left_pos = np.zeros(len(y_labels))
    bars = []  # one BarContainer per group, reused when placing the labels

    # Draw one group (column) at a time, stacking to the right.
    for i, group_label in enumerate(x_labels):
        segment = data_shap[:, i]
        bar = ax.barh(
            y_labels,
            segment,
            left=left_pos,
            label=group_label,
            color=colors[i],
            edgecolor='white',
            linewidth=0.5
        )
        bars.append(bar)
        left_pos += segment

    # Per-bar totals, also shown in the Y tick labels, e.g. "Beijing\n1200<unit>".
    totals = np.sum(data_shap, axis=1)
    y_labels_mix = [f"{place}\n{int(total)}{y_data_unit}" for place, total in zip(y_labels, totals)]
    ax.set_yticks(np.arange(len(y_labels)))
    ax.set_yticklabels(y_labels_mix)
    ax.invert_yaxis()  # first category at the top

    ax.legend(title=legend_title, bbox_to_anchor=(0.958, 0.25), frameon=True, framealpha=0.7)

    def add_smart_labels(spacing_pixels: int = 10):
        """Write "a+b+c" next to each bar: inside its right end if it fits, otherwise outside."""
        fig_dpi = fig.dpi
        x_min, x_max = ax.get_xlim()
        x_range = x_max - x_min
        # Convert the pixel spacing into data units along the X axis.
        spacing_data = spacing_pixels * x_range / (fig.get_size_inches()[0] * fig_dpi)

        text_style = {
            'fontsize': 9,
            'va': 'center',
            'bbox': {
                'boxstyle': 'round',
                'facecolor': 'white',
                'alpha': 0.7,
                'edgecolor': 'none',
                'pad': 0.3
            }
        }

        for i, total_width in enumerate(totals):
            values = [bar[i].get_width() for bar in bars]
            label_text = "+".join(str(int(v)) for v in values if v > 0)
            if not label_text:
                continue

            # Rough text width estimate, expressed in data units.
            text_width = (len(label_text) * 0.015 + 0.05) * x_range
            y_pos = bars[0][i].get_y() + bars[0][i].get_height() / 2

            if total_width >= text_width + 2 * spacing_data:
                ax.text(total_width - spacing_data, y_pos, label_text, ha='right', **text_style)
            else:
                ax.text(total_width + spacing_data, y_pos, label_text, ha='left', **text_style)

    add_smart_labels()

    # Grid lines along the X direction only.
    ax.grid(axis='x', linestyle=':', alpha=0.6)
    plt.tight_layout()

    return _save_to_base64(dpi)
# ===========================
# 3. gen_percent_stacked_bars: 100% stacked bar chart
# ===========================

def gen_percent_stacked_bars(
        data: Dict[str, List[float]],
        x_labels: List[str],
        legend_title: str = '',
        color_palette: str = 'BuPu',
        dpi: int = 128
) -> str:
    """
    Render a percentage stacked bar chart.

    Data layout:
    - data: Dict[str, List[float]]
        - each key is one group
        - each value is a list of len(x_labels) numbers, one per X label
        - e.g. {'A': [30, 40, 50], 'B': [70, 60, 50]}
    - x_labels: one label per bar

    Every column (X position) is normalized so its segments sum to 100.
    A column whose values sum to 0 is drawn as empty (all segments 0)
    instead of producing NaN heights.

    :param data: mapping of group name to its values
    :param x_labels: X-axis labels
    :param legend_title: legend title
    :param color_palette: Seaborn palette name
    :param dpi: figure resolution
    :return: base64-encoded SVG data URL
    :raises ValueError: if ``data`` is empty or any group's length differs
        from ``len(x_labels)``
    """
    if not data:
        raise ValueError("data 不能为空字典")

    # Every group must supply exactly one value per X label.
    if any(len(values) != len(x_labels) for values in data.values()):
        raise ValueError(f"所有分组的数值列表长度必须等于 x_labels 长度 ({len(x_labels)})")

    group_names = list(data.keys())
    data_array = np.array([data[group] for group in group_names], dtype=float)  # (n_groups, n_x)

    # Column-wise normalization; zero-sum columns stay at 0 rather than dividing by zero.
    column_totals = data_array.sum(axis=0)
    percent_data = np.divide(
        data_array,
        column_totals,
        out=np.zeros_like(data_array),
        where=column_totals != 0,
    ) * 100

    _setup_plot()

    _fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)

    bottom = np.zeros(len(x_labels))  # running top of the stack for each bar
    colors = chart.get_seaborn_colors(len(group_names), palette=color_palette)

    for group, color, heights in zip(group_names, colors, percent_data):
        ax.bar(x_labels, heights, bottom=bottom, label=group, color=color)
        bottom += heights

    # Y axis spans 0-100%.
    ax.set_ylim(0, 100)

    # Percentage label in the middle of every segment.
    for container in ax.containers:
        ax.bar_label(container, label_type='center', fmt='%.1f%%', padding=0)

    # Legend sits outside the plot area, on the right.
    ax.legend(title=legend_title, loc='center left', bbox_to_anchor=(1, 0.5), frameon=True, framealpha=0.7)

    plt.tight_layout()

    return _save_to_base64(dpi)
# ---------------------------------------------------------------------------
# paste/chart/line.py
# ---------------------------------------------------------------------------

def _style_axes(ax, legend_loc: str) -> None:
    """Apply the grid, spine and legend styling shared by the line charts."""
    ax.grid(True, which='major', linestyle='--', alpha=0.6)
    ax.grid(True, which='minor', linestyle=':', alpha=0.3)
    for spine in ['top', 'right']:
        ax.spines[spine].set_visible(False)
    ax.spines['left'].set_color('#d9d9d9')
    ax.spines['bottom'].set_color('#d9d9d9')

    legend = ax.legend(loc=legend_loc, frameon=True, framealpha=0.7, edgecolor='#f0f0f0')
    legend.get_frame().set_linewidth(1)


def _figure_to_data_url(fig, dpi: int) -> str:
    """Encode ``fig`` as a base64 SVG data URL with a transparent background."""
    buffer = BytesIO()
    fig.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
    image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/svg+xml;base64,{image_base64}"


def gen_lines(data_dict: Dict[str, pd.Series], color_palette: str = 'BuPu', dpi: int = 128) -> str:
    """
    Render one line per year so several years of a time series can be compared.

    Data layout:
    - data_dict: Dict[str, pd.Series]
        - key: the year, e.g. '2022' (compared to the current year via ``str(key)``)
        - value: pd.Series indexed by dates, holding the numeric observations
        - example:
          {
              '2022': pd.Series([100, 120, 90], index=pd.date_range('2022-01-01', periods=3, freq='M')),
              '2023': pd.Series([110, 130, 95], index=pd.date_range('2023-01-01', periods=3, freq='M'))
          }

    Behavior:
    - the series whose key equals the current year is drawn thicker and with
      circle markers (one marker every 30 points)
    - major X ticks are placed every second month and formatted as 'MM-DD'
    - inputs are validated before the figure is created, and the figure is
      always closed afterwards

    :param data_dict: per-year time series, see above
    :param color_palette: Seaborn palette name, default 'BuPu'
    :param dpi: output resolution, default 128
    :return: base64-encoded SVG data URL
    :raises ValueError: if ``data_dict`` is empty, a series is empty, or the
        palette yields no colors
    """
    # === Validate before any pyplot state is created ===
    if not data_dict:
        raise ValueError("data_dict 不能为空")
    for year, series in data_dict.items():
        if len(series) == 0:
            raise ValueError(f"数据序列 {year} 为空,请检查输入")

    # === Colors ===
    colors = chart.get_seaborn_colors(len(data_dict), palette=color_palette)
    if len(colors) == 0:
        raise ValueError("color_palette 返回空颜色列表,请检查色盘名称是否有效")

    # === Fonts ===
    available_font = ufont.get_fonts()
    plt.rcParams['font.sans-serif'] = list(available_font)
    plt.rcParams['axes.unicode_minus'] = False

    current_year = str(datetime.date.today().year)

    fig = plt.figure(figsize=(10, 6), facecolor='none', dpi=dpi)
    try:
        ax = fig.gca()

        # === One line per year ===
        for i, (year, series) in enumerate(data_dict.items()):
            # str() so integer keys such as 2024 are highlighted as well.
            is_current_year = (str(year) == current_year)

            ax.plot(
                series.index, series.values,
                color=colors[i],
                linewidth=2.5 if is_current_year else 1.5,
                alpha=0.9,
                label=f'{year}年',
                marker='o' if is_current_year else None,
                markersize=4,
                markevery=30  # one marker every 30 points
            )

        # === Date ticks: every second month, formatted MM-DD ===
        ax.xaxis.set_major_locator(MonthLocator(bymonth=range(1, 13, 2)))
        ax.xaxis.set_major_formatter(DateFormatter('%m-%d'))

        _style_axes(ax, 'upper right')

        return _figure_to_data_url(fig, dpi)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)


def gen_splines(data_dict: Dict[str, pd.Series], total: pd.Series, x_labels: Optional[List[Union[str, int]]] = None,
                color_palette: str = 'BuPu', dpi: int = 128) -> str:
    """
    Render smoothed curves (spline interpolation) for numerically indexed series.

    Data layout:
    - data_dict: Dict[str, pd.Series]
        - key: series name
        - value: pd.Series with a numeric index and numeric values
        - example:
          {
              'A': pd.Series([10, 15, 12, 18], index=[1, 2, 3, 4]),
              'B': pd.Series([8, 14, 16, 20], index=[1, 2, 3, 4])
          }
    - total: pd.Series
        - its index holds the series names to draw (looked up in data_dict)
        - its values are shown in the legend, e.g. "A(55)"
    - x_labels: optional list used both as tick positions and tick labels

    Behavior:
    - series are drawn in descending order of ``total``
    - each series is sorted by index and interpolated on 100 evenly spaced
      points; the spline degree is ``min(3, number_of_points - 1)``, so 2 points
      give a straight line, 3 points a quadratic and 4 or more a cubic
    - inputs are validated and interpolated before the figure is created, and
      the figure is always closed afterwards

    :param data_dict: observations per series, see above
    :param total: per-series totals used for ordering and legend text
    :param x_labels: optional X tick positions/labels
    :param color_palette: Seaborn palette name
    :param dpi: output resolution
    :return: base64-encoded SVG data URL
    :raises ValueError: if ``data_dict`` is empty, a series has fewer than two
        points, or the palette yields no colors
    :raises KeyError: if ``total`` names a series missing from ``data_dict``
    """
    if not data_dict:
        raise ValueError("data_dict 不能为空")

    # === Order by total, largest first ===
    total_sorted = total.sort_values(ascending=False)

    # === Interpolate every series before any pyplot state is created ===
    curves = []
    for series_name in total_sorted.index:
        s = data_dict[series_name]

        if len(s) < 2:
            raise ValueError(f"系列 {series_name} 数据点少于2个,无法进行插值")

        # make_interp_spline expects ascending x values.
        s = s.sort_index()
        x = s.index.values.astype(float)
        y = s.values.astype(float)

        # A degree-k spline needs at least k + 1 points; cap the degree accordingly.
        degree = min(3, len(x) - 1)
        x_new = np.linspace(x.min(), x.max(), 100)
        y_smooth = make_interp_spline(x, y, k=degree)(x_new)
        curves.append((series_name, x_new, y_smooth))

    # === Colors ===
    colors = chart.get_seaborn_colors(len(data_dict), palette=color_palette)
    if len(colors) == 0:
        raise ValueError("color_palette 返回空颜色列表,请检查色盘名称是否有效")

    # === Fonts ===
    available_font = ufont.get_fonts()
    plt.rcParams['font.sans-serif'] = list(available_font)
    plt.rcParams['axes.unicode_minus'] = False

    fig = plt.figure(figsize=(10, 6), facecolor='none', dpi=dpi)
    try:
        ax = fig.gca()

        for i, (series_name, x_new, y_smooth) in enumerate(curves):
            ax.plot(
                x_new, y_smooth,
                color=colors[i],
                linewidth=2,
                alpha=0.9,
                label=f'{series_name}({total_sorted[series_name]})',  # legend shows the total
            )

        # === Explicit X ticks, if requested ===
        if x_labels is not None:
            ax.set_xticks(x_labels)
            ax.set_xticklabels(x_labels)

        _style_axes(ax, 'upper left')

        return _figure_to_data_url(fig, dpi)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)
def gen_pie(data_df, value_column, percentage_column, legend_labels, color_palette='BuPu', dpi=128):
    """
    Render a doughnut chart with a colour-swatch legend on the right and
    return it as a base64 SVG data URL.

    Despite the name, the chart drawn is a ring (pie wedges with a hollow
    centre), not a classic filled pie.

    :param data_df: pandas.DataFrame holding one row per category
    :param value_column: name of the numeric column that sizes each wedge
    :param percentage_column: name of the column shown as the percentage in the
        legend; values may or may not already end with ``%``
    :param legend_labels: name of the column holding the category names
    :param color_palette: palette name passed to ``chart.get_seaborn_colors``
    :param dpi: resolution used for the figure and for ``savefig``
    :return: ``data:image/svg+xml;base64,...`` string
    """
    # One colour per data row.
    colors = chart.get_seaborn_colors(len(data_df.index), palette=color_palette)

    # Font setup (mutates the global matplotlib rcParams, as the original did).
    available_font = ufont.get_fonts()
    plt.rcParams['font.sans-serif'] = list(available_font)
    plt.rcParams['axes.unicode_minus'] = False

    # Transparent canvas so the SVG can be embedded on any page background.
    fig = plt.figure(figsize=(8, 6), facecolor='none', dpi=dpi)
    try:
        # Left axes: the ring; right axes: the hand-drawn legend.
        ax1 = fig.add_axes((0.1, 0.05, 0.6, 0.9))  # (left, bottom, width, height)
        ax2 = fig.add_axes((0.75, 0.05, 0.2, 0.9))
        for ax in (ax1, ax2):
            ax.set_axis_off()

        ax1.pie(
            data_df[value_column],
            colors=colors,
            startangle=90,
            wedgeprops=dict(
                width=0.6,          # ring thickness; a larger value gives a thicker ring
                edgecolor='white',
                linewidth=1,
            ),
            radius=1.2,
        )

        # Legend geometry, in axes coordinates of ax2.
        num_items = len(data_df.index)
        box_size = 0.08
        text_offset = 0.05
        font_size = 24
        line_height = 0.1
        total_legend_height = num_items * line_height
        y_pos = 0.45 + total_legend_height / 2  # start high enough to centre the legend

        # Iterate the columns positionally. The previous ``data_df[col][i]`` was a
        # label lookup and failed (or picked the wrong row) whenever the frame's
        # index was not 0..n-1, e.g. after filtering or sorting.
        legend_rows = zip(
            data_df[legend_labels],
            data_df[value_column],
            data_df[percentage_column],
            colors,
        )
        for label, value, percentage, color in legend_rows:
            ax2.add_patch(
                Rectangle(
                    (0.05, y_pos - box_size / 2),
                    box_size, box_size,
                    facecolor=color,
                    edgecolor='white',
                    lw=1,
                )
            )

            # Append '%' only when the value does not already carry one, so a
            # column formatted like "35.2%" is not rendered as "35.2%%".
            percent_text = f"{percentage}"
            if not percent_text.endswith('%'):
                percent_text += '%'

            ax2.text(
                0.05 + box_size + text_offset,
                y_pos,
                f"{label}({value}台,{percent_text})",
                va='center',
                ha='left',
                fontsize=font_size,
                fontweight='bold',
            )
            y_pos -= line_height

        buffer = BytesIO()
        fig.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
    finally:
        # Always release this figure, even when drawing fails part-way.
        plt.close(fig)

    image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/svg+xml;base64,{image_base64}"
async def gather_with_retry(*coro_constructors: Callable[[], Any],
                            max_retries: int = 3,
                            eba: float = 0.5,
                            return_exceptions: bool = False) -> tuple[Any, ...]:
    """
    Run a group of awaitables with ``asyncio.gather`` and retry the whole group on failure.

    Every positional argument must be a zero-argument callable returning a fresh
    awaitable (wrap coroutine calls in a ``lambda``), so that each attempt can
    create new awaitables.

    Args:
        *coro_constructors: Zero-argument callables, each returning an awaitable.
        max_retries (int): Number of retries; total attempts = ``max_retries + 1``.
        eba (float): Base wait in seconds. The pause after attempt ``k``
            (1-based) is ``eba * k``, i.e. the delay grows linearly.
        return_exceptions (bool): Forwarded to ``asyncio.gather``. When True,
            exceptions are returned as results, so no retry is triggered.

    Returns:
        tuple: Results in the same order as the constructors.

    Raises:
        Exception: When ``return_exceptions`` is False and the final attempt
            fails, the exception raised by that attempt is re-raised.
    """
    for attempt in range(max_retries + 1):
        try:
            tasks = [ctor() for ctor in coro_constructors]
            results = await asyncio.gather(*tasks, return_exceptions=return_exceptions)
            return tuple(results)
        except Exception as e:
            if attempt >= max_retries:
                echo_log(f"共执行 {max_retries + 1} 次后全部失败: {str(e)}")
                # Surface the real failure. The old code slept once more and
                # then raised RuntimeError("Unreachable"), hiding the cause.
                raise
            echo_log(f"执行第 {attempt + 1} 次重试失败.")
            # Linear backoff before the next attempt.
            await asyncio.sleep(eba * (attempt + 1))

    # Only reachable when max_retries is negative and no attempt was made.
    raise RuntimeError("Unreachable")
def set_logger_config(config_name: str):
    """
    Switch the configuration section name used for logger settings.

    :param config_name: new logger configuration name
    """
    global logger_config_name
    if config_name == logger_config_name:
        # Already pointing at this section; nothing to update.
        return
    logger_config_name = config_name
def echo_log(msg: Union[str, Exception], level: int = logging.INFO, is_log_exc: bool = False):
    """
    Write a message through the logger returned by ``get_logger()``.

    :param msg: message text; for an Exception, its first argument is used when
        that is a string, otherwise ``str(msg)``
    :param level: logging level; forced to ``logging.ERROR`` when ``msg`` is an Exception
    :param is_log_exc: when True, also call ``exception_to_file()`` to record
        the traceback of the exception currently being handled
    """
    _logging = get_logger()
    _log_level = level
    if isinstance(msg, Exception):
        # Exceptions are always reported as errors, whatever level was passed.
        _log_level = logging.ERROR
        if len(msg.args) > 0 and isinstance(msg.args[0], str):
            msg = msg.args[0]
        else:
            msg = str(msg)

    _logging.log(level=_log_level, msg=msg)
    if is_log_exc:
        exception_to_file()
def guid() -> str:
    """
    Generate a random GUID.

    Uses ``uuid.uuid4()`` directly. The previous version hashed a uuid1 value
    plus a ``random.random()`` float through uuid5 to reach the same output
    shape.

    :return: 32-character lowercase hexadecimal string without dashes
    """
    return uuid.uuid4().hex
@classmethod
def ping(cls):
    """
    Try to open a connection to the database server.

    The connection is opened inside a ``with`` block so it is released
    immediately; the previous version opened it and never closed it.

    :return: True when a connection could be opened; connection errors
        propagate to the caller as before
    """
    with engine.connect_engine().connect():
        return True
@classmethod
def row_count(cls, *where_expressions: BinaryExpression) -> int:
    """
    Count the rows of this model's table that match the given conditions.

    :param where_expressions: filter expressions passed to ``where()``
    :return: number of matching rows
    """
    query = select(func.count(1)).select_from(cls).where(*where_expressions)
    session = cls.get_session()
    try:
        return session.execute(query).scalars().first()
    finally:
        # Close the session even when the query raises.
        session.close()
@classmethod
def find_by_id(cls, id_val: Union[str, int], reset_session: Optional[bool] = True) -> Optional['BaseAdapter']:
    """
    Look up a single row by primary key. Only usable on models that have an ``id`` column.

    :param id_val: primary key value
    :param reset_session: when True, call ``reset_session()`` on the loaded
        model after the query session has been closed
    :return: the model instance, or None when no row matches
    """
    query = select(cls).where(cls.instrument_attr('id') == id_val)
    session = cls.get_session()
    try:
        model = session.execute(query).scalars().first()
    finally:
        # Close the query session even when execution fails.
        session.close()

    if reset_session and isinstance(model, BaseAdapter):
        model.reset_session()

    return model
""" + 按创建时间搜索。本方法使用了固定字段名称,确认有 created_at 字段后方可使用。 + + :param from_t: 开始时间 + :param to_t: 结束时间 + """ + if to_t is None: + to_t = from_t + + query = select(cls).where(cls.instrument_attr('created_at').between(from_t, to_t)) + session = cls.get_session() + rows = session.execute(query).scalars().all() + session.close() + return rows + + @classmethod + def find_by_updated(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']: + """ + 按更新时间搜索。本方法使用了固定字段名称,确认有 updated_at 字段后方可使用。 + + :param from_t: 开始时间 + :param to_t: 结束时间 + """ + if to_t is None: + to_t = from_t + + query = select(cls).where(cls.instrument_attr('updated_at').between(from_t, to_t)) + session = cls.get_session() + rows = session.execute(query).scalars().all() + session.close() + return rows + + @classmethod + def find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]) -> Sequence['BaseAdapter']: + """ + 根据数据列表查询已经在数据库中的数据。 + + :param row_list: 数据列表 + :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同 + :return: 查询到的数据模型列表 + """ + rows: Sequence[cls] = [] + _query = cls.datalist_query(row_list=row_list, condition_cols=condition_cols) + if _query is not None: + session = cls.get_session() + rows = session.execute(_query).scalars().all() + session.close() + + return rows + + @classmethod + def find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]): + """ + 根据数据框架查询已经在数据库中的数据。 + + :param row_df: 数据框架 + :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同 + :return: 查询到的数据模型列表 + """ + _query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols) + if _query is not None: + session = cls.get_session() + _rows = session.connection().execute(_query).all() + session.close() + return pd.DataFrame(_rows) + + return None + + def load(self, model_data: dict): + """ + 从字典对象载入数据。 + 该方法仅从字典中读取与对象属性对应的数据,忽略其他数据。 + 注意:与 copyXXX 方法不同,该方法跳过 model_data 中的 None 值,保留原始值不变 + + :param model_data: 包含数据的字典对象 + """ + for attr in self.__dir__(): + 
if model_data.get(attr, None) is not None: + self.__setattr__(attr, model_data.get(attr)) + return self + + def copy_from(self, source: 'BaseAdapter', mapping_fields: Optional[List[str]] = None): + """ + 从源数据模型复制数据,仅复制相同字段的数据。 + + :param source: 源数据对象 + :param mapping_fields: 映射字段列表,若为空,则从源数据模型中获取字段列表 + """ + + # 源对象字段列表 + if mapping_fields is None: + mapping_fields = [column.key for column in source.columns()] + + for column in self.columns(): + # 若源对象中不包含同名字段,则跳过 + if column.key not in mapping_fields: + continue + + # 取出对象属性值 + attr_val = source.ins_value(column.key) + setattr(self, column.key, attr_val) + + return self + + def copy_from_dict(self, source: dict, mapping_keys: Union[list, tuple, set] = None, skip_none: bool = False): + """ + 从源字典对象复制数据,仅复制相同键值的数据。 + + :param source: 源数据对象 + :param mapping_keys: 映射键列表,若为空,则从源字典对象中获取键列表 + :param skip_none: 是否跳过 None 值,即若源数据对象中属性值为 None 时,跳过 + :return 返回自身 + """ + + # 源对象字段列表 + if mapping_keys is None: + mapping_keys = source.keys() + + for column in self.columns(): + if column.key not in mapping_keys: + # 若源对象中不包含同名字段,则跳过 + continue + + # 取出对象属性值 + attr_val = source.get(column.key, None) + if skip_none and attr_val is None: + # 若设置了跳过 None 且此时属性值为空时,跳过 + continue + setattr(self, column.key, attr_val) + + return self + + def inspect(self) -> InstanceState: + """ + 取得对象的监视对象。 + + :return: 监视对象 + """ + return inspect(self) + + def session(self) -> Session: + """ + 取得当前对象的连接会话对象。 + + :return: 连接会话对象 + """ + return self.inspect().session + + def has_session(self) -> bool: + """ + 检测对象是否已有连接会话对象。 + + :return: 有则返回 True 否则返回 False + """ + return self.session() is not None + + def add_session(self): + """ + 将对象加入到连接会话对象。 + """ + if not self.has_session(): + self.get_session().add(self) + return self.session() + + def close_session(self): + """ + 关闭会话。 + """ + if self.has_session(): + self.session().close() + + def reset_session(self): + """ + 重置对象的连接会话。 + """ + self.close_session() + self.add_session() + + def 
def filter_expressions(self, likes: Optional[dict[str, str]] = None) -> List[BinaryExpression]:
    """
    Build filter expressions from the values currently set on this instance.

    ``likes`` maps a field name to a ``str.format`` pattern used for fuzzy matching::

        {
            field_name1: '%{}%',
            field_name2: '{}%',
            field_name3: '%{}',
        }

    Per column, a ``None`` or ``''`` value adds no condition. Otherwise:

    - String / Integer / Numeric columns: a non-empty list, tuple or set becomes
      ``IN`` (empty ones are skipped); int/float becomes ``==``; any other value
      becomes ``LIKE`` when the field is in ``likes``, else ``==``.
    - DateTime / Date columns: a list or tuple of exactly two items becomes
      ``BETWEEN``; other collections become ``IN``; a single value becomes
      ``BETWEEN value AND value``.
    - Any other column type: ``==``.

    :param likes: fields that should be matched with LIKE, and their patterns
    :return: list of filter expressions
    """
    expressions = []
    for column in self.columns():
        attr_val = self.ins_value(column.key)
        if attr_val in (None, ''):
            continue

        # Class-level attribute used to build the SQL expression.
        attr = self.instrument_attr(column.key)
        if attr is None:
            continue

        column_type = column.type
        if isinstance(column_type, (String, Integer, Numeric, ForeignKey)):
            if isinstance(attr_val, (list, tuple, set)):
                if not attr_val:
                    # Skip empty collections instead of emitting an empty IN.
                    continue
                expressions.append(attr.in_(attr_val))
            elif isinstance(attr_val, (int, float)):
                expressions.append(attr == attr_val)
            elif likes and attr.key in likes:
                expressions.append(attr.like(likes[attr.key].format(attr_val)))
            else:
                expressions.append(attr == attr_val)
        elif isinstance(column_type, (DateTime, Date)):
            if isinstance(attr_val, (list, tuple)) and len(attr_val) == 2:
                # Ordered pair -> range. Sets are excluded here: they are
                # unordered and not indexable, so attr_val[0] raised TypeError.
                expressions.append(attr.between(attr_val[0], attr_val[1]))
            elif isinstance(attr_val, (list, tuple, set)):
                expressions.append(attr.in_(attr_val))
            else:
                expressions.append(attr.between(attr_val, attr_val))
        else:
            expressions.append(attr == attr_val)

    return expressions
def find_piece(self, *where: BinaryExpression, offset=0, limit=500, is_desc=False,
               likes: Optional[dict[str, str]] = None):
    """
    Query one slice of rows using this instance's values plus extra conditions.

    :param where: additional filter expressions
    :param offset: row offset, applied when ``>= 0``
    :param limit: maximum number of rows, applied when ``> 0``
    :param is_desc: order by ``id`` descending when the model has an ``id`` attribute
    :param likes: fuzzy-match patterns, forwarded to ``filter_expressions``
    :return: list of matching model instances
    """
    clz = self.__class__
    expressions = self.filter_expressions(likes=likes)
    # ``where`` is a varargs tuple and can never be None; extend directly.
    expressions.extend(where)

    query = select(clz).where(*expressions)
    if limit > 0:
        query = query.limit(limit=limit)
    if offset >= 0:
        query = query.offset(offset=offset)
    if is_desc and hasattr(clz, 'id'):
        query = query.order_by(desc(clz.id))

    session = self.get_session()
    try:
        return session.execute(query).scalars().all()
    finally:
        # Close the session even when the query raises.
        session.close()
def to_dict(self) -> dict:
    """
    Convert the model into a plain dictionary, recursing into nested models.

    Attributes whose name starts with ``_`` are skipped. Nested ``BaseAdapter``
    instances are converted with their own ``to_dict()``. Lists and dicts are
    walked at any depth; the previous version converted only models sitting
    directly inside a top-level list or dict.

    :return: dictionary built from the instance ``__dict__``
    """

    def _convert(value):
        # Recursively replace model instances with their dictionary form.
        if isinstance(value, BaseAdapter):
            return value.to_dict()
        if isinstance(value, list):
            return [_convert(item) for item in value]
        if isinstance(value, dict):
            return {key: _convert(item) for key, item in value.items()}
        return value

    return {
        _key: _convert(_val)
        for _key, _val in dict(self.__dict__).items()
        if not f'{_key}'.startswith('_')
    }
@classmethod
def is_empty_or_none(cls, v: Any):
    """
    Tell whether a value is None, a missing-value marker, or an empty string.

    ``pd.isna`` is only consulted for scalars: for list-like input it returns
    an element-wise array whose truth value is ambiguous, which previously
    made this check raise ``ValueError``.

    :param v: value to check
    :return: True when ``v`` is None, NaN/NaT-like, or formats to ``''``; otherwise False
    """
    if v is None:
        return True
    if pd.api.types.is_scalar(v) and pd.isna(v):
        return True
    return f"{v}" == ''


@classmethod
def not_empty_or_none(cls, v: Any):
    """
    Logical negation of :meth:`is_empty_or_none`.

    :param v: value to check
    :return: True when ``v`` holds real content
    """
    return not cls.is_empty_or_none(v)


@classmethod
def is_digit(cls, v: str):
    """
    Tell whether the value, once formatted as text, consists only of digits.

    Signs, separators and the empty string are rejected because
    ``str.isdigit`` rejects them.

    :param v: value to check
    :return: True when the formatted value is all digits
    """
    return f"{v}".isdigit()


@classmethod
def is_decimal(cls, v: str):
    """
    Tell whether the value is an unsigned decimal number; integers also pass.

    Thousands separators (``,``) are ignored. At most one decimal point is
    allowed, and every part around it must be non-empty digits, so ``"1."``,
    ``".5"`` and ``"1.2.3"`` are all rejected.

    :param v: value to check
    :return: True for an integer or decimal number, otherwise False
    """
    parts = f"{v}".replace(',', '').split('.')
    # More than two parts means more than one decimal point.
    return len(parts) <= 2 and all(cls.is_digit(_part) for _part in parts)
@classmethod
def error_in_items_msg(cls, column: Union[Column, Any], items: list = None):
    """
    Build the "value not in the allowed items" validation message.

    Items are formatted individually before joining, so numeric item lists
    (the kind ``is_in_items`` is typed for) no longer raise ``TypeError``.

    :param column: column whose label (from ``cls.label``) opens the message
    :param items: allowed items; when None a generic "out of range" message is returned
    :return: the error message
    """
    if items is None:
        return '%s超出范围.' % cls.label(column=column)

    joined_items = ','.join(f"{_item}" for _item in items)
    return '%s必须在:[%s] 范围内.' % (cls.label(column=column), joined_items)
@classmethod
def mapping_data_struct(cls, source: Optional[dict], mapping: Optional[dict]):
    """
    Build a new dictionary from ``source`` according to ``mapping``.

    ``mapping`` has the shape ``{target_key: spec}``. Example::

        mapping = {
            'mainId': 'id',
            'mainCycle': lambda row: LABELS.get(row['main_cycle'], ''),
            'fileList': {
                '__name__': 'main_files',
                '__mapping__': {
                    'fileName': 'file_name',
                    'filePath': 'file_url',
                },
            },
        }

    Each ``spec`` is interpreted as follows:

    1. ``str``: read that key from ``source`` (None when absent). To set a
       literal string, use a lambda instead, otherwise it is taken as a key.
    2. callable: called with ``source``; its result is used.
    3. ``dict`` holding both ``__name__`` and ``__mapping__``: the value at
       ``source[__name__]`` is mapped recursively with ``__mapping__`` when it
       is a dict, element by element when it is a list, and copied otherwise.
    4. anything else (including a dict without both marker keys): used as is.

    :param source: source data dictionary
    :param mapping: mapping dictionary
    :return: the converted dictionary, or None when either argument is None
    """
    if source is None or mapping is None:
        return None

    def _nested(spec: dict):
        # A dict lacking either marker key is plain data, not a sub-mapping.
        if not ('__name__' in spec and '__mapping__' in spec):
            return spec

        inner_source = source.get(spec.get('__name__'), None)
        inner_mapping = spec.get('__mapping__', None)
        if isinstance(inner_source, dict):
            return cls.mapping_data_struct(inner_source, inner_mapping)
        if isinstance(inner_source, list):
            return [cls.mapping_data_struct(entry, inner_mapping) for entry in inner_source]
        # Neither dict nor list: copy the raw value.
        return inner_source

    def _resolve(spec):
        # The str test must precede the callable test, as in the rule order above.
        if isinstance(spec, str):
            return source.get(spec, None)
        if isinstance(spec, Callable):
            return spec(source)
        if isinstance(spec, dict):
            return _nested(spec)
        return spec

    return {target_key: _resolve(spec) for target_key, spec in mapping.items()}
@classmethod
def convert(cls, dataframe: pd.DataFrame, mapping: dict):
    """
    Build a new DataFrame from ``dataframe`` according to ``mapping``.

    ``mapping`` has the shape ``{target_column: spec}``. Example::

        mapping = {
            'mainId': 'id',
            'mainCycle': lambda row: LABELS.get(row['main_cycle'], ''),
            'source': 0,
        }

    Each ``spec`` is interpreted as follows:

    1. ``str``: copy that column of ``dataframe``.
    2. callable: applied to every row (``axis=1``); the result becomes the column.
    3. anything else: used directly as the column content (a scalar is
       repeated on every row).

    Unlike the dictionary mapping, nested (recursive) mappings are not supported.

    :param dataframe: source DataFrame
    :param mapping: mapping dictionary
    :return: the converted DataFrame
    """
    # Start from the source index so a constant broadcasts to every row even
    # when it is the first column written.
    _tar_df = pd.DataFrame(index=dataframe.index)
    for _tar_attr, _src_attr in mapping.items():
        if isinstance(_src_attr, str):
            _tar_df[_tar_attr] = dataframe[_src_attr]
        elif isinstance(_src_attr, Callable):
            _tar_df[_tar_attr] = dataframe.apply(_src_attr, axis=1)
        else:
            # Use the configured content, not the target column's own name.
            _tar_df[_tar_attr] = _src_attr
    return _tar_df
@classmethod
def sort_clauses(cls, sort_d: dict):
    """
    Turn a ``{field_name: direction}`` dictionary into ORDER BY expressions.

    Expected structure::

        {
            'field_name1': 'asc',
            'field_name2': 'desc',
        }

    ``''``, ``'asc'`` and ``'ascend'`` sort ascending; ``'desc'`` and
    ``'descend'`` sort descending; any other direction is ignored.

    :param sort_d: sort definition
    :return: list of sort expressions, in the dictionary's order
    """
    ascending = ('', 'asc', 'ascend')
    descending = ('desc', 'descend')

    clauses = []
    for field_name, direction in sort_d.items():
        # NOTE(review): field names are passed to text() verbatim; confirm
        # callers never forward unchecked user input here.
        if direction in ascending:
            clauses.append(text(field_name))
        elif direction in descending:
            clauses.append(desc(text(field_name)))
    return clauses
@classmethod
async def orm_execute_scalars(cls, query, session: AsyncSession=None, **kwargs) -> ScalarResult:
    """
    Run a query through :meth:`orm_execute` and return its ``scalars()`` view.

    :param query: query object
    :param session: database session; forwarded to ``orm_execute``
    :param kwargs: extra keyword arguments forwarded to ``orm_execute``
    :return: the scalar result of the query
    """
    return (await cls.orm_execute(query, session, **kwargs)).scalars()


@classmethod
async def query_all(cls, query: Select, session: AsyncSession=None) -> List['BaseTable']:
    """
    Run an ORM query and return ``scalars().all()`` as a list.

    :param query: query object
    :param session: database session
    :return: list of results; when only some columns are selected, just the
             first selected column of each row is returned
    """
    rows = (await cls.orm_execute_scalars(query, session)).all()
    return list(rows)


@classmethod
async def query_first(cls, query: Select, session: AsyncSession=None) -> 'BaseTable':
    """
    Run an ORM query and return ``scalars().first()``.

    :param query: query object
    :param session: database session
    :return: the first result of the query
    """
    return (await cls.orm_execute_scalars(query, session)).first()


@classmethod
async def query_count(cls, query: Select, is_only_count: bool = False, session: AsyncSession=None) -> int:
    """
    Count the rows the given query would return.

    :param query: query object
    :param is_only_count: when True, replace the query's columns with ``count(1)``;
                          otherwise wrap the query in a subquery and count that
    :param session: database session
    :return: number of rows
    """
    count_query = (
        query.with_only_columns(func.count(1))
        if is_only_count
        else select(func.count(1)).select_from(query.subquery())
    )
    scalars = await cls.orm_execute_scalars(count_query, session)
    return scalars.first()
@classmethod
async def async_row_count(cls, *where_expressions: Union[ColumnOperators, BinaryExpression],
                          session: AsyncSession=None) -> int:
    """
    Count the rows of this table that satisfy the given conditions.

    :param where_expressions: filter expressions
    :param session: database session
    :return: number of rows
    """
    count_query = select(func.count(1)).select_from(cls).where(*where_expressions)
    return (await cls.orm_execute_scalars(count_query, session)).first()


@classmethod
async def async_exist_by_id(cls, id_val: Union[str, int]):
    """
    Tell whether a row with the given primary key exists.

    :param id_val: primary key value
    :return: True when at least one row has that ``id``
    """
    return await cls.async_row_count(cls.instrument_attr('id') == id_val) >= 1


@classmethod
async def async_find_by_id(cls, id_val: Union[str, int]) -> Optional['BaseTable']:
    """
    Load one row by primary key. Only usable on models that have an ``id`` field.

    :param id_val: primary key value
    :return: the model object returned by ``query_first``
    """
    return await cls.query_first(query=select(cls).where(cls.instrument_attr('id') == id_val))


@classmethod
async def async_find_by_created(cls, from_t: datetime.datetime,
                                to_t: Optional[datetime.datetime] = None) -> List['BaseTable']:
    """
    Search by creation time. Uses the fixed field name ``created_at``, so the
    model must have that field.

    :param from_t: start of the range
    :param to_t: end of the range; defaults to ``from_t``
    :return: the matching model objects
    """
    upper = from_t if to_t is None else to_t
    created_col = cls.instrument_attr('created_at')
    return await cls.query_all(select(cls).where(created_col.between(from_t, upper)))


@classmethod
async def async_find_by_updated(cls, from_t: datetime.datetime,
                                to_t: Optional[datetime.datetime] = None) -> List['BaseTable']:
    """
    Search by update time. Uses the fixed field name ``updated_at``, so the
    model must have that field.

    :param from_t: start of the range
    :param to_t: end of the range; defaults to ``from_t``
    :return: the matching model objects
    """
    upper = from_t if to_t is None else to_t
    updated_col = cls.instrument_attr('updated_at')
    return await cls.query_all(select(cls).where(updated_col.between(from_t, upper)))
@classmethod
async def async_find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
    """
    Look up the rows already stored in the database for the given DataFrame.

    :param row_df: data to look up
    :param condition_cols: columns queried and used as conditions; their order
                           must match the index order
    :return: a DataFrame built from the query result, or None when
             ``dataframe_query`` produces no query
    """
    query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols)
    if query is None:
        return None

    scalars = await cls.orm_execute_scalars(query=query)
    # NOTE(review): the frame is built straight from the scalar rows; confirm
    # downstream code expects that layout.
    return pd.DataFrame(scalars.all())


async def async_find(self, likes: Optional[dict[str, str]] = None) -> List['BaseTable']:
    """
    Query the database using this object's own attribute values as filters.

    :param likes: fuzzy-match conditions forwarded to ``filter_expressions``
    :return: the matching model objects
    """
    conditions = self.filter_expressions(likes=likes)
    return await self.query_all(select(self.__class__).where(*conditions))


async def async_find_first(self, likes: Optional[dict[str, str]] = None) -> Optional['BaseTable']:
    """
    Query the database using this object's own attribute values and return
    only the first hit.

    :param likes: fuzzy-match conditions forwarded to ``filter_expressions``
    :return: the first result of the query
    """
    conditions = self.filter_expressions(likes=likes)
    return await self.query_first(select(self.__class__).where(*conditions))
def connect_engine() -> Union[MockConnection, Engine]:
    """
    Return the process-wide synchronous database engine, creating it on first use.

    The URL is read from ``db_engine.engine``. ``db_engine.engine_option`` is
    treated as optional: when the lookup yields nothing, the engine is created
    with no extra options instead of failing on ``**None``.

    :return: the shared engine
    """
    global GLOBAL_CONNECTOR_ENGINE
    if GLOBAL_CONNECTOR_ENGINE is None:
        _options = config.get_config('db_engine.engine_option') or {}
        GLOBAL_CONNECTOR_ENGINE = create_engine(
            config.get_config('db_engine.engine'), **_options
        )
    return GLOBAL_CONNECTOR_ENGINE


def async_connect_engine() -> AsyncEngine:
    """
    Return the process-wide asynchronous database engine, creating it on first use.

    The URL is read from ``db_engine.async_engine``; it shares the optional
    ``db_engine.engine_option`` settings with the synchronous engine.

    :return: the shared async engine
    """
    global ASYNC_CONNECTOR_ENGINE
    if ASYNC_CONNECTOR_ENGINE is None:
        _options = config.get_config('db_engine.engine_option') or {}
        ASYNC_CONNECTOR_ENGINE = create_async_engine(
            config.get_config('db_engine.async_engine'), **_options
        )
    return ASYNC_CONNECTOR_ENGINE
@classmethod
async def load_script(cls, redis_client: StrictRedis, script_name: str) -> str:
    """
    Load a Lua script, cache it, and pre-register it in Redis.

    An external ``<script_name>.lua`` file is used when external files are
    enabled and the file exists; otherwise the built-in default is used.

    :param redis_client: Redis client
    :param script_name: script name (for example ``stock_decr``)
    :return: SHA1 of the script content
    :raises ValueError: when no external file and no built-in default exist
    """
    script_content = None

    # Try the external file first.
    if cls._use_external_files and cls._script_dir:
        script_path = cls._script_dir / f"{script_name}.lua"
        if script_path.exists():
            script_content = script_path.read_text(encoding='utf-8')
            logging.echo_log(f"Lua 脚本从外部文件加载: {script_path}")

    # Fall back to the built-in script.
    if script_content is None:
        if script_name not in cls.DEFAULT_SCRIPTS:
            raise ValueError(f"脚本不存在: {script_name},且无内置默认值")
        script_content = cls.DEFAULT_SCRIPTS[script_name]
        logging.echo_log(f"Lua 脚本使用内置默认值: {script_name}")

    # The SHA is computed locally so the cache entry exists even if Redis is unreachable.
    sha = hashlib.sha1(script_content.encode()).hexdigest()
    cls._scripts[script_name] = (sha, script_content)

    # Preloading is best effort: execute() re-registers the script on NOSCRIPT.
    # The failure is logged rather than hidden.
    try:
        await redis_client.script_load(script_content)
    except Exception as e:
        logging.echo_log(f"Lua 脚本预加载失败: {script_name}: {e}", level=WARNING)

    return sha


@classmethod
async def load_default_scripts(cls, redis_client: StrictRedis):
    """
    Load every script listed in ``DEFAULT_SCRIPTS``.

    :param redis_client: Redis client
    """
    for script_name in cls.DEFAULT_SCRIPTS:
        await cls.load_script(redis_client, script_name)
    logging.echo_log(f"已加载 {len(cls.DEFAULT_SCRIPTS)} 个默认 Lua 脚本")


@classmethod
async def execute(cls, redis_client: StrictRedis, script_name: str,
                  keys: list, args: list) -> any:
    """
    Run a Lua script by name through ``EVALSHA``.

    A script that has not been cached yet is loaded first. When Redis answers
    ``NOSCRIPT`` the script is registered again and the call is retried once;
    any other response error is re-raised.

    :param redis_client: Redis client
    :param script_name: script name
    :param keys: KEYS arguments
    :param args: ARGV arguments
    :return: the script's result
    """
    if script_name not in cls._scripts:
        await cls.load_script(redis_client, script_name)

    sha, script_content = cls._scripts[script_name]

    try:
        return await redis_client.evalsha(sha, len(keys), *keys, *args)
    except redis.ResponseError as e:
        if "NOSCRIPT" not in str(e):
            raise
        # Redis no longer has the script: register it again and retry.
        await redis_client.script_load(script_content)
        return await redis_client.evalsha(sha, len(keys), *keys, *args)
@classmethod
def is_java_serialized(cls, bs: Union[bytes, str]):
    """
    Tell whether ``bs`` looks like Java-serialized data.

    :param bs: value to inspect; anything that is not ``bytes`` is reported
        as not serialized
    :return: ``True`` when the first four bytes equal ``cls.prefix``
    """
    return isinstance(bs, bytes) and bs[:4] == cls.prefix
@classmethod
async def stock_decr(cls, stock_key: str, quantity: int = 1) -> Tuple[bool, str]:
    """
    Atomically deduct stock through the ``stock_decr`` Lua script.

    The script result is mapped to a message: ``1`` means the deduction
    succeeded, ``0`` means not enough stock, anything else means the key
    does not exist.

    :param stock_key: stock key (may be a shard key such as ``stock:iPhone15:shard:0``)
    :param quantity: amount to deduct; must be a positive integer
    :return: ``(success, message)``; failures are reported, never raised
    """
    # Reject non-positive or non-numeric amounts before touching Redis:
    # DECRBY with a negative number would *add* stock.
    try:
        _amount = int(quantity)
    except (TypeError, ValueError, OverflowError):
        _amount = 0
    if _amount <= 0:
        return False, "扣减数量必须为正整数"

    async with await cls.get_redis() as _redis:
        try:
            result = await cls.lua_scripts.execute(
                _redis,
                "stock_decr",
                keys=[stock_key],
                args=[quantity]
            )
        except Exception as e:
            # Best effort by design: callers get a failure tuple, not an exception.
            logging.echo_log(f"扣减库存异常: {e}", level=ERROR, is_log_exc=True)
            return False, f"系统异常: {e}"

    if result == 1:
        return True, "扣减成功"
    if result == 0:
        return False, "库存不足"
    return False, "商品不存在"
@classmethod
def get_user_shard(cls, sku_id: str, user_id: str, shard_count: int = 10) -> str:
    """
    Pick the stock shard key for a user.

    The shard index must be identical in every process. The builtin
    ``hash()`` is salted per interpreter for ``str`` values, so string IDs
    are mapped through SHA-256 instead. Integer IDs keep using ``hash()``,
    which is not randomized for ``int``.

    :param sku_id: product ID
    :param user_id: user ID
    :param shard_count: total number of shards; must be positive
    :return: shard key built by ``get_shard_key``
    :raises ValueError: if ``shard_count`` is not positive
    """
    if shard_count <= 0:
        raise ValueError(f"shard_count 必须为正整数: {shard_count}")

    if isinstance(user_id, int):
        stable_hash = hash(user_id)
    else:
        digest = hashlib.sha256(str(user_id).encode("utf-8")).digest()
        # 64 bits of the digest are plenty for an even spread over the shards.
        stable_hash = int.from_bytes(digest[:8], "big")

    shard = stable_hash % shard_count
    return cls.get_shard_key(sku_id, shard)
@classmethod
def get_func_name(cls, func):
    """
    Return a printable name for a callback.

    :param func: the callback (function, method, wrapper, callable object, or anything else)
    :return: the function name, the class name for other callables, or ``str(func)``
    """
    if isinstance(func, types.FunctionType):
        return func.__name__

    # Bound methods and classmethod/staticmethod wrappers all expose the
    # wrapped function through __func__.
    if isinstance(func, (types.MethodType, classmethod, staticmethod)):
        return func.__func__.__name__

    # Any other callable is identified by the name of its class.
    if hasattr(func, '__call__'):
        return func.__class__.__name__

    return str(func)
isinstance(result, Awaitable): + result = await result + + if is_delete: + # 回调正确执行,且设置为删除删除的,才会删除消息 + await _redis.delete(message_key) + logging.echo_log(f"消息已删除: {message_key};数据为:{message_data}.") + except redis.RedisError as e: + logging.echo_log(f"Redis 操作异常: {e}.", level=ERROR, is_log_exc=True) + except Exception as e: + logging.echo_log( + f"执行回调异常:{e};方法:{cls.get_func_name(func)};消息: {message_key}.", + level=ERROR, is_log_exc=True + ) + return result + + +class PubSubActor(Redis): + """ + 发布订阅执行器。用于发布消息和订阅消息。 + 订阅采用阻塞式读取,可以在读取到数据后,执行回调方法,并根据参数确定是否删除历史消息。 + """ + + def __init__(self, hash_name: str): + self.hash_name = f"{hash_name}_HASH_NAME" + self.channel = f"{hash_name}_CHANNEL" + + self.running = False + """ + 优雅退出控制标志 + """ + + self.stopping = False + """ + 控制整个 run_forever 循环退出 + """ + + async def publish(self, data: dict) -> str: + """ + 数据写入 Redis 并发布消息。 + + :param data: 写入 Redis 的数据 + :return: 消息ID + """ + async with await self.get_redis() as _redis: + # 生成雪花 ID 作为 Hash Key + _random_num = random.randint(1000, 9999) + _id = IdWorker.get_id_worker(3, 3, _random_num).get_id() + + # 写入Redis hash + await _redis.hset(f"{self.hash_name}:{_id}", mapping=data) + # 发布新消息通知 + await _redis.publish(self.channel, _id) + return _id + + async def subscribe(self, func: Callable = None, is_delete=False): + """ + 监听消息。 + + :param func: 监听回调程序 + :param is_delete: 回调执行完毕后,是否删除消息 + """ + async with await self.get_redis() as _redis: + _pubsub = _redis.pubsub() + await _pubsub.subscribe(self.channel) + + try: + self.running = True + + # 使用 while 循环,而不是直接 async for,以便加入超时控制 + while not self.stopping and self.running: + try: + # 每次循环都重新获取迭代器 + listen_iter = _pubsub.listen() + message = await asyncio.wait_for(listen_iter.__anext__(), timeout=60.0) + + if message["type"] != "message": + continue + + message_id = message["data"] + message_key = f"{self.hash_name}:{message_id}" + + try: + # 隔离处理回调异常 + # 采用后台运行的方式处理,防止消息排队,提高消息处理性能 + await 
aio_pool.run_background_task(self.callback(func, message_key, is_delete), 10) + # await self.callback(func, message_key, is_delete=is_delete) + except Exception: + # 继续处理下条消息 + continue + + except asyncio.TimeoutError: + # 超时,是心跳成功的标志 + logging.echo_log("心跳:连接正常,继续监听...") + continue + except redis.exceptions.ConnectionError as e: + # 连接错误,触发重连 + logging.echo_log(f"检测到连接错误: {e}. 将触发重连...", level=ERROR, is_log_exc=True) + raise e + except StopAsyncIteration: + # pubsub 正常关闭 + logging.echo_log("PubSub 迭代器已停止.") + break + except (asyncio.CancelledError, KeyboardInterrupt): + logging.echo_log("收到退出信号,停止监听...") + self.running = False + raise + except Exception as e: + logging.echo_log(f"监听会话因错误结束: {e}.", level=ERROR, is_log_exc=True) + raise e + finally: + self.running = False + try: + await _pubsub.unsubscribe(self.channel) + await _pubsub.close() + except Exception as close_err: + logging.echo_log(f"资源关闭异常: {close_err}.") + finally: + logging.echo_log("监听已完全停止.") + + async def run_forever(self, func: Callable = None, is_delete=False): + """ + 持久运行的监听器,包含自动重连逻辑和优雅退出。 + """ + while not self.stopping: + try: + logging.echo_log("启动新的监听会话...") + await self.subscribe(func, is_delete) + except (asyncio.CancelledError, KeyboardInterrupt): + logging.echo_log("收到退出信号,停止监听...") + self.stopping = True + break + except Exception as e: + logging.echo_log(f"监听会话因未知错误结束: {e}. 
async def history(self, func: Callable = None, is_delete=False):
    """
    Replay messages that are still stored for this actor.

    Only keys under this actor's own ``{hash_name}:`` prefix are visited.
    Walking every key in the database would hand unrelated keys to the
    callback and, with ``is_delete`` set, delete them.

    :param func: callback invoked with each stored message
    :param is_delete: whether to delete a message after the callback has run
    """
    _pattern = f"{self.hash_name}:*"
    async with await self.get_redis() as _redis:
        # SCAN walks the keyspace incrementally instead of blocking the
        # server the way KEYS does.
        async for _k in _redis.scan_iter(match=_pattern):
            try:
                # Hand each message to the background pool so a slow
                # callback does not hold up the remaining messages.
                await aio_pool.run_background_task(self.callback(func, _k, is_delete), 10)
            except Exception as e:
                # One failed message must not stop the replay; log and move on.
                logging.echo_log(f"处理历史消息异常: {e};消息: {_k}.", level=ERROR, is_log_exc=True)
                continue
async def _ensure_group_exists(self):
    """
    Create the consumer group (and the stream) if it does not exist yet.

    An "already exists" reply from the server is treated as success; any
    other response error is re-raised.

    :raises redis.exceptions.ResponseError: for errors other than the group already existing
    """
    # Use the client as a context manager, like the other methods of this
    # class, so it is released on every path.
    async with await self.get_redis() as _redis:
        try:
            await _redis.xgroup_create(
                name=self.stream_name,
                groupname=self.group_name,
                id='0',  # consume from the beginning of the stream
                mkstream=True  # create the stream when it is missing
            )
        except redis.exceptions.ResponseError as e:
            _message = str(e)
            # The server answers "BUSYGROUP Consumer Group name already exists".
            if "BUSYGROUP" not in _message and "Consumer Group name already exists" not in _message:
                raise
            logging.echo_log(f"消费者组 '{self.group_name}' 已存在.")
        else:
            logging.echo_log(f"消费者组 '{self.group_name}' 已创建.")
发现僵尸任务 + try: + stale_tasks = await _redis.xpending_range( + name=self.stream_name, + groupname=self.group_name, + min='-', + max='+', + count=10, # 每次最多处理10个僵尸任务,避免启动时阻塞太久 + idle=stale_threshold_ms + ) + except Exception as e: + logging.echo_log(f"检查僵尸任务时出错: {e}", level=ERROR, is_log_exc=True) + return + + if not stale_tasks: + logging.echo_log(f"未发现空闲超过 {stale_threshold_ms / 1000} 秒的僵尸任务.") + return + + if not stale_tasks or not isinstance(stale_tasks, list): + logging.echo_log(f"未发现空闲超过 {stale_threshold_ms / 1000} 秒的僵尸任务.") + return + message_ids = [task['message_id'] for task in stale_tasks] + logging.echo_log(f"发现 {len(message_ids)} 个僵尸任务,尝试认领并重新处理...") + + # 2. 认领任务 + try: + reclaimed_messages = await _redis.xclaim( + name=self.stream_name, + groupname=self.group_name, + consumername=self.consumer_name, # 认领给自己 + min_idle_time=stale_threshold_ms, + message_ids=message_ids, + justid=False # 我们需要消息内容来处理 + ) + except Exception as e: + logging.echo_log(f"认领僵尸任务时出错: {e}", level=ERROR, is_log_exc=True) + return + + if not reclaimed_messages: + logging.echo_log("未能成功认领任何僵尸任务.") + return + + logging.echo_log(f"成功认领 {len(reclaimed_messages)} 个僵尸任务,开始处理.") + + # 3. 
处理被认领的任务 + for message_id, message_data in reclaimed_messages: + # 使用我们已有的 _callback_wrapper 来处理,保证逻辑一致 + await self._callback_wrapper( + func=func, + message_id=message_id, + message_data=message_data, + is_delete=is_delete + ) + + async def history(self, func: Callable, is_delete: bool): + """ + 启动时的恢复程序。 + 检查并处理长时间未完成的僵尸任务,确保系统健壮性。 + """ + logging.echo_log("执行启动恢复程序,检查僵尸任务...") + # 将 func 和 is_delete 传递下去 + await self.reclaim_stale_tasks(func=func, is_delete=is_delete, stale_threshold_ms=5 * 60 * 1000) + logging.echo_log("启动恢复程序执行完毕.") + + async def subscribe(self, func: Callable = None, is_delete=False): + """ + 从消费者组中监听并处理新消息。 + 启动时会先执行恢复程序,处理僵尸任务。 + """ + await self._ensure_group_exists() + + # === 核心改动:启动时先执行恢复,并传入回调参数 === + await self.history(func=func, is_delete=is_delete) + + async with await self.get_redis() as _redis: + try: + self.running = True + logging.echo_log("僵尸任务恢复完成,开始监听新消息...") + + while not self.stopping and self.running: + try: + # 阻塞读取新消息 + streams = await _redis.xreadgroup( + groupname=self.group_name, + consumername=self.consumer_name, + streams={self.stream_name: '>'}, # '>' 表示只读取新消息 + count=1, + block=5000 # 5秒超时,类似 PubSub 的心跳 + ) + + if not streams: + # 超时,是心跳成功的标志 + logging.echo_log("心跳:连接正常,继续监听...") + continue + + # 解析消息 + stream, messages = streams[0] + message_id, message_data = messages[0] + logging.echo_log(f"收到新消息: ID={message_id}, 数据={message_data}") + + try: + # 隔离处理回调异常 + # 采用后台运行的方式处理,防止消息排队,提高消息处理性能 + await aio_pool.run_background_task( + self._callback_wrapper(func, message_id, message_data, is_delete), 10 + ) + except Exception: + # 回调处理失败,消息未被确认,将留在队列中稍后重试 + continue + + except redis.exceptions.ConnectionError as e: + # 连接错误,触发重连 + logging.echo_log(f"检测到连接错误: {e}. 
将触发重连...", level=ERROR, is_log_exc=True) + raise e + except (asyncio.CancelledError, KeyboardInterrupt): + logging.echo_log("收到退出信号,停止监听...") + self.running = False + raise + except Exception as e: + logging.echo_log(f"监听会话因错误结束: {e}.", level=ERROR, is_log_exc=True) + raise e + finally: + self.running = False + logging.echo_log("Stream 监听已完全停止.") + + async def _callback_wrapper(self, func: Callable, message_id: str, message_data: dict, is_delete: bool): + """ + 一个包装器,用于将 Stream 的消息处理逻辑适配到基类的 callback 方法签名上。 + 这样做可以复用基类 callback 中的异常处理逻辑。 + """ + # 如果没有提供回调函数,则无法处理,直接返回,避免丢失任务 + if not func: + logging.echo_log(f"警告: 收到消息 {message_id} 但未提供业务回调函数,消息将被忽略.", level=WARNING) + return + + # 不能直接调用基类的 callback,因为它会尝试删除 + # 在这里复制它的异常处理逻辑,但使用 Stream 的操作 + result = None + async with await self.get_redis() as _redis: + try: + if func: + # 处理回调 + result = func(message_data) + # 处理协程 + if isinstance(result, Awaitable): + result = await result + + if is_delete: + # 先从 PENDING 列表中移除 + await _redis.xack(self.stream_name, self.group_name, message_id) + # 再从 Stream 中逻辑删除 + await _redis.xdel(self.stream_name, message_id) + logging.echo_log(f"消息已确认 (ACK): {message_id};数据为:{message_data}.") + except redis.RedisError as e: + logging.echo_log(f"Redis 操作异常: {e}.", level=ERROR, is_log_exc=True) + except Exception as e: + logging.echo_log( + f"执行回调异常:{e};方法:{self.get_func_name(func)};消息: {message_id}.", + level=ERROR, is_log_exc=True + ) + return result + + async def run_forever(self, func: Callable = None, is_delete=False): + """ + 持久运行的监听器,包含自动重连逻辑和优雅退出。 + """ + while not self.stopping: + try: + logging.echo_log("启动新的监听会话...") + await self.subscribe(func, is_delete) + except (asyncio.CancelledError, KeyboardInterrupt): + logging.echo_log("收到退出信号,停止监听...") + self.stopping = True + break + except Exception as e: + logging.echo_log(f"监听会话因未知错误结束: {e}. 
10秒后重试...", level=ERROR, is_log_exc=True) + + if self.stopping: + logging.echo_log("总开关已打开,停止重连.") + break + + logging.echo_log("等待重新连接...") + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + logging.echo_log("等待期间被取消,准备退出.") + break + + logging.echo_log("监听服务已完全停止.") + + def subscribe_stop(self): + self.running = False + self.stopping = True diff --git a/paste/rbac/__init__.py b/paste/rbac/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/paste/rbac/rbac_assignment.py b/paste/rbac/rbac_assignment.py new file mode 100644 index 0000000..34f3f3b --- /dev/null +++ b/paste/rbac/rbac_assignment.py @@ -0,0 +1,87 @@ +from typing import Union + +from sqlalchemy import select, delete + +from paste.rbac.rbac_models import RbacAssignmentModel + + +class RbacAssignment(RbacAssignmentModel): + """ + 权限分配器,负责用户权限的分配。 + 允许为用户分配角色,也允许为用户直接分配权限。 + """ + + @classmethod + def new(cls, user_id: Union[str, int], item_name: str): + """ + 新建角色分配对象。不检查用户是否存在,也不保存数据库。 + + :param user_id: 用户 ID + :param item_name: 授权项名称 + :return: 新的分配对象 + """ + return cls(user_id=user_id, item_name=item_name).before_save() + + @classmethod + async def new_batch(cls, user_id: int, item_names: set[str]): + """ + 为用户创建角色或权限,自动剔除已经分配的角色或授权项。注意:仅创建不保存。 + + :param user_id: 用户名 + :param item_names: 授权项名称列表,这里允许是角色名称,也允许是权限名称 + """ + from paste.rbac.rbac_user import RbacUser, Supervisors + + # 确认用户存在,且不是超级用户 + _mod_user: RbacUser = await RbacUser(**{RbacUser.id.key: user_id}).async_find_first() + assert _mod_user is not None, f"ID 为 {user_id} 的用户不存在." + assert _mod_user.username not in Supervisors, f"ID 为 {user_id} 的用户 {_mod_user.username} 禁止设置权限." 
+ + # 查库,过滤已经存在的权限分配 + _query = select(cls.item_name).where(cls.user_id == user_id, cls.item_name.in_(item_names)) + _rows = await cls.query_all(_query) + _exist_names: set[str] = set([_r[cls.item_name.key] for _r in _rows]) + item_names = set([_name for _name in item_names if _name not in _exist_names]) + + # 创建模型列表 + _new_assignments = [cls(**{cls.user_id.key: user_id, cls.item_name.key: name}) for name in item_names] + return _new_assignments + + @classmethod + async def assign(cls, user_id: int, item_names: set[str]): + """ + 为用户分配角色或权限,自动剔除已经分配的角色或授权项。 + + :param user_id: 用户名 + :param item_names: 授权项名称列表,这里允许是角色名称,也允许是权限名称 + """ + # 创建模型列表 + _new_assignments = await cls.new_batch(user_id=user_id, item_names=item_names) + + # 保存数据 + _session = cls.get_aio_session() + try: + _session.add_all(_new_assignments) + await _session.commit() + except Exception as e: + await _session.rollback() + raise e + finally: + await _session.close() + + @classmethod + async def delete(cls, user_id: Union[str, int], item_name: str): + """ + 删除授权。 + + :param user_id: 用户 ID + :param item_name: 授权项 + :return: 操作状态,游标返回对象 + """ + assert user_id not in ('', None), '必须提供用户 ID.' + assert item_name not in ('', None), '必须提供权限或角色名称.' 
class RbacItem(RbacItemModel):
    """
    Authorization item.

    Items come in two kinds, roles and permissions, each implemented by its
    own subclass.
    """

    # Item type: role.
    TYPE_ROLE = 1

    # Item type: permission.
    TYPE_PERMISSION = 2

    @classmethod
    async def all_parents(cls, item_names: set[str]):
        """
        Collect the names of every ancestor, at any depth, of the given items.

        The hierarchy is walked level by level, and names already seen are not
        queried again, so the walk terminates even if the stored relations
        contain a cycle.

        :param item_names: names of the items to start from
        :return: set of all ancestor names
        """
        _found: set[str] = set()
        _frontier = set(item_names)
        while _frontier:
            _query = select(RbacItemChild.parent).where(RbacItemChild.child.in_(_frontier))
            _rows = await cls.query_all(_query)
            # NOTE(review): assumes query_all yields plain names for a
            # single-column select — confirm against BaseModel.query_all.
            _frontier = set(_rows) - _found
            _found |= _frontier
        return _found

    @classmethod
    async def all_children(cls, item_names: set[str]):
        """
        Collect the names of every descendant, at any depth, of the given items.

        Walks the hierarchy level by level with the same cycle protection as
        ``all_parents``.

        :param item_names: names of the items to start from
        :return: set of all descendant names
        """
        _found: set[str] = set()
        _frontier = set(item_names)
        while _frontier:
            _query = select(RbacItemChild.child).where(RbacItemChild.parent.in_(_frontier))
            _rows = await cls.query_all(_query)
            # NOTE(review): same assumption about query_all rows as in all_parents.
            _frontier = set(_rows) - _found
            _found |= _frontier
        return _found

    @classmethod
    async def find_by_name(cls, name: str):
        """
        Look up an authorization item by name.

        :param name: item name
        :return: result of ``query_first`` for that name
        """
        _query = select(cls).where(cls.name == name)
        return await cls.query_first(_query)

    async def add_children(self, item_names: set[str]):
        """
        Attach child items to this item. Both roles and permissions may have children.

        Candidates are dropped when they do not exist, are already children of
        this item, are this item itself, or are ancestors of this item. The
        last case would make the hierarchy circular.

        :param item_names: names of the items to attach
        """
        # Keep only the names that really exist.
        _query = select(RbacItem.name).where(RbacItem.name.in_(item_names))
        _rows = await self.query_all(_query)
        _valid_names: set[str] = set(_rows)

        # Cycle guard: the ancestors that matter are those of *this* item,
        # because linking self -> X closes a loop when X already leads to self.
        _ancestors = await self.all_parents(item_names={self.name})

        # Children that are already attached.
        _query = select(RbacItemChild.child).where(
            RbacItemChild.parent == self.name,
            RbacItemChild.child.in_(_valid_names)
        )
        _rows = await self.query_all(_query)
        _exist_children: set[str] = set(_rows)

        _new_children = [
            RbacItemChild(**{RbacItemChild.parent.key: self.name, RbacItemChild.child.key: _name})
            for _name in _valid_names
            if _name not in _exist_children and _name not in _ancestors and _name != self.name
        ]
        if not _new_children:
            # Nothing to insert; skip the session round trip.
            return

        _session = self.get_aio_session()
        try:
            _session.add_all(_new_children)
            await _session.commit()
        except Exception:
            await _session.rollback()
            raise
        finally:
            await _session.close()

    async def get_children(self):
        """
        Load the direct children of this item through the relation table.

        Only direct children are returned, not items inherited through them.

        :return: list of child items
        """
        _query = select(RbacItem).join(
            RbacItemChild, RbacItemChild.child == RbacItem.name
        ).where(
            RbacItemChild.parent == self.name,
        )
        _item_model_list: list[RbacItem] = await self.query_all(_query)
        return _item_model_list

    async def remove_children(self, item_names: set[str]):
        """
        Detach child items from this item.

        :param item_names: names of the children to detach
        :return: ``rowcount`` reported for the delete statement
        """
        _delete = delete(RbacItemChild).where(RbacItemChild.parent == self.name, RbacItemChild.child.in_(item_names))
        _result = await self.raw_execute(_delete)
        return _result.rowcount
class RbacUserModel(BaseModel):
    """
    ORM model of the user table; the table name is read from the configuration.
    """
    __tablename__ = config.get_config('rbac.table.user')
    __table_args__ = {'comment': '用户'}

    id = Column(BigInteger, primary_key=True, comment='系统编号')
    username = Column(String(255, 'utf8mb4_unicode_ci'), nullable=False, unique=True, comment='用户名')
    password_hash = Column(String(255, 'utf8mb4_unicode_ci'), nullable=False, comment='密码')
    password_reset_token = Column(String(255, 'utf8mb4_unicode_ci'), comment='重置标记')
    auth_key = Column(String(255, 'utf8mb4_unicode_ci'), comment='授权码')
    status = Column(Integer, nullable=False, server_default=text("'0'"), comment='用户状态')
    type = Column(String(64, 'utf8mb4_unicode_ci'), nullable=False, server_default=text("'user'"), comment='用户类型')
    created_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
    updated_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='更新时间')

    def before_save(self):
        """
        Run the base-class hook and set ``status`` to 1 on new records.

        :return: whatever the base-class hook returns
        """
        # Pass the base result through: other code in this package chains on
        # the value returned by before_save().
        # NOTE(review): presumably BaseModel.before_save returns the instance — confirm.
        _result = super().before_save()
        if self.is_new:
            self.status = 1
        return _result
RbacItemModel(BaseModel): + __tablename__ = config.get_config('rbac.table.item') + __table_args__ = {'comment': '授权项(角色/权限)'} + + name = Column(String(64, 'utf8mb4_unicode_ci'), primary_key=True, comment='名称') + type = Column(SmallInteger, nullable=False, comment='类型,1角色,2权限') + description = Column(Text(collation='utf8mb4_unicode_ci'), comment='描述') + rule_name = Column( + ForeignKey(f"{config.get_config('rbac.table.rule')}.name", ondelete='SET NULL', onupdate='CASCADE'), + index=True, comment='规则名称' + ) + created_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间') + updated_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='更新时间') + + rbac_rule = relationship(RbacRuleModel) + + +class RbacAssignmentModel(BaseModel): + __tablename__ = config.get_config('rbac.table.assignment') + __table_args__ = {'comment': '权限分配'} + + item_name = Column( + ForeignKey(f"{config.get_config('rbac.table.item')}.name", ondelete='CASCADE', onupdate='CASCADE'), + primary_key=True, nullable=False, comment='授权项(角色/权限)' + ) + user_id = Column( + ForeignKey(f"{config.get_config('rbac.table.user')}.id", ondelete='CASCADE', onupdate='CASCADE'), + primary_key=True, nullable=False, index=True, comment='用户' + ) + created_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间') + + rbac_item = relationship(RbacItemModel) + rbac_user = relationship(RbacUserModel) + + +class RbacItemChildModel(BaseModel): + __tablename__ = config.get_config('rbac.table.item_child') + __table_args__ = {'comment': '授权关系'} + + parent = Column( + ForeignKey(f"{config.get_config('rbac.table.item')}.name", ondelete='CASCADE', onupdate='CASCADE'), + primary_key=True, nullable=False, comment='角色/权限' + ) + child = Column( + ForeignKey(f"{config.get_config('rbac.table.item')}.name", ondelete='CASCADE', onupdate='CASCADE'), + primary_key=True, nullable=False, index=True, comment='授权项(角色/权限)' + ) + 
class RbacPermission(RbacItem):
    """
    Permission item.

    Most permissions map to one concrete request operation and are imported automatically:
    ``name`` is the ``route_pattern`` of the request handler and ``description`` is the first
    line of the handler class docstring. A permission can be assigned to a user, to a role or
    to another permission.

    Permissions may also be created by hand, mainly to act as the carrier of a rule: every
    permission that is made a child of such a permission is subject to its rule as well.
    """

    @classmethod
    async def create(cls, name: str, description: str = None, rule_name: str = None):
        """
        Create and save a permission.

        :param name: permission name, must not be empty
        :param description: permission description
        :param rule_name: name of the rule to bind; ignored when empty
        :return: the saved permission object
        """
        assert name not in ('', None), '必须提供权限名称.'
        _permission = cls(name=name, description=description, type=cls.TYPE_PERMISSION)
        if rule_name:
            _permission.rule_name = rule_name
        await _permission.async_save()
        return _permission

    @classmethod
    async def delete(cls, name: str):
        """
        Delete a permission by name.

        :param name: name of the authorization item
        :return: tuple ``(deleted, result)`` - whether at least one row was removed,
                 and the raw execution result
        """
        assert name not in ('', None), '必须提供权限名称.'
        _query = delete(cls).where(cls.name == name)
        _result = await cls.raw_execute(query=_query)
        # Some drivers do not report an integer rowcount; treat that as "nothing deleted".
        _rowcount = _result.rowcount if isinstance(_result.rowcount, int) else 0
        return _rowcount > 0, _result

    @classmethod
    async def modify(cls, name: str, description: str = None, rule_name: str = None):
        """
        Edit an existing permission. Empty arguments leave the stored value unchanged.

        :param name: permission name
        :param description: new description
        :param rule_name: new rule name
        :return: the saved permission object
        """
        assert name not in ('', None), '必须提供权限名称.'
        _permission: cls = await cls(name=name).async_find_first()
        assert _permission, f"未找到名称为:{name} 的权限."

        if description:
            _permission.description = description
        if rule_name:
            _permission.rule_name = rule_name
        await _permission.async_save()
        return _permission

    @classmethod
    def identify_permission(cls):
        """
        Discover every request handler that requires authorization.

        Reads the ``tornado`` section of the configuration, walks the modules of each
        configured ``handlers_pkg`` and keeps the handlers whose ``post`` or ``get``
        method carries the ``__auth_permission__`` marker. Handlers without the marker
        (for example open or frontend APIs) are ignored.

        :return: list of ``(route, handler_type)`` tuples
        """
        # get_config may return None for a missing section; treat that as "no applications".
        apps_config: dict = config.get_config('tornado') or {}
        _handlers: list[tuple[str, type]] = []

        for _app_cfgs in apps_config.values():
            for app_cfg in _app_cfgs:
                _handlers_pkg = app_cfg.get('handlers_pkg')
                _modules_itr = Application.modules_iterator(package=_handlers_pkg)
                for _file_finder, _module_name, is_package in _modules_itr:
                    if is_package:
                        continue

                    _module = importlib.import_module(_module_name)
                    for _hls in Application.fetch_handlers(module=_module):
                        _uri, _hdl = _hls
                        # NOTE(review): only `post` and `get` are inspected, as in the original code.
                        _methods = (getattr(_hdl, _name, None) for _name in ('post', 'get'))
                        if any(callable(_m) and getattr(_m, '__auth_permission__', False) for _m in _methods):
                            _handlers.append(_hls)

        return _handlers

    @classmethod
    async def import_permissions(cls):
        """
        Import every assignable permission into the database; rows that already exist
        get their description refreshed.
        """
        _handlers = cls.identify_permission()
        if not _handlers:
            # Nothing to import; avoid querying with an empty route list.
            return

        # Load the permissions that already exist for the discovered routes and index them by name.
        _routes: list[str] = [rc[0] for rc in _handlers]
        _item_model_list: list[cls] = await cls(**{cls.name.key: _routes}).async_find()
        _item_model_dict: dict[str, cls] = {_item.name: _item for _item in (_item_model_list or [])}

        _permissions: list[cls] = []
        _seen_routes: set[str] = set()
        for _route, _cls in _handlers:
            # The same route may be discovered more than once (e.g. a package shared by
            # several applications); saving it twice would violate the primary key.
            if _route in _seen_routes:
                continue
            _seen_routes.add(_route)

            # First line of the handler docstring; empty when the handler has no docstring.
            # (The original had a trailing comma here, which stored a 1-tuple.)
            _desc = (_cls.__doc__ or '').strip().split('\n')[0].strip()

            _perm_item: cls = _item_model_dict.get(_route)
            if _perm_item is None:
                # Unknown route: create a new permission.
                _perm_item = cls(**{cls.name.key: _route, cls.description.key: _desc})
            else:
                # Known route: refresh the description and release the session it was loaded with.
                _perm_item.description = _desc
                _perm_item.close_session()

            _permissions.append(_perm_item)

        # Save everything in one transaction.
        _session = cls.get_aio_session()
        try:
            _session.add_all(_permissions)
            await _session.commit()
        except Exception:
            await _session.rollback()
            raise
        finally:
            await _session.close()

    def __init__(self, **kwargs):
        """Build the item and force its type to ``TYPE_PERMISSION``."""
        super().__init__(**kwargs)
        self.type = self.TYPE_PERMISSION
class RbacRule(RbacRuleModel):
    """
    A rule is an extra validation condition attached to an authorization item.

    Checking an item only tells whether the user holds it; it cannot validate the concrete
    data being operated on. A rule is a separately defined check that is persisted in the
    database and takes part in authorization whenever the item it is bound to is checked.

    An item can bind only one rule. To apply several rules, bind each rule to a manually
    created permission and make the permissions that need it children of that permission.

    Rule classes are written after this module, so they are restored from their persisted
    form and their ``run`` method is called on the restored object.
    """

    @classmethod
    async def create(cls, full_class_name: str):
        """
        Register a rule.

        :param full_class_name: dotted path of the rule class
        :return: the saved rule object
        """
        _rule_cls = cls.load_rule_class(full_class_name)
        assert _rule_cls, f"未找到名称为:{full_class_name} 的规则类."

        _rule_model = _rule_cls()
        # Assign the default name before serializing, so the persisted snapshot carries it.
        if not _rule_model.name:
            _rule_model.name = _rule_cls.__name__
        _rule_model.data = _rule_model.dumps()

        await _rule_model.async_save()
        return _rule_model

    @classmethod
    async def delete(cls, name: str):
        """
        Delete a rule by name.

        :param name: name of the rule to delete
        :return: tuple ``(deleted, rowcount)``
        """
        assert name not in ('', None), '必须提供规则名称.'
        _query = delete(cls).where(cls.name == name)
        _result = await cls.raw_execute(query=_query)
        # Some drivers do not report an integer rowcount; treat that as "nothing deleted".
        _rowcount = _result.rowcount if isinstance(_result.rowcount, int) else 0
        return _rowcount > 0, _rowcount

    @classmethod
    async def modify(cls, name: str, full_class_name: str):
        """
        Replace the implementation of an existing rule.

        :param name: name of the stored rule
        :param full_class_name: dotted path of the new rule class
        :return: the saved rule object
        """
        _rule_cls = cls.load_rule_class(full_class_name)
        assert _rule_cls, f"未找到名称为:{full_class_name} 的规则类."
        _rule_model = _rule_cls()

        assert name not in ('', None), '必须提供规则名称.'
        _rule: cls = await cls(name=name).async_find_first()
        assert _rule, f"未找到名称为:{name} 的规则."

        # Assign the default name before serializing, so the persisted snapshot carries it.
        if not _rule_model.name:
            _rule_model.name = _rule_cls.__name__
        _rule.data = _rule_model.dumps()
        _rule.name = _rule_model.name

        await _rule.async_save()
        return _rule

    @classmethod
    async def find_by_item_names(cls, item_names: set[str]):
        """
        Fetch the rules bound to the given authorization items (roles or permissions).

        :param item_names: names of the authorization items
        :return: list of rules; empty when no names are given
        """
        if not item_names:
            return []

        # Join items to their rule and ignore items without a rule.
        _query = select(cls).join(
            RbacItem, RbacItem.rule_name == cls.name
        ).where(
            RbacItem.name.in_(item_names),
            text(f"ifnull({RbacItem.rule_name.key},'')!=''")
        )
        _rule_list: list[cls] = await cls.query_all(_query)
        return _rule_list

    @classmethod
    def load_rule_class(cls, full_class_name: str):
        """
        Load a rule class from its dotted path.

        :param full_class_name: full path from the top-level module down to the class name
        :return: the rule class, or None when the module cannot be imported or holds no
                 matching ``RbacRule`` subclass
        """
        _mod_name, _, _cls_name = full_class_name.rpartition('.')
        if not _mod_name:
            return None

        try:
            _module = importlib.import_module(_mod_name)
            for _attr_name in dir(_module):
                _candidate = getattr(_module, _attr_name)
                if isinstance(_candidate, type) and issubclass(_candidate, RbacRule) \
                        and _candidate.__name__ == _cls_name:
                    return _candidate
        except Exception:
            # Deliberately best-effort: any failure while importing or inspecting the
            # module is reported to the caller as "class not found".
            return None

        return None

    def run(self, rbac_user: RbacUser, rbac_permission: RbacPermission, *args, **kwargs) -> bool:
        """
        Run the rule. Subclasses override this; the base implementation allows everything.

        :param rbac_user: user being checked
        :param rbac_permission: permission being checked
        :return: True to allow the operation, False to deny it
        """
        return True

    def dumps(self):
        """
        Serialize this rule object with pickle.

        :return: pickled bytes
        """
        return pickle.dumps(self)
class RbacUser(RbacUserModel):
    """
    RBAC user.
    """

    # User status values.
    STATUS_DEFAULT = 0b00000000000000000000000000000   # 0: default
    STATUS_ENABLED = 0b00000000000000000000000000001   # 1: enabled
    STATUS_DISABLED = 0b00000000000000000000000000010  # 2: disabled
    STATUS_DELETED = 0b00000000000000000000000000100   # 4: deleted

    # Every allowed status value.
    STATUS_LIST = [STATUS_DEFAULT, STATUS_ENABLED, STATUS_DISABLED, STATUS_DELETED]

    # Display text of each status.
    STATUS_DESCRIPTION = {
        STATUS_DEFAULT: '默认',
        STATUS_ENABLED: '激活',
        STATUS_DISABLED: '已禁用',
        STATUS_DELETED: '已删除',
    }

    # User types.
    TYPE_USER = 'user'
    TYPE_ADMINISTRATOR = 'admin'

    @classmethod
    async def import_supervisors(cls):
        """
        Create the supervisor accounts listed in ``Supervisors`` that do not exist yet.

        :return: initialization status (always True, as before)
        """
        _user_model_list: list[cls] = await cls(**{cls.username.key: Supervisors}).async_find()
        _existing_names: set[str] = {_user.username for _user in _user_model_list}

        init_status: bool = True
        for _username in Supervisors:
            if _username in _existing_names:
                continue
            # SECURITY NOTE(review): the initial password equals the username. These
            # accounts bypass every check in `can()`, so the password must be changed
            # right after the import.
            # NOTE(review): the save status returned by create() is not reflected in
            # init_status; its truthiness contract is not visible here.
            await cls.create(username=_username, password=_username)

        return init_status

    @classmethod
    async def create(cls, username: str, password: str):
        """
        Create a user.

        :param username: user name
        :param password: plain-text password; only its hash is stored
        :return: tuple ``(save_status, user)``
        """
        assert username is not None and password is not None, '必须提供用户名和密码.'
        _usr_model = cls()
        _usr_model.before_save()
        _usr_model.username = username
        _usr_model.password_hash = shash.generate_password_hash(pwd=password)
        _status = await _usr_model.async_save()
        return _status, _usr_model

    @classmethod
    async def find_by_username(cls, username: str):
        """
        Find a user by user name.

        :param username: user name
        :return: the user object returned by ``query_first``
        """
        return await cls.query_first(select(cls).where(cls.username == username))

    @classmethod
    def status_description(cls, status):
        """
        Return the display text of a status value.

        :param status: status value
        :return: the description, or the "unknown status" text for other values
        """
        return cls.STATUS_DESCRIPTION.get(status, '状态未知')

    async def assign(self, item_names: set[str]):
        """
        Assign roles or permissions to this user through ``RbacAssignment.assign``.

        :param item_names: names of the authorization items
        """
        await RbacAssignment.assign(user_id=self.id, item_names=item_names)

    async def can(self, permission_name: str, **kwargs) -> bool:
        """
        Check whether the user may use ``permission_name``.

        The user must hold the permission (directly or inherited). Then every rule bound
        to the user's authorization items is restored and its ``run`` method is called;
        the first rule returning a falsy value denies the operation.

        :param permission_name: permission name
        :param kwargs: extra arguments forwarded to each rule's ``run`` method
        :return: True when allowed, otherwise False
        """
        # Imported here because rbac_rule imports this module.
        from paste.rbac.rbac_rule import RbacRule

        # An uninitialized user is always denied.
        if self.username is None:
            return False

        # Supervisors are always allowed.
        if self.is_supervisors():
            return True

        _item_names = await self.get_all_permissions()

        # Fail closed: without this check a user holding no items (hence no rules)
        # would be allowed everything. Same membership test as has_permission().
        if permission_name not in _item_names:
            return False

        _rule_list: list[RbacRule] = await RbacRule.find_by_item_names(_item_names)
        _permission: RbacPermission = await RbacPermission.find_by_name(permission_name)

        for _stored_rule in _rule_list:
            # SECURITY NOTE(review): pickle.loads executes arbitrary code from the stored
            # bytes; the rule table must be writable by trusted administrators only.
            _rule: RbacRule = pickle.loads(_stored_rule.data)

            # run() may be synchronous or return an awaitable.
            _result = _rule.run(rbac_user=self, rbac_permission=_permission, **kwargs)
            if isinstance(_result, Awaitable):
                _result = await _result

            if not _result:
                return False

        return True

    async def get_all_roles(self):
        """
        Return the names of all roles assigned to the user.

        :return: list of role names
        """
        return [_role.name for _role in await self.roles()]

    async def get_all_permissions(self):
        """
        Return the names of all authorization items of the user, roles and permissions,
        direct and inherited.

        :return: set of item names
        """
        _dir_perms = await self.get_direct_permissions()
        _inh_perms = await self.get_inherit_permissions(direct_perms=_dir_perms)
        return _dir_perms.union(_inh_perms)

    async def get_direct_permissions(self):
        """
        Query the roles and permissions assigned directly to the user.

        :return: set built from the query result
        """
        # NOTE(review): two columns are selected; this presumably relies on query_all
        # returning only the first one (item_name) - confirm against query_all.
        _query = select(RbacAssignment.item_name, RbacAssignment.user_id).where(RbacAssignment.user_id == self.id)
        _item_names = await RbacAssignment.query_all(_query)
        return set(_item_names)

    async def get_inherit_permissions(self, direct_perms: Optional[set[str]] = None):
        """
        Query the items inherited through the directly assigned ones, i.e. the children
        reported by ``RbacItem.all_children``.

        :param direct_perms: directly assigned item names; queried when omitted
        :return: names of the inherited items
        """
        if direct_perms is None:
            direct_perms = await self.get_direct_permissions()

        return await RbacItem.all_children(item_names=direct_perms)

    async def has_permission(self, permission_name: str):
        """
        Check whether the user holds a permission, without running any rule.

        :param permission_name: permission name, usually a route_pattern
        :return: True when the user holds it
        """
        _permissions = await self.get_all_permissions()
        return permission_name in _permissions

    async def roles(self):
        """
        Query the roles assigned directly to the user through the assignment table.

        :return: list of roles
        """
        _query = select(RbacRole).join(
            RbacAssignment, RbacAssignment.item_name == RbacRole.name
        ).where(
            RbacRole.type == RbacRole.TYPE_ROLE,
            RbacAssignment.user_id == self.id
        )
        _roles: list[RbacRole] = await RbacRole.query_all(_query)
        return _roles

    def is_supervisors(self):
        """
        Check whether this user is a supervisor.

        :return: True when the user name is listed in ``Supervisors``
        """
        return self.username in Supervisors
class TestRule(RbacRule):
    """
    Sample rule used for testing. Real rules must derive from ``RbacRule`` in the same way.
    """

    def __init__(self, **kwargs):
        """Create the rule, give it its fixed name and store its serialized form."""
        super().__init__(**kwargs)
        self.name = '测试规则'
        # Serialize after the name is set so the snapshot includes it.
        serialized = self.dumps()
        self.data = serialized

    def run(self, rbac_user: RbacUser, rbac_permission: RbacPermission, *args, **kwargs):
        """Log that the rule is running and allow the operation."""
        message = f"正在运行规则:{self.name}."
        logging.echo_log(message)
        return True
# Characters a salt is drawn from.
SALT_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

# PBKDF2 iteration count used when the method string does not specify one.
DEFAULT_PBKDF2_ITERATIONS = 260000


def gen_salt(length: int) -> str:
    """
    Generate a random string of SALT_CHARS with specified ``length``.

    :raises ValueError: if ``length`` is not positive
    """
    if length <= 0:
        raise ValueError("Salt length must be positive")

    return "".join(secrets.choice(SALT_CHARS) for _ in range(length))


def _hash_internal(method: str, salt: str, password: str) -> Tuple[str, str]:
    """
    Internal password hash helper. Supports plaintext without salt, unsalted and
    salted passwords. Salted non-PBKDF2 methods use HMAC keyed with the salt.

    :param method: ``plain``, a hashlib algorithm name, or ``pbkdf2:<algo>[:<iterations>]``
    :param salt: salt string, may be empty for unsalted hashes
    :param password: plaintext password
    :return: tuple ``(hash, actual_method)``; for PBKDF2 the returned method always
             spells out the iteration count
    :raises ValueError: for a malformed PBKDF2 method or a PBKDF2 call without salt
    """
    if method == "plain":
        return password, method

    salt_bytes = salt.encode("utf-8")
    password_bytes = password.encode("utf-8")

    if method.startswith("pbkdf2:"):
        if not salt_bytes:
            raise ValueError("Salt is required for PBKDF2")

        args = method[7:].split(":")

        if len(args) not in (1, 2):
            raise ValueError("Invalid number of arguments for PBKDF2")

        digest_name = args.pop(0)
        iterations = int(args[0] or 0) if args else DEFAULT_PBKDF2_ITERATIONS
        return (
            hashlib.pbkdf2_hmac(digest_name, password_bytes, salt_bytes, iterations).hex(),
            f"pbkdf2:{digest_name}:{iterations}",
        )

    if salt_bytes:
        return hmac.new(salt_bytes, password_bytes, method).hexdigest(), method

    return hashlib.new(method, password_bytes).hexdigest(), method


def generate_password_hash(pwd: str, method: str = "pbkdf2:sha256", salt_length: int = 16) -> str:
    """
    Hash a password with the given method and a random salt of the given length.

    The returned string includes the method so that :func:`check_password_hash`
    can verify it::

        method$salt$hash

    This function can **not** generate unsalted hashes, but ``method='plain'``
    stores the password as plaintext (with an empty salt). For non-PBKDF2 methods
    HMAC keyed with the salt is used.

    PBKDF2 is selected with ``pbkdf2:<algo>[:<iterations>]``::

        pbkdf2:sha256:80000$salt$hash
        pbkdf2:sha256$salt$hash

    :param pwd: the password to hash
    :param method: the hash method to use (one that hashlib supports), optionally in
                   the ``pbkdf2:method:iterations`` format
    :param salt_length: the length of the salt in letters
    :return: the ``method$salt$hash`` string
    """
    salt = gen_salt(salt_length) if method != "plain" else ""
    h, actual_method = _hash_internal(method, salt, pwd)
    return f"{actual_method}${salt}${h}"


def check_password_hash(pwd_hash: str, password: str) -> bool:
    """
    Check a password against a given salted and hashed password value.

    Supports everything :func:`generate_password_hash` produces, plus unsalted
    legacy values (``method$$hash``) and plaintext (``plain$$password``).

    :param pwd_hash: a hashed string like returned by :func:`generate_password_hash`
    :param password: the plaintext password to compare against the hash
    :return: True if the password matched, False otherwise (including when
             ``pwd_hash`` does not have the ``method$salt$hash`` shape)
    """
    if pwd_hash.count("$") < 2:
        return False

    method, salt, hash_val = pwd_hash.split("$", 2)
    computed = _hash_internal(method, salt, password)[0]
    # compare_digest() rejects str arguments containing non-ASCII characters, which a
    # plaintext ("plain") value may hold; compare the UTF-8 bytes instead.
    return hmac.compare_digest(computed.encode("utf-8"), hash_val.encode("utf-8"))
class DaemonizeService:
    """
    Helper that turns the current process into a daemon guarded by a PID file.
    """

    def __init__(self, pid_file, name: str = '', stdin: str = None, stdout: str = None, stderr: str = None):
        """
        Initialize the service.

        :param pid_file: path of the PID file
        :param name: service name, used in messages
        :param stdin: path of the file that replaces standard input
        :param stdout: path of the file that receives standard output
        :param stderr: path of the file that receives standard error
        """
        self.start_callback = None
        self.start_callback_args = ()
        self.start_callback_kwargs = {}

        self.term_callback = None
        self.term_callback_args = ()
        self.term_callback_kwargs = {}

        self.pid_file = pid_file
        self.name = name
        self.stdin = stdin
        self.stdout = stdout
        self.stderr = stderr

    def _read_pid(self):
        """Return the positive PID stored in the PID file, or None if it is missing or invalid."""
        try:
            with open(self.pid_file, 'rb') as file:
                _pid = int(file.read().decode('utf8').strip())
        except (OSError, ValueError):
            # ValueError also covers undecodable bytes (UnicodeDecodeError).
            return None
        return _pid if _pid > 0 else None

    def _remove_pid_file(self):
        """Remove the PID file; a file that is already gone is not an error."""
        try:
            os.remove(self.pid_file)
        except FileNotFoundError:
            pass

    def set_start_callback(self, callback, *args, **kwargs):
        """
        Set the callback run after the process has been daemonized.

        :param callback: callback function
        :param args: positional arguments of the callback
        :param kwargs: keyword arguments of the callback
        """
        self.start_callback = callback
        self.start_callback_args = args
        self.start_callback_kwargs = kwargs

    def set_term_callback(self, callback, *args, **kwargs):
        """
        Set the callback run when the termination signal is received.

        :param callback: callback function
        :param args: positional arguments of the callback
        :param kwargs: keyword arguments of the callback
        """
        self.term_callback = callback
        self.term_callback_args = args
        self.term_callback_kwargs = kwargs

    def daemonize(self):
        """
        Detach the process (double fork), redirect the standard streams, write the PID
        file and install the SIGTERM handler.

        :raises RuntimeError: if the service is already running or a fork fails
        """
        if os.path.exists(self.pid_file):
            _pid = self._read_pid()
            if _pid is not None and psutil.pid_exists(_pid):
                raise RuntimeError(f"[{self.name}] 正在运行.")
            # Stale or corrupt PID file left by an earlier run.
            self._remove_pid_file()

        try:
            if os.fork() > 0:
                raise SystemExit(0)  # Parent exit
        except OSError as e:
            raise RuntimeError('创建子进程失败 #1.') from e

        os.chdir(os.path.abspath(os.path.curdir))
        # NOTE(review): umask 0 makes files created afterwards, including the PID
        # file, writable by every user.
        os.umask(0)
        os.setsid()

        try:
            if os.fork() > 0:
                raise SystemExit(0)
        except OSError as e:
            raise RuntimeError('创建子进程失败 #2.') from e

        sys.stdout.flush()
        sys.stderr.flush()

        # Replace the file descriptors of stdin, stdout and stderr.
        if self.stdin is not None:
            with open(self.stdin, 'rb', 0) as file:
                os.dup2(file.fileno(), sys.stdin.fileno())

        if self.stdout is not None:
            with open(self.stdout, 'ab', 0) as file:
                os.dup2(file.fileno(), sys.stdout.fileno())

        if self.stderr is not None:
            with open(self.stderr, 'ab', 0) as file:
                os.dup2(file.fileno(), sys.stderr.fileno())

        # Write the PID file.
        with open(self.pid_file, 'w') as file:
            print(os.getpid(), file=file)

        # Remove the PID file when the interpreter exits.
        atexit.register(self._remove_pid_file)
        # Route the termination signal to the handler.
        signal.signal(signal.SIGTERM, self.sigterm_handler)

    def sigterm_handler(self, signum, frame):
        """
        Termination signal handler: run the termination callback, then exit.

        :param signum: signal number
        :param frame: current stack frame
        :raises SystemExit: always, with exit code 1
        """
        if callable(self.term_callback):
            self.term_callback(*self.term_callback_args, **self.term_callback_kwargs)
        raise SystemExit(1)

    def start(self):
        """
        Daemonize, then run the start callback. A RuntimeError is printed to stderr and
        turned into exit code 1.
        """
        try:
            self.daemonize()
            if callable(self.start_callback):
                self.start_callback(*self.start_callback_args, **self.start_callback_kwargs)
        except RuntimeError as e:
            print(e, file=sys.stderr)
            raise SystemExit(1)

    def stop(self):
        """
        Send SIGTERM to the process named in the PID file. A stale or corrupt PID file is
        removed; a missing PID file exits with code 1.
        """
        if not os.path.exists(self.pid_file):
            print(f"[{self.name}] 尚未启动.", file=sys.stderr)
            raise SystemExit(1)

        _pid = self._read_pid()
        if _pid is not None and psutil.pid_exists(_pid):
            try:
                os.kill(_pid, signal.SIGTERM)
            except ProcessLookupError:
                # The process exited between the existence check and the signal.
                self._remove_pid_file()
        else:
            self._remove_pid_file()

    def cli_run(self):
        """
        Command-line entry: expects exactly one argument, ``start`` or ``stop``.
        """
        if len(sys.argv) != 2:
            print(f"Please use like: python3 {sys.argv[0]} [start|stop].", file=sys.stderr)
            raise SystemExit(1)

        if sys.argv[1] == 'start':
            self.start()
        elif sys.argv[1] == 'stop':
            self.stop()
        else:
            print(f"未知命令: {sys.argv[1]}", file=sys.stderr)
            raise SystemExit(1)
def read_service_info(full_module_name: str, full_log_path=True):
    """
    Collect the description and runtime state of one service module.

    :param full_module_name: dotted path of the service module
    :param full_log_path: whether to report the log file as an absolute path
    :return: dict describing the service (names, PID file, PID, process figures, log file)
    :raises AssertionError: if the module lacks the attributes required of a service
    """
    _module = importlib.import_module(full_module_name)
    assert is_service(_module), "未设置关键属性,请确认是否为服务."

    _pid = ''
    _process_info = {}
    if os.path.exists(_module.pid_file):
        _pid = ufile.read_to_buffer(_module.pid_file).decode('utf8').strip()
        try:
            _process_info = aio_pool.process_info(int(_pid)) or {}
        except ValueError:
            # Empty or corrupt PID file: no process information available.
            _process_info = {}
    # A PID file alone does not prove the service is alive (it may be stale);
    # the service counts as running only when process information was found.
    _is_running = bool(_process_info)

    _configure = config.load_config()
    _logger_config_name = getattr(_module, 'logger_config_name', logger_config_name)
    _logger_file_path = udict.get_by_path(_configure, f"{_logger_config_name}.filename")
    if full_log_path and _logger_file_path:
        _logger_file_path = os.path.abspath(_logger_file_path)

    # The module's `service_name` serves as the description here,
    # while `full_module_name` is the service identifier.
    _info = {
        'service': full_module_name,
        'service_name': _module.service_name,
        'is_running': _is_running,
        'logger_config_name': _logger_config_name,
        'logger_file_path': _logger_file_path,
        'pid_file': _module.pid_file,
        'pid': _pid,
        'process_name': _process_info.get('name', ''),
        'cpu_usage': _process_info.get('cpu_usage', ''),
        'memory_usage': _process_info.get('memory_usage', ''),
        'running_time': _process_info.get('running_time', ''),
    }
    return _info
class PeriodType(Enum):
    """
    Recurrence period of a periodic task; the create_weekly_task, create_monthly_task,
    create_yearly_task and create_quarterly_task factories pass one of these members
    to create_periodic_task.
    """
    WEEKLY = "weekly"        # used by create_weekly_task
    MONTHLY = "monthly"      # used by create_monthly_task
    YEARLY = "yearly"        # used by create_yearly_task
    QUARTERLY = "quarterly"  # used by create_quarterly_task
= relativedelta(datetime.datetime.now(), _next_time).seconds if _next_time else 1 + + if _delta_seconds > 0: + await task_warp() + # 执行服务后,更新日期值 + _next_time = datetime.datetime.now() + relativedelta(seconds=delay) + _log_next_time = True + else: + log_next(_log_next_time) + _log_next_time = False + await asyncio.sleep(0.5) + continue + + _loop = self.event_loop() + _tsk: Task = _loop.create_task(task_loop()) + return _tsk + + def create_daily_task(self, fn: Callable = None, run_on_start=False, + year: int = None, month: int = None, day: int = None, + hour: int = None, minute: int = None, **kwargs): + """ + 日常任务工厂,每次任务完成后,在第二天的固定时间继续执行。若设置的时间小于当前时间,则自动加一天。 + + :param fn: 要执行的任务函数 + :param run_on_start: 是否在启动时立即运行一次,默认不运行 + :param year: 年 + :param month: 月 + :param day: 日 + :param hour: 时 + :param minute: 分 + :param kwargs fn 函数的参数 + """ + _now = datetime.datetime.now() + year = _now.year if year is None else year + month = _now.month if month is None else month + day = _now.day if day is None else day + hour = _now.hour if hour is None else hour + minute = _now.minute if minute is None else minute + _next_time = datetime.datetime(year, month, day, hour, minute, 0) + if relativedelta(_next_time, datetime.datetime.now()).seconds < 0: + # 小于当前时间的,自动加一天 + _next_time = _next_time + relativedelta(days=1) + + def log_next(log_next_time: bool, next_time: datetime.datetime): + if log_next_time and self.log_next_time: + _delta = relativedelta(next_time, datetime.datetime.now()) + _d, _h, _m, _s = _delta.days, _delta.hours, _delta.minutes, _delta.seconds + echo_log(f"距下次执行:{fn.__name__} 还有:{_d} 天 {_h} 时 {_m} 分 {_s} 秒.") + + async def task_warp(): + """ + 任务包装器。 + """ + if fn is not None: + try: + _result = fn(**kwargs) + if isinstance(_result, Awaitable): + await _result + except Exception as e: + echo_log(msg=e, level=logging.ERROR, is_log_exc=True) + + async def task_loop(next_time: datetime.datetime): + """ + 任务循环。 + + :param next_time: 下次执行时间 + """ + if fn is None: + 
return + + _log_next_time = True + _run_on_start = run_on_start + while self.is_running: + _delta_seconds = relativedelta(datetime.datetime.now(), next_time).seconds + + if _run_on_start or _delta_seconds > 0: + await task_warp() + _run_on_start = False + # 执行服务后,更新日期值 + next_time = next_time + relativedelta(days=1) + _log_next_time = True + else: + log_next(_log_next_time, next_time) + _log_next_time = False + await asyncio.sleep(0.5) + continue + + _loop = self.event_loop() + _tsk: Task = _loop.create_task(task_loop(next_time=_next_time)) + return _tsk + + def create_weekly_task(self, fn: Callable = None, weekday: int = 0, + hour: int = 0, minute: int = 0, run_on_start: bool = False, **kwargs): + """每周某天固定时间执行(周一=0,周日=6)""" + return self.create_periodic_task( + fn, PeriodType.WEEKLY, run_on_start=run_on_start, + hour=hour, minute=minute, weekday=weekday, **kwargs + ) + + def create_monthly_task(self, fn: Callable = None, day_of_month: int = 1, + hour: int = 0, minute: int = 0, run_on_start: bool = False, **kwargs): + """每月固定日期执行""" + return self.create_periodic_task( + fn, PeriodType.MONTHLY, run_on_start=run_on_start, + hour=hour, minute=minute, day_of_month=day_of_month, **kwargs + ) + + def create_yearly_task(self, fn: Callable = None, month: int = 1, day_of_month: int = 1, + hour: int = 0, minute: int = 0, run_on_start: bool = False, **kwargs): + """每年固定日期执行""" + return self.create_periodic_task( + fn, PeriodType.YEARLY, run_on_start=run_on_start, + hour=hour, minute=minute, month=month, day_of_month=day_of_month, **kwargs + ) + + def create_quarterly_task(self, fn: Callable = None, start_month: int = 1, day_of_month: int = 1, + hour: int = 0, minute: int = 0, run_on_start: bool = False, **kwargs): + """每季度固定日期执行(start_month: 1,4,7,10)""" + return self.create_periodic_task( + fn, PeriodType.QUARTERLY, run_on_start=run_on_start, + hour=hour, minute=minute, month=start_month, day_of_month=day_of_month, **kwargs + ) + + def create_periodic_task(self, fn: 
Callable = None, period_type: PeriodType = PeriodType.WEEKLY, + run_on_start: bool = False, hour: int = 0, minute: int = 0, + weekday: Optional[int] = None, # 0=周一, 6=周日,仅 weekly 有效 + day_of_month: Optional[int] = None, # 1-31,仅 monthly/quarterly/yearly 有效 + month: Optional[int] = None, # 1-12,仅 quarterly/yearly 有效 + **kwargs + ): + """ + 通用周期任务工厂。 + + :param fn: 任务函数 + :param period_type: 周期类型 (weekly/monthly/yearly/quarterly) + :param run_on_start: 启动时是否立即运行一次 + :param hour: 时 (0-23) + :param minute: 分 (0-59) + :param weekday: 星期几 (0=周一, 6=周日),仅 period_type=weekly 时使用 + :param day_of_month: 每月几号 (1-31),仅 monthly/quarterly/yearly 时使用 + :param month: 月份 (1-12),仅 quarterly/yearly 时使用(quarterly 时表示起始季度月份) + """ + + def get_next_run_time(now: datetime.datetime) -> Optional[datetime.datetime]: + """根据规则计算下一次执行时间""" + + if period_type == PeriodType.WEEKLY: + if weekday is None: + raise ValueError("weekly 模式需要指定 weekday") + + # 计算目标星期 + days_ahead = (weekday - now.weekday()) % 7 + + # 如果今天就是目标星期 + if days_ahead == 0: + target_time = now.replace(hour=hour, minute=minute, second=0, microsecond=0) + # 如果今天的目标时间已过,则推迟到下周 + if target_time <= now and not run_on_start: + days_ahead = 7 + else: + return target_time + + next_date = now + relativedelta(days=days_ahead) + return datetime.datetime( + next_date.year, next_date.month, next_date.day, + hour, minute, 0 + ) + + elif period_type == PeriodType.MONTHLY: + if day_of_month is None: + raise ValueError("monthly 模式需要指定 day_of_month") + + # 获取当前月份的最后一天 + last_day = (now.replace(day=1) + relativedelta(months=1) - relativedelta(days=1)).day + target_day = min(day_of_month, last_day) + + candidate = now.replace(day=target_day, hour=hour, minute=minute, second=0, microsecond=0) + + # 如果候选时间已过,则下个月 + if candidate <= now and not run_on_start: + next_month = now + relativedelta(months=1) + last_day_next = (next_month.replace(day=1) + relativedelta(months=1) - relativedelta(days=1)).day + target_day_next = min(day_of_month, 
last_day_next) + candidate = next_month.replace(day=target_day_next, hour=hour, minute=minute, second=0, + microsecond=0) + + return candidate + + elif period_type == PeriodType.QUARTERLY: + if day_of_month is None or month is None: + raise ValueError("quarterly 模式需要指定 day_of_month 和 month(起始季度月份)") + + # 修正:计算季度的月份 + q_months = [] + for i in range(4): + qm = month + i * 3 + if qm > 12: + qm -= 12 + q_months.append(qm) + + # 查找下一个季度月 + target_month = None + target_year = now.year + + for qm in sorted(q_months): + if qm > now.month: + target_month = qm + break + + if target_month is None: + target_month = q_months[0] + target_year += 1 + + # 处理日期有效性 + last_day = (datetime.datetime(target_year, target_month, 1) + + relativedelta(months=1) - relativedelta(days=1)).day + target_day = min(day_of_month, last_day) + + candidate = datetime.datetime(target_year, target_month, target_day, hour, minute, 0) + + if candidate <= now and not run_on_start: + # 跳到下个季度 + candidate = candidate + relativedelta(months=3) + + return candidate + + elif period_type == PeriodType.YEARLY: + if day_of_month is None or month is None: + raise ValueError("yearly 模式需要指定 day_of_month 和 month") + + # 检查年份 + try: + candidate = datetime.datetime(now.year, month, day_of_month, hour, minute, 0) + except ValueError: + # 日期无效(如2月30日),取当月最后一天 + last_day = (datetime.datetime(now.year, month, 1) + + relativedelta(months=1) - relativedelta(days=1)).day + candidate = datetime.datetime(now.year, month, last_day, hour, minute, 0) + + if candidate <= now and not run_on_start: + try: + candidate = datetime.datetime(now.year + 1, month, day_of_month, hour, minute, 0) + except ValueError: + last_day = (datetime.datetime(now.year + 1, month, 1) + + relativedelta(months=1) - relativedelta(days=1)).day + candidate = datetime.datetime(now.year + 1, month, last_day, hour, minute, 0) + + return candidate + + async def task_warp(): + if fn is not None: + try: + _result = fn(**kwargs) + if isinstance(_result, Awaitable): 
+ await _result + except Exception as e: + echo_log(msg=e, level=logging.ERROR, is_log_exc=True) + + async def task_loop(): + if fn is None: + return + + _log_next_time = True + _run_on_start = run_on_start + next_time = get_next_run_time(datetime.datetime.now()) + while self.is_running: + now = datetime.datetime.now() + if _run_on_start or now >= next_time: + await task_warp() + _run_on_start = False + # 执行完后,基于当前时间重新计算下一次 + next_time = get_next_run_time(now) + _log_next_time = True + else: + if _log_next_time and self.log_next_time: + delta = relativedelta(next_time, now) + echo_log( + f"距下次执行:{fn.__name__} 还有:{delta.days}天 {delta.hours}时 {delta.minutes}分 {delta.seconds}秒") + _log_next_time = False + await asyncio.sleep(1) + + _loop = self.event_loop() + return _loop.create_task(task_loop()) + + def add_task(self, creator: Callable, fn: Callable, **kwargs): + """ + 添加任务。注意:这里只是存储创建参数,直到任务启动前,才实际把任务创建出来。 + + :param creator: 任务工厂,对应:应延时任务工厂、日常任务工厂 + :param fn: 任务函数,即任务对应的实际功能函数 + :param kwargs: 任务函数的参数 + """ + _d = { + 'creator': creator, # 创建器,对应延时任务和日常任务 + 'fn': fn, 'kwargs': kwargs # 任务函数及其参数 + } + self._create_task_params.append(_d) + + def rebuild_task_list(self): + """ + 重建任务列表。 + :return: 任务列表 + """ + self.task_list.clear() + + # 遍历创建器列表,创建任务 + for _ctp in self._create_task_params: + _creator: Callable = _ctp.get('creator') + _fn: Callable = _ctp.get('fn') + _kwargs: dict = _ctp.get('kwargs') + _task = _creator(_fn, **_kwargs) + self.task_list.append(_task) + + return self.task_list + + async def run_tasks(self): + """ + 执行任务。支持多任务同时启动,如:预统计服务、设备数据同步服务等。 + """ + try: + self.rebuild_task_list() + echo_log(f'{self.service_name}启动成功.') + await asyncio.gather(*self.task_list) + except Exception as e: + echo_log(msg=e, level=logging.ERROR, is_log_exc=True) + + def start_service(self, env_check: bool = True): + """ + 以控制台服务方式启动服务。 + """ + echo_log(f'正在启动{self.service_name}...') + + try: + if env_check: + # 检测 MySQL 服务是否正常 + echo_log('检测 Database 服务...') + 
BaseAdapter.ping() + + # 注意,这里是取得协程 + _future = self.run_tasks() + # 开始执行任务事件循环 + _loop = self.event_loop() + _loop.run_until_complete(_future) + except KeyboardInterrupt: + echo_log(msg='KeyboardInterrupt') + self.stop_service() + except Exception as e: + echo_log(msg=e, level=logging.ERROR, is_log_exc=True) + + def stop_service(self): + """ + 停止服务。 + """ + self.is_running = False + echo_log(f'{self.service_name}已停止.') + + def start(self, env_check: bool = True): + """ + 以驻内存服务方式启动服务。 + """ + ds = DaemonizeService(pid_file=self.pid_file, name=f'{self.service_name}') + ds.set_start_callback(self.start_service, env_check=env_check) + ds.set_term_callback(self.stop_service) + ds.start() + + def stop(self, env_check: bool = True): + """ + 停止驻内存服务。 + """ + ds = DaemonizeService(pid_file=self.pid_file, name=f'{self.service_name}') + ds.set_start_callback(self.start_service, env_check=env_check) + ds.set_term_callback(self.stop_service) + ds.stop() diff --git a/paste/util/__init__.py b/paste/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/paste/util/encoder.py b/paste/util/encoder.py new file mode 100644 index 0000000..ee62746 --- /dev/null +++ b/paste/util/encoder.py @@ -0,0 +1,185 @@ +import base64 +import binascii +import datetime +import decimal +import json +import re +import zlib +from typing import Union + +import numpy as np +from pandas._libs.tslibs.nattype import NaTType + +from paste.db.baseadapter import BaseAdapter +from paste.db.basemodel import LOCAL_DATETIME_FORMAT, LOCAL_DATE_FORMAT, LOCAL_TIME_FORMAT + + +class JsonDumpsEncoder(json.JSONEncoder): + """ + JSON 转字符串时对一些特殊类型进行转码(编码)方法。 + """ + + def default(self, obj): + if isinstance(obj, NaTType): + return '' + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, (np.floating, decimal.Decimal)): + return float(obj) + elif isinstance(obj, bytes): + return obj.decode(encoding='utf-8', errors='ignore') 
+ elif isinstance(obj, datetime.datetime): + return obj.strftime(LOCAL_DATETIME_FORMAT) + elif isinstance(obj, datetime.date): + return obj.strftime(LOCAL_DATE_FORMAT) + elif isinstance(obj, datetime.time): + return obj.strftime(LOCAL_TIME_FORMAT) + elif isinstance(obj, BaseAdapter): + return obj.to_dict() + + return super().default(obj) + + +class BaseX: + """ + Base 编码解码方法,主要用于解码,针对编码数据自动检测编码类型。 + 能根据编码方式自动选择解码方法,同时在解码后尝试执行标准 Zip 解压。 + """ + + Type_Base16 = 'b16' + Type_Base32 = 'b32' + Type_Base64 = 'b64' + Type_Base85 = 'b85' + + Decoders = { + Type_Base16: base64.b16decode, + Type_Base32: base64.b32decode, + Type_Base64: base64.b64decode, + Type_Base85: base64.b85decode, + } + """ + 解码器。 + """ + + @classmethod + def base_x_detect(cls, data: Union[bytes, str]): + """ + 检测采用的内置 Base 编码种类。返回数据与以下编码方式对应:: + + 1、b16: Base16 + 2、b32: Base32 + 3、b64: Base64 + 4、b85: Base85 + + :param data: Base 编码数据,允许为字节流或字符串 + :return: 编码名称,全小写 + """ + if isinstance(data, bytes): + try: + data = data.decode() + except (UnicodeDecodeError, Exception): + return + + try: + _reg = re.compile("^[0-9A-F=]+$") + if _reg.match(data) is not None: + return cls.Type_Base16 + except (re.error, Exception): + pass + + try: + _reg = re.compile("^[A-Z2-7=]+$") + if _reg.match(data) is not None: + return cls.Type_Base32 + except (re.error, Exception): + pass + + try: + _reg = re.compile("^[A-Za-z0-9+/=]+$") + if _reg.match(data) is not None: + return cls.Type_Base64 + except (re.error, Exception): + pass + + try: + _reg = re.compile("^[A-Za-z0-9!#$%&()*+-;<=>?@^_`{|}~']+$") + if _reg.match(data) is not None: + return cls.Type_Base85 + except (re.error, Exception): + pass + + @classmethod + def base_x_decode(cls, data: Union[bytes, str], base_type: str = None): + """ + 自动检测编码种类后,解码 Base 编码数据。参数 base_type 与以下编码方式对应:: + + 1、b16: Base16 + 2、b32: Base32 + 3、b64: Base64 + 4、b85: Base85 + + 若在解码过程中发生异常,则返回原始数据。 + + :param data: Base 编码数据,允许为字节流或字符串 + :param base_type: Base 编码类型,如为 b64 则代表须用 Base64 解码 + 
:return: 解码后数据 + """ + _res_data = b'' + if isinstance(data, bytes): + try: + _tmp_data = data.decode() + _res_data = _tmp_data + except (UnicodeDecodeError, Exception): + return data + else: + _res_data = data + + if base_type not in cls.Decoders: + # 检测 Base 编码种类 + base_type = cls.base_x_detect(_res_data) + + if base_type in cls.Decoders: + try: + # 尝试 BaseX 解码 + _decoder = cls.Decoders.get(base_type) + _tmp_data = _decoder(_res_data) + _res_data = _tmp_data + except (binascii.Error, UnicodeDecodeError): + return data + + return _res_data + + @classmethod + def auto_decode_unzip(cls, data: Union[bytes, str], base_type: str = None): + """ + 对参数尝试执行自动 BaseX 解码 和 Zip 解压:: + + 1、若能 BaseX 解码,则执行解码,否则保持原始数据不变。 + 2、若能 Zip 解压,则执行解压,否则保持上一层数据不变。 + + 若各种解码方法都无法顺利解码,或在解码过程中发生异常,则返回原始数据。 + + 参数 base_type 与以下编码方式对应:: + + 1、b16: Base16 + 2、b32: Base32 + 3、b64: Base64 + 4、b85: Base85 + + :param data: Base64 数据,允许为字节流或字符串 + :param base_type: Base 编码类型,如为 b64 则代表须用 Base64 解码 + :return: 解码后数据 + """ + # 尝试 BaseX 解码 + _res_data = cls.base_x_decode(data, base_type) + + try: + # 尝试 Zip 解压 + _tmp_data = zlib.decompress(_res_data) + _res_data = _tmp_data + except (zlib.error, TypeError): + return _res_data + + return _res_data diff --git a/paste/util/pagination.py b/paste/util/pagination.py new file mode 100644 index 0000000..4637c18 --- /dev/null +++ b/paste/util/pagination.py @@ -0,0 +1,125 @@ +""" +基础分页程序,处理分页计算,后续应当扩展其功能。 +""" + + +class Pagination: + """ + 分页程序。 + """ + + def __init__(self, row_count: int): + """ + 初始化分页。 + + :param row_count: 总记录行数 + """ + self._offset = 0 + """ + 偏移量。 + """ + + self._pages = -1 + """ + 总页数。 + """ + + self._page_number = 1 + """ + 当前页码。 + """ + + self.row_count = row_count + """ + 数据行数。 + """ + + self.page_size = 20 + """ + 每页显示的数据量。默认 20 行每页。 + """ + + @property + def page_count(self): + """ + 取得页数。该属性必须在调用 :meth:`.pages` 方法后调用,例如:: + + >>> self.pages() + >>> self.page_count + + :return: 页数 + """ + return self._pages + + @property + def 
page_number(self): + """ + 取得当前页码。该属性必须在调用 :meth:`.number` 方法后调用, 例如:: + + >>> self.number(3) + >>> self.page_number + + :return: 页码 + """ + return self._page_number + + @property + def offset_size(self): + """ + 取得偏移量。 + """ + return self._offset + + def pages(self, page_size: int = 20): + """ + 计算页数。 + + :param page_size: 每页行数,必须处于 [1, 1000] 区间中。若不在此区间,则强制转换到此区间。默认每页 20 条。 + :return: 计算取得的页数。 + """ + page_size = 1 if page_size < 1 else page_size + page_size = 1000 if page_size > 1000 else page_size + self.page_size = page_size + + if self.row_count == 0: + self._pages = 1 + else: + _v1 = self.row_count / page_size + _v2 = self.row_count // page_size + self._pages = _v2 if _v1 == _v2 else _v2 + 1 + + return self._pages + + def number(self, page_number: int): + """ + 检查页码范围。 + + :param page_number: 页码 + :return: 正确页码 + """ + _pages = self.pages(self.page_size) + self._page_number = 1 if page_number < 1 else page_number + self._page_number = _pages if self._page_number > _pages else self._page_number + return self._page_number + + def offset(self, page_number: int): + """ + 偏移量。 + + :param page_number: 页码 + :return: 偏移量 + """ + self._offset = self.page_size * (page_number - 1) + return self._offset + + def paging(self, page_number: int = 1, page_size: int = 20): + """ + 分页计算,支持链式调用。 + + :params page_number 页码 + :params page_size 每页显示的数量 + :return self + """ + self.pages(page_size=page_size) + self.offset(self.number(page_number)) + return self diff --git a/paste/util/pdf.py b/paste/util/pdf.py new file mode 100644 index 0000000..c5f35cd --- /dev/null +++ b/paste/util/pdf.py @@ -0,0 +1,63 @@ +import mimetypes +import os +import urllib.parse +import urllib.request + +import weasyprint as wp + +from paste.core.logging import echo_log + + +class Html2Pdf: + """ + 将 HTML 内容转换为 PDF 文件。 + """ + + @classmethod + def custom_url_fetcher(cls, url, timeout=30, **kwargs): + """ + 自定义 URL 加载器,增加超时时间 + """ + # 处理 file:// URLs + if url.startswith('file://'): + parsed = 
urllib.parse.urlparse(url) + path = urllib.request.url2pathname(parsed.path) + + if not os.path.exists(path): + raise ValueError(f"File not found: {path}") + + _mime_type, _ = mimetypes.guess_type(path) + if not _mime_type: + _mime_type = 'application/octet-stream' + + return { + 'mime_type': _mime_type, + 'encoding': 'binary', + 'filename': os.path.basename(path), + 'file_obj': open(path, 'rb'), + } + + # 增加超时时间(默认是 30 秒) + return wp.default_url_fetcher(url, timeout=timeout, **kwargs) + + @classmethod + def write_pdf(cls, content, output_pdf=None, base_url=""): + """ + 将 HTML 转换为 PDF。 + + :param content: HTML 字符串 + :param output_pdf: 输出的 PDF 文件路径,默认为空 + :param base_url: 跨域默认地址 + """ + try: + # HTML 转换为 PDF + _html = wp.HTML(string=content, url_fetcher=cls.custom_url_fetcher, base_url=base_url) + _bytes = _html.write_pdf(output_pdf) + if output_pdf: + echo_log(f"PDF 已成功生成在: {output_pdf}.") + else: + echo_log(f"PDF 已成功生成.") + return _bytes + except Exception as e: + echo_log(f"转换失败: {e}") + raise e diff --git a/paste/util/snow_id.py b/paste/util/snow_id.py new file mode 100644 index 0000000..b00b2b8 --- /dev/null +++ b/paste/util/snow_id.py @@ -0,0 +1,126 @@ +""" +雪花 ID 生成程序。 +""" + +import time +import logging + + +# 64位ID的划分 +WORKER_ID_BITS = 5 +DATACENTER_ID_BITS = 5 +SEQUENCE_BITS = 12 + +# 最大取值计算 +MAX_WORKER_ID = -1 ^ (-1 << WORKER_ID_BITS) # 2**5-1 0b11111 +MAX_DATACENTER_ID = -1 ^ (-1 << DATACENTER_ID_BITS) + +# 移位偏移计算 +WORKER_ID_SHIFT = SEQUENCE_BITS +DATACENTER_ID_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS +TIMESTAMP_LEFT_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS + DATACENTER_ID_BITS + +# 序号循环掩码 +SEQUENCE_MASK = -1 ^ (-1 << SEQUENCE_BITS) + +# Twitter元年时间戳 +TW_EPOCH = 1288834974657 + +ID_WORKER = None + + +class InvalidSystemClock(Exception): + """ + 时钟回拨异常 + """ + pass + + +class IdWorker(object): + """ + 用于生成 Snow ID。 + """ + + @classmethod + def get_id_worker(cls, datacenter_id=1, worker_id=1, sequence=0): + """ + 创建 Snow ID 对象。 + + :param datacenter_id: 
数据中心(机器区域)ID + :param worker_id: 机器ID + :param sequence: 起始序号 + """ + global ID_WORKER + if ID_WORKER is None: + ID_WORKER = IdWorker(datacenter_id, worker_id, sequence) + return ID_WORKER + + def __init__(self, datacenter_id, worker_id, sequence=0): + """ + 初始化。 + + :param datacenter_id: 数据中心(机器区域)ID + :param worker_id: 机器ID + :param sequence: 起始序号 + """ + # sanity check + if worker_id > MAX_WORKER_ID or worker_id < 0: + raise ValueError('worker_id值越界') + + if datacenter_id > MAX_DATACENTER_ID or datacenter_id < 0: + raise ValueError('datacenter_id值越界') + + self.worker_id = worker_id + self.datacenter_id = datacenter_id + self.sequence = sequence + + self.last_timestamp = -1 # 上次计算的时间戳 + + @staticmethod + def _gen_timestamp(): + """ + 生成整数时间戳。 + + :return:int timestamp + """ + return int(time.time() * 1000) + + def get_id(self): + """ + 获取新ID。 + + :return: 新的 Snow ID + """ + timestamp = self._gen_timestamp() + + # 时钟回拨 + if timestamp < self.last_timestamp: + logging.error(f"时钟正在向后倒转。拒绝请求直至 {self.last_timestamp}.") + raise InvalidSystemClock + + if timestamp == self.last_timestamp: + self.sequence = (self.sequence + 1) & SEQUENCE_MASK + if self.sequence == 0: + timestamp = self._til_next_millis(self.last_timestamp) + else: + self.sequence = 0 + + self.last_timestamp = timestamp + + new_id = ((timestamp - TW_EPOCH) << TIMESTAMP_LEFT_SHIFT) | (self.datacenter_id << DATACENTER_ID_SHIFT) | \ + (self.worker_id << WORKER_ID_SHIFT) | self.sequence + return new_id + + def _til_next_millis(self, last_timestamp): + """ + 等到下一毫秒。 + """ + timestamp = self._gen_timestamp() + while timestamp <= last_timestamp: + timestamp = self._gen_timestamp() + return timestamp + + +if __name__ == '__main__': + worker = IdWorker(1, 1, 0) + print(worker.get_id()) diff --git a/paste/util/svg.py b/paste/util/svg.py new file mode 100644 index 0000000..76b06d7 --- /dev/null +++ b/paste/util/svg.py @@ -0,0 +1,704 @@ +import re +from typing import Optional + +import svgwrite +from svgwrite.container 
import Group +from svgwrite.path import Path +from svgwrite.shapes import Rect +from svgwrite.text import Text + +from paste.util import ustr + + +class TextRect(Group): + """ + 可显示文本的矩形。 + """ + + def __init__(self, text, insert, text_extra: dict = None, rect_extra: dict = None, **extra): + # 父类初始化 + super().__init__(**extra) + + self.text = text + """ + 要显示的文本内容。 + """ + + self.extra = extra + """ + 组合扩展信息。 + """ + + self.rectExtra = rect_extra if rect_extra is not None else {} + """ + 外框扩展信息。 + """ + + self.textExtra = text_extra if text_extra is not None else {} + """ + 文本扩展信息。 + """ + + self.rectInsert = insert + """ + 整体位置参数,即外框的位置参数。 + """ + + # 初始化文本尺寸 + _fs = self.font_size + + self.textInsert = self.text_pos + """ + 文本位置参数。 + """ + + # 文本初始化 + self.textElement = Text(self.text, insert=self.textInsert, **self.textExtra) + # 矩形初始化 + self.rectElement = Rect(insert=self.rectInsert, size=self.rect_size, **self.rectExtra) + + # 加入元素 + self.add(self.rectElement) + self.add(self.textElement) + + @property + def font_size(self): + """ + 从样式中识别字体大小,单位用像素,缺省 14px。 + + :return: 字体大小 + """ + _font_size = self.textExtra.get('font-size', self.extra.get('font-size', f"{14}px")) + self.textExtra['font-size'] = _font_size + _size = re.sub(r'\D', '', _font_size.strip()) + return int(_size) + + @property + def text_width(self): + """ + 文本宽度(近似)。 + """ + total = len(self.text) + q_count = ustr.str_q_count(self.text) + return q_count * self.font_size + (total - q_count) * self.font_size * 0.5 + + @property + def text_height(self): + """ + 文本高度(近似)。 + """ + return self.font_size * 1.2 + + @property + def rect_width(self): + """ + 外框宽度。 + """ + return self.text_width + self.font_size * 1.5 + + @property + def rect_height(self): + """ + 外框高度。 + """ + return self.text_height * 1.9 + + @property + def rect_size(self): + """ + 外框尺寸。 + """ + return self.rect_width, self.rect_height + + @property + def text_pos(self): + """ + 文本位置。 + """ + return \ + self.rectInsert[0] + 
(self.rect_width - self.text_width) * 0.5, \ + self.rectInsert[1] + self.text_height * 1.25 + + def reposition(self, position: tuple): + """ + 重新定位。 + + :param position: 位置坐标 + """ + self.rectInsert = position + self.rectElement.attribs['x'] = self.rectInsert[0] + self.rectElement.attribs['y'] = self.rectInsert[1] + + self.textInsert = self.text_pos + self.textElement.attribs['x'] = self.textInsert[0] + self.textElement.attribs['y'] = self.textInsert[1] + + def point_bottom(self): + """ + 底部点。 + """ + return self.rectInsert[0] + self.rect_width / 2, self.rectInsert[1] + self.rect_size[1] + + def point_top(self): + """ + 顶部点。 + """ + return self.rectInsert[0] + self.rect_width / 2, self.rectInsert[1] + + def point_left(self): + """ + 左侧点。 + """ + return self.rectInsert[0], self.rectInsert[1] + self.rect_height / 2 + + def point_right(self): + """ + 右侧点。 + """ + return self.rectInsert[0] + self.rect_width, self.rectInsert[1] + self.rect_height / 2 + + @classmethod + def horizontal_path(cls, start: tuple, end: tuple, **extra): + """ + 生成水平方向连接线。 + + :param start: 起点坐标 + :param end: 终点坐标 + :param extra: 扩展参数 + :return: 路径对象 + """ + _p_control = [ + (start[0] + end[0]) * 0.5, + start[1] + ] + + _p_center = [ + (start[0] + end[0]) * 0.5, + (start[1] + end[1]) * 0.5 + ] + + _path = Path(**extra) + + _path.push(['M', start]) + _path.push(['Q', _p_control + _p_center]) + _path.push(['T', end]) + + return _path + + @classmethod + def vertical_path(cls, start: tuple, end: tuple, **extra): + """ + 生成垂直方向连接线。 + + :param start: 起点坐标 + :param end: 终点坐标 + :param extra: 扩展参数 + :return: 路径对象 + """ + _p_control = [ + start[0], + (start[1] + end[1]) * 0.5 + ] + + _p_center = [ + (start[0] + end[0]) * 0.5, + (start[1] + end[1]) * 0.5 + ] + + _path = Path(**extra) + + _path.push(['M', start]) + _path.push(['Q', _p_control + _p_center]) + _path.push(['T', end]) + + return _path + + def choose_point(self, sibling: 'TextRect'): + """ + 选择与目标文本框的连线点。 + + 返回起点(tuple)在自生文本框上,终点(tuple)在目标文本框上。 
+ + :param sibling: 目标文本框 + :return: 起点、终点、是否水平连线 + """ + _start = self.point_bottom() + _end = sibling.point_top() + _is_horizontal = True + + if self.point_bottom()[1] > sibling.point_top()[1]: + if self.point_right()[0] < sibling.point_left()[0]: + _start = self.point_right() + _end = sibling.point_left() + _is_horizontal = False + elif self.point_left()[0] > sibling.point_right()[0]: + _start = self.point_left() + _end = sibling.point_right() + _is_horizontal = False + else: + _start = self.point_top() + _end = sibling.point_bottom() + _is_horizontal = True + elif self.point_top()[1] < sibling.point_bottom()[1]: + if self.point_right()[0] < sibling.point_left()[0]: + _start = self.point_right() + _end = sibling.point_left() + _is_horizontal = False + elif self.point_left()[0] > sibling.point_right()[0]: + _start = self.point_left() + _end = sibling.point_right() + _is_horizontal = False + else: + _start = self.point_bottom() + _end = sibling.point_top() + _is_horizontal = True + else: + if self.point_right()[0] < sibling.point_left()[0]: + _start = self.point_right() + _end = sibling.point_left() + _is_horizontal = False + elif self.point_left()[0] > sibling.point_right()[0]: + _start = self.point_left() + _end = sibling.point_right() + _is_horizontal = False + else: + _start = self.point_bottom() + _end = sibling.point_top() + _is_horizontal = True + + return _start, _end, _is_horizontal + + def connect(self, sibling: 'TextRect', **extra): + """ + 取得连接路径。 + + :param sibling: 目标文本框 + :param extra: 连线扩展参数 + :return: 连接路径 + """ + _start, _end, _is_horizontal = self.choose_point(sibling) + if _is_horizontal: + return self.vertical_path(_start, _end, **extra) + else: + return self.horizontal_path(_start, _end, **extra) + + +class RelationGraph: + """ + SVG 关系图。 + + 根据 title 名称和 row_list 列表数据输出 svg 格式的关系图谱。 + """ + + def __init__(self, filename: str = 'noname.svg'): + self.filename = filename + + self.width = 800 + """ + 画布宽度。 + """ + self.height = 600 + """ + 画布高度。 
+ """ + self.vhSpace = 170 + """ + 内容垂直浮动空间。 + """ + self.lrSpace = 100 + """ + 内容左右留白空间。 + """ + + self.titleTextExtra = { + 'font-size': '16px', 'fill': 'rgb(255, 255, 255)' + } + """ + 标题文本样式。 + """ + + self.titleRectExtra = { + 'rx': 10, 'ry': 10, 'fill': 'rgb(233, 72, 41)', 'fill-opacity': 1, 'stroke': 'rgb(233, 72, 41)' + } + """ + 标题外框样式 + """ + + self.textExtra = { + 'font-size': '14px', 'fill': 'rgb(255, 255, 255)' + } + """ + 普通文本样式。 + """ + + self.rectExtra = { + 'rx': 10, 'ry': 10, 'fill': 'rgb(65, 130, 164)', 'fill-opacity': 1, 'stroke': 'rgb(65, 130, 164)' + } + """ + 普通文本外框样式。 + """ + + self.pathExtra = { + 'fill': 'none', 'stroke': 'rgb(65, 130, 164)' + } + """ + 连线样式。 + """ + + self.drawing = svgwrite.Drawing(filename=self.filename) + """ + 主绘图对象。 + """ + + self.attribs = self.drawing.attribs + """ + 图像样式。 + """ + + self.save = self.drawing.save + """ + 保存文件方法。 + """ + + self.attribs.update({ + 'width': self.width, 'height': self.height + }) + + self.titleTr: Optional[TextRect] = None + """ + 标题文本框对象。 + """ + + def draw(self, title: str, row_list: list[dict]): + """ + 绘制图形。 + + :param row_list: 数据对象列表,必须包含 unit_name, unit_uscid, enterprise_id 三个字段 + :param title: 标题文本 + :return: 自身对象 + """ + # 重定设图像参数 + self.attribs.update(self.attribs) + + # 创建标题文本框 + self.titleTr = TextRect( + text=title, insert=(0, 0), text_extra=self.titleTextExtra, rect_extra=self.titleRectExtra, **{ + 'debug': False + } + ) + self.titleTr.reposition(( + (self.width - self.titleTr.rect_width) * 0.5, (self.height - self.titleTr.rect_height) * 0.5 - 20 + )) + self.drawing.add(self.titleTr) + + _tr_list: list[TextRect] = [] + for _i, _row in enumerate(row_list): + # 遍历数据,初始创建所有的文本框,得到文本框尺寸信息 + # 同时保留所有需要输出的数据 + _text = f"{_row['short_name']} ({_row['count']})" + _tr = TextRect( + text=_text, insert=(0, 0), rect_extra=self.rectExtra, **self.textExtra, **{ + 'debug': False, + 'data-name': _row['unit_name'], + 'data-uscid': _row['unit_uscid'], + 'data-enterprise-id': 
_row['enterprise_id'], + } + ) + _tr_list.append(_tr) + + _harf = int(len(_tr_list) / 2) if int(len(_tr_list) % 2) == 0 else int(len(_tr_list) / 2) + 1 + _top_list = [] + _lft_list = _tr_list[:_harf] + _rit_list = _tr_list[_harf:] + _btm_list = [] + + if len(_tr_list) >= 12: + _top_list = _lft_list[:2] + _lft_list = _lft_list[2:] + if len(_tr_list) >= 14: + _btm_list = _rit_list[-2:] + _rit_list = _rit_list[:-2] + + # 遍历所有顶部文本框,重新定位 + for _i, _tr in enumerate(_top_list): + if _i == 0: + _position = ( + self.titleTr.point_top()[0] - _tr.rect_width - 15, + self.titleTr.point_top()[1] - self.vhSpace - _tr.rect_height - 15 + ) + else: + _position = ( + self.titleTr.point_top()[0] + 15, + self.titleTr.point_top()[1] - self.vhSpace - _tr.rect_height - 15 + ) + _tr.reposition(_position) + + # 遍历所有底部文本框,重新定位 + for _i, _tr in enumerate(_btm_list): + if _i == 0: + _position = ( + self.titleTr.point_bottom()[0] - _tr.rect_width - 15, + self.titleTr.point_bottom()[1] + self.vhSpace + _tr.rect_height + 15 + ) + else: + _position = ( + self.titleTr.point_bottom()[0] + 15, + self.titleTr.point_bottom()[1] + self.vhSpace + _tr.rect_height + 15 + ) + _tr.reposition(_position) + + _top = self.titleTr.point_top()[1] - self.vhSpace + # 遍历所有左则文本框,重新定位 + for _tr in _lft_list: + _w = _tr.rect_width + _h = _tr.rect_height + _space = self.titleTr.point_bottom()[1] - self.titleTr.point_top()[1] + self.vhSpace * 2 + _h + + _margin = 0 + if len(_lft_list) > 1: + _margin = (_space - len(_lft_list) * _h) / (len(_lft_list) - 1) + + _left = self.titleTr.point_left()[0] - _w - self.lrSpace + _position = (_left, _top) + _tr.reposition(_position) + if _tr.point_left()[0] < 20: + _left = 20 + _position = (_left, _top) + _tr.reposition(_position) + _top += _h + _margin + + _top = self.titleTr.point_top()[1] - self.vhSpace + # 遍历所有右侧文本框,重新定位 + for _tr in _rit_list: + _w = _tr.rect_width + _h = _tr.rect_height + _space = self.titleTr.point_bottom()[1] - self.titleTr.point_top()[1] + self.vhSpace * 2 + 
_h + + _margin = 0 + if len(_rit_list) > 1: + _margin = (_space - len(_rit_list) * _h) / (len(_rit_list) - 1) + + _left = self.titleTr.point_right()[0] + self.lrSpace + _position = (_left, _top) + _tr.reposition(_position) + if _tr.point_right()[0] > self.width - 20: + _left = self.width - _tr.rect_width - 20 + _position = (_left, _top) + _tr.reposition(_position) + + _top += _h + _margin + + for _tr in _tr_list: + self.drawing.add(self.titleTr.connect(_tr, **self.pathExtra)) + + for _tr in _tr_list: + self.drawing.add(_tr) + + +class EnterpriseGraph: + """ + SVG 企业汇总信息图。 + + 根据 title 名称和 row_list 列表数据输出 svg 格式的关系图谱。 + """ + + def __init__(self, filename: str = 'noname.svg'): + self.filename = filename + + self.width = 800 + """ + 画布宽度。 + """ + self.height = 300 + """ + 画布高度。 + """ + self.vhSpace = 50 + """ + 内容垂直浮动空间。 + """ + self.lrSpace = 100 + """ + 内容左右留白空间。 + """ + + self.titleTextExtra = { + 'font-size': '16px', 'fill': 'rgb(255, 255, 255)' + } + """ + 标题文本样式。 + """ + + self.titleRectExtra = { + 'rx': 10, 'ry': 10, 'fill': 'rgb(233, 72, 41)', 'fill-opacity': 1, 'stroke': 'rgb(233, 72, 41)' + } + """ + 标题外框样式 + """ + + self.textExtra = { + 'font-size': '14px', 'fill': 'rgb(255, 255, 255)' + } + """ + 普通文本样式。 + """ + + self.rectExtra = { + 'rx': 10, 'ry': 10, 'fill': 'rgb(65, 130, 164)', 'fill-opacity': 1, 'stroke': 'rgb(65, 130, 164)' + } + """ + 普通文本外框样式。 + """ + + self.pathExtra = { + 'fill': 'none', 'stroke': 'rgb(65, 130, 164)' + } + """ + 连线样式。 + """ + + self.drawing = svgwrite.Drawing(filename=self.filename) + """ + 主绘图对象。 + """ + + self.attribs = self.drawing.attribs + """ + 图像样式。 + """ + + self.save = self.drawing.save + """ + 保存文件方法。 + """ + + self.attribs.update({ + 'width': self.width, 'height': self.height + }) + + self.titleTr: Optional[TextRect] = None + """ + 标题文本框对象。 + """ + + def draw(self, title: str, data_item: dict): + """ + 绘制图形。 + + :param data_item: 数据项字典,中文名称:数据值 + :param title: 标题文本 + :return: 自身对象 + """ + # 重定设图像参数 + 
self.attribs.update(self.attribs) + + # 创建标题文本框 + self.titleTr = TextRect( + text=title, insert=(0, 0), text_extra=self.titleTextExtra, rect_extra=self.titleRectExtra, **{ + 'debug': False + } + ) + self.titleTr.reposition(( + (self.width - self.titleTr.rect_width) * 0.5, (self.height - self.titleTr.rect_height) * 0.5 - 20 + )) + self.drawing.add(self.titleTr) + + _tr_list: list[TextRect] = [] + for _key, _val in data_item.items(): + # 遍历数据,初始创建所有的文本框,得到文本框尺寸信息 + # 同时保留所有需要输出的数据 + _text = f"{_key}:{_val}" + _tr = TextRect( + text=_text, insert=(0, 0), rect_extra=self.rectExtra, **self.textExtra, **{ + 'debug': False, + } + ) + _tr_list.append(_tr) + + _harf = int(len(_tr_list) / 2) if int(len(_tr_list) % 2) == 0 else int(len(_tr_list) / 2) + 1 + _top_list = [] + _lft_list = _tr_list[:_harf] + _rit_list = _tr_list[_harf:] + _btm_list = [] + + if len(_tr_list) >= 12: + _top_list = _lft_list[:2] + _lft_list = _lft_list[2:] + if len(_tr_list) >= 14: + _btm_list = _rit_list[-2:] + _rit_list = _rit_list[:-2] + + # 遍历所有顶部文本框,重新定位 + for _key, _tr in enumerate(_top_list): + if _key == 0: + _position = ( + self.titleTr.point_top()[0] - _tr.rect_width - 15, + self.titleTr.point_top()[1] - self.vhSpace - _tr.rect_height - 15 + ) + else: + _position = ( + self.titleTr.point_top()[0] + 15, + self.titleTr.point_top()[1] - self.vhSpace - _tr.rect_height - 15 + ) + _tr.reposition(_position) + + # 遍历所有底部文本框,重新定位 + for _key, _tr in enumerate(_btm_list): + if _key == 0: + _position = ( + self.titleTr.point_bottom()[0] - _tr.rect_width - 15, + self.titleTr.point_bottom()[1] + self.vhSpace + _tr.rect_height + 15 + ) + else: + _position = ( + self.titleTr.point_bottom()[0] + 15, + self.titleTr.point_bottom()[1] + self.vhSpace + _tr.rect_height + 15 + ) + _tr.reposition(_position) + + _top = self.titleTr.point_top()[1] - self.vhSpace + # 遍历所有左则文本框,重新定位 + for _tr in _lft_list: + _w = _tr.rect_width + _h = _tr.rect_height + _space = self.titleTr.point_bottom()[1] - 
class TailRead:
    """
    Reverse ("tail") reader for large, append-only files such as logs.

    The first read scans backwards from the end of the file so that only the
    last few lines are loaded. Later reads go forward from the offset returned
    by the previous read and return only what was appended since.

    The instance keeps one binary file handle open. Call :meth:`close`, or use
    the instance as a context manager, to release it.
    """

    # Block size used when scanning backwards for line breaks.
    _CHUNK_SIZE = 8192

    @classmethod
    def logReader(cls, log_fn: Optional[str] = None):
        """
        Build a reader for a log file.

        :param log_fn: log file name; when ``None`` the name is taken from the
            ``logger.filename`` configuration entry
        :return: a reader for that file
        """
        if log_fn is None:
            log_fn = config.get_config('logger.filename')
        # ``cls`` rather than the hard-coded class so subclasses get their own type.
        return cls(fn=log_fn)

    def __init__(self, fn: str):
        """
        Open ``fn`` for binary reading and park the handle on its last byte.

        :param fn: name of the file to read
        """
        # Name of the file being read.
        self.file_name = fn
        # Binary file handle used for every read.
        self.file_io = open(self.file_name, 'rb')
        # Offset at which the last read finished; None until the first read.
        self.current_position: Optional[int] = None

        # Last byte of the file, or offset 0 when the file is empty.
        self.file_io.seek(max(self.size() - 1, 0))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the underlying file handle. Safe to call more than once."""
        if not self.file_io.closed:
            self.file_io.close()

    def size(self):
        """
        Return the current size of the file in bytes.

        :return: file size
        """
        return os.path.getsize(self.file_name)

    def readTail(self, lines: int = 100):
        """
        Read up to ``lines`` lines backwards from the handle's current offset.

        The byte under the handle is treated as the last byte to return. A
        line break sitting exactly there terminates the last line and is not
        counted as a line of its own. When the file holds fewer lines than
        requested, everything from the start of the file is returned.

        :param lines: maximum number of lines to return
        :return: ``(data, end_offset)``; ``end_offset`` is where a following
            forward read should start
        """
        _c_pos = self.file_io.tell()
        # Exclusive end of the region to return, clamped so an empty file yields 0.
        _end = max(min(_c_pos + 1, self.size()), 0)
        _start = _end

        if lines > 0 and _end > 0:
            _start = 0
            self.file_io.seek(_end - 1)
            # A trailing line break closes the last line, so one more break must be found.
            _needed = lines + 1 if self.file_io.read(1) == b'\n' else lines
            _found = 0
            _pos = _end
            # Scan block by block instead of one byte per read call.
            while _pos > 0 and _found < _needed:
                _step = min(self._CHUNK_SIZE, _pos)
                _pos -= _step
                self.file_io.seek(_pos)
                _chunk = self.file_io.read(_step)
                _idx = len(_chunk)
                while _found < _needed:
                    _idx = _chunk.rfind(b'\n', 0, _idx)
                    if _idx < 0:
                        break
                    _found += 1
                    if _found == _needed:
                        # Data starts right after the break that precedes the wanted lines.
                        _start = _pos + _idx + 1

        self.file_io.seek(_start)
        _buffer = self.file_io.read(_end - _start)

        self.current_position = _end
        self.file_io.seek(self.current_position)
        return _buffer, self.current_position

    def readLines(self, lines: int = 100, crt_pos: Optional[int] = None):
        """
        Read lines, backwards when ``crt_pos`` is ``None`` and forwards otherwise.

        :param lines: number of lines to read; in the forward direction the loop
            runs until the counter reaches zero, so a negative count reads to end of file
        :param crt_pos: offset returned by a previous read, or ``None`` for a tail read
        :return: ``(data, end_offset)``
        """
        if crt_pos is None:
            # Tail read: start from the last byte (offset 0 for an empty file).
            self.file_io.seek(max(self.size() - 1, 0))
            _buffer, crt_pos = self.readTail(lines)
        else:
            # Incremental read: continue from the caller's offset.
            self.file_io.seek(crt_pos)
            _parts = []
            while lines:
                _bytes = self.file_io.readline()
                if _bytes == b'':
                    # End of file.
                    break
                lines -= 1
                _parts.append(_bytes)
            _buffer = b''.join(_parts)
            crt_pos = self.file_io.tell()

        self.current_position = crt_pos
        return _buffer, self.current_position

    def read(self, lines: int = 200, crt_pos: Optional[int] = None):
        """
        Read file data, choosing the direction from ``crt_pos``.

        1. ``crt_pos is None``: tail read; returns the data and the end offset.
        2. File larger than ``crt_pos``: forward read of the appended part.
        3. File smaller than ``crt_pos`` (truncated or replaced): returns empty
           data and ``None`` so the next call performs a fresh tail read.
        4. File size equal to ``crt_pos``: empty data, offset unchanged.

        :param lines: maximum number of lines to read, 200 by default
        :param crt_pos: offset from the previous read, or ``None``
        :return: ``(data, end_offset)``
        """
        _buffer = b''
        if crt_pos is None:
            _buffer, crt_pos = self.readLines(lines, crt_pos=crt_pos)
        else:
            _log_size = self.size()
            if _log_size > crt_pos:
                # Any growth, even a single byte, is returned right away.
                _buffer, crt_pos = self.readLines(lines, crt_pos=crt_pos)
            elif _log_size < crt_pos:
                # File shrank: reset so the caller starts over with a tail read.
                crt_pos = None

        self.current_position = crt_pos
        return _buffer, self.current_position
'jpeg': (b'\xFF\xD8\xFF', b'\xff\xd8\xff'), + 'png': (b'\x89PNG',), + 'gif': (b'GIF8',), + 'bmp': (b'BM',), + 'tiff': (b'II*\x00', b'MM\x00*'), + 'webp': (b'RIFF\x00\x00\x00\x00WEBP',), + 'ico': (b'\x00\x00\x01\x00',), + 'psd': (b'8BPS',), + 'svg': (b' 1024*2: + file_data = file_data[:1024*2] + + file_type: Optional[Union[str, tuple[str]]] = '' + for _key, _val in file_types.items(): + for _bs in _val: + if file_data.startswith(_bs): + file_type = _key + break + if file_type: + break + + if isinstance(file_type, tuple): + if file_type[0] == 'doc': + # 使用读取到的全部数据(≤4KB)进行启发式判断 + file_type = _heuristic_office_type(file_data) + elif file_type[0] == 'docx': + # 使用读取到的全部数据(≤4KB)进行启发式判断 + file_type = _heuristic_office_x_type(file_data) + + return file_type + + +def _heuristic_office_type(data: bytes) -> str: + """ + 仅基于前 4KB 数据,启发式判断是 .doc、.xls 还是 .ppt + 依据:各格式在 OLE 结构中的典型字符串偏移位置 + """ + # 关键词及其对应类型 + patterns = [ + (b'W\x00o\x00r\x00d\x00D\x00o\x00c\x00u\x00m\x00e\x00n\x00t', 'doc'), + (b'W\x00o\x00r\x00d', 'doc'), + (b'WordDocument', 'doc'), + (b'Word', 'doc'), + (b'W\x00o\x00r\x00k\x00b\x00o\x00o\x00k', 'xls'), + (b'B\x00o\x00o\x00k', 'xls'), + (b'Workbook', 'xls'), + (b'Book', 'xls'), + (b'P\x00o\x00w\x00e\x00r\x00P\x00o\x00i\x00n\x00t', 'ppt'), + (b'PowerPoint', 'ppt'), + ] + # 一次性遍历:在 data 中查找任一关键词 + # 由于模式短,且数据小(≤4KB),用简单循环即可 + for keyword, file_type in patterns: + if keyword in data: + return file_type + # 未匹配时,保守返回 "" + return "" + + +def _heuristic_office_x_type(data: bytes) -> str: + """ + 仅用 `in` 判断 .docx/.xlsx/.pptx,精准匹配 Open XML 标准 MIME 类型 + 不解压、不解析、不猜,就看有没有那三个关键字符串 + """ + # 关键词及其对应类型 + patterns = [ + (b'word/PK', 'docx'), + (b'xl/PK', 'xlsx'), + (b'ppt/PK', 'pptx'), + ] + # 一次性遍历:在 data 中查找任一关键词 + # 由于模式短,且数据小(≤4KB),用简单循环即可 + for keyword, file_type in patterns: + if keyword in data: + return file_type + # 未匹配时,保守返回 "" + return "" + + +def get_file_info(file_path): + """ + 取得文件信息,包括:文件大小、创建时间。 + + :param file_path: 文件绝对路径 + :return: 大小,创建时间 + """ + _ctime = 
def read_to_buffer(file) -> bytes:
    """
    Read a whole file in binary mode and return its contents.

    :param file: path of the file to read
    :return: the file contents as bytes
    :raises AssertionError: if ``file`` is not an existing regular file
    """
    # Explicit check instead of an ``assert`` statement, which is stripped under
    # ``python -O``. AssertionError is kept so callers see the same exception type.
    if not os.path.isfile(file):
        raise AssertionError('File not found: %s' % file)
    # ``with`` closes the handle; ``read()`` without a size reads to end of file.
    with open(file, 'rb') as f:
        return f.read()
def check_and_create_dir(file_path, mode=0o777, exist_ok=False):
    """
    Make sure the directory that will contain ``file_path`` exists.

    Nothing is done when ``file_path`` has no directory component (a bare file
    name) or when the directory path already exists.

    :param file_path: path of a file; only its directory part is used
    :param mode: mode passed to ``os.makedirs`` for newly created directories
    :param exist_ok: passed to ``os.makedirs``; because existence is checked
        first, it only matters if the directory appears between the check and
        the creation
    """
    _directory = os.path.dirname(file_path)
    # A bare file name has an empty dirname; os.makedirs('') would raise FileNotFoundError.
    if not _directory:
        return
    if not os.path.exists(_directory):
        os.makedirs(_directory, mode, exist_ok)
def get_fonts():
    """
    Return the preferred fonts that are installed on this system.

    The preference list below is intersected with the font names known to
    matplotlib's font manager. The result keeps the order of the preference
    list and is empty when none of the preferred fonts is installed.

    :return: list of font names, most preferred first
    """
    # Preference order, highest priority first.
    preferred = (
        'PingFang SC', 'Hiragino Sans GB', 'Heiti SC', 'SimSong', 'SimHei',
        'WenQuanYi Micro Hei', 'WenQuanYi Zen Hei', 'Source Han Sans SC',
        'Noto Sans CJK', 'Noto Sans CJK SC', 'DejaVu Sans'
    )
    # Names of every font matplotlib has registered.
    installed = {entry.name for entry in fontManager.ttflist}
    # Walking the preference tuple yields the matches already in priority order.
    return [name for name in preferred if name in installed]
def fetch_image(img_url: str) -> tuple[requests.Response, str]:
    """
    Download an image from an external URL.

    The request is streamed, sent with a browser-like User-Agent, and times
    out after 10 seconds.

    :param img_url: image URL; must contain both a scheme and a host
    :return: ``(response, content_type)``; ``content_type`` falls back to
        ``image/jpeg`` when the server sends no Content-Type header
    :raises ValueError: the URL has no scheme or no host
    :raises requests.exceptions.RequestException: the request failed or
        returned an error status
    """
    parts = urlparse(img_url)
    if not (parts.scheme and parts.netloc):
        raise ValueError("Invalid URL")

    # Browser-like User-Agent sent with the request.
    user_agent = (
        'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
        'AppleWebKit/537.36 (KHTML, like Gecko) '
        'Chrome/91.0.4472.124 Safari/537.36'
    )

    response = requests.get(
        img_url, headers={'User-Agent': user_agent}, stream=True, timeout=10
    )
    response.raise_for_status()

    return response, response.headers.get('Content-Type', 'image/jpeg')
def decode_base64_image(header: str, data: str, output_dir: str) -> str:
    """
    Decode base64 image data and save it into ``output_dir``.

    The file extension is taken from the media subtype in ``header``
    (e.g. ``data:image/png;base64`` gives ``png``). Subtypes that are not on
    the allow-list, and headers that carry no subtype at all, fall back to
    ``jpg``.

    :param header: the data header that precedes the base64 payload
    :param data: base64-encoded image data
    :param output_dir: output directory
    :return: relative path of the saved file, as returned by ``save_image_to_dir``
    """
    allowed_extensions = {'jpg', 'jpeg', 'png', 'gif', 'webp', 'svg', 'bmp'}
    # Media subtypes whose name differs from the file extension. SVG is sent as
    # "image/svg+xml", so without this the allowed "svg" entry could never match.
    # NOTE(review): SVG files can embed script; confirm saved uploads are served safely.
    subtype_aliases = {'svg+xml': 'svg'}

    # "data:image/png;base64" -> ["data:image", "png"]; no "/" means no subtype.
    type_parts = header.split(';')[0].split('/')
    subtype = type_parts[1].strip().lower() if len(type_parts) > 1 else ''

    image_type = subtype_aliases.get(subtype, subtype)
    if image_type not in allowed_extensions:
        image_type = 'jpg'

    image_data = base64.b64decode(data)
    return save_image_to_dir(image_data, image_type, output_dir)
def str_q_count(ustring):
    """
    Count the Chinese and full-width characters in a string.

    A character is counted when it lies in U+4E00..U+9FFF or in the full-width
    range U+FF01..U+FF5E. The ideographic space U+3000 is in neither range and
    is not counted.

    :param ustring: text to scan
    :return: number of matching characters
    """
    return sum(
        1
        for ch in ustring
        if '\u4e00' <= ch <= '\u9fff' or 0xFF01 <= ord(ch) <= 0xFF5E
    )
def is_valid_id_number(id_str):
    """
    Check whether a string is a well-formed Chinese resident identity number.

    Both forms are accepted:

    - 15 digits: region (6) + birth date ``YYMMDD`` + sequence (3); this form
      has no check digit, so only the layout is verified;
    - 18 characters: region (6) + birth date ``YYYYMMDD`` with a 19xx/20xx
      year + sequence (3) + check character (digit or ``X``/``x``); the check
      character is verified against the weighted sum of the first 17 digits.

    The month must be 01-12 and the day 01-31; the day is not checked against
    the length of the month. Only ASCII digits are accepted, and the whole
    string must match (no trailing line break).

    :param id_str: string to check
    :return: True when the string is well-formed, otherwise False
    """
    date_part = r'(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])'
    pattern_15 = r'[1-9]\d{5}\d{2}' + date_part + r'\d{3}'
    pattern_18 = r'[1-9]\d{5}(19|20)\d{2}' + date_part + r'\d{3}[\dXx]'

    # re.ASCII keeps \d to 0-9; fullmatch rejects anything after the last character.
    if re.fullmatch(pattern_15, id_str, re.ASCII):
        return True
    if not re.fullmatch(pattern_18, id_str, re.ASCII):
        return False

    # Weight of each of the first 17 digits.
    weight = [7, 9, 10, 5, 8, 4, 2, 1, 6, 3, 7, 9, 10, 5, 8, 4, 2]
    # Check character for each value of (weighted sum mod 11).
    validate = ['1', '0', 'X', '9', '8', '7', '6', '5', '4', '3', '2']

    # zip stops after the 17 weights, so the check character itself is not summed.
    sum_val = sum(int(digit) * w for digit, w in zip(id_str, weight))
    return validate[sum_val % 11] == id_str[17].upper()
def is_valid_postcode(postcode):
    """
    Check whether a value has the form of a Chinese postal code.

    A postal code is exactly six ASCII digits. A leading zero is allowed:
    several provinces use prefixes that start with 0 (for example 010000 and
    050000). A number passed as ``int`` cannot carry a leading zero, so such
    codes must be passed as strings.

    :param postcode: postal code as a string or a number
    :return: True when the value is six digits, otherwise False
    """
    postcode_str = str(postcode)
    # [0-9] rather than \d so full-width and other Unicode digits are rejected.
    return bool(re.fullmatch(r'[0-9]{6}', postcode_str))
def to_datetime(dt_str: str, fmt_list: List[str]):
    """
    Parse a string into a ``datetime`` using the first format that fits.

    The formats are tried in order, so put the most likely one first. Any
    error raised while trying a format is treated as "does not fit".

    :param dt_str: text to parse
    :param fmt_list: candidate ``strptime`` formats
    :return: the parsed ``datetime``, or ``None`` when no format fits
    """
    for candidate in fmt_list:
        try:
            return datetime.datetime.strptime(dt_str, candidate)
        except Exception:
            # This format does not fit; try the next one.
            continue
    return None
def apply_data_style(df, worksheet, workbook, **kwargs):
    """
    Rewrite every data cell of ``df`` into the worksheet with one cell format.

    The format is the default (font size 12, border 1) overridden by
    ``kwargs``. Row 0 of the sheet is left for the header, so DataFrame row
    ``i`` is written to sheet row ``i + 1``.

    :param df: source DataFrame
    :param worksheet: xlsxwriter worksheet object
    :param workbook: xlsxwriter workbook object
    :param kwargs: format properties that override the defaults
    :return: None
    """
    cell_format = workbook.add_format({'font_size': 12, 'border': 1, **kwargs})

    column_count = len(df.columns)
    for data_row in range(len(df)):
        # Shift by one: sheet row 0 holds the header.
        sheet_row = data_row + 1
        for col in range(column_count):
            worksheet.write(sheet_row, col, df.iloc[data_row, col], cell_format)
class Application(tornado.web.Application):
    """
    Tornado application that can discover request handlers automatically.

    Handlers found in a package are registered under the URL stored in their
    ``route_pattern`` class attribute, optionally prefixed with ``uri_prefix``.
    """

    @classmethod
    def modules_iterator(cls, package: str | ModuleType):
        """
        Iterate over every module below a package, sub-packages included.

        :param package: dotted package name (imported first) or an already
            imported package object
        :return: the iterator produced by ``pkgutil.walk_packages``
        """
        if isinstance(package, str):
            package = importlib.import_module(package)

        # The prefix makes the yielded names fully qualified.
        return pkgutil.walk_packages(package.__path__, f"{package.__name__}.")

    @classmethod
    def fetch_handlers(cls, module: ModuleType):
        """
        Collect the routable request handlers visible in a module.

        A member qualifies when it is a subclass of
        ``tornado.web.RequestHandler`` and has a string ``route_pattern``
        attribute. Every name in ``dir(module)`` is inspected, so classes the
        module merely imports qualify as well.

        :param module: module to inspect
        :return: list of ``(route_pattern, handler_class)`` tuples
        """
        def is_handler(handler_cls):
            return isinstance(handler_cls, type) and issubclass(handler_cls, tornado.web.RequestHandler)

        def has_pattern(handler_cls):
            return isinstance(getattr(handler_cls, 'route_pattern', None), str)

        handlers: list[tuple[str, type]] = []
        for _n in dir(module):
            _cls = getattr(module, _n)
            if is_handler(_cls) and has_pattern(_cls):
                handlers.append((_cls.route_pattern, _cls))

        return handlers

    @classmethod
    def load_ui_modules(cls, ui_modules_config):
        """
        Turn a mapping of dotted class paths into a mapping of classes.

        :param ui_modules_config: ``{name: "package.module.ClassName"}``
        :return: ``{name: class}``
        :raises RuntimeError: a path has no module part, its module cannot be
            imported, or the module has no such attribute
        """
        loaded_modules = {}
        for name, path in ui_modules_config.items():
            try:
                # ValueError: the path has no "." to split module from class name.
                module_path, class_name = path.rsplit('.', 1)
                module = importlib.import_module(module_path)
                loaded_modules[name] = getattr(module, class_name)
            except (ImportError, AttributeError, ValueError) as e:
                raise RuntimeError(f"Failed to load UIModule {name} from {path}: {str(e)}") from e
        return loaded_modules

    def __init__(
            self,
            handlers: Optional[_RuleList] = None,
            handlers_pkg: str | ModuleType | None = None,
            uri_prefix: str = "",
            **settings: Any
    ) -> None:
        """
        Build the route list, then initialise the Tornado application.

        :param handlers: explicitly configured routes
        :param handlers_pkg: package whose handlers are discovered automatically
        :param uri_prefix: URI prefix; a leading "/" is added when missing
        :param settings: remaining Tornado settings
        """
        # All routes passed to Tornado: explicit ones first, discovered ones after.
        self.routes: _RuleList = []

        if uri_prefix:
            uri_prefix = uri_prefix if uri_prefix.startswith('/') else f"/{uri_prefix}"

        # URI prefix; applied only to automatically discovered handlers.
        self.uri_prefix = uri_prefix

        if handlers:
            self.routes.extend(handlers)

        if handlers_pkg:
            self.routes.extend(self.load_handlers(handlers_pkg=handlers_pkg))

        # Hook for subclasses; runs before Tornado sees the routes.
        self.before_create()

        super().__init__(handlers=self.routes, **settings)

    def before_create(self):
        """
        Hook called after the routes are collected and before the Tornado
        application is initialised. Does nothing by default.
        """
        pass

    def load_handlers(self, handlers_pkg: str | ModuleType | None = None):
        """
        Import every module of a package and build routes for its handlers.

        Packages themselves are skipped; only plain modules are imported and
        inspected. Each route is named after its handler class.

        :param handlers_pkg: dotted package name or package object
        :return: list of URL specs for the discovered handlers
        """
        _routes = []

        if handlers_pkg is None:
            return _routes

        modules_itr = self.modules_iterator(package=handlers_pkg)
        for _file_finder, _module_name, _is_package in modules_itr:
            if _is_package:
                continue

            _module = importlib.import_module(_module_name)
            for _route, _cls in self.fetch_handlers(module=_module):
                _pattern = str(_route)
                _pattern = _pattern if _pattern.startswith('/') else f"/{_pattern}"
                _url_spec = tornado.web.url(
                    pattern=f"{self.uri_prefix}{_pattern}",
                    handler=_cls, name=_cls.__name__
                )
                _routes.append(_url_spec)

        return _routes
setup_swagger + setup_swagger( + app=self, + routes=self.routes, + swagger_url=_swagger_url, + api_base_url=_swagger_api_base_url, + title=self.swagger_title, + description=self.swagger_description, + api_version=self.swagger_api_version, + contact=self.swagger_contact, + schemes=self.swagger_schemes, + ) diff --git a/paste/web/decorators.py b/paste/web/decorators.py new file mode 100644 index 0000000..b69a3e8 --- /dev/null +++ b/paste/web/decorators.py @@ -0,0 +1,205 @@ +import functools +import logging +from typing import Awaitable + +from jwt import ExpiredSignatureError, InvalidSignatureError, InvalidTokenError + +from paste.security import token +from paste.web.handler import RequestHandler + + +def route(route_pattern: str): + """ + 路由装饰器。为类增加 route_pattern 属性,并赋值。 + + :param route_pattern: URL 路径模式 + """ + + def wrapper(cls: type[RequestHandler]): + cls.route_pattern = route_pattern + return cls + + # 标记已经被 route 装饰 + setattr(wrapper, '__route__', True) + return wrapper + + +def auth_token(func): + """ + 令牌验证装饰器,用于 :class:`tornado.web.RequestHandler` 子类中的 get()/post() 等需要执行权限验证的方法,以便在正 + 式执行方法前利用客户端提交的令牌进行鉴权。 + + 当执行该装饰器解码令牌后,将更新 RequestHandler 对象的 token_payload 属性数据。其次,若能通过令牌中配载的 user_id 取 + 得用户数据,则还将设置 current_user 属性。 + + 该装饰器仅用于校验令牌的有效性,并取得用户信息,不负责校验用户的具体权限,若要验证权限需使用 @auth_permission 装饰器。 + + 使用方式如下:: + + @auth_token + async def post(self): + pass + + 要求在请求的 Headers 中必须包含 Access-Token,且内容由 encode_token 方法签发。 + + :param func: 被装饰的函数对象,不需要手动传入该参数 + :return: 装饰后的函数对象 + """ + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + req_handler: RequestHandler = args[0] + try: + # 请求头 + req_headers = dict(req_handler.request.headers) + + # 取出 Token + access_token = req_headers.get('Access-Token', None) + if access_token in (None, b'', ''): + raise InvalidTokenError(f'请求地址:{req_handler.request.uri}') + + # 如果采用 OAuth2 规范,这里应当调用远程 API 执行解码,解码后返回用户信息 + + # 用解码后的 Token 字典更新 Handler 中的的 token_dict + token_payload = token.decode_token(access_token) + 
req_handler.token_payload.update(token_payload) + + # 根据 Token 读取用户对象,并设置到请求处理对象(控制器) + _user_id = req_handler.token_param('user_id') + if _user_id and req_handler.user_class: + _user = await req_handler.user_class.async_find_by_id(_user_id) + if _user is None: + raise InvalidTokenError() + req_handler.current_user = _user + await req_handler.after_auth_token(token_payload) + + # 兼容同步或异步方法 + _result = func(*args, **kwargs) + if isinstance(_result, Awaitable): + _result = await _result + return _result + except ExpiredSignatureError as e: + e.args = ('令牌已过期,请求被拒绝.',) + req_handler.response_error(e, status_code=403, api_status_code=403) + req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True) + return None + except InvalidSignatureError as e: + e.args = ('令牌签名错误,请求被拒绝.',) + req_handler.response_error(e, status_code=403, api_status_code=403) + req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True) + return None + except InvalidTokenError as e: + e.args = ('令牌错误,请求被拒绝.',) + req_handler.response_error(e, status_code=401, api_status_code=401) + req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True) + return None + except Exception as e: + req_handler.response_error(e, status_code=501, api_status_code=501) + req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True) + return None + + # 标记已经被 auth_token 装饰 + setattr(wrapper, '__auth_token__', True) + return wrapper + + +def auth_permission(func): + """ + 权限检查装饰器。若不启用 RBAC 则不应用该装饰器。 + 用于检查用户是否有执行某个操作的具体权限。该装饰器须跟随在 @auth_token 装饰器的后面使用。 + + :param func: 被装饰的函数对象,不需要手动传入该参数 + :return: 装饰后的函数对象 + """ + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + # + # 为了能在不使用 RBAC 的系统中正常运行,这里的引用必须放在函数中 + # 否则初始化过程中 RBAC 数据模型会尝试读取数据表配置,发生错误 + # + from paste.rbac.rbac_user import RbacUser, Supervisors + + req_handler: RequestHandler = args[0] + try: + # 验证当前用户状态 + _user: RbacUser = req_handler.current_user + assert _user is not None, f"无效令牌或未登录,无权执行:{req_handler.route_pattern} 操作." 
+ + # 类型检测 + _right_type = isinstance(_user, RbacUser) + assert _right_type, f"当前用户类型错误,必须为 RbacUser 的子类." + + if _user.username not in Supervisors: + # 验证用户权限状态 + _has_permission = await _user.has_permission(req_handler.route_pattern) + assert _has_permission, f"当前用户 {_user.username} 无权执行:{req_handler.route_pattern} 操作." + + # 兼容同步或异步方法 + _result = func(*args, **kwargs) + if isinstance(_result, Awaitable): + _result = await _result + return _result + except AssertionError as e: + req_handler.response_error(e, status_code=401, api_status_code=401) + req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True) + return None + except Exception as e: + req_handler.response_error(e, status_code=501, api_status_code=501) + req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True) + return None + + # 标记已经被 auth_permission 装饰 + setattr(wrapper, '__auth_permission__', True) + return wrapper + + +def auth_rule(func): + """ + 规则检查装饰器。若不启用规则验证,则不应用该装饰器。 + 用于对用户按规则验证。该装饰器须跟随在 @auth_token 装饰器的后面使用。 + + :param func: 被装饰的函数对象,不需要手动传入该参数 + :return: 装饰后的函数对象 + """ + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + # + # 为了能在不使用 RBAC 的系统中正常运行,这里的引用必须放在函数中 + # 否则初始化过程中 RBAC 数据模型会尝试读取数据表配置,发生错误 + # + from paste.rbac.rbac_user import RbacUser, Supervisors + + req_handler: RequestHandler = args[0] + try: + # 验证当前用户状态 + _user: RbacUser = req_handler.current_user + assert _user is not None, f"无效令牌或未登录,无权执行:{req_handler.route_pattern} 操作." + + # 类型检测 + _right_type = isinstance(_user, RbacUser) + assert _right_type, f"当前用户类型错误,必须为 RbacUser 的子类." + + if _user.username not in Supervisors: + # 验证用户权限状态 + _user_can = await _user.can(req_handler.route_pattern, **kwargs) + assert _user_can, f"当前用户 {_user.username} 无权执行:{req_handler.route_pattern} 操作(规则验证不通过)." 
+ + # 兼容同步或异步方法 + _result = func(*args, **kwargs) + if isinstance(_result, Awaitable): + _result = await _result + return _result + except AssertionError as e: + req_handler.response_error(e, status_code=401, api_status_code=401) + req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True) + return None + except Exception as e: + req_handler.response_error(e, status_code=501, api_status_code=501) + req_handler.log(msg=e, level=logging.ERROR, is_log_exc=True) + return None + + # 标记已经被 auth_rule 装饰 + setattr(wrapper, '__auth_rule__', True) + return wrapper diff --git a/paste/web/form.py b/paste/web/form.py new file mode 100644 index 0000000..5186f76 --- /dev/null +++ b/paste/web/form.py @@ -0,0 +1,61 @@ +from wtforms_tornado import Form + + +class ModelForm(Form): + """ + 模型表单。派生后主要处理以下内容:: + + 有可能在 formdata 中出现的非列表类型,统一转为列表类型。 + """ + + def __init__(self, formdata=None, obj=None, prefix="", data=None, meta=None, **kwargs): + """ + 构造模型表单。 + + :param formdata: 来自客户端的输入数据,通常为 request.form 或等效数据。应该提供一个 multi-dict 接口来获取给定键的值列表。 + :param obj: 从该对象上与表单字段属性匹配的属性中获取现有数据。仅在未传递 formdata 时使用。 + :param prefix: 如果提供,所有字段的名称都将以值为前缀。这是为了区分单个页面上的多个表单。这只会影响匹配输入数据的 HTML 名称,而不会影响匹配现有数据的 Python 名称。 + :param data: 从该 dict 中与表单字段属性匹配的键中获取现有数据,如果 obj 也有匹配的属性,则它优先。仅在未传递 formdata 时使用。 + :param meta: 要在此窗体的 :attr: meta 实例上重写的属性 dict。 + :param kwargs: 与 data 合并以允许将现有数据作为参数传递。覆盖 data 中的任何重复键。仅在未传递 formdata 时使用。 + """ + if isinstance(formdata, dict): + # 对有可能在 formdata 中出现的非列表类型,统一转为列表类型 + formdata = {k: list(v) if isinstance(v, (list, tuple, set)) else [f'{v}'] for k, v in formdata.items()} + + # 启动父类构造 + super(Form, self).__init__(formdata=formdata, obj=obj, prefix=prefix, data=data, meta=meta, **kwargs) + + @classmethod + def list_to_field_list(cls, formdata, field_name: str, separator: str = '-'): + """ + 将 list 数据转换为符合 FieldList 的赋值规则的字段数据,即转换为 字段名+隔符+下标 格式的表单数据。 + + :param formdata: 来自客户端的输入数据 + :param field_name: 字段名 + :param separator: 分隔符,默认与 FieldList 一致为 "-" 符号 + """ + # 当以 JSON 
数组传入时,转换为以 - 连接的字段项,以符合 FieldList 的赋值规则 + _value_list = formdata.get(field_name, []) + if _value_list and isinstance(_value_list, list): + for _idx, _item in enumerate(_value_list): + formdata[f"{field_name}{separator}{_idx}"] = _item + + def validate_form(self, auto_raise: bool = True): + """ + 验证表单数据。 + + :param auto_raise: 当该参数为 True 时,若验证不成功,抛出验证异常。 + :return: 验证结果 + """ + validate_result = {} + + if self.validate(): + return True, validate_result + else: + validate_result.update(self.errors.items()) + + if auto_raise: + raise Exception('数据验证错误!', {'form_data': self.data, 'form_errors': validate_result}) + + return False, validate_result diff --git a/paste/web/handler.py b/paste/web/handler.py new file mode 100755 index 0000000..f6ab628 --- /dev/null +++ b/paste/web/handler.py @@ -0,0 +1,249 @@ +import importlib +import json +import logging +from abc import ABC +from collections import namedtuple +from typing import Optional, Union, Any, Type + +import tornado.web + +from paste.core import config +from paste.db.basemodel import BaseModel +from paste.util.encoder import JsonDumpsEncoder +from paste.core.logging import echo_log + + +def init_user_class(): + """ + 从配置文件初始化用户类。默认采用 rbac.RbacUser。 + """ + + try: + # 若没有配置 RBAC 直接返回 None + _rbac_cfg = config.get_config('rbac.user_class', None) + except AssertionError: + return None + + _cfg_user_class: str = config.get_config('rbac.user_class', None) + if _cfg_user_class is not None: + _parts = _cfg_user_class.split('.') + _module_name = '.'.join(_parts[:-1]) + _user_module = importlib.import_module(_module_name) + _user_class = getattr(_user_module, _parts[-1]) + return _user_class + + from paste.rbac.rbac_user import RbacUser + return RbacUser + + +class RequestHandler(tornado.web.RequestHandler, ABC): + """ + 请求控制父类。 + """ + + route_pattern: Optional[str] = None + """ + URL 路径模式。由装饰器 web.decorators.route 赋值,在 base.Application.load_handler_module 自动加载时调用,作为访问 + 路径,设置到 Application 中。 + """ + + user_class: 
Type[BaseModel] = init_user_class() + """ + 用户数据处理类。装饰器 web.decorators.auth_token 执行令牌验证时调用该类,用于创建用户对象,并保存在 current_user 属性中。 + 注意:这里仅初始化类,而不创建对象。该类允许用户继承扩展,然后自行配置。主要用于执行有关用户的数据操作。 + """ + + @classmethod + def log(cls, msg: Union[str, Exception], level: int = logging.INFO, is_log_exc: bool = False): + """ + 输出日志文本。 + + :param msg: 消息内容,当是 Exception 对象时,从 args 中取出第一项作为消息 + :param level: 消息等级 + :param is_log_exc: 是否输出异常信息到日志文件 + """ + echo_log(msg=msg, level=level, is_log_exc=is_log_exc) + + @classmethod + def dict_to_namedtuple(cls, name, data): + """ + 递归转换字典和列表中的字典为 namedtuple 对象。 + + 参数: + name: 用于创建 namedtuple 的名称 + data: 要转换的数据,可以是 dict、list 或基本类型 + + 返回: + 转换后的 namedtuple 对象或列表 + """ + if isinstance(data, dict): + # 处理字典类型 + NT = namedtuple(name, data.keys()) + return NT(**{ + k: cls.dict_to_namedtuple(k, v) + for k, v in data.items() + }) + elif isinstance(data, list): + # 处理列表类型:递归转换列表中的每个元素 + return [ + cls.dict_to_namedtuple(f"{name}_item", item) + if isinstance(item, (dict, list)) else item + for item in data + ] + else: + # 基本类型直接返回 + return data + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.rule_kwargs = {} + """ + 规则参数,用于在控制器和规则之间做数据交换 + """ + + self.token_payload: dict[str: Any] = {} + """ + 令牌配载数据字典。装饰器 web.decorators.auth_token 执行令牌验证时解码并赋值。 在 HandlerRequest 子类中 + 只要配置 auth_token 装饰即可使用该配载数据。 + + 其结构为:: + + { + 'iss': private_iss, + 'iat': datetime.datetime.utcnow(), + 'exp': datetime.datetime.utcnow() + datetime.timedelta(days=7), + 'params': { + 'id': user_id, + 'username': username + } + } + """ + + async def after_auth_token(self, token_payload: dict): + """ + 在验证 Token 后调用的函数,子类可覆盖。 + + :param token_payload: Token 数据项 + """ + pass + + def token_params(self) -> dict: + """ + 取出 Token 中的参数字典。 + + :return: 参数字典 + """ + return self.token_payload.get('params', {}) + + def token_param(self, key): + """ + 取出 Token 参数字典中的参数。 + + :param key: 参数名称 + """ + return self.token_params().get(key, None) + + def 
set_default_headers(self): + """ + 设置默认的请求头。 + """ + request_headers = dict(self.request.headers) + allow_headers = [ + 'Accept', 'Content-Type', 'Origin', 'Access-Token', 'ClientId', 'Timestamp', 'Verify-Hash', 'Security-Key' + ] + allow_methods = [ + 'OPTIONS', 'GET', 'POST' + ] + allow_origins = [ + request_headers.get('Origin', '*') + ] + content_type = [ + request_headers.get('Content-type', 'application/json') + ] + response_header_cfg = { + 'Access-Control-Allow-Headers': ','.join(set(allow_headers)), + 'Access-Control-Allow-Methods': ','.join(set(allow_methods)), + 'Access-Control-Allow-Origin': ','.join(set(allow_origins)), + 'Access-Control-Allow-Credentials': 'true', + 'Content-type': ','.join(set(content_type)), + } + for _k, _v in response_header_cfg.items(): + self.set_header(_k, _v) + + def get_current_user(self) -> Any: + if not hasattr(self, '_current_user'): + if self.user_class is not None: + # 设置了用户类,但是未创建对象的,这里默认创建空用户对象 + setattr(self, '_current_user', self.user_class()) + else: + setattr(self, '_current_user', None) + return self._current_user + + def options(self): + """ + 处理跨域请求中的 OPTIONS 预检。 + """ + self.set_status(status_code=200) + self.finish() + + def request_arguments(self): + """ + 取得所有请求参数。若 self.request.arguments 中有参数,则优先读取。 + 若无参数,则从 self.request.body 读取,且该参数必须为 JSON 结构。 + + :return: 请求参数字典 + """ + _args: dict[str: Any] = dict() + if len(self.request.arguments) > 0: + # 按 Form 提交时,从 Form 参数中读取命令,命令参数从 request.arguments 读取 + for _n, _v in self.request.arguments.items(): + if isinstance(_v, list): + # 对数组进行分解 + if len(_v) == 1: + _args[_n] = _v[0].decode("utf-8") + else: + _args[_n] = [__v.decode("utf-8") for __v in _v] + else: + _args[_n] = f"{_v}" + else: + # 非 Form 提交时,从 Body 解析命令,命令参数从 body.params 读取 + _body = self.request.body if self.request.body else '{}' + _args = json.loads(_body) + return _args + + def response_ok(self, **kwargs): + """ + 成功响应内容。 + + :param kwargs: 参数 + """ + self.set_status(status_code=200) + chunk = 
{'code': 200, 'status': 'OK'} + chunk.update(kwargs) + self.write(json.dumps(chunk, cls=JsonDumpsEncoder, ensure_ascii=False)) + self.set_header('Content-Type', 'application/json') + + def response_error(self, e: Exception, status_code: int = 200, api_status_code: int = None, **kwargs): + """ + 错误响应内容。 + + :param e: 异常对象 + :param status_code: HTTP/HTTPS 响应状态码 + :param api_status_code: API 状态码,若不提供则使用 status_code 参数 + """ + if api_status_code is None: + api_status_code = status_code + + self.set_status(status_code=status_code) + chunk = {'code': api_status_code, 'status': 'error'} + chunk.update(kwargs) + if len(e.args) > 0 and isinstance(e.args[0], str): + chunk['message'] = e.args[0] + if len(e.args) > 1: + if isinstance(e.args[1], dict): + chunk.update(e.args[1]) + elif isinstance(e.args[1], list): + chunk['errors'] = e.args[1] + self.write(json.dumps(chunk, cls=JsonDumpsEncoder, ensure_ascii=False)) + self.set_header('Content-Type', 'application/json') diff --git a/paste/web/param_aware_loader.py b/paste/web/param_aware_loader.py new file mode 100644 index 0000000..a9494e5 --- /dev/null +++ b/paste/web/param_aware_loader.py @@ -0,0 +1,212 @@ +import ast +import asyncio +import hashlib +import os +import threading +from typing import Tuple, List, Any, Optional, Dict + +from tornado.template import Loader, Template +from tornado.web import UIModule + + +class ParamAwareUIModuleDataWarehouse: + """ + 预处理数据仓库。 + 数据用唯一调用 ID 作为 Key 存储。 + """ + + def __init__(self): + self._store = {} + self._lock = threading.Lock() + + def prepare(self, module_name: str, call_id: str, data: Any): + """存储预处理数据""" + with self._lock: + self._store.setdefault(module_name, {})[call_id] = data + + def fetch(self, module_name: str, call_id: str) -> Any: + """获取预处理数据""" + with self._lock: + return self._store.get(module_name, {}).get(call_id) + + +warehouse = ParamAwareUIModuleDataWarehouse() +""" +全局单例,参数感知预处理仓库。 +""" + + +class ParamAwareUIModule(UIModule): + """ + 参数感知 UIModule 父类。 + 
1、子类应当实现 async_prepare 方法完成数据预处理,该方法在 Handler 执行过程中会根据模板文件的配置调用完成数据初始化,模板中配置的参数会传给该方法。 + 2、原有的 render 方法作为从数据仓库中获取数据,调用 render_with_data 方法完成渲染,已无需在子类中实现,模板中配置的参数也会传给该方法。 + 3、子类应当实现 render_with_data 方法完成渲染,预处理数据通过参数 prepared_data 传入。 + """ + + @classmethod + def generate_call_id(cls, module_name: str, kwargs: dict) -> str: + """根据模块名和参数生成唯一调用ID""" + param_str = ",".join(f"{k}={v}" for k, v in sorted(kwargs.items())) + return hashlib.md5(f"{module_name}|{param_str}".encode()).hexdigest() + + async def async_prepare(self, **kwargs) -> Any: + """子类实现异步数据加载,用静态方法避免参数缺失""" + raise NotImplementedError + + def render(self, **kwargs): + """自动关联预处理数据""" + call_id = self.generate_call_id(self.__class__.__name__, kwargs) + prepared_data = warehouse.fetch(self.__class__.__name__, call_id) + return self.render_with_data(prepared_data, **kwargs) + + def render_with_data(self, prepared_data: Any, **kwargs): + """子类实现具体渲染逻辑""" + raise NotImplementedError + + +class UIModuleCallAnalyzer(ast.NodeVisitor): + """ + 用于分析 Tornado 模板生成的 Python 代码,从中解析出对 UIModule 的名称和实际调用参数。 + """ + + def __init__(self): + self.calls = [] # 存储 (module_class_name, kwargs) + + def visit_Assign(self, node): + """ + 匹配 _tt_tmp = _tt_modules.XxxModule(...) 
模式。 + + :param node: + :return: + """ + if (isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Attribute) + and isinstance(node.value.func.value, ast.Name) and node.value.func.value.id == '_tt_modules'): + module_class = node.value.func.attr + kwargs = self._extract_kwargs(node.value) + self.calls.append((module_class, kwargs)) + + @classmethod + def _extract_kwargs(cls, call_node: ast.Call) -> dict: + """ + 安全提取调用参数。 + + :param call_node: 调用节点 + :return: 实际参数 + """ + kwargs = {} + + # 处理位置参数 (Tornado不会生成这种情况) + for arg in call_node.args: + if isinstance(arg, ast.Constant): + kwargs.setdefault('_pos_args', []).append(arg.s) + + # 处理关键字参数 + for kw in call_node.keywords: + if isinstance(kw.value, (ast.Constant, ast.Constant, ast.Constant)): + kwargs[kw.arg] = ast.literal_eval(ast.unparse(kw.value)) + elif isinstance(kw.value, ast.Name) and kw.value.id in ('True', 'False', 'None'): + kwargs[kw.arg] = ast.literal_eval(kw.value.id) + + return kwargs + + @classmethod + def ui_module_calls(cls, template_code: str) -> List[Tuple[str, dict]]: + """ + 从模板生成的 Python 代码中提取 UIModule 的调用。 + + :param template_code: 模板生成的函数。 + :return: + """ + try: + tree = ast.parse(template_code) + analyzer = cls() + analyzer.visit(tree) + return analyzer.calls + except: + return [] + + +class ParamAwareTemplate(Template): + """ + 参数感知模板。 + 重写 _generate_python 方法,从 Tornado 模板编译生成的 Python 代码中分析出 UIModule 调用参数。 + 提供 prepare_ui_modules 方法在 Handler 中 load 完成后预处理数据,预处理得到的数据会保存在数据仓库中。 + """ + + def __init__(self, *args, **kwargs): + self.ui_module_calls = [] + super().__init__(*args, **kwargs) + + def _generate_python(self, *args, **kwargs): + code = super()._generate_python(*args, **kwargs) + self.ui_module_calls = UIModuleCallAnalyzer.ui_module_calls(code) + return code + + async def prepare_ui_modules(self, template: 'ParamAwareTemplate', ui_modules: dict[UIModule]): + """执行模板中所有UIModule的异步预处理""" + tasks = [] + + for module_name, kwargs in template.ui_module_calls: + module_class 
= ui_modules.get(module_name) + if not hasattr(module_class, 'async_prepare'): + continue + + call_id = ParamAwareUIModule.generate_call_id(module_name, kwargs) + task = asyncio.create_task( + self._prepare_single(module_class, call_id, kwargs) + ) + tasks.append(task) + + await asyncio.gather(*tasks) + + async def _prepare_single(self, module_class, call_id, kwargs): + """单个模块的预处理流程""" + try: + _ui_modulr: ParamAwareUIModule = module_class(handler=self.namespace.get('handler')) + data = await _ui_modulr.async_prepare(**kwargs) + warehouse.prepare(module_class.__name__, call_id, data) + except Exception as e: + warehouse.prepare(module_class.__name__, call_id, { + "__error__": str(e) + }) + + +class ParamAwareLoader(Loader): + """ + 参数感知装载器,也是本代码文件中主要对外开放的类。 + 重写 _create_template 方法,用参数感知模板替换原有模板。 + 重写 load 明确返回参数感知模板。 + """ + + def __init__(self, root_directory: str, **kwargs: Any) -> None: + super().__init__(root_directory, **kwargs) + self.templates = {} # type: Dict[str, ParamAwareTemplate] + + def _create_template(self, name: str) -> ParamAwareTemplate: + path = os.path.join(self.root, name) + with open(path, "rb") as f: + template = ParamAwareTemplate(f.read(), name=name, loader=self) + return template + + def load(self, name: str, parent_path: Optional[str] = None) -> ParamAwareTemplate: + """Loads a template.""" + name = self.resolve_path(name, parent_path=parent_path) + with self.lock: + if name not in self.templates: + self.templates[name] = self._create_template(name) + return self.templates[name] + + async def load_with_prepare(self, name: str) -> ParamAwareTemplate: + """ + 加载模板,并完成数据准备。 + + :param name: 模板名称 + :return: 完成数据准备的模板 + """ + template = self.load(name) + _modules = self.namespace.get('modules', None) + if _modules and hasattr(_modules, 'ui_modules'): + _ui_modules = _modules.ui_modules + await template.prepare_ui_modules(template, _ui_modules) + return template diff --git a/paste/web/requests.py b/paste/web/requests.py new file mode 100644 
index 0000000..47928bc --- /dev/null +++ b/paste/web/requests.py @@ -0,0 +1,364 @@ +import asyncio +import io +import json +import logging +import mimetypes +import random +import sys +import time +from asyncio import Task +from typing import Optional, Callable, Awaitable, Dict, Any +from urllib.parse import urlencode + +from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPClientError, HTTPResponse +from tornado.web import RequestHandler + +from paste.core.logging import echo_log +from paste.util.encoder import JsonDumpsEncoder + + +_global_http_client: Optional[AsyncHTTPClient] = None + + +def get_http_client(): + """获取全局共享的 HTTP 客户端,避免重复创建和销毁。""" + global _global_http_client + if _global_http_client is None: + _global_http_client = AsyncHTTPClient() + return _global_http_client + + +async def close_http_client(): + """关闭全局 HTTP 客户端。""" + global _global_http_client + if _global_http_client: + _global_http_client.close() + _global_http_client = None + + +async def async_request(request: HTTPRequest, before_request: Callable = None, after_request: Callable = None, + retry_queue: asyncio.Queue[HTTPRequest] = None, is_log_exc=True, + on_error: Callable = None): + """ + 异步提交请求,返回响应数据对象。如提供回调函数,则将响应对象作为回调函数的参数传入,并执行。 + + :param request: 请求对象 + :param before_request: 在提交请求前要处理的行为 + :param after_request: 请求后的回调函数,回调参数为:HTTPResponse 响应对象、重试请求队列(无该参数则为 None) + :param retry_queue: 重试队列,若传入该参数,则失败的请求会放入该队列 + :param is_log_exc: 是否记录日志 + :param on_error: 发生异常后的处理,回调参数为:请求对象、异常对象、重试请求队列(无该参数则为 None) + :return: 响应数据对象 + """ + _http_client: Optional[AsyncHTTPClient] = None + try: + # 执行请求前的回调函数 + if before_request: + _before_result = before_request(request, retry_queue, is_log_exc=is_log_exc) + if isinstance(_before_result, Awaitable): + # 处理回调协程 + await _before_result + + if is_log_exc: + # 记录请求信息 + echo_log(f'请求地址:{request.method}: {request.url}.') + echo_log(f'主体长度:{sys.getsizeof(request.body)}.') + + # 在此之前的异常,不加入重试队列 + _http_client = get_http_client() + _response: 
HTTPResponse = await _http_client.fetch(request=request) + if after_request: + _after_result = after_request(_response, retry_queue) + if isinstance(_after_result, Awaitable): + # 处理协程回调 + await _after_result + return _response + except HTTPClientError as e: + if e.response and e.response.code in (302, 412): + # 这里依然可以拿到响应对象,继续返回 + return e.response + if is_log_exc: + echo_log(f'请求错误:{e},地址:{request.url}', level=logging.ERROR, is_log_exc=True) + if e.response is not None: + echo_log(f'响应内容:{e.response.body.decode()}', level=logging.ERROR) + if retry_queue is not None and _http_client is not None: + await retry_queue.put(request) + if on_error: + _err_result = on_error(request, e, retry_queue) + if isinstance(_err_result, Awaitable): + # 处理协程回调 + await _err_result + except ConnectionError as e: + if is_log_exc: + echo_log(f'连接错误:{e}', level=logging.ERROR, is_log_exc=True) + if retry_queue is not None and _http_client is not None: + await retry_queue.put(request) + if on_error: + _err_result = on_error(request, e, retry_queue) + if isinstance(_err_result, Awaitable): + # 处理协程回调 + await _err_result + except Exception as e: + if is_log_exc: + echo_log(f'未知错误:{e}', level=logging.ERROR, is_log_exc=True) + if retry_queue is not None and _http_client is not None: + await retry_queue.put(request) + if on_error: + _err_result = on_error(request, e, retry_queue) + if isinstance(_err_result, Awaitable): + # 处理协程回调 + await _err_result + + return None + + +async def async_concurrency(request_queue: Optional[asyncio.Queue[HTTPRequest]], con_count=10, retry=5, + before_request: Callable = None, + after_request: Callable = None, after_done: Callable = None, + is_log_exc=True, + retry_queue: asyncio.Queue[HTTPRequest] = None, + response_list: list[HTTPResponse] = None, + on_error: Callable = None, + wait_after: int = None): + """ + 异步并发请求,默认并发 10 个请求,且默认合计尝试 5 次(除第 1 次外,再尝试 4 次)。 + + :param request_queue: 请求队列 + :param con_count: 每批并发请求数量 + :param retry: 总尝试次数,默认尝试 5 次 + :param 
before_request: 在提交请求前要处理的行为 + :param after_request: 请求后的回调函数,回调参数为:HTTPResponse 响应对象、重试请求队列(无该参数则为 None) + :param after_done: 所有任务都完成后的回调,回调参数为:response_list,发生异常的请求不在列表中,应通过 on_error 回调获取 + :param is_log_exc: 是否记录异常日志 + :param retry_queue: 重试队列,若传入该参数,则失败的请求会放入该队列 + :param response_list: 响应列表 + :param on_error: 发生异常后的处理,回调参数为:请求对象、异常对象、重试请求队列(无该参数则为 None) + :param wait_after: 在请求完成后的等待时间,应当考虑请求服务器的处理时间,必要时可设置等待时间,但是不易设置过长一般 1~3 秒 + :return 若设置了 after_done 且有返回值,则返回,否则返响应列表 + """ + if retry_queue is None: + retry_queue = asyncio.Queue() + + if response_list is None: + response_list = [] + + while not request_queue.empty(): + # 按配置,读取队列,创建任务组 + _tasks: set[Task] = set() + for _i in range(con_count): + if request_queue.empty(): + break + + _request = await request_queue.get() + setattr(_request, 'max_retry', retry) + _task = asyncio.create_task(async_request( + request=_request, before_request=before_request, after_request=after_request, + retry_queue=retry_queue, is_log_exc=is_log_exc, on_error=on_error + )) + _tasks.add(_task) + + # 执行,并等待任务组完成 + response_list += await asyncio.gather(*_tasks) + # 处理等待 + if wait_after: + await asyncio.sleep(wait_after) + + # 检查任务(包含重试任务)是否完成,完成则返回,否则继续 + if not request_queue.empty(): + continue + + # 任然有需要重试的请求 + while not retry_queue.empty(): + _request = await retry_queue.get() + _retry = getattr(_request, 'retry', 0) + 1 + if _retry < retry: + setattr(_request, 'retry', _retry) + await request_queue.put(_request) + + if is_log_exc and not request_queue.empty(): + echo_log(f'启动重试,共有:{request_queue.qsize()} 个请求启动重试.') + + echo_log(f'所有请求执行完毕,任务结束.') + # 所有请求包括重试都已经完成,执行回调 + _result = None + if after_done: + _after_done_result = after_done(response_list) + if isinstance(_after_done_result, Awaitable): + # 处理协程回调 + _result = await _after_done_result + else: + # 普通函数调用 + _result = _after_done_result + # 钩子有返回时,返回钩子处理结果 + if _result is not None: + return _result + return response_list + + +async def async_forward(handler: 
RequestHandler, forward_url: str, is_log_exc=True, request_timeout=60): + """ + 转发请求。 + + :param handler: 收到请求的控制器对象 + :param forward_url: 要转发的目标地址 + :param is_log_exc: 是否记录日志 + :param request_timeout: 超时时长 + :return: 转发响应结果 + """ + _req_params = { + 'body': handler.request.body, + 'headers': { + 'Accept': '*/*', + 'Access-Token': handler.request.headers.get('Access-Token'), + 'Content-Type': handler.request.headers.get('Content-Type', 'application/json'), + 'timestamp': f'{int(time.time() * 1000)}' + }, + 'method': handler.request.method, + 'request_timeout': request_timeout, + 'url': forward_url, + } + _request = HTTPRequest(**_req_params) + + _http_client = get_http_client() + try: + if is_log_exc: + # 记录请求信息 + echo_log(f'请求地址:{_request.method}: {_request.url}.') + echo_log(f'主体长度:{sys.getsizeof(_request.body)}.') + _response: HTTPResponse = await _http_client.fetch(request=_request) + return _response + except HTTPClientError as e: + if is_log_exc: + echo_log(f'请求错误:{e},地址:{_request.url}', level=logging.ERROR, is_log_exc=True) + if e.response is not None: + echo_log(f'响应内容:{e.response.body.decode()}', level=logging.ERROR) + raise e + except ConnectionError as e: + if is_log_exc: + echo_log(f'连接错误:{e}', level=logging.ERROR, is_log_exc=True) + raise e + except Exception as e: + if is_log_exc: + echo_log(f'未知错误:{e}', level=logging.ERROR, is_log_exc=True) + raise e + + +def build_http_request( + url: str, + body: Optional[Dict[str, Any]] = None, + method: str = 'POST', + timeout: Optional[float] = None, + follow_redirects: bool = True, + use_form: bool = False, + extra_headers: Optional[Dict[str, str]] = None, + ** kwargs +) -> HTTPRequest: + """ + 构建一个 tornado.httpclient.HTTPRequest 对象。 + + 支持 GET 和 POST 方法: + - GET: 参数通过 URL 查询字符串传递 + - POST: 参数通过 JSON body 或 form 表单传递(由 use_form 控制) + + :param url: 请求的完整 URL + :param body: 请求体(字典),GET 时为查询参数,POST 时为 JSON 或 form 数据 + :param method: HTTP 方法,仅支持 'GET' 或 'POST' + :param timeout: 请求超时时间(秒) + :param follow_redirects: 
是否跟随重定向 + :param use_form: 如果为 True,POST 时使用 application/x-www-form-urlencoded 格式;否则使用 JSON + :param extra_headers: 可选的额外请求头,用于传入 Cookie、Authorization 等 + :param kwargs: 其他参数,符合 tornado.httpclient.HTTPRequest 参数要求 + :return: tornado.httpclient.HTTPRequest 对象 + :raises ValueError: 当 method 不合法时抛出 + """ + if method not in ('GET', 'POST'): + raise ValueError(f"Unsupported HTTP method: {method}") + + body = body or {} + + # 基础头 + headers = { + 'Accept': '*/*', + 'Accept-Encoding': 'gzip, deflate', + 'Accept-Language': 'zh-CN,zh;q=0.9', + 'Connection': 'keep-alive', + 'Content-Type': 'application/x-www-form-urlencoded; charset=UTF-8', + 'X-Requested-With': 'XMLHttpRequest', + } + + # 合并额外头(优先级:extra_headers > DEFAULT_HEADERS) + if extra_headers: + headers.update(extra_headers) + + req_params = { + 'url': url, + 'method': method, + 'headers': headers, + 'follow_redirects': follow_redirects, + } + + if timeout: + req_params['request_timeout'] = timeout + + if method == 'GET': + # GET 方法:参数拼接到 URL + if body: + req_params['url'] = f"{url}?{urlencode(body)}" + req_params.pop('body', None) + req_params['headers'].pop('Content-Type', None) + req_params['headers'].pop('Content-Length', None) + req_params['headers'].pop('Transfer-Encoding', None) + else: + # POST 方法 + if use_form: + # 检查 body 中是否有文件对象 + has_files = any(isinstance(v, io.IOBase) for v in body.values()) + + if has_files: + # 构建 multipart/form-data + boundary = f"----WebKitFormBoundary{random.randint(10000000, 99999999)}" + multipart_body = [] + + for key, value in body.items(): + if isinstance(value, io.IOBase): + # 文件对象 + value.seek(0) # 确保从头读取 + filename = getattr(value, 'name', key) # 尝试获取文件名 + content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream' + + multipart_body.append(f'--{boundary}\r\n'.encode()) + multipart_body.append(f'Content-Disposition: form-data; name="{key}"; filename="{filename}"\r\n'.encode()) + multipart_body.append(f'Content-Type: {content_type}\r\n\r\n'.encode()) + 
multipart_body.append(value.read()) + multipart_body.append(b'\r\n') + else: + # 普通文本 + multipart_body.append(f'--{boundary}\r\n'.encode()) + multipart_body.append(f'Content-Disposition: form-data; name="{key}"\r\n\r\n'.encode()) + # 处理布尔值和 None + if isinstance(value, bool): + value = str(value).lower() + elif value is None: + value = '' + else: + value = str(value) + multipart_body.append(value.encode('utf-8')) + multipart_body.append(b'\r\n') + + multipart_body.append(f'--{boundary}--\r\n'.encode()) + req_params['body'] = b''.join(multipart_body) + headers['Content-Length'] = str(len(req_params['body'])) + headers['Content-Type'] = f'multipart/form-data; boundary={boundary}' + else: + # 普通 form 表单(无文件) + req_params['body'] = urlencode(body).encode('utf-8') + headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8' + headers['Content-Length'] = str(len(req_params['body'])) + else: + body_bytes = json.dumps(body, cls=JsonDumpsEncoder, ensure_ascii=False).encode('utf-8') + req_params['body'] = body_bytes + headers['Content-Length'] = str(len(body_bytes)) + # 保持 application/json + + req_params.update(kwargs) + return HTTPRequest(**req_params) diff --git a/paste/web/swagger.py b/paste/web/swagger.py new file mode 100644 index 0000000..0ef5443 --- /dev/null +++ b/paste/web/swagger.py @@ -0,0 +1,98 @@ +import os +import typing + +import tornado +import tornado.web +from tornado_swagger._builders import generate_doc_from_endpoints +from tornado_swagger._handlers import TornadoBaseHandler +from tornado_swagger.const import API_SWAGGER_2 +from tornado_swagger.setup import STATIC_PATH + +from paste.web.application import ApplicationSwagger + + +class SwaggerUiHandler(TornadoBaseHandler): + """ + 自定义 Ui,支持从应用程序读取文档页面。 + 主要是为了允许不同的应用具有不同的接口描述页面。 + """ + + def get(self): + if hasattr(self.application, 'swagger_home_template'): + self.write(self.application.swagger_home_template) + else: + self.write( + f'类型错误,无法从应用程序读取 swagger_home_template 属性,' + f'请使用 
ApplicationSwagger 以支持 Swagger。' + ) + + +class SwaggerSpecHandler(TornadoBaseHandler): + """ + 自定义 Spec,支持从应用程序读取 Schema。 + 主要是为了允许不同的应用具有不同的接口描述页面。 + """ + + def get(self): + if hasattr(self.application, 'swagger_schema'): + self.write(self.application.swagger_schema) + else: + self.write( + f'类型错误,无法从应用程序读取 swagger_schema 属性,' + f'请使用 ApplicationSwagger 以支持 Swagger。' + ) + + +def setup_swagger( + app: ApplicationSwagger, + routes: typing.List[tornado.web.URLSpec], + *, + swagger_url: str = "/api/doc", + api_base_url: str = "/", + description: str = "Swagger API definition", + api_version: str = "1.0.0", + title: str = "Swagger API", + contact: str = "", + schemes: list = None, + security_definitions: dict = None, + security: list = None, + display_models: bool = True, + api_definition_version: str = API_SWAGGER_2 +): + """ + 注入 Swagger ui 到应用程序路由。 + """ + + swagger_schema = generate_doc_from_endpoints( + routes, + api_base_url=api_base_url, + description=description, + api_version=api_version, + title=title, + contact=contact, + schemes=schemes, + security_definitions=security_definitions, + security=security, + api_definition_version=api_definition_version, + ) + + _swagger_ui_url = f"/{swagger_url}" if not swagger_url.startswith("/") else swagger_url + _base_swagger_ui_url = _swagger_ui_url.rstrip("/") + _swagger_spec_url = f"{_swagger_ui_url}/swagger.json" + + routes[:0] = [ + tornado.web.url(_swagger_ui_url, SwaggerUiHandler), + tornado.web.url(f"{_base_swagger_ui_url}/", SwaggerUiHandler), + tornado.web.url(_swagger_spec_url, SwaggerSpecHandler), + ] + + app.swagger_schema = swagger_schema + + with open(os.path.join(STATIC_PATH, "ui.html"), "r", encoding="utf-8") as f: + app.swagger_home_template = ( + f.read().replace( + "{{ SWAGGER_URL }}", _swagger_spec_url + ).replace( + "{{ DISPLAY_MODELS }}", str(-1 if not display_models else 1) + ) + ) diff --git a/paste/web/websocket.py b/paste/web/websocket.py new file mode 100644 index 0000000..87b9800 --- 
/dev/null +++ b/paste/web/websocket.py @@ -0,0 +1,130 @@ +from abc import ABC +from typing import Optional, Awaitable, Any, Type + +from tornado import websocket + +import tornado.websocket + +from paste.db.basemodel import BaseModel +from paste.web.handler import init_user_class + + +class WebSocketHandler(tornado.websocket.WebSocketHandler, ABC): + """ + WebSocketHandler 的派生父类,主要增加了 send 方法,用于向客户端发送数据。 + """ + + _web_sockets: set['WebSocketHandler'] = set() + """ + 用于全局保存所有的客户端连接。 + """ + + user_class: Type[BaseModel] = init_user_class() + """ + 用户数据处理类。装饰器 web.decorators.auth_token 执行令牌验证时调用该类,用于创建用户对象,并保存在 current_user 属性中。 + 注意:这里仅初始化类,而不创建对象。该类允许用户继承扩展,然后自行配置。主要用于执行有关用户的数据操作。 + """ + + @classmethod + def add_socket(cls, web_socket): + """ + 加入 WebSocket 集合。 + + :param web_socket: 要加入的 WebSocketHandler 对象 + """ + assert hasattr(web_socket, 'send') + cls._web_sockets.add(web_socket) + + @classmethod + def get_sockets(cls): + """ + 取得 WebSocket 集合。 + """ + return cls._web_sockets + + @classmethod + def has_sockets(cls): + """ + 连接队列中是否还有连接。 + """ + return True if cls._web_sockets else False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.token_payload: dict[str: Any] = {} + """ + 令牌配载数据字典。装饰器 web.decorators.auth_token 执行令牌验证时解码并赋值。 在 HandlerRequest 子类中 + 只要配置 auth_token 装饰即可使用该配载数据。 + + 其结构为:: + + { + 'iss': private_iss, + 'iat': datetime.datetime.utcnow(), + 'exp': datetime.datetime.utcnow() + datetime.timedelta(days=7), + 'params': { + 'id': user_id, + 'username': username + } + } + """ + + def token_params(self) -> dict: + """ + 取出 Token 中的参数字典。 + + :return: 参数字典 + """ + return self.token_payload.get('params', {}) + + def token_param(self, key): + """ + 取出 Token 参数字典中的参数。 + + :param key: 参数名称 + """ + return self.token_params().get(key, None) + + def is_connected(self): + """ + 检查当前WebSocket连接是否打开。 + """ + return self.ws_connection is not None and self.ws_connection.stream is not None + + def select_subprotocol(self, 
subprotocols: [str]) -> Optional[str]: + """ + 选择子协议字符串。注意:: + + 1、该方法返回的数据必须位于 subprotocols 数组中; + 2、若有 subprotocols 参数传入,默认始终返回第 0 项; + 3、用于验证的 Token 始终放在子协议的最后一项,读取该数据设置到 request.headers 中; + + :param subprotocols: 子协议数组,当前端传入字符串时,该数组仅有一项 + :return: 选择的子协议 + """ + if subprotocols: + _token = subprotocols[-1] + self.request.headers.add('Access-Token', _token) + + return subprotocols[0] + return None + + def on_close(self): + """ + 关闭连接时,从集合中删除客户端连接。 + """ + if self in self._web_sockets: + self._web_sockets.remove(self) + + def send(self) -> Optional[Awaitable[None]]: + """ + 向客户端发送数据。必须在子类中加以实现。 + """ + raise NotImplementedError() + + async def data_received(self, chunk: bytes): + pass + + def check_origin(self, origin): + return True diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2cc8def --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,132 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "paste-framework" +version = "2.0.1" +description = "A production-ready lightweight Python framework with built-in RBAC, JWT, async tasks, Swagger, and modular utilities" +readme = "README.md" +requires-python = ">=3.11" +license = {text = "MIT"} +authors = [ + {name = "Paste Contributors", email = "waynezwf@qq.com"}, +] +keywords = ["framework", "tornado", "rbac", "jwt", "redis", "swagger"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Topic :: Internet :: WWW/HTTP :: WSGI :: Server", + "Topic :: Software Development :: Libraries :: Application Frameworks", +] + +dependencies = [ + # --- Web 框架核心 --- + "tornado>=6.4", + + # --- 数据库 / ORM --- + "sqlalchemy>=2.0.49,<3.0", + "PyMySQL>=1.1.0", + "aiomysql>=0.2.0", + + # --- 安全 / JWT --- + "PyJWT>=1.7.1", + + # --- Redis / 消息队列 --- + "redis>=5.2.1", + + # 
--- 表单验证 --- + "WTForms>=3.2.1", + "tornado-wtforms>=0.0.1", + + # --- 文件处理(PDF / SVG / Excel)--- + "weasyprint>=64.1", + "svgwrite>=1.4.2", + "pandas>=2.0.0", + "xlsxwriter>=3.0", + "openpyxl>=3.1.5", + + # --- 日期 / 时间 --- + "python-dateutil>=2.8.0", + + # --- 工具 / 序列化 --- + "PyYAML>=6.0.2", + "aiofiles>=23.0.0", + "psutil>=5.9.0", + + # --- 通用 --- + "numpy>=1.24.0", +] + +[project.optional-dependencies] +swagger = [ + "tornado-swagger>=1.4.5", +] +async = [ + "aiohttp>=3.13.0", + "aiosqlite>=0.21.0", + "aioquic>=1.2.0", +] +java = [ + "javaobj-py3>=0.4.4", +] +chart = [ + "matplotlib>=3.10.1", + "scipy>=1.14.0", + "seaborn>=0.13.2", +] +image = [ + "pillow>=10.0.0", + "opencv-python>=4.11.0.86", +] +all = [ + "paste-framework[swagger,async,java,chart,image]", +] +test = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "pytest-cov>=4.0.0", + "httpx>=0.25.0", + "httpx-sse>=0.4.0", +] +dev = [ + "black>=24.0.0", + "flake8>=7.0.0", + "mypy>=1.0.0", + "pre-commit>=3.0.0", +] + +[project.urls] +Homepage = "https://github.com/wayne-zwf/paste" +Documentation = "https://paste-framework.readthedocs.io" +Repository = "https://github.com/wayne-zwf/paste" +Changelog = "https://github.com/wayne-zwf/paste/blob/main/CHANGELOG.md" + +[tool.setuptools.packages.find] +include = ["paste", "paste.*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" +addopts = [ + "--verbose", + "--cov=paste", + "--cov-report=term-missing", + "--cov-report=html", + "--asyncio-mode=auto", +] + +[tool.black] +line-length = 120 +target-version = ["py311"] + +[tool.flake8] +max-line-length = 120 +extend-ignore = ["E203", "W503"] + +[project.scripts] +paste-hello = "examples.01_hello_world.main:main" +paste-task = "examples.04_tasks_service.task_service:start_service" \ No newline at end of file diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..034266d --- /dev/null +++ b/tests/README.md @@ -0,0 +1,55 @@ +# 测试指南 + +## 测试目录结构 + +``` +tests/ +├── 
conftest.py # 全局 fixtures 和配置 +├── unit/ # 单元测试(无外部依赖,可离线运行) +│ ├── test_paste.py # 基础导入和版本 +│ ├── test_snow_id.py # 雪花ID生成器 +│ ├── test_jwt.py # JWT 令牌(mock) +│ ├── test_configure.py # 配置管理(mock) +│ ├── test_pagination.py # 分页逻辑 +│ ├── test_ustr.py # 字符串工具 +│ └── test_udict.py # 字典工具 +├── integration/ # 集成测试(需要真实服务) +│ └── test_db.py # 数据库连接测试 +└── README.md # 本文件 +``` + +## 运行方式 + +### 运行所有单元测试(推荐日常使用) + +```bash +cd <项目根目录> +pytest tests/unit/ -v +``` + +### 运行所有测试(含集成测试) + +```bash +pytest tests/ -v +``` + +### 仅运行集成测试 + +```bash +pytest tests/integration/ -v +``` + +### 生成覆盖率报告 + +```bash +pytest tests/unit/ --cov=paste --cov-report=html +open htmlcov/index.html +``` + +## 编写规范 + +1. **单元测试**放在 `tests/unit/`,无外部依赖 +2. **集成测试**放在 `tests/integration/`,加 `@pytest.mark.integration` +3. 每个测试类以 `Test` 开头,测试方法以 `test_` 开头 +4. 使用断言而非 print 验证结果 +5. 不依赖外部配置文件或数据库连接 \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ef7394f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,68 @@ +""" +pytest 全局配置和 fixtures。 +单元测试使用 mock,集成测试需要真实服务。 +""" +import json +import tempfile +from pathlib import Path +from typing import Dict, Any + +import pytest + + +@pytest.fixture(scope="session") +def mock_config_dict() -> Dict[str, Any]: + """ + 模拟配置数据,用于不依赖真实 config.json 的测试。 + """ + return { + "db_engine": { + "engine": "sqlite+pysqlite:///:memory:", + "async_engine": "sqlite+aiosqlite:///:memory:", + "engine_option": {"echo": False}, + }, + "redis": { + "connection": { + "url": "redis://localhost:6379/15", + }, + }, + "rbac": { + "table": { + "rule": "rbac_rule", + "user": "rbac_user", + "item": "rbac_item", + "assignment": "rbac_assignment", + "item_child": "rbac_item_child", + }, + "user_class": "paste.rbac.rbac_user.RbacUser", + }, + "logger": { + "default": { + "basic": { + "level": 40, + }, + }, + }, + "tornado": { + "demo": { 
+ "port": 9000, + }, + }, + } + + +@pytest.fixture +def temp_config_file(mock_config_dict): + """ + 创建临时配置文件,测试后自动清理。 + """ + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False, encoding="utf-8" + ) as f: + json.dump(mock_config_dict, f, ensure_ascii=False) + temp_path = Path(f.name) + + yield temp_path + + # 清理 + temp_path.unlink(missing_ok=True) \ No newline at end of file diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_db_connection.py b/tests/integration/test_db_connection.py new file mode 100644 index 0000000..c6b8f7e --- /dev/null +++ b/tests/integration/test_db_connection.py @@ -0,0 +1,39 @@ +""" +数据库连接集成测试。 +需要真实数据库连接,默认被跳过。 +运行方式:pytest tests/integration/ -v +""" + +import pytest +from paste.db.basetable import BaseTable +from paste.db.baseadapter import BaseAdapter + + +@pytest.mark.integration +class TestDbConnection: + """数据库连接基础测试""" + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_ping_database(self): + """测试数据库连通性""" + result = BaseAdapter.ping() + assert result is True + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_tables_in_db(self): + """测试获取表列表""" + tables = await BaseTable.tables_in_db() + assert isinstance(tables, list) + # 验证返回的表名称都是字符串 + for table in tables: + assert isinstance(table, str) + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_is_table_exist(self): + """测试表存在性判断""" + # 测试存在的表(information_schema.tables 必然存在) + exists = await BaseTable.is_table_exist('information_schema') + assert isinstance(exists, bool) \ No newline at end of file diff --git a/tests/integration/test_db_oracle.py b/tests/integration/test_db_oracle.py new file mode 100644 index 0000000..dbe7f27 --- /dev/null +++ b/tests/integration/test_db_oracle.py @@ -0,0 +1,28 @@ +""" +数据库集成测试。 +需要真实数据库连接,默认跳过。 +通过 `--run-integration` 参数运行。 +""" + 
+import pytest + +from paste.db.basetable import BaseTable + + +@pytest.mark.integration +class TestDatabaseIntegration: + """数据库集成测试""" + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_tables_in_db(self): + """测试获取数据库表列表""" + tables = await BaseTable.tables_in_db() + assert isinstance(tables, list) + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_is_table_exist(self): + """测试表存在性判断""" + exists = await BaseTable.is_table_exist('hat_article') + assert isinstance(exists, bool) \ No newline at end of file diff --git a/tests/integration/test_rbac_models.py b/tests/integration/test_rbac_models.py new file mode 100644 index 0000000..e8e4738 --- /dev/null +++ b/tests/integration/test_rbac_models.py @@ -0,0 +1,113 @@ +""" +RBAC 模型集成测试。 +需要真实数据库连接,默认被跳过。 +运行方式:pytest tests/integration/ -v +""" + +import pandas as pd +import pytest + +from paste.rbac.rbac_user import RbacUser, RbacItem, RbacPermission, RbacAssignment + + +@pytest.mark.integration +class TestRbacUser: + """RBAC 用户模型集成测试""" + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_user_query_generates_sql(self): + """验证用户查询能生成 SQL""" + query = RbacUser().gen_query() + sql = RbacUser.raw_sql(query) + assert sql is not None + assert 'SELECT' in str(sql) + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_user_query_all(self): + """测试查询所有用户""" + query = RbacUser().gen_query() + users = await RbacUser.query_all(query) + assert isinstance(users, list) + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_user_query_as_dataframe(self): + """测试用户查询返回 DataFrame""" + query = RbacUser().gen_query() + df = await RbacUser.query_as_df(query) + assert isinstance(df, pd.DataFrame) + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_find_by_username(self): + """测试根据用户名查询""" + user = await RbacUser.find_by_username('test') + # 如果用户不存在,返回 None + if user 
is not None: + assert isinstance(user, RbacUser) + assert user.username == 'test' + + +@pytest.mark.integration +class TestRbacItem: + """RBAC 授权项模型集成测试""" + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_item_query_all(self): + """测试查询所有授权项""" + query = RbacItem().gen_query() + items = await RbacItem.query_all(query) + assert isinstance(items, list) + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_item_query_as_dataframe(self): + """测试授权项查询返回 DataFrame""" + query = RbacItem().gen_query() + df = await RbacItem.query_as_df(query) + assert isinstance(df, pd.DataFrame) + + +@pytest.mark.integration +class TestRbacPermission: + """RBAC 权限模型集成测试""" + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_permission_query_all(self): + """测试查询所有权限""" + query = RbacPermission().gen_query() + permissions = await RbacPermission.query_all(query) + assert isinstance(permissions, list) + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_permission_query_as_dataframe(self): + """测试权限查询返回 DataFrame""" + query = RbacPermission().gen_query() + df = await RbacPermission.query_as_df(query) + assert isinstance(df, pd.DataFrame) + + +@pytest.mark.integration +class TestRbacAssignment: + """RBAC 分配关系集成测试""" + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_assignment_query_all(self): + """测试查询所有分配关系""" + query = RbacAssignment().gen_query() + assignments = await RbacAssignment.query_all(query) + assert isinstance(assignments, list) + + @pytest.mark.skip(reason="需要真实数据库连接") + @pytest.mark.asyncio + async def test_assignment_query_as_dataframe(self): + """测试分配关系查询返回 DataFrame""" + query = RbacAssignment().gen_query() + df = await RbacAssignment.query_as_df(query) + assert isinstance(df, pd.DataFrame) \ No newline at end of file diff --git a/tests/integration/test_redis.py b/tests/integration/test_redis.py new file mode 100644 index 
0000000..d53947d --- /dev/null +++ b/tests/integration/test_redis.py @@ -0,0 +1,45 @@ +""" +Redis 集成测试。 +需要真实 Redis 服务,默认被跳过。 +运行方式:pytest tests/integration/ -v +""" + +import pytest + +from paste.db.redis import Redis + + +@pytest.mark.integration +class TestRedisConnection: + """Redis 集成测试""" + + @pytest.mark.skip(reason="需要真实 Redis 服务") + @pytest.mark.asyncio + async def test_redis_ping(self): + """测试 Redis 连通性""" + result = await Redis.ping() + assert result is True + + @pytest.mark.skip(reason="需要真实 Redis 服务") + @pytest.mark.asyncio + async def test_redis_get_set(self): + """测试 Redis 基本读写""" + from paste.db.redis import Redis + + async with await Redis.get_redis() as r: + # 写入测试 + await r.set("test_key", "test_value") + # 读取验证 + value = await r.get("test_key") + assert value == b"test_value" + # 清理 + await r.delete("test_key") + + @pytest.mark.skip(reason="需要真实 Redis 服务") + @pytest.mark.asyncio + async def test_redis_get_keys(self): + """测试获取所有 keys""" + from paste.db.redis import Redis + + keys = await Redis.keys() + assert isinstance(keys, list) \ No newline at end of file diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..0e4c0a5 --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,60 @@ +""" +测试配置读取功能。 +使用 mock 配置文件,不依赖真实 config.json。 +""" + +from unittest.mock import patch + +import pytest + +from paste.core import config + + +class TestConfiguration: + """配置管理测试""" + + def test_load_config_with_mock_file(self, temp_config_file): + """使用临时配置文件测试加载""" + with patch.object(config, 'load_config', wraps=config.load_config): + # 模拟配置文件路径 + with patch('paste.core.config.os.path.abspath') as mock_abspath: + mock_abspath.return_value = str(temp_config_file) + cfg = config.load_config() + assert isinstance(cfg, dict) + assert 'db_engine' in cfg + + def test_get_config_by_path_existing(self, 
mock_config_dict): + """测试读取存在的配置项""" + with patch('paste.core.config.GLOBAL_CONFIG', mock_config_dict): + result = config.get_config_by_path('db_engine.engine') + assert result == "sqlite+pysqlite:///:memory:" + + def test_get_config_by_path_nested(self, mock_config_dict): + """测试读取深层嵌套配置""" + with patch('paste.core.config.GLOBAL_CONFIG', mock_config_dict): + result = config.get_config_by_path('redis.connection.url') + assert result == "redis://localhost:6379/15" + + def test_get_config_by_path_with_default(self, mock_config_dict): + """测试带默认值的配置读取""" + with patch('paste.core.config.GLOBAL_CONFIG', mock_config_dict): + result = config.get_config_by_path('nonexistent.key', default='fallback') + assert result == 'fallback' + + def test_get_config_by_path_missing_without_default(self, mock_config_dict): + """测试缺失配置项且无默认值时抛出异常""" + with patch('paste.core.config.GLOBAL_CONFIG', mock_config_dict): + with pytest.raises(AssertionError): + config.get_config_by_path('nonexistent.key') + + def test_get_config_shortcut(self, mock_config_dict): + """测试 get_config 快捷方法""" + with patch('paste.core.config.GLOBAL_CONFIG', mock_config_dict): + result = config.get_config('tornado.demo.port') + assert result == 9000 + + def test_get_config_default_none(self, mock_config_dict): + """测试 get_config 默认值 None 的情况""" + with patch('paste.core.config.GLOBAL_CONFIG', mock_config_dict): + with pytest.raises(AssertionError): + config.get_config('completely.nonexistent') \ No newline at end of file diff --git a/tests/unit/test_jwt.py b/tests/unit/test_jwt.py new file mode 100644 index 0000000..b264d3d --- /dev/null +++ b/tests/unit/test_jwt.py @@ -0,0 +1,75 @@ +""" +测试 JWT 令牌编解码功能。 +使用 mock 配置,不依赖真实密钥文件。 +""" + +import time + +import pytest + +from paste.security.token import encode_token, decode_token + + +class TestJwtToken: + """JWT 令牌测试""" + + def test_encode_decode_basic(self): + """基础编解码测试""" + payload = { + 'user_id': 123, + 'username': 'test_user', + 'role': 'admin', + } + token = 
encode_token(**payload) + assert token is not None + assert isinstance(token, str) + assert len(token) > 0 + + decoded = decode_token(token) + assert decoded is not None + assert decoded.get('params', {}).get('user_id') == 123 + + def test_token_contains_expected_fields(self): + """验证 token 包含必要字段""" + payload = {'user_id': 456, 'username': 'demo'} + token = encode_token(**payload) + decoded = decode_token(token) + + # 标准 JWT 字段 + assert 'iss' in decoded, "Token should have issuer" + assert 'iat' in decoded, "Token should have issued-at time" + assert 'exp' in decoded, "Token should have expiration time" + + # 自定义字段 + params = decoded.get('params', {}) + assert params.get('user_id') == 456 + assert params.get('username') == 'demo' + + def test_token_expiration(self): + """验证 token 过期机制""" + payload = { + 'user_id': 789, + 'username': 'expired_user', + 'exp': int(time.time()) - 3600, # 1小时前过期 + } + token = encode_token(**payload) + + with pytest.raises(Exception): + decode_token(token) + + def test_token_tampering(self): + """验证 token 防篡改""" + payload = {'user_id': 999, 'username': 'hacker'} + token = encode_token(**payload) + + # 篡改 token + tampered_token = token[:-5] + 'XXXXX' + + with pytest.raises(Exception): + decode_token(tampered_token) + + def test_empty_payload(self): + """空 payload 处理""" + token = encode_token() + decoded = decode_token(token) + assert decoded is not None \ No newline at end of file diff --git a/tests/unit/test_pagination.py b/tests/unit/test_pagination.py new file mode 100644 index 0000000..bf06e39 --- /dev/null +++ b/tests/unit/test_pagination.py @@ -0,0 +1,119 @@ +""" +测试分页逻辑。 +无外部依赖,可离线运行。 +""" + +from paste.util.pagination import Pagination + + +class TestPagination: + """分页功能测试""" + + def test_pages_calculation_exact(self): + """精确整除的分页计算""" + p = Pagination(row_count=100) + pages = p.pages(page_size=20) + assert pages == 5 + + def test_pages_calculation_remainder(self): + """有余数的分页计算""" + p = Pagination(row_count=101) + pages = 
p.pages(page_size=20) + assert pages == 6 + + def test_pages_calculation_zero_rows(self): + """零行数据的处理""" + p = Pagination(row_count=0) + pages = p.pages(page_size=20) + assert pages == 1 + + def test_page_number_valid(self): + """页码有效性检测""" + p = Pagination(row_count=50) + p.pages(page_size=20) + assert p.number(1) == 1 + assert p.number(3) == 3 # 超出范围应返回最大页 + + def test_page_number_negative(self): + """负页码处理""" + p = Pagination(row_count=50) + p.pages(page_size=20) + assert p.number(-1) == 1 + + def test_offset_calculation(self): + """偏移量计算""" + p = Pagination(row_count=100) + p.pages(page_size=20) + assert p.offset(1) == 0 + assert p.offset(2) == 20 + assert p.offset(3) == 40 + + def test_paging_chain(self): + """链式调用分页""" + p = Pagination(row_count=123).paging(page_number=2, page_size=20) + assert p.page_count == 7 + assert p.page_number == 2 + assert p.offset_size == 20 + + def test_page_size_bounds(self): + """页大小边界: 最小1,最大1000""" + p = Pagination(row_count=2000) + assert p.pages(page_size=0) > 0 + assert p.pages(page_size=2000) <= 1000 + + def test_large_dataset(self): + """大数据集分页""" + p = Pagination(row_count=1000000) + pages = p.pages(page_size=100) + assert pages == 10000 + + def test_single_row(self): + """单行数据""" + p = Pagination(row_count=1) + assert p.pages(page_size=10) == 1 + p.paging(page_number=1, page_size=10) + assert p.page_number == 1 + assert p.offset_size == 0 + + def test_page_number_upper_bound(self): + """页码上限处理""" + p = Pagination(row_count=30).paging(page_number=100, page_size=10) + assert p.page_number == 3 # 最大只能到3页 + + +# ============================================================ +# 以下是从 test_db.py 迁移过来的分页测试 +# 原函数 test_pagination() 改为标准的 pytest 测试 +# ============================================================ + + +class TestPaginationFromDb: + """从 test_db.py 迁移的分页测试""" + + def test_pagination_basic(self): + """基础分页计算""" + from paste.util.pagination import Pagination + p = Pagination(row_count=123).paging(page_number=2, 
page_size=20) + assert p.page_count == 7 + assert p.page_number == 2 + assert p.offset_size == 20 + + def test_pagination_first_page(self): + """第一页""" + from paste.util.pagination import Pagination + p = Pagination(row_count=50).paging(page_number=1, page_size=10) + assert p.page_number == 1 + assert p.offset_size == 0 + + def test_pagination_last_page(self): + """最后一页""" + from paste.util.pagination import Pagination + p = Pagination(row_count=55).paging(page_number=6, page_size=10) + assert p.page_number == 6 + assert p.offset_size == 50 + + def test_pagination_out_of_range(self): + """超出范围时自动修正""" + from paste.util.pagination import Pagination + p = Pagination(row_count=30).paging(page_number=100, page_size=10) + assert p.page_number == 3 # 只有3页,自动修正 \ No newline at end of file diff --git a/tests/unit/test_paste.py b/tests/unit/test_paste.py new file mode 100644 index 0000000..d998b10 --- /dev/null +++ b/tests/unit/test_paste.py @@ -0,0 +1,26 @@ +""" +测试 paste 包的基本导入和版本信息。 +无外部依赖,可离线运行。 +""" + +import paste + + +class TestPasteImport: + """测试 paste 包基础功能""" + + def test_paste_imports(self): + """确保 paste 包能正确导入""" + assert paste is not None + + def test_paste_version(self): + """检查 paste 包是否有 __version__""" + assert hasattr(paste, "__version__"), "paste package should have __version__" + assert isinstance(paste.__version__, str), "__version__ should be a string" + + def test_paste_version_value(self): + """验证版本号格式符合语义化版本规范""" + import re + version = paste.__version__ + assert re.match(r'^\d+\.\d+\.\d+', version), \ + f"Version {version} should follow semver format" \ No newline at end of file diff --git a/tests/unit/test_snow_id.py b/tests/unit/test_snow_id.py new file mode 100644 index 0000000..bc24237 --- /dev/null +++ b/tests/unit/test_snow_id.py @@ -0,0 +1,46 @@ +""" +测试雪花 ID 生成器。 +无外部依赖,可离线运行。 +""" + +from paste.util.snow_id import IdWorker + + +class TestSnowflakeId: + """雪花 ID 生成器测试""" + + def test_snow_id_generates_string(self): + """测试 Snowflake ID 
是否生成字符串""" + id_worker = IdWorker.get_id_worker() + sid = f'{id_worker.get_id()}' + assert isinstance(sid, str), "雪花 ID 必须是字符串" + assert len(sid) > 0, "雪花 ID 必须包含内容" + + def test_snow_id_is_unique(self): + """测试生成的 ID 是否唯一(简单验证)""" + id_worker = IdWorker.get_id_worker() + ids = [id_worker.get_id() for _ in range(50)] + assert len(set(ids)) == len(ids), "All generated IDs should be unique" + + def test_snow_id_monotonic_increase(self): + """测试雪花 ID 单调递增""" + id_worker = IdWorker.get_id_worker() + ids = [id_worker.get_id() for _ in range(100)] + for i in range(1, len(ids)): + assert ids[i] > ids[i - 1], \ + f"ID at position {i} should be greater than previous" + + def test_snow_id_worker_isolation(self): + """测试不同 worker 生成的 ID 不冲突""" + worker1 = IdWorker.get_id_worker(datacenter_id=1, worker_id=1) + worker2 = IdWorker.get_id_worker(datacenter_id=2, worker_id=2) + ids = [worker1.get_id() for _ in range(50)] + \ + [worker2.get_id() for _ in range(50)] + assert len(set(ids)) == len(ids), \ + "IDs from different workers should be unique" + + def test_snow_id_high_throughput(self): + """测试短时间高并发生成""" + id_worker = IdWorker.get_id_worker() + ids = [id_worker.get_id() for _ in range(1000)] + assert len(set(ids)) == 1000 \ No newline at end of file diff --git a/tests/unit/test_udict.py b/tests/unit/test_udict.py new file mode 100644 index 0000000..31eb371 --- /dev/null +++ b/tests/unit/test_udict.py @@ -0,0 +1,45 @@ +""" +测试字典工具。 +无外部依赖,可离线运行。 +""" + +from paste.util import udict + + +class TestUdict: + """字典工具测试""" + + def test_get_by_path_simple(self): + """简单路径读取""" + data = {"a": 1, "b": 2} + assert udict.get_by_path(data, "a") == 1 + + def test_get_by_path_nested(self): + """嵌套路径读取""" + data = {"a": {"b": {"c": 123}}} + assert udict.get_by_path(data, "a.b.c") == 123 + + def test_get_by_path_missing(self): + """缺失路径处理""" + data = {"a": 1} + assert udict.get_by_path(data, "b.c.d", "default") == "default" + + def test_get_by_path_none_default(self): + """缺失路径无默认值""" + 
data = {"a": 1} + assert udict.get_by_path(data, "b") is None + + def test_get_with_default_existing(self): + """存在的键读取""" + data = {"key": "value"} + assert udict.get_with_default(data, "key", "fallback") == "value" + + def test_get_with_default_missing(self): + """缺失键使用默认值""" + data = {"key": "value"} + assert udict.get_with_default(data, "missing", "fallback") == "fallback" + + def test_get_with_default_none_value(self): + """值为 None 时使用默认值""" + data = {"key": None} + assert udict.get_with_default(data, "key", "fallback") == "fallback" \ No newline at end of file diff --git a/tests/unit/test_ustr.py b/tests/unit/test_ustr.py new file mode 100644 index 0000000..174e153 --- /dev/null +++ b/tests/unit/test_ustr.py @@ -0,0 +1,44 @@ +""" +测试字符串工具。 +无外部依赖,可离线运行。 +""" + +import datetime +from paste.util import ustr + + +class TestUstr: + """字符串工具测试""" + + def test_str_q_count_all_cn(self): + """全中文统计""" + assert ustr.str_q_count("中国汉字") == 4 + + def test_str_q_count_mixed(self): + """中英文混合统计""" + count = ustr.str_q_count("Hello中国") + assert count == 2 # 只有中文字符算 + + def test_str_q_count_empty(self): + """空字符串统计""" + assert ustr.str_q_count("") == 0 + + def test_str_q_count_no_cn(self): + """纯英文统计""" + assert ustr.str_q_count("HelloWorld") == 0 + + def test_to_datetime_standard(self): + """标准格式解析""" + result = ustr.to_datetime("2024-01-15 10:30:00", ["%Y-%m-%d %H:%M:%S"]) + assert result is not None + assert isinstance(result, datetime.datetime) + + def test_to_datetime_invalid(self): + """无效格式解析""" + result = ustr.to_datetime("not-a-date", ["%Y-%m-%d"]) + assert result is None + + def test_to_datetime_empty(self): + """空字符串解析""" + result = ustr.to_datetime("", ["%Y-%m-%d"]) + assert result is None \ No newline at end of file