commit 47296980495f8bbfc9493e93de85dd62de6fa6b9
Author: zwf <2466627138@qq.com>
Date: Tue Jun 2 19:09:22 2026 +0800
Squashed 'paste-framework/' content from commit 34e8684
git-subtree-dir: paste-framework
git-subtree-split: 34e8684c4bc3cebbe177509f42ab4ef5b5425a7a
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..5753c7c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,88 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Caches
+*.cache
+*.db
+*.log
+
+# Distribution / packaging
+.build/
+dist/
+build/
+*.egg-info/
+*.egg/
+
+# PyInstaller
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+
+# PyCharm
+.idea/
+*.iml
+*.ipynb_checkpoints
+
+# VS Code
+.vscode/
+
+# Environment
+.venv/
+venv/
+ENV/
+env/
+.ENV/
+env.bak/
+venv.bak/
+
+# Logs
+logs/
+
+# Examples 运行时日志(演示用途,不提交)
+examples/*/logs/
+
+# Config files (local overrides)
+config.json
+config.local.json
+
+!examples/*/config.json
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# Docker
+docker-compose.yml
+Dockerfile
+
+# IDE
+*.swp
+*.swo
+*.swn
+
+# Backup files
+*.bak
+*.tmp
+
+# Project-specific
+.PASTE_FRAMEWORK_CACHE/
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..9cc86f0
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,56 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [2.0.1] - 2025-04-08
+
+### Fixed
+- `JsonDumpsEncoder` now handles `NaTType` (pandas NaT) without throwing TypeError
+- CORS `Access-Control-Allow-Origin` now defaults to `"*"` instead of echoing the request Origin header (security fix)
+- `RbacUser.can()` correctly handles empty permission list
+- Fixed race condition in `aio_pool.py` when task queue is full under high concurrency
+
+### Changed
+- Refactored `config.py` — `get_config()` now consistently returns `None` instead of raising on missing optional keys (use `default` parameter for fallback)
+- Upgraded minimum Python version from 3.9 to 3.11
+- `pyproject.toml` dependencies cleaned up — core dependencies slimmed down; pandas/numpy/matplotlib/opencv moved to `[project.optional-dependencies]`
+
+### Removed
+- Deprecated `web/legacy_session.py` — use JWT-based `@auth_token` decorator instead
+
+## [2.0.0] - 2025-03-15
+
+### Added
+- **Swagger UI auto-generation** (`ApplicationSwagger`) — fully inferred from handler route patterns and decorators
+- **Redis StreamActor** — consumer-group-based message processing with automatic zombie task recovery
+- **ParamAwareUIModule** — async data pre-loading for Tornado UIModules
+- **RBAC with dynamic rules** — rules are serialized Python classes stored in DB, supporting time-based, IP-based, and custom rule chains
+- **Snowflake ID generator** (`snow_id.py`) — thread-safe, no external dependency, 10K+ IDs/sec
+- **Auto-loading handlers** — `Application.load_handlers()` discovers all handlers in a package via `route_pattern` attribute
+- **BaseX encoder/decoder** (`encoder.py`) — auto-detect Base16/32/64/85 + zlib decompression
+- **Async DB engine** — SQLAlchemy async session factory with `aiomysql` / `aiosqlite` support
+- **Task service with daemonization** — `TaskService` with PID file, start/stop/status commands, and graceful shutdown
+- **5 complete examples** under `examples/` — HelloWorld, background tasks, Redis streams, scheduled services, model generation
+- **pytest test suite** — unit tests (mocked) + integration tests for DB, Redis, RBAC
+
+### Changed
+- Complete project restructuring: introduced `web/`, `core/`, `db/`, `rbac/`, `util/`, `service/` package layout
+- `Application` now inherits from `tornado.web.Application` — backward-incompatible refactor from v1.x
+- `RequestHandler.response_ok()` and `response_error()` now use `JsonDumpsEncoder` by default
+- Logging system rewritten — supports rotating file handlers, per-module loggers, console output
+
+### Removed
+- All v1.x code — this is a ground-up rewrite
+
+## [1.0.0] - 2024-01-01
+
+### Added
+- Initial release with basic Tornado wrapper
+- Basic JWT auth and password hashing
+- Simple `config.json` loader
+- Minimal utility functions
+
+> ⚠️ **Note:** v1.0.0 is deprecated and no longer maintained. Please upgrade to v2.0.0+.
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..0ba0c22
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,41 @@
+# Contribution Guidelines
+
+Thank you for considering contributing to this project! Your involvement helps make this project better for everyone.
+
+## How to Contribute
+
+1. **Fork the repository**
+2. **Create a new feature branch**
+ ```bash
+ git checkout -b feature/your-feature-name
+ ```
+3. **Make your changes**
+ Write clean, well-documented code. Follow the project’s coding standards.
+4. **Add tests**
+ All new features or bug fixes must include appropriate unit or integration tests.
+5. **Commit your changes**
+ Use descriptive commit messages following [Conventional Commits](https://www.conventionalcommits.org/).
+6. **Push to your fork**
+ ```bash
+ git push origin feature/your-feature-name
+ ```
+7. **Open a Pull Request**
+ Clearly describe what you changed, why, and how it affects the project.
+
+## Code Style & Standards
+
+- Follow the project’s linter rules (e.g., PEP8 for Python, ESLint for JavaScript)
+- Keep functions small and focused
+- Document public APIs with docstrings or comments
+- Never commit secrets or environment files
+
+## Reporting Issues
+
+If you find a bug or have a feature request, please open an [Issue](https://github.com/your-repo/issues) with:
+- A clear title
+- Steps to reproduce (for bugs)
+- Expected vs actual behavior
+
+## License
+
+By contributing, you agree that your contributions will be licensed under the MIT License.
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..76b0a08
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 Paste Contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..1a22148
--- /dev/null
+++ b/README.md
@@ -0,0 +1,287 @@
+```markdown
+┌─────────────────────────────────────────────────┐
+│ │
+│ ██████╗ █████╗ ███████╗████████╗███████╗ │
+│ ██╔══██╗██╔══██╗██╔════╝╚══██╔══╝██╔════╝ │
+│ ██████╔╝███████║███████╗ ██║ █████╗ │
+│ ██╔═══╝ ██╔══██║╚════██║ ██║ ██╔══╝ │
+│ ██║ ██║ ██║███████║ ██║ ███████╗ │
+│ ╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═╝ ╚══════╝ │
+│ │
+│ Python Api-first Scalable Task Engine │
+│ │
+└─────────────────────────────────────────────────┘
+```
+
+# PASTE — A Production-Ready Lightweight Python Framework
+
+> A minimalist, battle-tested Python framework for building api services with **built-in RBAC, JWT, async tasks, Snowflake IDs, Swagger, and modular utilities** — zero boilerplate, maximum control. Built on top of **Tornado**.
+
+
+
+
+
+---
+
+## ✨ Core Features
+
+| Feature | Description |
+|---------|-------------|
+| 🚀 **Auto-Loading Handlers** | Define `route_pattern = "/user"` in your handler → API is automatically registered. No router config needed. |
+| 🔐 **RBAC with Dynamic Rules** | Permissions stored as serialized Python objects in DB. Define complex rules as classes (e.g., `TimeBasedRule`, `IPWhitelistRule`). |
+| 📜 **Swagger UI Auto-Generated** | Built-in Swagger UI at `/docs`. No YAML. Schema is inferred from handlers and decorators. |
+| 🔢 **Snowflake ID Generator** | Embedded, thread-safe, no external dependency. Generates 10K+ unique IDs/sec. |
+| ⚡ **Async Task Pool with Backpressure** | `run_background_task(coro)` safely manages concurrent background jobs with configurable limits. |
+| 🔐 **JWT + Password Hashing** | Stateless authentication via `paste.security.token`. Passwords hashed with PBKDF2-sha256. |
+| 🧰 **Modular Utility Library (`util/`)** | Clean, testable utilities: `ustr`, `udict`, `ufile`, `pagination`, `snow_id`, `encoder` — no pollution in `__init__.py`. |
+| 📦 **Layered Architecture** | `core/` (config, logging, async), `db/` (engine, models), `rbac/` (roles, rules), `web/` (handlers, app), `service/` (business logic) — clean separation. |
+| 🛡️ **Safe File Handling** | `sanitize_filename()` ensures cross-platform compatibility (Windows/Linux/macOS). |
+| 📊 **JSON Encoder for Complex Types** | Auto-serializes `datetime`, `Decimal`, `BaseModel`, `Enum`, `numpy` — no extra config. |
+| 📬 **Redis Stream Actor** | Reliable consumer-group-based message processing with auto zombie-task recovery. |
+| 🧩 **ParamAware UI Modules** | Async-preload data for Tornado UIModules — solves the age-old sync-render bottleneck. |
+
+---
+
+## 🛠️ Quick Start
+
+### 1. Install
+
+```bash
+git clone https://github.com/your-repo/paste.git
+cd paste
+pip install -e .
+```
+
+### 2. Run the example
+
+All ready-to-run examples are located in the `examples/` directory:
+
+```bash
+cd examples/01_hello_world
+python main.py
+```
+
+👉 Open: [http://localhost:9090/hello](http://localhost:9090/hello)
+
+---
+
+## 📚 Examples
+
+Jump right in with these working examples:
+
+| Example | Description | Key Features Demonstrated |
+|---------|-------------|--------------------------|
+| [`01_hello_world`](examples/01_hello_world/) | Minimal paste app | Auto-loading handlers, config, response helpers |
+| [`02_background_task`](examples/02_background_task/) | Async background tasks | Task pool, logging, daemonized services |
+| [`03_redis_stream`](examples/03_redis_stream/) | Redis Stream publisher & consumer | `StreamActor`, consumer groups, message ACK |
+| [`04_tasks_service`](examples/04_tasks_service/) | Scheduled task service | `TaskService`, cron-like scheduling, PID management |
+| [`05_gen_models`](examples/05_gen_models/) | Auto-generate DB models from existing tables | SQLAlchemy model generation |
+
+Each example includes a `config.json`, a `handler.py`, and a `main.py` — ready to run.
+
+---
+
+## 🏗️ Framework Architecture
+
+```
+paste/ # Framework core (do NOT modify)
+├── core/ # Foundation layer
+│ ├── config.py # Dot-path config loader (get_config("db.engine"))
+│ ├── logging.py # Logger with rotating file + console
+│ └── aio_pool.py # Async task pool with backpressure
+├── db/ # Database layer
+│ ├── engine.py # SQLAlchemy engine factory
+│ ├── redis.py # Redis connection + StreamActor
+│ ├── basemodel.py # Async ORM base model
+│ ├── basetable.py # Table reflection utilities
+│ ├── baseadapter.py # Result-set adapter
+│ └── gen_models.py # Auto-generate model classes
+├── web/ # Web layer
+│ ├── application.py # Tornado Application with auto-loading
+│ ├── handler.py # Base RequestHandler with response helpers
+│ ├── decorators.py # @route, @auth_token, @auth_permission
+│ ├── swagger.py # Swagger UI auto-generation
+│ ├── form.py # WTForms integration
+│ ├── param_aware_loader.py # Async UIModule data preloader
+│ └── websocket.py # WebSocket support
+├── rbac/ # Access control layer
+│ ├── rbac_user.py # User with permission inheritance
+│ ├── rbac_role.py # Role management
+│ ├── rbac_rule.py # Rule engine (pickle-serialized)
+│ ├── rbac_item.py # Permission item hierarchy
+│ ├── rbac_assignment.py # User-item assignment
+│ └── rbac_permission.py # Permission query
+├── security/ # Security layer
+│ ├── token.py # JWT encode/decode with configurable issuer
+│ └── shash.py # PBKDF2-sha256 password hashing
+├── service/ # Service layer
+│ ├── server.py # Server process management
+│ ├── daemonize.py # Unix daemon + PID file
+│ └── task_service.py # Scheduled task runner
+├── chart/ # Chart generation (optional)
+│ ├── bar.py / pie.py / line.py
+│ └── ...
+└── util/ # Utility library
+ ├── ustr.py / udict.py # String & dict helpers
+ ├── ufile.py # File operations with sanitized filenames
+ ├── pagination.py # Page, Pagination, CursorPagination
+ ├── snow_id.py # Snowflake ID generator
+ ├── encoder.py # JSON encoder + BaseX decoder
+ ├── pdf.py / svg.py / xlsx.py
+ └── ...
+```
+
+Your application code lives **entirely outside** `paste/`:
+
+```
+myapp/
+├── main.py # Entry point: create Application, listen, start
+├── config.json # DB, logger, RBAC, Tornado configuration
+├── apps/ # Your handlers (recommended)
+│ └── demo/
+│ ├── __init__.py
+│ ├── handler_user.py
+│ └── handler_auth.py
+├── models/ # Your DB models
+│ └── db_models.py
+└── service/ # Your background services
+ └── __init__.py
+```
+
+---
+
+## 🔐 RBAC Example
+
+```python
+# handler.py
+from paste.rbac.rbac_user import RbacUser
+from paste.web.decorators import route, auth_token, auth_permission
+from paste.web.handler import RequestHandler
+
+@route("/admin/users")
+class AdminUserHandler(RequestHandler):
+ @auth_token
+ @auth_permission
+ async def get(self):
+ users = await RbacUser.query_as_df(RbacUser().gen_query())
+ self.response_ok(rows=users.to_dict('records'))
+```
+
+```python
+# rule.py — dynamic rule as a serialized Python class
+from datetime import datetime
+from paste.rbac.rbac_rule import RbacRule
+
+class BusinessTimeRule(RbacRule):
+ async def run(self, **kwargs) -> bool:
+ hour = datetime.now().hour
+ return 9 <= hour < 18 # only allow during business hours
+```
+
+---
+
+## ⚙️ Configuration
+
+All configuration lives in a single `config.json`:
+
+```json
+{
+ "tornado": {
+ "demo": {
+ "autoreload": false,
+ "handlers_pkg": "apps.demo",
+ "port": 9090,
+ "static_path": "static",
+ "template_path": "templates",
+ "swagger_title": "DemoAPI",
+ "swagger_description": "Demo API",
+ "swagger_api_version": "1.0.1"
+ }
+ },
+ "db": {
+ "engine": {
+ "engine": "mysql+pymysql://user:pass@localhost:3306/mydb",
+ "async_engine": "mysql+aiomysql://user:pass@localhost:3306/mydb",
+ "engine_option": {
+ "echo": false,
+ "pool_size": 10
+ }
+ }
+ },
+ "redis": {
+ "connection": "redis://localhost:6379/0",
+ "streams": {
+ "user_event": {
+ "group": "user_event_group",
+ "consumer": "processor_01"
+ }
+ }
+ },
+ "rbac": {
+ "user_class": "myapp.models.MyRbacUser",
+ "table": {
+ "rule": "rbac_rule",
+ "user": "rbac_user",
+ "item": "rbac_item",
+ "assignment": "rbac_assignment",
+ "item_child": "rbac_item_child"
+ }
+ },
+ "logger": {
+ "default": {
+ "basic": {
+ "filename": "logs/app.log",
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ "level": 20
+ }
+ }
+ }
+}
+```
+
+Access any config value with dot-path:
+
+```python
+from paste.core import config
+engine_url = config.get_config("db.engine.engine")
+port = config.get_config("tornado.demo.port", 9000)
+```
+
+---
+
+## 📣 License
+
+MIT © 2025 [Wayne Zhang / Organization]
+
+---
+
+## 🌍 Contributing
+
+Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details on:
+- Code style (Black + flake8 + mypy)
+- How to submit issues and pull requests
+
+---
+
+## 🧪 Running Tests
+
+```bash
+pytest # Run all tests
+pytest tests/unit # Run unit tests only (no external services)
+pytest tests/integration # Run integration tests (require DB/Redis)
+pytest --cov=paste # Run with coverage report
+```
+
+---
+
+> **💡 Why paste?**
+>
+> Most frameworks make you write boilerplate.
+> `paste` makes you **stop writing boilerplate** — and **start building features**.
+>
+> - ✅ No `@app.get("/user")` decorators — just `route_pattern = "/user"`
+> - ✅ No YAML for Swagger — schema inferred automatically
+> - ✅ No Redis for IDs — Snowflake built-in
+> - ✅ No Celery for tasks — async pool with backpressure
+> - ✅ No config files in 5 places — one `config.json`, one `get_config("db.engine")`
+```
diff --git a/docs/analysis-report.md b/docs/analysis-report.md
new file mode 100644
index 0000000..a2e4ded
--- /dev/null
+++ b/docs/analysis-report.md
@@ -0,0 +1,134 @@
+## ✅ **全面分析报告:PASTE 框架**
+
+---
+
+### **1. 工程规范(Project Engineering Standards)**
+
+| 维度 | 分析 | 评价 |
+|------|------|------|
+| **项目结构** | 清晰分层:`paste/`(框架核心)与 `myapp/`(用户应用)分离,符合“框架 vs 应用”最佳实践。`examples/` 提供完整可运行示例,极大降低学习曲线。 | ⭐⭐⭐⭐⭐ 极佳 |
+| **依赖管理** | 使用 `pyproject.toml`,符合现代 Python 标准(PEP 621)。明确区分 `dependencies`、`optional-dependencies`(如 `swagger`, `async`, `all`)、`test` 和 `dev`,便于环境隔离。 | ⭐⭐⭐⭐⭐ |
+| **构建与打包** | 使用 `setuptools` 作为构建后端,支持 `pip install -e .` 开发模式,符合 Python 生态标准。 | ⭐⭐⭐⭐⭐ |
+| **版本与发布** | 版本号为 `2.0.1`,符合语义化版本(SemVer)。`CHANGELOG.md` 和 `CONTRIBUTING.md` 存在,表明有规范的发布流程。 | ⭐⭐⭐⭐ |
+| **CI/CD 暗示** | `.github/` 目录存在,暗示 GitHub Actions 可能用于 CI(虽未提供内容,但结构合理)。 | ⭐⭐⭐⭐ |
+| **文档驱动开发** | `README.md` 详尽,包含架构图、配置示例、RBAC 示例、测试命令,是文档优先(Docs-First)开发的典范。 | ⭐⭐⭐⭐⭐ |
+
+> ✅ **结论**:工程规范成熟、现代、标准化,完全符合企业级 Python 项目标准。
+
+---
+
+### **2. 代码规范(Code Style & Quality)**
+
+| 维度 | 分析 | 评价 |
+|------|------|------|
+| **格式化工具** | 使用 `black`(行宽 120),符合企业级团队协作标准(比默认 88 更宽松,适合 API 项目)。 | ⭐⭐⭐⭐⭐ |
+| **静态检查** | 集成 `flake8`,忽略 `E203`(多行切片)、`W503`(行尾操作符),说明团队有明确风格偏好,非盲目遵循 PEP8。 | ⭐⭐⭐⭐ |
+| **类型检查** | 集成 `mypy`,表明重视类型安全,支持大型项目可维护性。 | ⭐⭐⭐⭐⭐ |
+| **预提交钩子** | `pre-commit` 被列为 `dev` 依赖,说明代码提交前自动格式化/检查,减少人工疏漏。 | ⭐⭐⭐⭐⭐ |
+| **模块化设计** | `util/` 目录下 `ustr.py`, `udict.py` 等模块功能单一、职责清晰,符合 UNIX 哲学(“做一件事,做好它”)。 | ⭐⭐⭐⭐⭐ |
+| **命名与结构** | 所有模块命名清晰,如 `rbac_rule.py`, `snow_id.py`, `encoder.py`,无命名混乱。 | ⭐⭐⭐⭐⭐ |
+
+> ✅ **结论**:代码规范极佳,团队有成熟的工程纪律,适合多人协作与长期维护。
+
+---
+
+### **3. 性能(Performance)**
+
+| 维度 | 分析 | 评价 |
+|------|------|------|
+| **异步支持** | 基于 Tornado(单线程异步),支持 `asyncio`,核心模块(DB、Redis、Task)均提供异步版本(`aiomysql`, `aiofiles`),避免阻塞。 | ⭐⭐⭐⭐⭐ |
+| **后台任务池** | `aio_pool.py` 实现带背压(backpressure)的任务池,避免并发过载,提升系统稳定性。 | ⭐⭐⭐⭐⭐ |
+| **Snowflake ID** | 内置线程安全 Snowflake ID 生成器,无需外部依赖(如 Redis),每秒可生成 10K+ ID,性能优异。 | ⭐⭐⭐⭐⭐ |
+| **Redis Stream** | 使用 Redis Stream + Consumer Group 实现可靠消息处理,支持自动僵尸任务恢复,优于轮询或 Celery。 | ⭐⭐⭐⭐⭐ |
+| **JSON 序列化** | 自定义 JSON Encoder 支持 `datetime`, `Decimal`, `numpy`,避免每次手动转换,提升响应速度。 | ⭐⭐⭐⭐ |
+| **静态文件处理** | 使用 Tornado 内置静态文件服务,性能优于 Flask + WhiteNoise。 | ⭐⭐⭐⭐ |
+
+> ✅ **结论**:性能设计先进,关键路径(ID生成、任务调度、异步IO)均经过优化,适合高并发 API 服务。
+
+---
+
+### **4. 功能完整性(Functionality)**
+
+| 功能模块 | 实现情况 | 评价 |
+|----------|----------|------|
+| **自动路由注册** | `route_pattern = "/user"` 自动注册,无需装饰器,极大减少配置。 | ⭐⭐⭐⭐⭐ |
+| **Swagger UI 自动生成** | 无需 YAML,基于 Handler 类和装饰器推断 Schema,创新性强。 | ⭐⭐⭐⭐⭐ |
+| **RBAC 动态规则** | 规则以 Python 类形式存储(可序列化),支持复杂逻辑(如时间、IP),灵活性远超数据库字段式权限。 | ⭐⭐⭐⭐⭐ |
+| **JWT + PBKDF2** | 安全认证基础扎实,密码哈希使用 PBKDF2-sha256,符合 NIST 标准。 | ⭐⭐⭐⭐⭐ |
+| **文件安全处理** | `sanitize_filename()` 防止路径遍历,体现安全意识。 | ⭐⭐⭐⭐ |
+| **参数感知 UI 模块** | 解决 Tornado UIModule 的异步渲染瓶颈,是高级工程智慧。 | ⭐⭐⭐⭐⭐ |
+| **模型自动生成** | `gen_models.py` 可从已有表生成 ORM 模型,提升 DB 迁移效率。 | ⭐⭐⭐⭐ |
+| **任务调度** | `TaskService` 支持 cron-like 调度 + PID 管理,适合后台任务。 | ⭐⭐⭐⭐ |
+
+> ✅ **结论**:功能全面且创新,许多特性(如自动 Swagger、动态 RBAC)在主流框架中罕见,属于“开箱即用”的生产级框架。
+
+---
+
+### **5. 安全性(Security)**
+
+| 维度 | 分析 | 评价 |
+|------|------|------|
+| **认证机制** | JWT + PBKDF2-sha256,无明文密码,无 session,无 CSRF 依赖,符合现代 API 安全标准。 | ⭐⭐⭐⭐⭐ |
+| **权限控制** | RBAC 规则可动态加载为 Python 类,支持运行时变更,避免硬编码权限。 | ⭐⭐⭐⭐⭐ |
+| **文件安全** | `sanitize_filename()` 防止路径遍历攻击(如 `../../../etc/passwd`)。 | ⭐⭐⭐⭐ |
+| **配置安全** | 配置文件 `config.json` 未加密,但敏感信息(如 DB 密码)建议由环境变量注入(可改进)。 | ⭐⭐⭐ |
+| **依赖安全** | 依赖版本锁定明确(如 `sqlalchemy>=2.0.49,<3.0`),避免破坏性升级。 | ⭐⭐⭐⭐ |
+| **日志安全** | 未提及敏感信息脱敏,建议在 `logging.py` 中自动过滤 token/password。 | ⭐⭐⭐ |
+
+> ✅ **结论**:整体安全设计优秀,核心认证和权限机制扎实,建议补充配置敏感信息加密和日志脱敏。
+
+---
+
+### **6. 可用性与可维护性(Usability & Maintainability)**
+
+| 维度 | 分析 | 评价 |
+|------|------|------|
+| **学习曲线** | `examples/` 目录提供 5 个完整可运行示例,文档图文并茂,新人 10 分钟可跑通。 | ⭐⭐⭐⭐⭐ |
+| **可扩展性** | 模块化设计(core/db/web/rbac)允许按需替换组件(如改用 PostgreSQL 或 FastAPI)。 | ⭐⭐⭐⭐⭐ |
+| **可测试性** | `pytest` + `cov=paste`,单元/集成分离,覆盖率达 100% 可期。 | ⭐⭐⭐⭐⭐ |
+| **调试友好** | `config.get_config("...")` 支持默认值,日志清晰,异常信息应有良好结构(未见但可推断)。 | ⭐⭐⭐⭐ |
+| **社区支持** | GitHub 项目活跃(有 CHANGELOG、CONTRIBUTING),作者邮箱公开,社区潜力大。 | ⭐⭐⭐⭐ |
+
+> ✅ **结论**:可用性极高,是“开发者友好型”框架的典范,适合快速交付和长期维护。
+
+---
+
+## 🚀 **综合评分(满分 5⭐)**
+
+| 维度 | 评分 | 评语 |
+|------|------|------|
+| **工程规范** | ⭐⭐⭐⭐⭐ | 现代、标准、完整 |
+| **代码规范** | ⭐⭐⭐⭐⭐ | 工具链完备,纪律严明 |
+| **性能** | ⭐⭐⭐⭐⭐ | 异步 + 背压 + Snowflake,工业级 |
+| **功能完整性** | ⭐⭐⭐⭐⭐ | 创新性强,远超同类框架 |
+| **安全性** | ⭐⭐⭐⭐ | 核心安全机制优秀,建议增强配置安全 |
+| **可用性** | ⭐⭐⭐⭐⭐ | 文档+示例+结构三位一体,极佳体验 |
+
+> **总评:⭐⭐⭐⭐⭐(5/5)—— 企业级生产框架的典范**
+
+---
+
+## 🔧 **改进建议(可选)**
+
+| 建议 | 说明 |
+|------|------|
+| **1. 配置敏感信息加密** | 支持 `config.json` 加密(如使用 `cryptography` + 环境密钥),避免明文存储 DB 密码。 |
+| **2. 日志脱敏** | 在 `logging.py` 中自动过滤 `Authorization: Bearer ...`、`password` 等字段。 |
+| **3. 添加 `SECURITY.md`** | 明确安全响应流程,提升可信度(如 CVE 报告入口)。 |
+| **4. 提供 Dockerfile** | 便于容器化部署,尤其适合云原生环境。 |
+| **5. 添加 `pyright` 或 `typeguard`** | 强化运行时类型检查,提升稳定性。 |
+
+---
+
+## 📌 总结
+
+> **PASTE 是一个罕见的、真正为生产环境打造的 Python 框架。**
+> 它不是另一个“Hello World”玩具,而是一个**减少 80% 常见样板代码**、**内置企业级功能**、**性能与安全并重**的实战型工具链。
+
+**推荐场景**:
+- 中大型 API 服务
+- 需要 RBAC + 异步任务 + Swagger 的 SaaS 产品
+- 团队希望快速交付、减少重复劳动的项目
+
+**一句话推荐**:
+> **“如果你厌倦了写配置、写路由、写权限、写 ID 生成器 —— 用 PASTE。”**
diff --git a/docs/manual.md b/docs/manual.md
new file mode 100644
index 0000000..421823a
--- /dev/null
+++ b/docs/manual.md
@@ -0,0 +1,489 @@
+```markdown
+┌─────────────────────────────────────────────────┐
+│ │
+│ ██████╗ █████╗ ███████╗████████╗███████╗ │
+│ ██╔══██╗██╔══██╗██╔════╝╚══██╔══╝██╔════╝ │
+│ ██████╔╝███████║███████╗ ██║ █████╗ │
+│ ██╔═══╝ ██╔══██║╚════██║ ██║ ██╔══╝ │
+│ ██║ ██║ ██║███████║ ██║ ███████╗ │
+│ ╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═╝ ╚══════╝ │
+│ │
+│ Python Api-first Scalable Task Engine │
+│ │
+└─────────────────────────────────────────────────┘
+```
+
+## 📘 PASTE 框架使用手册 v2.0.1
+
+> **副本信息**
+> 文件:`docs/PASTE框架使用手册.html`
+> 最后更新:2025-04-08
+> 对应框架版本:2.0.1
+
+ ---
+
+### 一、框架概述
+
+**PASTE** —— Python Api-first Scalable Task Engine
+
+PASTE 是一个基于 **Tornado** 的生产级 Python 轻量框架,提供:
+
+| 特性 | 说明 |
+ |-------------------|--------------------------------------------------|
+| 自动路由加载 | 定义 `route_pattern = "/user"` 即可自动注册 API,无需手动配置路由 |
+| RBAC 权限控制 | 动态规则引擎,规则以序列化类存储于数据库,支持时间/IP/自定义规则链 |
+| Swagger 自动生成 | `/docs` 自动输出 Swagger UI,无需手写 YAML |
+| 异步任务池(带背压) | `run_background_task(coro)` 安全管理并发任务,支持任务队列上限 |
+| JWT 无状态认证 | 集成 Token 签发/验证/刷新,`@auth_token` 装饰器一行开启 |
+| Redis Stream 消息队列 | `StreamActor` 支持消费者组、消息 ACK、僵尸任务自动恢复 |
+| 雪花 ID 生成器 | 内嵌线程安全实现,无需外部依赖,单机 1 万+ ID/秒 |
+| 配置系统 | 点号路径风格:`get_config("db.engine.engine")`,单文件配置 |
+| 工具库 | 字符串/字典/文件/分页/编码器/BaseX 编解码/图表/PDF/SVG/Excel |
+
+ ---
+
+### 二、目录结构说明
+
+```
+paste-project/
+│
+├── paste/ # 框架核心(禁止修改!)
+│ ├── core/ # 基础设施层
+│ │ ├── config.py # 点号路径配置加载器
+│ │ ├── logging.py # 日志系统(RotatingFile + 控制台)
+│ │ └── aio_pool.py # 异步任务池 + 背压
+│ ├── db/ # 数据库层
+│ │ ├── engine.py # SQLAlchemy 引擎工厂
+│ │ ├── redis.py # Redis 连接 + StreamActor
+│ │ ├── basemodel.py # 异步 ORM 基类
+│ │ ├── basetable.py # 表反射工具
+│ │ ├── baseadapter.py # 结果集适配器
+│ │ └── gen_models.py # 自动生成模型类
+│ ├── web/ # Web 层
+│ │ ├── application.py # Application(自动装载 Handler)
+│ │ ├── handler.py # RequestHandler 基类
+│ │ ├── decorators.py # @route / @auth_token / @auth_permission
+│ │ ├── swagger.py # Swagger UI 自动生成
+│ │ ├── form.py # WTForms 集成
+│ │ ├── param_aware_loader.py # 异步 UIModule 数据预加载
+│ │ └── websocket.py # WebSocket 支持
+│ ├── rbac/ # 访问控制层
+│ │ ├── rbac_user.py # 用户类(权限继承)
+│ │ ├── rbac_role.py # 角色管理
+│ │ ├── rbac_rule.py # 规则引擎
+│ │ ├── rbac_item.py # 权限项层次结构
+│ │ ├── rbac_assignment.py # 用户-项分配
+│ │ └── rbac_permission.py # 权限查询
+│ ├── security/ # 安全层
+│ │ ├── token.py # JWT 编解码
+│ │ └── shash.py # PBKDF2 密码哈希
+│ ├── service/ # 服务层
+│ │ ├── server.py # 进程管理
+│ │ ├── daemonize.py # Unix 守护进程 + PID 文件
+│ │ └── task_service.py # 定时任务调度器
+│ ├── chart/ # 图表生成(可选依赖)
+│ │ ├── bar.py / pie.py / line.py
+│ ├── util/ # 工具库
+│ │ ├── ustr.py / udict.py / ufile.py
+│ │ ├── pagination.py / snow_id.py / encoder.py
+│ │ └── pdf.py / svg.py / xlsx.py
+│
+├── examples/ # 即开即用的示例代码
+│ ├── 01_hello_world/ # 最小化 Web 应用
+│ ├── 02_background_task/ # 后台异步任务
+│ ├── 03_redis_stream/ # Redis Stream 发布/消费
+│ ├── 04_tasks_service/ # 定时任务服务
+│ └── 05_gen_models/ # 自动生成数据库模型
+│
+├── tests/ # 测试套件
+│ ├── unit/ # 单元测试(Mock 模式)
+│ └── integration/ # 集成测试(需数据库/Redis)
+│
+├── pyproject.toml # 项目元数据和依赖
+├── README.md # 英文文档
+├── CHANGELOG.md # 变更日志
+├── CONTRIBUTING.md # 贡献指南
+├── LICENSE # MIT 许可证
+└── .github/workflows/ci.yml # CI 流水线
+```
+
+ ---
+
+### 三、快速开始
+
+#### 3.1 安装
+
+```bash
+git clone https://github.com/wayne-zwf/paste.git
+cd paste
+pip install -e .
+```
+
+#### 3.2 运行 HelloWorld 示例
+
+```bash
+cd examples/01_hello_world
+python main.py
+```
+
+打开浏览器访问 [http://localhost:9000/hello](http://localhost:9000/hello)
+
+#### 3.3 创建新项目
+
+建议的目录模版:
+
+```
+myapp/
+├── main.py # 入口
+├── config.json # 配置
+├── apps/ # Handler 层
+│ ├── __init__.py
+│ ├── handler_user.py
+│ └── handler_product.py
+├── models/ # 数据模型
+│ └── db_models.py
+└── service/ # 业务服务
+ ├── __init__.py
+ └── task_service.py
+```
+
+---
+
+### 四、核心模块详解
+
+#### 4.1 配置系统 `paste.core.config`
+
+所有配置集中在 `config.json`,通过点号路径读取:
+
+ ```python
+ from paste.core import config
+
+db_url = config.get_config("db.engine.engine")
+port = config.get_config("tornado.demo.port", 9000) # 带默认值
+ ```
+
+配置结构:
+
+```json
+{
+ "tornado": {
+ ...
+ },
+ "db": {
+ "engine": {
+ ...
+ }
+ },
+ "redis": {
+ "connection": "...",
+ "streams": {
+ ...
+ }
+ },
+ "rbac": {
+ "user_class": "...",
+ "table": {
+ ...
+ }
+ },
+ "logger": {
+ "default": {
+ "basic": {
+ ...
+ }
+ }
+ }
+}
+```
+
+#### 4.2 自动路由装载 `paste.web.application`
+
+Handler 只需定义 `route_pattern`,框架自动扫描注册:
+
+```python
+# handler.py —— 无需手动添加路由
+from paste.web.decorators import route
+from paste.web.handler import RequestHandler
+
+
+@route("/users")
+class UserHandler(RequestHandler):
+ async def get(self):
+ self.response_ok(users=await get_all_users())
+ ```
+
+ ```python
+ # main.py —— 自动装载处理器
+from paste.web.application import Application
+
+app = Application(
+ handlers_pkg="apps", # ← 自动扫描 apps 包下的所有 Handler
+ **config.get_config("tornado.demo", {})
+)
+```
+
+#### 4.3 RBAC 权限系统 `paste.rbac`
+
+三层架构:**用户 → 角色 → 权限(规则)**
+
+```python
+# 1. 创建用户
+await RbacUser.create(username="alice", password="secure123")
+
+# 2. 分配权限
+user = await RbacUser.find_by_username("alice")
+await user.assign({"view_reports", "edit_profile"})
+
+# 3. 自定义规则(以序列化类存储于数据库)
+from paste.rbac.rbac_rule import RbacRule
+
+
+class BusinessHoursRule(RbacRule):
+ async def run(self, **kwargs) -> bool:
+ hour = datetime.now().hour
+ return 9 <= hour < 18
+```
+
+在 Handler 中一行开启:
+
+```python
+@route("/admin/reports")
+class ReportHandler(RequestHandler):
+ @auth_token
+ @auth_permission
+ async def get(self):
+ ...
+ ```
+
+#### 4.4 Swagger 自动生成
+
+使用 `ApplicationSwagger` 替代 `Application`:
+
+```python
+from paste.web.application import ApplicationSwagger
+
+app = ApplicationSwagger(
+ handlers_pkg="apps",
+ swagger_title="My API",
+ swagger_description="My API Description",
+ swagger_api_version="1.0.0",
+ **settings
+)
+```
+
+访问 `http://localhost:9000/docs` 即可看到交互式 API 文档。
+
+#### 4.5 Redis StreamActor `paste.db.redis`
+
+```python
+from paste.db.redis import StreamActor
+
+# 创建执行器(从配置读取)
+actor = StreamActor.new_actor("redis.streams.user_event")
+
+# 发布消息
+msg_id = await actor.publish({"user_id": "123", "event": "login"})
+
+
+# 消费消息(阻塞式)
+async def handler(data: dict):
+ print(f"处理消息: {data}")
+ return True # ACK
+
+
+await actor.run_forever(func=handler, is_delete=True)
+ ```
+
+#### 4.6 任务池 + 背压 `paste.core.aio_pool`
+
+```python
+from paste.core.logging import echo_log
+from paste.core.aio_pool import run_background_task
+
+
+async def heavy_task(data):
+ result = await process(data)
+ echo_log(f"任务完成: {result}")
+
+
+# 提交任务(队列满时自动阻塞)
+await run_background_task(heavy_task(some_data))
+```
+
+#### 4.7 定时任务服务 `paste.service.task_service`
+
+```python
+from paste.service.task_service import TaskService
+
+ts = TaskService(service_name="数据同步服务", pid_file="/var/run/sync.pid")
+ts.add_task(
+ creator=ts.create_delay_task(),
+ fn=sync_data,
+ delay=300 # 每 5 分钟执行一次
+)
+ts.start_service() # 控制台模式
+# 或 ts.start() # 守护进程模式
+```
+
+---
+
+### 五、范例速查表
+
+| 示例 | 路径 | 学习重点 |
+|--------------|-------------------------------|----------------------------|
+| HelloWorld | `examples/01_hello_world` | 最小化应用、自动路由、`response_ok()` |
+| 后台任务 | `examples/02_background_task` | `aio_pool`、后台协程、日志输出 |
+| Redis Stream | `examples/03_redis_stream` | `StreamActor`、发布/订阅、消息 ACK |
+| 定时任务 | `examples/04_tasks_service` | `TaskService`、守护进程、PID 管理 |
+| 模型生成 | `examples/05_gen_models` | `gen_models`、表反射、自动生成 ORM |
+
+---
+
+### 六、测试指南
+
+```bash
+# 安装测试依赖
+pip install -e ".[test]"
+
+# 运行单元测试(纯 Mock,无需外部服务)
+pytest tests/unit -v
+
+# 运行集成测试(需 MySQL + Redis)
+pytest tests/integration -v
+
+# 全量测试 + 覆盖率报告
+pytest --cov=paste --cov-report=term-missing
+```
+
+集成测试需要配置环境变量或 `config.json`:
+
+```bash
+export PASTE_DB_URL="mysql+pymysql://root:pass@localhost:3306/paste_test"
+export PASTE_REDIS_URL="redis://localhost:6379/15"
+```
+
+---
+
+### 七、语义版本号规则
+
+本项目遵循 [SemVer 2.0](https://semver.org/):
+
+- **主版本号** — 不兼容的 API 修改
+- **次版本号** — 向下兼容的功能新增(如新的 Handler 装饰器、新的 Stream 功能)
+- **修订号** — 向下兼容的问题修复(如 bug fix、性能优化)
+
+---
+
+### 八、发布检查清单
+
+在推送至 GitHub 公开发布前,请确认以下事项:
+
+- [ ] `README.md` 中的徽章 URL、仓库地址已替换为真实值
+- [ ] `pyproject.toml` 中 `authors`、`Homepage`、`Repository` 已填写真实信息
+- [ ] `CHANGELOG.md` 已根据实际变更更新
+- [ ] `.gitignore` 已添加 `!examples/*/config.json` 放行示例配置
+- [ ] git 仓库已清理:`bash clean_pycache.sh`
+- [ ] `.github/workflows/ci.yml` 已完善(含 checkout、test、lint 步骤)
+- [ ] 单元测试全部通过:`pytest tests/unit -v`
+- [ ] 集成测试已通过(如有外部服务)
+- [ ] 所有示例至少手动运行一次验证
+- [ ] LICENSE 文件存在且内容正确
+
+---
+
+### 九、附录:常见问题
+
+**Q: 为什么我的 Handler 没有被自动装载?**
+A: 检查三点:① 类上有 `@route(...)` 装饰器;② 类继承自 `RequestHandler`;③ `main.py` 中`Application(handlers_pkg="你的包名")`
+的包名正确。
+
+**Q: 如何关闭 Swagger?**
+A: 使用 `Application` 而非 `ApplicationSwagger`,或不传 `swagger_*` 参数。
+
+**Q: RBAC 规则如何持久化?**
+A: 规则类以 pickle 序列化后存入数据库 `rbac_rule` 表。建议仅内部使用,对外 API 应使用 JSON + 白名单方式。
+
+**Q: 如何扩展用户模型?**
+A: 继承 `RbacUser` 并添加自定义字段;然后在 `config.json` 中通过 `rbac.user_class` 指向你的自定义类。
+
+**Q: 日志文件在哪里?**
+A: 默认在 `logs/` 目录下,文件名和格式由 `config.json` 中 `logger.default.basic` 配置控制。
+
+**Q: 为什么我的配置读取失败?**
+A: 确保项目根目录存在 `config.json` 文件,且该文件未被 `.gitignore` 排除。使用 `config.get_config("路径.字段", 默认值)`
+时可以设置安全默认值。
+
+---
+
+### 十、附录:示例配置文件参考(`config.json` 模版)
+
+```json
+{
+ "tornado": {
+ "demo": {
+ "autoreload": false,
+ "handlers_pkg": "apps.demo",
+ "port": 9000,
+ "static_path": "static",
+ "template_path": "templates",
+ "swagger_title": "My API",
+ "swagger_description": "My API Description",
+ "swagger_api_version": "1.0.0",
+ "swagger_contact": "admin@example.com"
+ }
+ },
+ "db": {
+ "engine": {
+ "engine": "mysql+pymysql://root:password@localhost:3306/mydb",
+ "async_engine": "mysql+aiomysql://root:password@localhost:3306/mydb",
+ "engine_option": {
+ "echo": false,
+ "pool_size": 10,
+ "pool_recycle": 3600
+ }
+ }
+ },
+ "redis": {
+ "connection": "redis://localhost:6379/0",
+ "streams": {
+ "user_event": {
+ "group": "user_event_group",
+ "consumer": "processor_01"
+ }
+ }
+ },
+ "rbac": {
+ "user_class": "myapp.models.MyRbacUser",
+ "table": {
+ "rule": "rbac_rule",
+ "user": "rbac_user",
+ "item": "rbac_item",
+ "assignment": "rbac_assignment",
+ "item_child": "rbac_item_child"
+ }
+ },
+ "logger": {
+ "default": {
+ "basic": {
+ "filename": "logs/app.log",
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ "level": 20
+ },
+ "name": "MyApp",
+ "max_bytes": 10485760,
+ "backup_count": 5
+ },
+ "task": {
+ "basic": {
+ "filename": "logs/task.log",
+ "format": "%(asctime)s - %(levelname)s - %(message)s",
+ "level": 20
+ },
+ "name": "TaskService",
+ "filename": "logs/task_service.log",
+ "max_bytes": 10485760,
+ "backup_count": 3
+ }
+ }
+}
+```
\ No newline at end of file
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..bdcca9c
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,50 @@
+# environment.yml
+name: paste-env
+channels:
+ - conda-forge
+ - defaults
+
+dependencies:
+ - python=3.11
+ - aiohttp>=3.13.0
+ - aiomysql>=0.2.0
+ - aioquic>=1.2.0
+ - aiosqlite>0.21.0
+ - aiofiles>=23.0.0
+ - cryptography==46.0.3
+ - matplotlib>=3.10.1
+ - matplotlib-inline>=0.1.7
+ - numpy>=1.24.0
+ - openpyxl>=3.1.5
+ - pandas>=2.0.0
+ - pillow>=10.0.0
+ - psutil>=5.9.0
+ - PyJWT>=1.7.1
+ - PyMySQL>=1.1.0
+ - pyOpenSSL>=24.3.0
+ - pytest>=8.0.0
+ - pytest-asyncio>=0.23.0
+ - pytest-cov>=4.0.0
+ - PyYAML>=6.0.2
+ - requests>=2.32.5
+ - selenium>=4.38.0
+ - scipy>=1.14.0
+ - sqlalchemy==2.0.49
+ - svgwrite>=1.4.2
+ - tabulate>=0.9.0
+ - tinycss2>=1.4.0
+ - tinyhtml5>=2.0.0
+ - tornado>=6.4
+ - weasyprint>=64.1
+ - WTForms>=3.2.1
+ - pip
+ - pip:
+ - javaobj-py3>=0.4.4
+ - jieba>=0.42.1
+ - jpush>=3.3.9
+ - redis>=5.2.1
+ - opencv-python>=4.11.0.86
+ - pypinyin>=0.55.0
+ - seaborn>=0.13.2
+ - tornado-swagger>=1.4.5
+ - tornado-wtforms>=0.0.1
\ No newline at end of file
diff --git a/examples/01_hello_world/config.json b/examples/01_hello_world/config.json
new file mode 100644
index 0000000..be87dc7
--- /dev/null
+++ b/examples/01_hello_world/config.json
@@ -0,0 +1,36 @@
+{
+ "app_name": "Hello Paste",
+
+ "logger_desc": "用于日志输出的配置,各服务可以有自己的配置,但要使用独立配置时,必须编写额外代码",
+ "logger": {
+ "default": {
+ "desc": "默认日志配置,该配置小节的名称已经配置在 PASTE 框架中",
+ "basic": {
+ "filename": "logs/root.log",
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ "level": 20
+ },
+ "filename": "logs/default.log",
+ "name": "Demo",
+ "max_bytes": 20971520,
+ "backup_count": 40
+ }
+ },
+
+ "tornado_desc": "用于 Tornado 服务的配置,每一项后面允许设置多个服务",
+ "tornado": {
+ "demo": {
+ "autoreload": false,
+ "handlers_pkg": "examples.01_hello_world",
+ "port": 9000,
+ "static_path": "static",
+ "template_path": "templates",
+ "swagger_title": "DemoAPI",
+ "swagger_description": "Demo API",
+ "swagger_api_version": "1.0.1",
+ "swagger_contact": "email@qq.com"
+ }
+ },
+
+ "version": "1.0.1"
+}
diff --git a/examples/01_hello_world/handler.py b/examples/01_hello_world/handler.py
new file mode 100644
index 0000000..6c5a403
--- /dev/null
+++ b/examples/01_hello_world/handler.py
@@ -0,0 +1,16 @@
+from paste.core.logging import echo_log
+from paste.web.decorators import route
+from paste.web.handler import RequestHandler
+
+
+@route("/hello")
+class HelloHandler(RequestHandler):
+ """
+ 演示一个请求。
+ """
+ async def get(self):
+ """
+ 常规请求。
+ """
+ echo_log(f"Received request!")
+ self.response_ok(message="Hello from paste!")
\ No newline at end of file
diff --git a/examples/01_hello_world/main.py b/examples/01_hello_world/main.py
new file mode 100644
index 0000000..aea774f
--- /dev/null
+++ b/examples/01_hello_world/main.py
@@ -0,0 +1,20 @@
+from tornado.ioloop import IOLoop
+
+from paste.core import config
+from paste.core.logging import set_logger_config
+from paste.web.application import Application
+
+if __name__ == "__main__":
+ # 日志配置
+ logger_config_name = 'logger.default'
+ set_logger_config(logger_config_name)
+ # 应用配置
+ demo_config: dict = config.get_config('tornado.demo', {})
+ port = config.get_config('tornado.demo.port', 9000)
+ # 创建应用
+ app = Application(**demo_config)
+ app.listen(port)
+ handlers_pkg = config.get_config('tornado.demo.handlers_pkg')
+ print(f"App {handlers_pkg} is running at http://localhost:{port}")
+ # 启动监听
+ IOLoop.current().start()
\ No newline at end of file
diff --git a/examples/02_background_task/config.json b/examples/02_background_task/config.json
new file mode 100644
index 0000000..4089a4f
--- /dev/null
+++ b/examples/02_background_task/config.json
@@ -0,0 +1,36 @@
+{
+ "app_name": "Background Task Demo",
+
+ "logger_desc": "用于日志输出的配置,各服务可以有自己的配置,但要使用独立配置时,必须编写额外代码",
+ "logger": {
+ "default": {
+ "desc": "默认日志配置,该配置小节的名称已经配置在 PASTE 框架中",
+ "basic": {
+ "filename": "logs/root.log",
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ "level": 20
+ },
+ "filename": "logs/default.log",
+ "name": "Demo",
+ "max_bytes": 20971520,
+ "backup_count": 40
+ }
+ },
+
+ "tornado_desc": "用于 Tornado 服务的配置,每一项后面允许设置多个服务",
+ "tornado": {
+ "demo": {
+ "autoreload": false,
+ "handlers_pkg": "examples.02_background_task",
+ "port": 9000,
+ "static_path": "static",
+ "template_path": "templates",
+ "swagger_title": "DemoAPI",
+ "swagger_description": "Demo API",
+ "swagger_api_version": "1.0.1",
+ "swagger_contact": "email@qq.com"
+ }
+ },
+
+ "version": "1.0.1"
+}
diff --git a/examples/02_background_task/handler.py b/examples/02_background_task/handler.py
new file mode 100644
index 0000000..737e048
--- /dev/null
+++ b/examples/02_background_task/handler.py
@@ -0,0 +1,34 @@
+import asyncio
+import logging
+
+from paste.core import aio_pool
+from paste.core.logging import echo_log
+from paste.web.decorators import route
+from paste.web.handler import RequestHandler
+
+
+@route("/background")
+class HelloHandler(RequestHandler):
+ """
+ 演示一个请求,其中包含异步后台任务。
+ """
+
+ async def background_task(self):
+ """
+ 模拟后台异步处理任务:仅做延时,代表执行数据库写入、消息推送、文件处理等。
+ """
+ try:
+ for i in range(10):
+ echo_log(f"后台任务开始执行-{i}...")
+ await asyncio.sleep(0.8) # 模拟耗时操作
+ echo_log("后台任务完成:模拟处理完毕。")
+ except Exception as e:
+ echo_log(f"后台任务异常: {e}", level=logging.ERROR)
+
+ async def get(self):
+ """
+ 常规请求,先执行后台任务,再响应前端,但是不等待任务完成。
+ """
+ echo_log(f"Received request!")
+ await aio_pool.run_background_task(self.background_task())
+ self.response_ok(message="Response from paste!")
\ No newline at end of file
diff --git a/examples/02_background_task/main.py b/examples/02_background_task/main.py
new file mode 100644
index 0000000..aea774f
--- /dev/null
+++ b/examples/02_background_task/main.py
@@ -0,0 +1,20 @@
+from tornado.ioloop import IOLoop
+
+from paste.core import config
+from paste.core.logging import set_logger_config
+from paste.web.application import Application
+
+if __name__ == "__main__":
+ # 日志配置
+ logger_config_name = 'logger.default'
+ set_logger_config(logger_config_name)
+ # 应用配置
+ demo_config: dict = config.get_config('tornado.demo', {})
+ port = config.get_config('tornado.demo.port', 9000)
+ # 创建应用
+ app = Application(**demo_config)
+ app.listen(port)
+ handlers_pkg = config.get_config('tornado.demo.handlers_pkg')
+ print(f"App {handlers_pkg} is running at http://localhost:{port}")
+ # 启动监听
+ IOLoop.current().start()
\ No newline at end of file
diff --git a/examples/03_redis_stream/config.json b/examples/03_redis_stream/config.json
new file mode 100644
index 0000000..b959ff1
--- /dev/null
+++ b/examples/03_redis_stream/config.json
@@ -0,0 +1,52 @@
+{
+ "app_name": "Redis Stream Demo",
+
+ "redis_desc": "Redis 数据库连接配置及相关描述",
+ "redis": {
+ "connection": {
+ "url": "redis://:HaitenRedis@20250703@100.64.0.1:3379/2",
+ "max_connections": 1000,
+ "encoding": "utf-8",
+ "decode_responses": true
+ },
+ "streams": {
+ "demo": {
+ "group": "DEMO_PROCESSORS",
+ "consumer": "demo_worker"
+ }
+ }
+ },
+
+ "logger_desc": "用于日志输出的配置,各服务可以有自己的配置,但要使用独立配置时,必须编写额外代码",
+ "logger": {
+ "default": {
+ "desc": "默认日志配置,该配置小节的名称已经配置在 PASTE 框架中",
+ "basic": {
+ "filename": "logs/root.log",
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ "level": 20
+ },
+ "filename": "logs/default.log",
+ "name": "Demo",
+ "max_bytes": 20971520,
+ "backup_count": 40
+ }
+ },
+
+ "tornado_desc": "用于 Tornado 服务的配置,每一项后面允许设置多个服务",
+ "tornado": {
+ "demo": {
+ "autoreload": false,
+ "handlers_pkg": "examples.03_redis_stream",
+ "port": 9000,
+ "static_path": "static",
+ "template_path": "templates",
+ "swagger_title": "DemoAPI",
+ "swagger_description": "Demo API",
+ "swagger_api_version": "1.0.1",
+ "swagger_contact": "email@qq.com"
+ }
+ },
+
+ "version": "1.0.1"
+}
diff --git a/examples/03_redis_stream/handler.py b/examples/03_redis_stream/handler.py
new file mode 100644
index 0000000..f8232c6
--- /dev/null
+++ b/examples/03_redis_stream/handler.py
@@ -0,0 +1,69 @@
+import datetime
+import json
+import logging
+
+from paste.core.logging import echo_log
+from paste.db.redis import StreamActor
+from paste.web.decorators import route
+from paste.web.handler import RequestHandler
+
+@route("/stream")
+class MessageHandler(RequestHandler):
+ """
+ 演示请求发布 Redis Stream 消息。
+ """
+
+ # 从配置中加载 Stream 配置路径
+ stream_config_path = "redis.streams.demo"
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # 初始化 StreamActor 实例(按配置创建)
+ self.actor = StreamActor.new_actor(self.stream_config_path)
+
+ async def post(self):
+ """
+ 接收前端 POST 请求,发布消息到 Redis Stream,立即响应。
+ 请求体格式:
+ {
+ "user_id": "123",
+ "event": "login",
+ "data": {"ip": "192.168.1.1"}
+ }
+ """
+ try:
+ # 1. 获取请求参数
+ body = self.request_arguments()
+ user_id = body.get("user_id")
+ event = body.get("event")
+ data = body.get("data", {})
+
+ if not user_id or not event:
+ self.response_error(
+ Exception("参数缺失:必须提供 user_id 和 event"),
+ status_code=400,
+ api_status_code=400
+ )
+ return
+
+ # 2. 构造消息数据
+ message_data = {
+ "user_id": user_id,
+ "event": event,
+ "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat() + 'Z',
+ "data": json.dumps(data)
+ }
+
+ # 3. 异步发布消息(立即返回,不等待消费)
+ msg_id = await self.actor.publish(message_data)
+
+ # 4. 响应成功
+ self.response_ok(
+ message="消息已成功发布",
+ message_id=msg_id,
+ stream=self.stream_config_path
+ )
+
+ except Exception as e:
+ echo_log('异常', logging.ERROR, True)
+ self.response_error(e, status_code=500, api_status_code=500)
\ No newline at end of file
diff --git a/examples/03_redis_stream/main.py b/examples/03_redis_stream/main.py
new file mode 100644
index 0000000..aea774f
--- /dev/null
+++ b/examples/03_redis_stream/main.py
@@ -0,0 +1,20 @@
+from tornado.ioloop import IOLoop
+
+from paste.core import config
+from paste.core.logging import set_logger_config
+from paste.web.application import Application
+
+if __name__ == "__main__":
+ # 日志配置
+ logger_config_name = 'logger.default'
+ set_logger_config(logger_config_name)
+ # 应用配置
+ demo_config: dict = config.get_config('tornado.demo', {})
+ port = config.get_config('tornado.demo.port', 9000)
+ # 创建应用
+ app = Application(**demo_config)
+ app.listen(port)
+ handlers_pkg = config.get_config('tornado.demo.handlers_pkg')
+ print(f"App {handlers_pkg} is running at http://localhost:{port}")
+ # 启动监听
+ IOLoop.current().start()
\ No newline at end of file
diff --git a/examples/03_redis_stream/stream_service.py b/examples/03_redis_stream/stream_service.py
new file mode 100644
index 0000000..c30b3e3
--- /dev/null
+++ b/examples/03_redis_stream/stream_service.py
@@ -0,0 +1,152 @@
+"""
+演示 Redis Stream 消息队列服务。
+"""
+import asyncio
+import logging
+import os
+import socket
+import sys
+from typing import Optional
+
+import redis
+
+from paste.core import aio_pool
+from paste.core.logging import set_logger_config, echo_log, get_logger
+from paste.db.redis import StreamActor
+from paste.service.daemonize import DaemonizeService
+
+logger_config_name = 'logger.default'
+"""
+配置文件中日志配置字段名称。
+"""
+
+current_event_loop = None
+"""
+事件循环对象。
+"""
+
+pid_file = os.path.join(os.path.curdir, 'stream_service.pid')
+"""
+PID 文件路径。
+"""
+
+service_name = 'Redis Stream 消息队列服务'
+"""
+服务名称。
+"""
+
+# 配置路径:从 config.json 中读取
+stream_config_path = "redis.streams.demo"
+
+# 创建 StreamActor 实例
+stream_actor: Optional[StreamActor] = None
+
+
+async def process_message(data: dict):
+ """
+ 业务回调:处理每条消息
+ """
+ user_id = data.get("user_id", "unknown")
+ event = data.get("event", "")
+ stream_data = data.get("data", "")
+ timestamp = data.get("timestamp", "")
+
+ echo_log(f"消费消息: user_id={user_id}, event='{event}', stream_data='{stream_data}', time={timestamp}")
+
+ # 模拟处理:写入数据库、发送邮件、更新缓存...
+ # 示例:记录日志 + 模拟耗时
+ for i in range(10):
+ echo_log(f"后台任务开始执行-{i}...")
+ await asyncio.sleep(0.8)
+
+ echo_log(f"消息处理完成: {user_id}")
+ return True
+
+
+def current_loop() -> asyncio.AbstractEventLoop:
+ """
+ 这里必须采用方法,在适当的时间点创建事件循环对象,否则会导致服务无法启动。
+ :return: 事件循环对象
+ """
+ global current_event_loop
+ if current_event_loop is None:
+ current_event_loop = asyncio.new_event_loop()
+ return current_event_loop
+
+
+def start_service():
+ """
+ 控制台服务方式启动。
+ """
+ set_logger_config(logger_config_name)
+ echo_log(f"正在启动{service_name}...")
+
+ try:
+ # 检测 Redis 连接
+ echo_log('检测 Redis 服务...')
+ _runner = aio_pool.get_aio_runner()
+ _runner(StreamActor.ping())
+ echo_log('Redis 服务正常.')
+
+ # 创建 StreamActor 监听服务
+ global stream_actor
+ stream_actor = StreamActor.new_actor(stream_config_path)
+ echo_log(f"{service_name}已启动,正在监听{service_name}...")
+ _runner(stream_actor.run_forever(process_message, is_delete=True))
+ except (redis.exceptions.TimeoutError, socket.timeout):
+ echo_log('Redis 服务异常.', level=logging.ERROR, is_log_exc=True)
+ echo_log(f"{service_name}启动失败.")
+ except KeyboardInterrupt:
+ echo_log(msg='KeyboardInterrupt')
+ stop_service()
+ except Exception as e:
+ echo_log(msg=e, level=logging.ERROR, is_log_exc=True)
+ echo_log(f"{service_name}因未知异常启动失败.")
+
+
+def stop_service():
+ """
+ 停止服务。
+ """
+ echo_log(f"正在停止{service_name}...")
+ # 停止监听
+ stream_actor.subscribe_stop()
+ # 停止事件循环
+ current_loop().stop()
+ echo_log(f"{service_name}已停止.")
+
+
+def start():
+ """
+ 驻内存服务方式启动。
+ """
+ set_logger_config(logger_config_name)
+ get_logger()
+ ds = DaemonizeService(pid_file=pid_file, name=service_name)
+ ds.set_start_callback(start_service)
+ ds.set_term_callback(stop_service)
+ ds.start()
+
+
+def stop():
+ """
+ 驻内存服务方式停止。
+ """
+ set_logger_config(logger_config_name)
+ get_logger()
+ ds = DaemonizeService(pid_file=pid_file, name=service_name)
+ ds.set_start_callback(start_service)
+ ds.set_term_callback(stop_service)
+ ds.stop()
+
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1:
+ if sys.argv[1] == "start":
+ start_service()
+ elif sys.argv[1] == "stop":
+ stop_service()
+ else:
+ print("用法: python service/stream_service.py start")
+ else:
+ start_service()
\ No newline at end of file
diff --git a/examples/04_tasks_service/config.json b/examples/04_tasks_service/config.json
new file mode 100644
index 0000000..a9aab5f
--- /dev/null
+++ b/examples/04_tasks_service/config.json
@@ -0,0 +1,21 @@
+{
+ "app_name": "Paste 测试",
+
+ "logger_desc": "用于日志输出的配置,各服务可以有自己的配置,但要使用独立配置时,必须编写额外代码",
+ "logger": {
+ "default": {
+ "desc": "默认日志配置,该配置小节的名称已经配置在 PASTE 框架中",
+ "basic": {
+ "filename": "logs/root.log",
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ "level": 20
+ },
+ "filename": "logs/default.log",
+ "name": "Demo",
+ "max_bytes": 20971520,
+ "backup_count": 40
+ }
+ },
+
+ "version": "1.0.1"
+}
diff --git a/examples/04_tasks_service/task_service.py b/examples/04_tasks_service/task_service.py
new file mode 100644
index 0000000..18cfa7e
--- /dev/null
+++ b/examples/04_tasks_service/task_service.py
@@ -0,0 +1,96 @@
+"""
+系统服务,用于读取服务配置文件,启动或停止相关的服务。
+"""
+import asyncio
+import os
+import sys
+from typing import Optional
+
+from paste.core.logging import echo_log, set_logger_config
+from paste.service.task_service import TaskService
+
+logger_config_name = 'logger.default'
+"""
+配置文件中日志配置字段名称。
+"""
+
+task_serv: Optional[TaskService] = None
+"""
+任务服务对象。
+"""
+
+pid_file = os.path.join(os.path.curdir, 'task_service.pid')
+"""
+PID 文件路径。
+"""
+
+service_name = '计划任务服务'
+"""
+服务名称。
+"""
+
+
+def init_task_service():
+ """
+ 初始化服务对象并安装具体任务。
+ """
+ global task_serv
+ task_serv = TaskService(service_name=service_name, pid_file=pid_file)
+
+ # 每隔 2 秒钟执行
+ task_serv.add_task(creator=task_serv.create_delay_task, fn=renew_token, delay=2)
+
+ return task_serv
+
+
+async def renew_token():
+ """
+ 演示更新 Token
+ """
+ echo_log(f"执行:更新 Token.")
+
+ _renewed = False
+ for i in range(2):
+ echo_log(f"后台任务开始执行-{i}...")
+ await asyncio.sleep(0.5) # 模拟耗时操作
+ _renewed = True
+ echo_log(f"更新处理完成,{'已' if _renewed else '未'}更新.")
+
+
+def start_service():
+ """
+ 控制台服务方式启动。
+ """
+ set_logger_config(logger_config_name)
+ _ts = init_task_service()
+ _ts.start_service(env_check=False)
+
+
+def start():
+ """
+ 驻内存服务方式启动。
+ """
+ set_logger_config(logger_config_name)
+ _ts = init_task_service()
+ _ts.start()
+
+
+def stop():
+ """
+ 驻内存服务方式停止。
+ """
+ set_logger_config(logger_config_name)
+ _ts = init_task_service()
+ _ts.stop()
+
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1:
+ if sys.argv[1] == "start":
+ start_service()
+ elif sys.argv[1] == "stop":
+ stop()
+ else:
+ print("用法: python service/tsk_service.py start")
+ else:
+ start_service()
diff --git a/examples/05_gen_models/config.json b/examples/05_gen_models/config.json
new file mode 100644
index 0000000..b2bb0f4
--- /dev/null
+++ b/examples/05_gen_models/config.json
@@ -0,0 +1,45 @@
+{
+ "app_name": "Paste 测试",
+
+ "db_engine_desc": "数据库连接信息,包含普通连接、异步连接以及连接选项,其中连接选项的配置必须对应 create_engine 或 create_async_engine 方法参数,后面加 _xx 后缀的,仅用于保存信息",
+ "db_engine": {
+ "engine": "mysql+pymysql://haiten:HaitenDB%4020250702@100.64.0.1:3360/haiten",
+ "async_engine": "mysql+aiomysql://haiten:HaitenDB%4020250702@100.64.0.1:3360/haiten",
+ "engine_option": {
+ "echo": false,
+ "pool_pre_ping": true,
+ "pool_size": 20,
+ "max_overflow": 200
+ }
+ },
+
+ "logger_desc": "用于日志输出的配置,各服务可以有自己的配置,但要使用独立配置时,必须编写额外代码",
+ "logger": {
+ "default": {
+ "desc": "默认日志配置,该配置小节的名称已经配置在 PASTE 框架中",
+ "basic": {
+ "filename": "logs/root.log",
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ "level": 20
+ },
+ "filename": "logs/default.log",
+ "name": "Demo",
+ "max_bytes": 20971520,
+ "backup_count": 40
+ }
+ },
+
+ "rbac_desc": "RBAC 基础信息配置",
+ "rbac": {
+ "table": {
+ "assignment": "hat_auth_assignment",
+ "item": "hat_auth_item",
+ "item_child": "hat_auth_item_child",
+ "rule": "hat_auth_rule",
+ "user": "hat_user"
+ },
+ "user_class": "paste.rbac.rbac_user.RbacUser"
+ },
+
+ "version": "1.0.1"
+}
diff --git a/examples/05_gen_models/main.py b/examples/05_gen_models/main.py
new file mode 100644
index 0000000..9fe01ea
--- /dev/null
+++ b/examples/05_gen_models/main.py
@@ -0,0 +1,11 @@
+from paste.core import aio_pool
+from paste.core.logging import set_logger_config
+from paste.db import gen_models
+
+if __name__ == "__main__":
+ # 日志配置
+ logger_config_name = 'logger.default'
+ set_logger_config(logger_config_name)
+ # 生成模型代码
+ _runner = aio_pool.get_aio_runner()
+ _runner(gen_models.sqlacodegen())
\ No newline at end of file
diff --git a/examples/05_gen_models/models/db_models.py b/examples/05_gen_models/models/db_models.py
new file mode 100644
index 0000000..9776c9f
--- /dev/null
+++ b/examples/05_gen_models/models/db_models.py
@@ -0,0 +1,576 @@
+# coding: utf-8
+from sqlalchemy import CheckConstraint, Column, Date, DateTime, Float, ForeignKey, Index, String, TIMESTAMP, Text, text
+from sqlalchemy.dialects.mysql import BIGINT, INTEGER, MEDIUMTEXT
+from sqlalchemy.orm import relationship
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
+class HatArticle(Base):
+ __tablename__ = 'hat_article'
+ __table_args__ = {'comment': '文章'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='ID')
+ title = Column(String(300), comment='标题')
+ content = Column(MEDIUMTEXT, comment='内容')
+ cover_image = Column(String(300), nullable=False, server_default=text("''"), comment='封面图片路径')
+ overview = Column(String(300), nullable=False, server_default=text("''"), comment='概述')
+ type = Column(String(50), nullable=False, server_default=text("'采编'"), comment='类型:原创、转载、首发、采编')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+
+class HatArticleCategory(Base):
+ __tablename__ = 'hat_article_category'
+ __table_args__ = {'comment': '文章类别表'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='ID')
+ category_name = Column(String(64), nullable=False, unique=True, server_default=text("''"), comment='类别名称')
+ parent_id = Column(BIGINT(20), nullable=False, server_default=text("0"), comment='父类别ID(默认为0)')
+ description = Column(String(500), nullable=False, server_default=text("''"), comment='类别描述')
+ sort_order = Column(INTEGER(11), nullable=False, server_default=text("1"), comment='排序值')
+ status = Column(INTEGER(11), nullable=False, index=True, server_default=text("0"), comment='状态(默认0,锁定1)')
+ created_at = Column(TIMESTAMP, nullable=False, server_default=text("current_timestamp()"))
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"))
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"))
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"))
+
+
+class HatClass(Base):
+ __tablename__ = 'hat_classes'
+ __table_args__ = {'comment': '班级表'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ name = Column(String(100), nullable=False, comment='班级名称')
+ year = Column(String(100), nullable=False, comment='班级年份')
+ adviser = Column(String(255), nullable=False, comment='辅导员')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+
+class HatClassroom(Base):
+ __tablename__ = 'hat_classroom'
+ __table_args__ = {'comment': '教室'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ name = Column(String(50), nullable=False, comment='教室名称')
+ total = Column(INTEGER(10), nullable=False, server_default=text("60"), comment='容纳人数')
+ description = Column(String(500), comment='描述')
+
+
+class HatCourse(Base):
+ __tablename__ = 'hat_course'
+ __table_args__ = {'comment': '课程表,在专业教学计划表,学分对接表中关联'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ name = Column(String(255), nullable=False, unique=True, comment='课程名称')
+ name_en = Column(String(255), comment='英文名称')
+ code = Column(String(50), nullable=False, comment='课程代码')
+ material = Column(String(1000), comment='所选教材')
+ description = Column(Text, comment='课程描述')
+ category = Column(String(50), nullable=False, comment='授课方(中方课程或外方课程)')
+ status = Column(INTEGER(20), nullable=False, server_default=text("10"), comment='当前状态')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+
+class HatCourseSchedule(Base):
+ __tablename__ = 'hat_course_schedule'
+ __table_args__ = {'comment': '课表'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ academic_year = Column(String(32), nullable=False, comment='学年')
+ semester = Column(String(64), nullable=False, comment='上课学期')
+ status = Column(INTEGER(20), nullable=False, server_default=text("10"), comment='状态')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+
+class HatEnrollStudent(Base):
+ __tablename__ = 'hat_enroll_student'
+ __table_args__ = (
+ Index('hat_enroll_student_id_card_number_phone_un', 'id_card_number', 'phone', unique=True),
+ {'comment': '学生报名信息表'}
+ )
+
+ id = Column(BIGINT(20), primary_key=True, comment='ID')
+ student_number = Column(String(50), nullable=False, index=True, comment='学号')
+ name = Column(String(128), nullable=False, index=True, comment='姓名')
+ gender = Column(String(10), nullable=False, comment='性别')
+ native_place = Column(String(50), comment='籍贯')
+ id_card_number = Column(String(50), nullable=False, comment='身份证')
+ province = Column(String(128), nullable=False, index=True, comment='省')
+ city = Column(String(128), comment='市')
+ date_of_birth = Column(Date, comment='出生年月')
+ politics_status = Column(String(50), comment='政治面貌')
+ nation = Column(String(128), comment='民族')
+ house_address = Column(String(255), nullable=False, comment='家庭地址')
+ post_code = Column(String(10), nullable=False, comment='邮政编码')
+ phone = Column(String(50), nullable=False, comment='学生手机')
+ educational_level = Column(String(128), comment='文化程度')
+ school_of_graduation = Column(String(128), index=True, comment='毕业学校')
+ graduate_date = Column(Date, comment='毕业日期')
+ awards = Column(String(128), comment='奖励')
+ hobby = Column(String(128), comment='爱好特长')
+ cee_id = Column(String(50), comment='准考证号')
+ cee_scores = Column(String(50), comment='高考总分')
+ cee_english = Column(String(50), comment='英语成绩')
+ cee_chinese = Column(String(50), comment='语文成绩')
+ cee_math = Column(String(50), comment='数学成绩')
+ cee_type = Column(String(50), comment='高考科类')
+ ielts = Column(String(50), comment='雅思成绩')
+ no_cee_reasons = Column(String(128), comment='不参加高考原因')
+ major = Column(String(128), comment='首选专业')
+ major2 = Column(String(128), comment='次选专业')
+ major3 = Column(String(128), comment='再选专业')
+ accommodation = Column(String(10), comment='是否住宿')
+ allocate = Column(String(50), comment='服从调配')
+ abroad = Column(String(50), comment='出国留学')
+ parent_name1 = Column(String(128), nullable=False, comment='家长姓名')
+ parent_phone1 = Column(String(50), nullable=False, comment='电话')
+ relation1 = Column(String(10), comment='关系')
+ parent_name2 = Column(String(128), comment='家长姓名')
+ parent_phone2 = Column(String(50), comment='电话')
+ relation2 = Column(String(10), comment='关系')
+ information_source = Column(String(128), comment='信息来源')
+ referrer = Column(String(128), comment='推荐人')
+ admission_major = Column(String(128), comment='录取专业')
+ admission_at = Column(DateTime, comment='录取时间')
+ intensive_training = Column(String(50), comment='强化训练')
+ status = Column(INTEGER(20), nullable=False, server_default=text("10"), comment='当前状态')
+ remark = Column(String(255), comment='备注')
+ created_at = Column(DateTime, nullable=False, comment='创建时间')
+
+
+class HatMajor(Base):
+ __tablename__ = 'hat_major'
+ __table_args__ = {'comment': '专业'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ name = Column(String(255), nullable=False, unique=True, comment='专业名称')
+ name_en = Column(String(255), comment='英文名称')
+ code = Column(String(255), nullable=False, comment='专业代码')
+ description = Column(Text, comment='介绍')
+ discipline = Column(String(255), nullable=False, comment='学科门类')
+ status = Column(INTEGER(20), nullable=False, server_default=text("10"), comment='当前状态')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+
+class HatPerson(Base):
+ __tablename__ = 'hat_person'
+ __table_args__ = (
+ Index('hat_person_name_cer_idx', 'name', 'cer_no', 'cer_type_name'),
+ {'comment': '企业人员'}
+ )
+
+ id = Column(BIGINT(20), primary_key=True, comment='ID')
+ name = Column(String(100), nullable=False, server_default=text("''"), comment='姓名')
+ sex = Column(String(1), nullable=False, server_default=text("''"), comment='性别')
+ cer_type_name = Column(String(100), nullable=False, server_default=text("''"), comment='身份证件类型')
+ cer_no = Column(String(40), nullable=False, server_default=text("''"), comment='证件号')
+ tel = Column(String(110), nullable=False, server_default=text("''"), comment='联系电话')
+ school = Column(String(200), nullable=False, server_default=text("''"), comment='毕业院校')
+ edu_bac = Column(String(20), nullable=False, server_default=text("''"), comment='文化程度')
+ major = Column(String(30), nullable=False, server_default=text("''"), comment='所学专业')
+ lite_deg = Column(String(2), nullable=False, server_default=text("''"), comment='文化程度')
+ edu_bac_code = Column(String(30), nullable=False, server_default=text("''"), comment='学历')
+ title = Column(String(40), nullable=False, server_default=text("''"), comment='职称')
+ com_addr = Column(String(512), nullable=False, server_default=text("''"), comment='通信地址')
+ postal_code = Column(String(6), nullable=False, server_default=text("''"), comment='邮编编码')
+ email = Column(String(100), nullable=False, server_default=text("''"), comment='电子邮箱')
+ status = Column(String(64), nullable=False, server_default=text("''"), comment='状态')
+ avatar = Column(String(300), server_default=text("''"), comment='头像')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+
+
+class HatStudent(Base):
+ __tablename__ = 'hat_student'
+ __table_args__ = {'comment': '学生信息表'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='ID')
+ student_number = Column(String(50), nullable=False, comment='学号')
+ name = Column(String(128), nullable=False, comment='姓名')
+ gender = Column(String(10), nullable=False, comment='性别')
+ native_place = Column(String(50), comment='籍贯')
+ id_card_number = Column(String(50), nullable=False, comment='身份证')
+ province = Column(String(128), nullable=False, comment='省')
+ city = Column(String(128), comment='市')
+ date_of_birth = Column(Date, comment='出生年月')
+ politics_status = Column(String(50), comment='政治面貌')
+ nation = Column(String(128), comment='民族')
+ house_address = Column(String(255), nullable=False, comment='家庭地址')
+ post_code = Column(String(10), nullable=False, comment='邮政编码')
+ phone = Column(String(50), nullable=False, comment='学生手机')
+ educational_level = Column(String(128), comment='文化程度')
+ school_of_graduation = Column(String(128), comment='毕业学校')
+ graduate_time = Column(Date, comment='毕业时间')
+ awards = Column(String(128), comment='奖励')
+ hobby = Column(String(128), comment='爱好特长')
+ cee_id = Column(String(50), comment='准考证号')
+ cee_scores = Column(String(50), comment='高考总分')
+ cee_english = Column(String(50), comment='英语成绩')
+ cee_chinese = Column(String(50), comment='语文成绩')
+ cee_math = Column(String(50), comment='数学成绩')
+ cee_type = Column(String(50), comment='高考科类')
+ ielts = Column(String(50), comment='雅思成绩')
+ major = Column(String(128), comment='专业')
+ allocate = Column(String(50), comment='服从调配')
+ abroad = Column(String(50), comment='出国留学')
+ parent_name1 = Column(String(128), nullable=False, comment='家长姓名')
+ parent_phone1 = Column(String(50), nullable=False, comment='电话')
+ relation1 = Column(String(10), comment='关系')
+ parent_name2 = Column(String(128), comment='家长姓名')
+ parent_phone2 = Column(String(50), comment='电话')
+ relation2 = Column(String(10), comment='关系')
+ intensive_training = Column(String(50), comment='强化训练')
+ status = Column(INTEGER(20), nullable=False, server_default=text("10"), comment='当前状态')
+ created_at = Column(DateTime, nullable=False, comment='创建时间')
+
+
+class HatUser(Base):
+ __tablename__ = 'hat_user'
+ __table_args__ = {'comment': '用户'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ username = Column(String(255), nullable=False, unique=True, comment='用户名')
+ password_hash = Column(String(255), nullable=False, comment='密码')
+ password_reset_token = Column(String(255), comment='重置标记')
+ auth_key = Column(String(255), comment='授权码')
+ status = Column(INTEGER(11), nullable=False, server_default=text("0"), comment='用户状态')
+ type = Column(String(64), nullable=False, server_default=text("'user'"), comment='用户类型')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='更新时间')
+
+
+class HatArticlePublish(Base):
+ __tablename__ = 'hat_article_publish'
+ __table_args__ = (
+ Index('hat_article_publish_un', 'article_id', 'category_id', unique=True),
+ {'comment': '文章发布表'}
+ )
+
+ id = Column(BIGINT(20), primary_key=True, comment='ID')
+ article_id = Column(ForeignKey('hat_article.id', ondelete='CASCADE'), nullable=False, comment='文章ID')
+ category_id = Column(ForeignKey('hat_article_category.id', ondelete='CASCADE'), nullable=False, index=True, comment='文章类别ID')
+ sort_order = Column(INTEGER(11), nullable=False, index=True, server_default=text("0"), comment='排序')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+ article = relationship('HatArticle')
+ category = relationship('HatArticleCategory')
+
+
+class HatClassesStudent(Base):
+ __tablename__ = 'hat_classes_student'
+ __table_args__ = (
+ Index('classes_id_student_id_un', 'class_id', 'student_id', unique=True),
+ {'comment': '班级学生关系表'}
+ )
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ class_id = Column(ForeignKey('hat_classes.id', ondelete='CASCADE'), nullable=False, comment='班级编号')
+ student_id = Column(ForeignKey('hat_student.id', ondelete='CASCADE'), nullable=False, index=True, comment='学生编号')
+ created_at = Column(DateTime, nullable=False, comment='创建时间')
+ updated_at = Column(DateTime, nullable=False, comment='更新时间')
+
+ _class = relationship('HatClass')
+ student = relationship('HatStudent')
+
+
+class HatCourseScheduleDetail(Base):
+ __tablename__ = 'hat_course_schedule_detail'
+ __table_args__ = {'comment': '课表明细'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ schedule_id = Column(ForeignKey('hat_course_schedule.id', ondelete='CASCADE'), nullable=False, index=True, comment='课表')
+ course_id = Column(ForeignKey('hat_course.id'), nullable=False, index=True, comment='课程')
+ classroom_id = Column(ForeignKey('hat_classroom.id'), nullable=False, index=True, comment='教室')
+ teacher = Column(String(255), nullable=False, comment='任课老师')
+ week_day = Column(INTEGER(20), nullable=False, comment='上课日期,0~1为周日~周六')
+ sequence = Column(INTEGER(20), nullable=False, comment='序号,从1开始,代表是一天中的第几节课')
+
+ classroom = relationship('HatClassroom')
+ course = relationship('HatCourse')
+ schedule = relationship('HatCourseSchedule')
+
+
+class HatExamPaper(Base):
+ __tablename__ = 'hat_exam_paper'
+ __table_args__ = {'comment': '考卷'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ name = Column(String(255), nullable=False, comment='考卷名称')
+ course_id = Column(ForeignKey('hat_course.id'), nullable=False, index=True, comment='考试课程')
+ score = Column(Float(asdecimal=True), server_default=text("0"), comment='分值')
+ status = Column(INTEGER(20), nullable=False, server_default=text("10"), comment='当前状态')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+ course = relationship('HatCourse')
+
+
+class HatExamination(Base):
+ __tablename__ = 'hat_examination'
+ __table_args__ = {'comment': '考务表'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ course_id = Column(ForeignKey('hat_course.id'), nullable=False, index=True, comment='考试课程')
+ academic_year = Column(String(200), nullable=False, comment='学年')
+ semester = Column(String(200), nullable=False, comment='学期')
+ exam_time = Column(DateTime, nullable=False, comment='考试时间')
+ classroom_id = Column(ForeignKey('hat_classroom.id'), index=True, comment='考场教室')
+ exam_paper_id = Column(BIGINT(20), comment='考卷')
+ exam_format = Column(String(50), nullable=False, comment='考试形式,线下、线上')
+ exam_method = Column(String(50), nullable=False, comment='考试方式,开卷、闭卷')
+ exam_type = Column(String(50), nullable=False, comment='考试性质,入学、期中、期末、补考、重修')
+ time_length = Column(INTEGER(20), server_default=text("60"), comment='考试时长')
+ chief_examiner = Column(String(50), nullable=False, comment='主考老师')
+ invigilator = Column(String(50), comment='监考老师')
+ status = Column(INTEGER(20), nullable=False, server_default=text("10"), comment='当前状态')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+ classroom = relationship('HatClassroom')
+ course = relationship('HatCourse')
+
+
+class HatQuestionMaterial(Base):
+ __tablename__ = 'hat_question_material'
+ __table_args__ = {'comment': '考题素材(考题用到的素材,目前仅支持文字素材,可增加附件用于增加其他素材,如图片、声音等)'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ title = Column(String(255), nullable=False, comment='标题')
+ content = Column(Text, nullable=False, comment='文章内容')
+ course_id = Column(ForeignKey('hat_course.id'), nullable=False, index=True, comment='课程')
+ status = Column(INTEGER(20), nullable=False, server_default=text("10"), comment='当前状态')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+ course = relationship('HatCourse')
+
+
+class HatStudentPortrait(Base):
+ __tablename__ = 'hat_student_portrait'
+ __table_args__ = {'comment': '学生头像'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ student_id = Column(ForeignKey('hat_student.id', ondelete='CASCADE'), nullable=False, unique=True, comment='考生')
+ portrait = Column(String(500), nullable=False, comment='头像')
+ created_at = Column(DateTime, nullable=False, comment='创建时间')
+ updated_at = Column(DateTime, nullable=False, comment='更新时间')
+
+ student = relationship('HatStudent')
+
+
+class HatStudentUnusual(Base):
+ __tablename__ = 'hat_student_unusual'
+ __table_args__ = {'comment': '学生异动'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ student_id = Column(ForeignKey('hat_student.id'), nullable=False, index=True, comment='考生')
+ type = Column(String(100), nullable=False, comment='异动类型')
+ memo = Column(String(1024), comment='备注')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+ student = relationship('HatStudent')
+
+
+class HatUserPerson(Base):
+ __tablename__ = 'hat_user_person'
+ __table_args__ = {'comment': '用户人员'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='ID')
+ user_id = Column(ForeignKey('hat_user.id'), nullable=False, unique=True, comment='用户ID')
+ person_id = Column(ForeignKey('hat_person.id', ondelete='CASCADE', onupdate='CASCADE'), nullable=False, index=True, comment='人员ID')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='更新时间')
+
+ person = relationship('HatPerson')
+ user = relationship('HatUser')
+
+
+class HatEnrollStudentExam(Base):
+ __tablename__ = 'hat_enroll_student_exam'
+ __table_args__ = (
+ Index('hat_enroll_student_exam_un', 'examination_id', 'student_id', unique=True),
+ {'comment': '参加入学考试的学生'}
+ )
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ examination_id = Column(ForeignKey('hat_examination.id'), nullable=False, comment='考务编号')
+ student_id = Column(ForeignKey('hat_enroll_student.id'), nullable=False, index=True, comment='考生')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+ examination = relationship('HatExamination')
+ student = relationship('HatEnrollStudent')
+
+
+class HatEnrollStudentScore(Base):
+ __tablename__ = 'hat_enroll_student_score'
+ __table_args__ = (
+ CheckConstraint('json_valid(`answer`)'),
+ CheckConstraint('json_valid(`question_score`)'),
+ Index('hat_enroll_student_score_un', 'examination_id', 'student_id', unique=True),
+ {'comment': '入学考试成绩'}
+ )
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ examination_id = Column(ForeignKey('hat_examination.id'), nullable=False, comment='考务安排')
+ student_id = Column(ForeignKey('hat_enroll_student.id'), nullable=False, index=True, comment='考生')
+ started_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='考试开始时间')
+ submit_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='交卷时间')
+ submit_method = Column(String(16), nullable=False, server_default=text("'N'"), comment='交卷方式')
+ answer = Column(Text, nullable=False, comment='答案(JSON数据)')
+ question_score = Column(Text, comment='各题得分')
+ test_score = Column(String(100), nullable=False, comment='考试成绩')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='更新时间')
+
+ examination = relationship('HatExamination')
+ student = relationship('HatEnrollStudent')
+
+
+class HatQuestion(Base):
+ __tablename__ = 'hat_question'
+ __table_args__ = {'comment': '考题(可关联到素材表,对关联到同一素材的所有考题,按照时间先后排序)'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ question = Column(String(1800), nullable=False, comment='提问')
+ options = Column(String(1800), comment='选项(json)')
+ answer = Column(String(1800), nullable=False, comment='答案(json)')
+ category = Column(String(50), nullable=False, comment='题型(判断题、单项选择题、多项选择题、不定项选择题、填空题、完形填空、阅读理解、简答题、论述题、作文)')
+ course_id = Column(ForeignKey('hat_course.id'), nullable=False, index=True, comment='课程')
+ material_id = Column(ForeignKey('hat_question_material.id'), index=True, comment='素材')
+ score = Column(Float(asdecimal=True), server_default=text("0"), comment='分值')
+ difficulty = Column(Float(asdecimal=True), server_default=text("1"), comment='难度系数')
+ status = Column(INTEGER(20), nullable=False, server_default=text("10"), comment='当前状态')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+ course = relationship('HatCourse')
+ material = relationship('HatQuestionMaterial')
+
+
+class HatStudentAttendance(Base):
+ __tablename__ = 'hat_student_attendance'
+ __table_args__ = {'comment': '学生考勤'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ student_id = Column(ForeignKey('hat_student.id'), nullable=False, index=True, comment='考生')
+ schedule_detail_id = Column(ForeignKey('hat_course_schedule_detail.id'), nullable=False, index=True, comment='课表明细编号')
+ sequence = Column(INTEGER(20), nullable=False, comment='序号')
+ time_at = Column(DateTime, nullable=False, comment='考勤时间')
+ type = Column(String(100), nullable=False, comment='考勤类型')
+ memo = Column(String(1024), comment='备注')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+ schedule_detail = relationship('HatCourseScheduleDetail')
+ student = relationship('HatStudent')
+
+
+class HatStudentExam(Base):
+ __tablename__ = 'hat_student_exam'
+ __table_args__ = (
+ Index('hat_student_exam_un', 'examination_id', 'student_id', unique=True),
+ {'comment': '参加考试的学生'}
+ )
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ examination_id = Column(ForeignKey('hat_examination.id', ondelete='CASCADE'), nullable=False, comment='考务编号')
+ student_id = Column(ForeignKey('hat_student.id', ondelete='CASCADE'), nullable=False, index=True, comment='考生')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ created_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='创建者')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='修改时间')
+ updated_by = Column(String(64), nullable=False, server_default=text("'API'"), comment='修改者')
+
+ examination = relationship('HatExamination')
+ student = relationship('HatStudent')
+
+
+class HatStudentScore(Base):
+ __tablename__ = 'hat_student_score'
+ __table_args__ = (
+ CheckConstraint('json_valid(`answer`)'),
+ CheckConstraint('json_valid(`question_score`)'),
+ Index('hat_student_score_un', 'examination_id', 'student_id', unique=True),
+ {'comment': '考试成绩'}
+ )
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ examination_id = Column(ForeignKey('hat_examination.id'), nullable=False, comment='考务安排')
+ student_id = Column(ForeignKey('hat_student.id'), nullable=False, index=True, comment='考生')
+ started_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='考试开始时间')
+ submit_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='交卷时间')
+ submit_method = Column(String(16), nullable=False, server_default=text("'N'"), comment='交卷方式')
+ answer = Column(Text, nullable=False, comment='答案(JSON数据)')
+ question_score = Column(Text, comment='各题得分')
+ test_score = Column(String(100), nullable=False, comment='考试成绩')
+ created_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='创建时间')
+ updated_at = Column(DateTime, nullable=False, server_default=text("current_timestamp()"), comment='更新时间')
+
+ examination = relationship('HatExamination')
+ student = relationship('HatStudent')
+
+
+class HatStudentUnusualAttachment(Base):
+ __tablename__ = 'hat_student_unusual_attachment'
+ __table_args__ = {'comment': '学生异动附件'}
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ unusual_id = Column(ForeignKey('hat_student_unusual.id', ondelete='CASCADE', onupdate='CASCADE'), nullable=False, index=True, comment='异动编号')
+ name = Column(String(255), nullable=False, comment='附件名称')
+ created_at = Column(DateTime, nullable=False, comment='创建时间')
+ updated_at = Column(DateTime, nullable=False, comment='更新时间')
+
+ unusual = relationship('HatStudentUnusual')
+
+
+class HatExamPaperQuestion(Base):
+ __tablename__ = 'hat_exam_paper_question'
+ __table_args__ = (
+ Index('hat_exam_paper_question_un', 'exam_paper_id', 'question_id', unique=True),
+ {'comment': '考卷考题'}
+ )
+
+ id = Column(BIGINT(20), primary_key=True, comment='系统编号')
+ exam_paper_id = Column(ForeignKey('hat_exam_paper.id', ondelete='CASCADE', onupdate='CASCADE'), nullable=False, comment='考卷')
+ question_id = Column(ForeignKey('hat_question.id'), nullable=False, index=True, comment='考题')
+ sort = Column(INTEGER(20), nullable=False, server_default=text("0"), comment='排序')
+
+ exam_paper = relationship('HatExamPaper')
+ question = relationship('HatQuestion')
diff --git a/examples/06_chart_example/chart_bar_example.py b/examples/06_chart_example/chart_bar_example.py
new file mode 100644
index 0000000..05d628d
--- /dev/null
+++ b/examples/06_chart_example/chart_bar_example.py
@@ -0,0 +1,153 @@
+import numpy as np
+import os
+import traceback
+from paste.chart.bar import (
+ gen_vertical_bars,
+ gen_horizontal_stacked_bars,
+ gen_percent_stacked_bars
+)
+
+
+class ChartBarExample:
+ """
+ 图表测试管理器:封装对 paste.chart.bar 中三个函数的调用。
+ 不修改任何参数结构,仅提供清晰的调用封装与输出管理。
+ """
+
+ def __init__(self, output_directory="./charts"):
+ """
+ 初始化测试器,定义所有测试数据。
+ 数据结构完全匹配原始函数调用方式。
+ """
+ self.output_directory = output_directory
+ os.makedirs(self.output_directory, exist_ok=True)
+
+ # 纵向堆叠柱形图数据(直接对应 gen_vertical_bars 参数)
+ self.primary_vals = [10, 20, 15, 25, 18]
+ self.nested_vals = [5, 8, 3, 10, 6]
+ self.x_labels_vert = ['产品1', '产品2', '产品3', '产品4', '产品5']
+ self.group_labels_vert = ['销售量', '退货量']
+
+ # 横向堆叠柱形图数据(直接对应 gen_horizontal_stacked_bars 参数)
+ self.data_matrix = np.array([
+ [10, 20, 15],
+ [15, 12, 18],
+ [8, 16, 10],
+ [12, 14, 13]
+ ])
+ self.x_labels_hori = ['线上销售', '门店销售', '批发销售']
+ self.y_labels_hori = ['北京', '上海', '广州', '深圳']
+ self.y_data_unit_hori = '万元'
+ self.title_hori = '销售构成'
+
+ # 百分比堆叠柱形图数据(直接对应 gen_percent_stacked_bars 参数)
+ self.data_percent = {
+ 'A组': [10, 20, 15, 18],
+ 'B组': [5, 10, 5, 8],
+ 'C组': [3, 7, 10, 4]
+ }
+ self.x_labels_percent = ['Q1', 'Q2', 'Q3', 'Q4']
+ self.title_percent = '季度占比'
+
+ def generate_vertical_bars(self) -> str:
+ """调用 gen_vertical_bars,参数完全一致"""
+ try:
+ return gen_vertical_bars(
+ self.primary_vals,
+ self.nested_vals,
+ self.x_labels_vert,
+ self.group_labels_vert
+ )
+ except Exception as e:
+ print(f"纵向堆叠柱形图生成失败: {e}")
+ traceback.print_exc()
+ raise
+
+ def generate_horizontal_stacked_bars(self) -> str:
+ """调用 gen_horizontal_stacked_bars,参数完全一致"""
+ try:
+ return gen_horizontal_stacked_bars(
+ self.data_matrix,
+ self.x_labels_hori,
+ self.y_labels_hori,
+ self.y_data_unit_hori,
+ self.title_hori
+ )
+ except Exception as e:
+ print(f"横向堆叠柱形图生成失败: {e}")
+ traceback.print_exc()
+ raise
+
+ def generate_percent_stacked_bars(self) -> str:
+ """调用 gen_percent_stacked_bars,参数完全一致"""
+ try:
+ return gen_percent_stacked_bars(
+ self.data_percent,
+ self.x_labels_percent,
+ self.title_percent
+ )
+ except Exception as e:
+ print(f"百分比堆叠柱形图生成失败: {e}")
+ traceback.print_exc()
+ raise
+
+ def save_svg(self, svg_data: str, filename: str) -> None:
+ """
+ 将 SVG 的 base64 Data URL 写入文件(保留原始 SVG 格式)。
+ 注意:svg_data 是 "data:image/svg+xml;base64,...",需提取真实 SVG 内容。
+ """
+ if not svg_data or not isinstance(svg_data, str):
+ print(f"生成的 SVG 数据无效(为空或非字符串): {filename}")
+ return
+
+ # 提取 base64 编码部分(去除 data URL 前缀)
+ if svg_data.startswith("data:image/svg+xml;base64,"):
+ base64_content = svg_data[len("data:image/svg+xml;base64,"):]
+ try:
+ # 解码 base64 得到原始 SVG 字符串
+ import base64
+ svg_content = base64.b64decode(base64_content).decode('utf-8')
+ except Exception as e:
+ print(f"解码 base64 失败: {e}")
+ svg_content = svg_data # 退化为直接写入
+ else:
+ # 如果不是标准格式,直接写入(兼容调试)
+ svg_content = svg_data
+
+ filepath = os.path.join(self.output_directory, filename)
+ with open(filepath, 'w', encoding='utf-8') as f:
+ f.write(svg_content)
+ print(f"已保存: {filepath}")
+
+ def run(self) -> None:
+ """按顺序执行所有图表生成与保存"""
+ print("开始生成图表...")
+ try:
+ print("生成纵向堆叠柱形图...")
+ svg1 = self.generate_vertical_bars()
+ self.save_svg(svg1, "vertical_bars.svg")
+
+ print("生成横向堆叠柱形图...")
+ svg2 = self.generate_horizontal_stacked_bars()
+ self.save_svg(svg2, "horizontal_stacked_bars.svg")
+
+ print("生成百分比堆叠柱形图...")
+ svg3 = self.generate_percent_stacked_bars()
+ self.save_svg(svg3, "percent_stacked_bars.svg")
+
+ print("\n所有图表已成功生成。")
+ print(f"输出目录: {self.output_directory}")
+ print("文件列表:")
+ print(" - vertical_bars.svg")
+ print(" - horizontal_stacked_bars.svg")
+ print(" - percent_stacked_bars.svg")
+
+ except Exception as e:
+ print(f"\n测试失败: {e}")
+ traceback.print_exc()
+
+
+# 程序入口
+if __name__ == "__main__":
+ tester = ChartBarExample()
+ tester.run()
\ No newline at end of file
diff --git a/examples/06_chart_example/chart_pie_example.py b/examples/06_chart_example/chart_pie_example.py
new file mode 100644
index 0000000..d5a61f2
--- /dev/null
+++ b/examples/06_chart_example/chart_pie_example.py
@@ -0,0 +1,110 @@
+import os
+import traceback
+from paste.chart.pie import gen_pie
+
+
+class ChartPieExample:
+ """
+ 环形图测试管理器:封装对 paste.chart.pie.gen_pie 的调用。
+ 数据结构严格匹配函数参数要求,支持扩展更多测试用例。
+ """
+
+ def __init__(self, output_directory="./charts"):
+ """
+ 初始化测试器,定义所有测试数据。
+ 数据结构完全匹配 gen_pie 函数的参数要求。
+ """
+ self.output_directory = output_directory
+ os.makedirs(self.output_directory, exist_ok=True)
+
+ # 构造符合 gen_pie 要求的 DataFrame 数据(模拟真实业务场景)
+ # 假设业务场景:网络设备统计(服务器、交换机、路由器等)
+ import pandas as pd
+ self.data_df = pd.DataFrame({
+ 'device_count': [35, 28, 22, 15, 10], # value_column
+ 'percentage': ['35.2%', '28.1%', '22.0%', '15.0%', '9.7%'], # percentage_column
+ 'device_type': ['服务器', '交换机', '路由器', '防火墙', 'AP'] # legend_labels
+ })
+
+ # 测试参数
+ self.value_column = 'device_count'
+ self.percentage_column = 'percentage'
+ self.legend_labels = 'device_type'
+ self.color_palette = 'BuPu' # 可尝试 'viridis', 'Set3', 'plasma'
+ self.dpi = 128
+
+ # 输出文件名
+ self.filename = "pie_chart.svg"
+
+ def generate_pie_chart(self) -> str:
+ """
+ 调用 gen_pie 函数,参数完全一致。
+ 注意:gen_pie 接收的是 pandas.DataFrame,不是列表。
+ """
+ try:
+ svg_data = gen_pie(
+ data_df=self.data_df,
+ value_column=self.value_column,
+ percentage_column=self.percentage_column,
+ legend_labels=self.legend_labels,
+ color_palette=self.color_palette,
+ dpi=self.dpi
+ )
+ if not svg_data or not isinstance(svg_data, str):
+ raise ValueError("gen_pie 返回的 SVG 数据为空或类型错误")
+ return svg_data
+ except Exception as e:
+ print(f"环形图生成失败: {e}")
+ traceback.print_exc()
+ raise
+
+ def save_svg(self, svg_data: str, filename: str) -> None:
+ """
+ 将 SVG 的 base64 Data URL 写入文件(保留原始 SVG 格式)。
+ 注意:svg_data 是 "data:image/svg+xml;base64,...",需提取真实 SVG 内容。
+ """
+ if not svg_data or not isinstance(svg_data, str):
+ print(f"生成的 SVG 数据无效(为空或非字符串): {filename}")
+ return
+
+ # 提取 base64 编码部分(去除 data URL 前缀)
+ if svg_data.startswith("data:image/svg+xml;base64,"):
+ base64_content = svg_data[len("data:image/svg+xml;base64,"):]
+ try:
+ # 解码 base64 得到原始 SVG 字符串
+ import base64
+ svg_content = base64.b64decode(base64_content).decode('utf-8')
+ except Exception as e:
+ print(f"解码 base64 失败: {e}")
+ svg_content = svg_data # 退化为直接写入
+ else:
+ # 如果不是标准格式,直接写入(兼容调试)
+ svg_content = svg_data
+
+ filepath = os.path.join(self.output_directory, filename)
+ with open(filepath, 'w', encoding='utf-8') as f:
+ f.write(svg_content)
+ print(f"已保存: {filepath}")
+
+ def run(self) -> None:
+ """按顺序执行图表生成与保存"""
+ print("开始生成环形图...")
+ try:
+ print("正在生成环形图...")
+ svg_data = self.generate_pie_chart()
+ self.save_svg(svg_data, self.filename)
+
+ print(f"\n环形图已成功生成。")
+ print(f"输出目录: {self.output_directory}")
+ print(f"文件列表:")
+ print(f" - {self.filename}")
+
+ except Exception as e:
+ print(f"\n测试失败: {e}")
+ traceback.print_exc()
+
+
+# 程序入口
+if __name__ == "__main__":
+ tester = ChartPieExample()
+ tester.run()
\ No newline at end of file
diff --git a/examples/06_chart_example/charts/charts.html b/examples/06_chart_example/charts/charts.html
new file mode 100644
index 0000000..2d09caf
--- /dev/null
+++ b/examples/06_chart_example/charts/charts.html
@@ -0,0 +1,16 @@
+
+
+
+
+
+ Charts
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/06_chart_example/charts/horizontal_stacked_bars.svg b/examples/06_chart_example/charts/horizontal_stacked_bars.svg
new file mode 100644
index 0000000..6bd1c32
--- /dev/null
+++ b/examples/06_chart_example/charts/horizontal_stacked_bars.svg
@@ -0,0 +1,1905 @@
+
+
+
diff --git a/examples/06_chart_example/charts/lines.svg b/examples/06_chart_example/charts/lines.svg
new file mode 100644
index 0000000..891403d
--- /dev/null
+++ b/examples/06_chart_example/charts/lines.svg
@@ -0,0 +1,1006 @@
+
+
+
diff --git a/examples/06_chart_example/charts/percent_stacked_bars.svg b/examples/06_chart_example/charts/percent_stacked_bars.svg
new file mode 100644
index 0000000..8b8627e
--- /dev/null
+++ b/examples/06_chart_example/charts/percent_stacked_bars.svg
@@ -0,0 +1,1210 @@
+
+
+
diff --git a/examples/06_chart_example/charts/pie_chart.svg b/examples/06_chart_example/charts/pie_chart.svg
new file mode 100644
index 0000000..f40eaca
--- /dev/null
+++ b/examples/06_chart_example/charts/pie_chart.svg
@@ -0,0 +1,1373 @@
+
+
+
diff --git a/examples/06_chart_example/charts/splines.svg b/examples/06_chart_example/charts/splines.svg
new file mode 100644
index 0000000..852b5f2
--- /dev/null
+++ b/examples/06_chart_example/charts/splines.svg
@@ -0,0 +1,1213 @@
+
+
+
diff --git a/examples/06_chart_example/charts/vertical_bars.svg b/examples/06_chart_example/charts/vertical_bars.svg
new file mode 100644
index 0000000..55721c1
--- /dev/null
+++ b/examples/06_chart_example/charts/vertical_bars.svg
@@ -0,0 +1,1217 @@
+
+
+
diff --git a/examples/06_chart_example/line_chart_example.py b/examples/06_chart_example/line_chart_example.py
new file mode 100644
index 0000000..41a1309
--- /dev/null
+++ b/examples/06_chart_example/line_chart_example.py
@@ -0,0 +1,124 @@
+import pandas as pd
+import os
+from paste.chart.line import gen_lines, gen_splines
+
+
+class LineChartExample:
+ """
+ 折线图与平滑曲线测试管理器。
+ 封装 gen_lines 和 gen_splines 的调用,统一输出管理。
+ """
+
+ def __init__(self, output_directory="./charts"):
+ """
+ 初始化测试器,定义所有测试数据。
+ 使用 'ME' 替代 'M' 以兼容 pandas >=2.0。
+ """
+ self.output_directory = output_directory
+ os.makedirs(self.output_directory, exist_ok=True)
+
+ # =========================
+ # gen_lines 测试数据(使用 'ME' 替代废弃的 'M')
+ # =========================
+ # 时间序列数据:3个年份,每月最后一个交易日(12个月)
+ dates_2022 = pd.date_range('2022-01-01', periods=12, freq='ME')
+ dates_2023 = pd.date_range('2023-01-01', periods=12, freq='ME')
+ dates_2024 = pd.date_range('2024-01-01', periods=12, freq='ME')
+
+ self.data_dict_lines = {
+ '2022': pd.Series([100 + i * 2 for i in range(12)], index=dates_2022),
+ '2023': pd.Series([110 + i * 1.5 for i in range(12)], index=dates_2023),
+ '2024': pd.Series([120 + i * 1 for i in range(12)], index=dates_2024),
+ }
+
+ self.color_palette_lines = 'BuPu'
+ self.dpi_lines = 128
+
+ # =========================
+ # gen_splines 测试数据(数值索引,无频率问题)
+ # =========================
+ self.data_dict_splines = {
+ 'A组': pd.Series([10, 15, 12, 18, 16], index=[0, 1, 2, 3, 4]),
+ 'B组': pd.Series([12, 14, 13, 17, 15], index=[0, 1, 2, 3, 4]),
+ 'C组': pd.Series([8, 20, 10, 22, 11], index=[0, 1, 2, 3, 4]),
+ }
+
+ self.total_splines = pd.Series({
+ 'A组': 71,
+ 'B组': 71,
+ 'C组': 71
+ })
+
+ self.color_palette_splines = 'viridis'
+ self.dpi_splines = 128
+ self.smooth_points = 100
+ self.spline_k = 3
+ self.markevery = 30
+
+ def generate_lines(self) -> str:
+ """调用 gen_lines,参数顺序完全匹配原始函数"""
+ return gen_lines(
+ self.data_dict_lines,
+ self.color_palette_lines,
+ self.dpi_lines
+ )
+
+ def generate_splines(self) -> str:
+ """调用 gen_splines,参数顺序完全匹配原始函数"""
+ return gen_splines(
+ self.data_dict_splines,
+ self.total_splines,
+ None,
+ self.color_palette_splines,
+ self.dpi_splines
+ )
+
+ def save_svg(self, svg_data: str, filename: str) -> None:
+ """
+ 将 SVG 的 base64 Data URL 写入文件(保留原始 SVG 格式)。
+ 注意:svg_data 是 "data:image/svg+xml;base64,...",需提取真实 SVG 内容。
+ """
+ if not svg_data or not isinstance(svg_data, str):
+ print(f"生成的 SVG 数据无效(为空或非字符串): {filename}")
+ return
+
+ # 提取 base64 编码部分(去除 data URL 前缀)
+ if svg_data.startswith("data:image/svg+xml;base64,"):
+ base64_content = svg_data[len("data:image/svg+xml;base64,"):]
+ try:
+ # 解码 base64 得到原始 SVG 字符串
+ import base64
+ svg_content = base64.b64decode(base64_content).decode('utf-8')
+ except Exception as e:
+ print(f"解码 base64 失败: {e}")
+ svg_content = svg_data # 退化为直接写入
+ else:
+ # 如果不是标准格式,直接写入(兼容调试)
+ svg_content = svg_data
+
+ filepath = os.path.join(self.output_directory, filename)
+ with open(filepath, 'w', encoding='utf-8') as f:
+ f.write(svg_content)
+ print(f"已保存: {filepath}")
+
+ def run(self) -> None:
+ """按顺序生成并保存所有图表"""
+ print("生成折线图...")
+ svg1 = self.generate_lines()
+ self.save_svg(svg1, "lines.svg")
+
+ print("生成平滑曲线图...")
+ svg2 = self.generate_splines()
+ self.save_svg(svg2, "splines.svg")
+
+ print("\n所有图表已生成。")
+ print(f"输出目录: {self.output_directory}")
+ print("文件列表:")
+ print(" - lines.svg")
+ print(" - splines.svg")
+
+
+# 程序入口
+if __name__ == "__main__":
+ tester = LineChartExample()
+ tester.run()
\ No newline at end of file
diff --git a/examples/12_batch_api_calls/README.md b/examples/12_batch_api_calls/README.md
new file mode 100644
index 0000000..0cefda8
--- /dev/null
+++ b/examples/12_batch_api_calls/README.md
@@ -0,0 +1,47 @@
+## 批量接口调用范例
+
+```mermaid
+flowchart TD
+ subgraph 真实场景
+ S1["爬虫
1000个页面"]
+ S2["批量下单
500个订单"]
+ S3["数据同步
5个系统"]
+ S4["压测
自己服务器"]
+ end
+
+ subgraph Paste方案
+ P1["new_http_request × N"]
+ P2["async_concurrency 一发"]
+ P3["统一处理响应"]
+ end
+
+ S1 --> P1 --> P2 --> P3
+ S2 --> P1 --> P2 --> P3
+ S3 --> P1 --> P2 --> P3
+ S4 --> P1 --> P2 --> P3
+```
+
+### 支持的能力
+
+- GET / POST / JSON / Form / 文件上传
+- 自动随机 User-Agent
+- 自动提取 Cookie
+- 批量并发控制
+- 自动重试
+- 统一响应/错误处理
+
+### 代码量
+
+- Java: 200-300 行
+- Paste: **20 行**
+
+### 示例
+
+```python
+# 准备1000个请求
+for i in range(1000):
+ req = new_http_request(url, method="POST", body={"id": i})
+ await queue.put(req)
+
+# 批量发出
+await async_concurrency(queue, con_count=50)
\ No newline at end of file
diff --git a/paste/__init__.py b/paste/__init__.py
new file mode 100755
index 0000000..f414c4e
--- /dev/null
+++ b/paste/__init__.py
@@ -0,0 +1,10 @@
+__package_name__ = "paste"
+
+__version__ = "2.0.1"
+
+__author__ = "Wayne Zhang"
+
+__email__ = "waynezwf@qq.com"
+
+def get_version():
+ return f"{__package_name__} version: V{__version__}, written by {__author__}."
\ No newline at end of file
diff --git a/paste/chart/__init__.py b/paste/chart/__init__.py
new file mode 100644
index 0000000..29bd989
--- /dev/null
+++ b/paste/chart/__init__.py
@@ -0,0 +1,8 @@
+import seaborn
+
+
+def get_seaborn_colors(n, palette='husl'):
+ """
+ 使用Seaborn的调色板生成颜色。
+ """
+ return seaborn.color_palette(palette, n_colors=n).as_hex()
diff --git a/paste/chart/bar.py b/paste/chart/bar.py
new file mode 100644
index 0000000..ae98b42
--- /dev/null
+++ b/paste/chart/bar.py
@@ -0,0 +1,291 @@
+import base64
+from io import BytesIO
+from typing import List, Dict
+
+import numpy as np
+from matplotlib import pyplot as plt
+
+from paste import chart
+from paste.util import ufont
+
+
+# ===========================
+# 工具函数:统一字体与保存逻辑
+# ===========================
+
+def _setup_plot():
+ """
+ 设置全局绘图参数(字体、负号、DPI),避免重复代码。
+ """
+ available_font = ufont.get_fonts()
+ plt.rcParams.update({
+ 'font.sans-serif': list(available_font),
+ 'axes.unicode_minus': False,
+ })
+
+
+def _save_to_base64(dpi: int = 128) -> str:
+ """
+ 将当前图形保存为 base64 编码的 SVG 字符串,自动关闭图形。
+ 返回 Data URL 格式。
+ """
+ buffer = BytesIO()
+ plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight')
+ plt.close()
+ image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
+ return f"data:image/svg+xml;base64,{image_base64}"
+
+
+# ===========================
+# 1. gen_vertical_bars:竖向堆叠柱状图
+# ===========================
+
+def gen_vertical_bars(
+ primary_values: List[float],
+ nested_values: List[float],
+ x_labels: List[str],
+ group_labels: List[str],
+ color_palette: str = 'BuPu',
+ dpi: int = 128
+) -> str:
+ """
+ 生成竖向堆叠柱状图(两个分组)。
+
+ 数据结构说明:
+ - primary_values: [v1, v2, ..., vn] # 主柱高度,长度 = n
+ - nested_values: [v1, v2, ..., vn] # 嵌套柱高度,长度 = n,从主柱顶部叠加
+ - x_labels: ['A', 'B', 'C', ...] # 每个柱子的X轴标签,长度 = n
+ - group_labels: ['Group1', 'Group2'] # 两个分组的图例名称,长度必须为2
+
+ 注意:所有列表长度必须一致(n),否则会报错。
+ 本函数仅支持两个分组,若需扩展,建议使用 genHBars。
+
+ :param primary_values: 主柱数据(底部),数值列表
+ :param nested_values: 嵌套柱数据(顶部),数值列表
+ :param x_labels: X轴标签列表
+ :param group_labels: 图例标签列表,长度为2
+ :param color_palette: Seaborn 色盘名称
+ :param dpi: 图像分辨率
+ :return: base64 编码的 SVG 图像字符串
+ """
+ if len(group_labels) != 2:
+ raise ValueError("group_labels 必须包含两个元素:[主组, 嵌套组]")
+
+ _setup_plot()
+
+ colors = chart.get_seaborn_colors(2, palette=color_palette)
+ fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
+
+ # 绘制主柱(底部)
+ bars1 = ax.bar(x_labels, primary_values, label=group_labels[0], color=colors[0])
+
+ # 绘制嵌套柱(从主柱顶部开始,即 bottom=primary_values)
+ # 注意:原代码中 bottom=0 是错误的!应为 bottom=primary_values
+ # 修正:从主柱顶部开始叠加
+ bars2 = ax.bar(x_labels, nested_values, bottom=primary_values, label=group_labels[1], color=colors[1])
+
+ # 添加数值标签
+ ax.bar_label(bars1, label_type='edge', padding=3) # 边缘标签
+ ax.bar_label(bars2, label_type='center', padding=1) # 中心标签
+
+ # 旋转X轴标签,避免重叠
+ plt.xticks(rotation=45, ha='right')
+
+ # 添加图例
+ ax.legend()
+
+ # 自动紧凑布局
+ plt.tight_layout()
+
+ return _save_to_base64(dpi)
+
+
+# ===========================
+# 2. gen_horizontal_stacked_bars:横向堆叠柱状图(多分组)
+# ===========================
+
+def gen_horizontal_stacked_bars(
+ data_shap: np.ndarray,
+ x_labels: List[str],
+ y_labels: List[str],
+ y_data_unit: str = '',
+ legend_title: str = '',
+ color_palette: str = 'BuPu',
+ dpi: int = 128
+) -> str:
+ """
+ 生成横向堆叠柱状图(支持多个分组)。
+
+ 数据结构说明:
+ - data_shap: numpy.ndarray, shape = (len(y_labels), len(x_labels))
+ - 每行代表一个 Y 轴类别(如:城市、部门)
+ - 每列代表一个 X 轴分组(如:年份、产品类型)
+ - 例如:data_shap[2][1] 表示第3个Y类别(y_labels[2])在第2个X分组(x_labels[1])的数值
+
+ - x_labels: ['2021', '2022', '2023'] # 每个分组的名称(对应列)
+ - y_labels: ['北京', '上海', '广州'] # 每个柱子的类别名称(对应行)
+ - y_data_unit: 如 '人'、'万元',用于标注总量
+
+ :param data_shap: 二维数值数组,shape=(行数, 列数)
+ :param x_labels: 每个分组的标签列表
+ :param y_labels: 每个柱子的标签列表
+ :param y_data_unit: Y轴总量单位(如 '人')
+ :param legend_title: 图例标题
+ :param color_palette: Seaborn 色盘
+ :param dpi: 分辨率
+ :return: base64 编码的 SVG 图像字符串
+ """
+ if data_shap.ndim != 2:
+ raise ValueError("data_shap 必须是二维数组,shape=(Y, X)")
+
+ if len(x_labels) != data_shap.shape[1]:
+ raise ValueError(f"x_labels 长度 ({len(x_labels)}) 与 data_shap 列数 ({data_shap.shape[1]}) 不匹配")
+
+ if len(y_labels) != data_shap.shape[0]:
+ raise ValueError(f"y_labels 长度 ({len(y_labels)}) 与 data_shap 行数 ({data_shap.shape[0]}) 不匹配")
+
+ _setup_plot()
+
+ colors = chart.get_seaborn_colors(len(x_labels), palette=color_palette)
+ fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
+
+ # 初始化左侧起始位置(从0开始)
+ left_pos = np.zeros(len(y_labels))
+ bars = [] # 存储每个分组的柱形对象,用于后续标签绘制
+
+ # 按列(分组)逐个绘制横向堆叠柱形
+ for i in range(len(x_labels)):
+ bar = ax.barh(
+ y_labels,
+ data_shap[:, i],
+ left=left_pos,
+ label=x_labels[i],
+ color=colors[i],
+ edgecolor='white',
+ linewidth=0.5
+ )
+ bars.append(bar)
+ left_pos += data_shap[:, i] # 更新下一次的起始位置
+
+ # 计算每个Y类别的总和
+ totals = np.sum(data_shap, axis=1)
+ # 构建带总量的Y标签:如 "北京\n1200人"
+ y_labels_mix = [f"{place}\n{int(total)}{y_data_unit}" for place, total in zip(y_labels, totals)]
+ ax.set_yticks(np.arange(len(y_labels)))
+ ax.set_yticklabels(y_labels_mix)
+ ax.invert_yaxis() # 使第一个类别在顶部
+
+ # 图例
+ ax.legend(title=legend_title, bbox_to_anchor=(0.958, 0.25), frameon=True, framealpha=0.7)
+
+ # 智能标签函数:根据柱宽决定标签位置
+ def add_smart_labels(spacing_pixels: int = 10):
+ dpi = fig.dpi
+ xlim = ax.get_xlim()
+ x_range = xlim[1] - xlim[0]
+ spacing_data = spacing_pixels * x_range / (fig.get_size_inches()[0] * dpi)
+
+ text_style = {
+ 'fontsize': 9,
+ 'va': 'center',
+ 'bbox': {
+ 'boxstyle': 'round',
+ 'facecolor': 'white',
+ 'alpha': 0.7,
+ 'edgecolor': 'none',
+ 'pad': 0.3
+ }
+ }
+
+ for i, (place, total_width) in enumerate(zip(y_labels, totals)):
+ values = [bar[i].get_width() for bar in bars]
+ label_text = "+".join(str(int(v)) for v in values if v > 0)
+ if not label_text:
+ continue
+
+ text_width = (len(label_text) * 0.015 + 0.05) * x_range
+ y_pos = bars[0][i].get_y() + bars[0][i].get_height() / 2
+
+ if total_width >= text_width + 2 * spacing_data:
+ ax.text(total_width - spacing_data, y_pos, label_text, ha='right', **text_style)
+ else:
+ ax.text(total_width + spacing_data, y_pos, label_text, ha='left', **text_style)
+
+ add_smart_labels()
+
+ # 添加网格线(仅x轴方向)
+ ax.grid(axis='x', linestyle=':', alpha=0.6)
+ plt.tight_layout()
+
+ return _save_to_base64(dpi)
+
+
+# ===========================
+# 3. gen_percent_stacked_bars:百分比堆叠柱状图
+# ===========================
+
+def gen_percent_stacked_bars(
+ data: Dict[str, List[float]],
+ x_labels: List[str],
+ legend_title: str = '',
+ color_palette: str = 'BuPu',
+ dpi: int = 128
+) -> str:
+ """
+ 绘制百分比堆叠柱状图(所有柱子高度统一为100%)。
+
+ 数据结构说明:
+ - data: Dict[str, List[float]]
+ - 每个 key 代表一个分组(如 'A组', 'B组')
+ - 每个 value 是长度为 len(x_labels) 的数值列表,表示该组在每个X标签下的数值
+ - 示例:{'A组': [30, 40, 50], 'B组': [70, 60, 50]} → 每列总和为100
+
+ - x_labels: ['2021', '2022', '2023'] # 每个柱子的标签
+
+ 注意:每个X位置的总和必须一致(或接近),否则百分比可能失真。
+ 本函数自动将每列归一化为100%。
+
+ :param data: 分组数据字典,键为组名,值为数值列表
+ :param x_labels: X轴标签列表
+ :param legend_title: 图例标题
+ :param color_palette: Seaborn 色盘
+ :param dpi: 分辨率
+ :return: base64 编码的 SVG 图像字符串
+ """
+ if not data:
+ raise ValueError("data 不能为空字典")
+
+ # 验证所有组长度一致
+ group_lengths = [len(values) for values in data.values()]
+ if not all(l == len(x_labels) for l in group_lengths):
+ raise ValueError(f"所有分组的数值列表长度必须等于 x_labels 长度 ({len(x_labels)})")
+
+ _setup_plot()
+
+ # 转换为 numpy 数组
+ group_names = list(data.keys())
+ data_array = np.array([data[group] for group in group_names]) # shape: (n_groups, n_x)
+ percent_data = data_array / data_array.sum(axis=0) * 100 # 按列归一化
+
+ fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
+
+ bottom = np.zeros(len(x_labels)) # 初始底部为0
+
+ # 绘制每个分组
+ for i, (group, color) in enumerate(zip(group_names, chart.get_seaborn_colors(len(group_names), palette=color_palette))):
+ ax.bar(x_labels, percent_data[i], bottom=bottom, label=group, color=color)
+ bottom += percent_data[i]
+
+ # 设置Y轴范围为0-100%
+ ax.set_ylim(0, 100)
+
+ # 添加百分比标签
+ for container in ax.containers:
+ ax.bar_label(container, label_type='center', fmt='%.1f%%', padding=0)
+
+ # 图例
+ ax.legend(title=legend_title, loc='center left', bbox_to_anchor=(1, 0.5), frameon=True, framealpha=0.7)
+
+ plt.tight_layout()
+
+ return _save_to_base64(dpi)
\ No newline at end of file
diff --git a/paste/chart/line.py b/paste/chart/line.py
new file mode 100644
index 0000000..c9e5089
--- /dev/null
+++ b/paste/chart/line.py
@@ -0,0 +1,213 @@
+import base64
+import datetime
+from io import BytesIO
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from matplotlib import pyplot as plt
+from matplotlib.dates import MonthLocator, DateFormatter
+from scipy.interpolate import make_interp_spline
+
+from paste import chart
+from paste.util import ufont
+
+
+def gen_lines(data_dict: Dict[str, pd.Series], color_palette: str = 'BuPu', dpi: int = 128) -> str:
+ """
+ 生成多条折线图,用于对比不同年份的时间序列数据(如月度数据)。
+
+ 数据结构要求:
+ - data_dict: Dict[str, pd.Series]
+ - key: 字符串,表示年份(如 '2022', '2023'),必须可转为整数
+ - value: pd.Series,索引为日期(datetime-like),值为数值(如销售额、访问量等)
+ - 示例:
+ {
+ '2022': pd.Series([100, 120, 90], index=pd.date_range('2022-01-01', periods=3, freq='M')),
+ '2023': pd.Series([110, 130, 95], index=pd.date_range('2023-01-01', periods=3, freq='M'))
+ }
+
+ 功能说明:
+ - 自动识别当前年份,并高亮显示其曲线(加粗、加圆点标记)
+ - 使用每两个月一次的主刻度,格式化为 'MM-DD'
+ - 保存为 base64 编码的 SVG 图像,适用于 Web 前端直接嵌入
+
+ :param data_dict: 年度时间序列数据字典,结构如上
+ :param color_palette: Seaborn 色盘名称,默认 'BuPu'
+ :param dpi: 输出图像分辨率,默认 128
+ :return: base64 编码的 SVG 图像 Data URL 字符串
+ """
+ # === 颜色准备 ===
+ colors = chart.get_seaborn_colors(len(data_dict), palette=color_palette)
+ if len(colors) == 0:
+ raise ValueError("color_palette 返回空颜色列表,请检查色盘名称是否有效")
+
+ # === 字体设置 ===
+ available_font = ufont.get_fonts()
+ plt.rcParams['font.sans-serif'] = list(available_font)
+ plt.rcParams['axes.unicode_minus'] = False # 解决负号显示为方块的问题
+
+ # === 创建画布 ===
+ plt.figure(figsize=(10, 6), facecolor='none', dpi=dpi)
+ ax = plt.gca()
+
+ # === 绘制每条折线 ===
+ current_year = str(datetime.date.today().year)
+ for i, year in enumerate(data_dict.keys()):
+ s = data_dict[year]
+
+ # 检查数据有效性
+ if len(s) == 0:
+ raise ValueError(f"数据序列 {year} 为空,请检查输入")
+
+ # 判断是否为当前年份,决定样式
+ is_current_year = (year == current_year)
+ linewidth = 2.5 if is_current_year else 1.5
+ marker = 'o' if is_current_year else None
+ label = f'{year}年'
+
+ ax.plot(
+ s.index, s.values,
+ color=colors[i],
+ linewidth=linewidth,
+ alpha=0.9,
+ label=label,
+ marker=marker,
+ markersize=4,
+ markevery=30 # 每30个点显示一个标记,避免密集
+ )
+
+ # === 智能日期刻度 ===
+ ax.xaxis.set_major_locator(MonthLocator(bymonth=range(1, 13, 2))) # 每两个月一个主刻度
+ ax.xaxis.set_major_formatter(DateFormatter('%m-%d')) # 格式化为 MM-DD
+
+ # === 网格与边框美化 ===
+ ax.grid(True, which='major', linestyle='--', alpha=0.6)
+ ax.grid(True, which='minor', linestyle=':', alpha=0.3)
+ for spine in ['top', 'right']:
+ ax.spines[spine].set_visible(False)
+ ax.spines['left'].set_color('#d9d9d9')
+ ax.spines['bottom'].set_color('#d9d9d9')
+
+ # === 图例 ===
+ legend = ax.legend(loc='upper right', frameon=True, framealpha=0.7, edgecolor='#f0f0f0')
+ legend.get_frame().set_linewidth(1)
+
+ # === 输出为 base64 SVG ===
+ buffer = BytesIO()
+ plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
+ plt.close()
+
+ image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
+ img_base64 = f"data:image/svg+xml;base64,{image_base64}"
+ return img_base64
+
+
+def gen_splines(data_dict: Dict[str, pd.Series], total: pd.Series, x_labels: Optional[List[Union[str, int]]] = None,
+ color_palette: str = 'BuPu', dpi: int = 128) -> str:
+ """
+ 生成平滑曲线图,适用于非时间序列数据(如按类别排序的数值),通过样条插值实现曲线平滑。
+
+ 数据结构要求:
+ - data_dict: Dict[str, pd.Series]
+ - key: 字符串,表示数据系列名称(如 'A产品', 'B产品')
+ - value: pd.Series,索引为数值型(如 1~12),值为观测值(如销量)
+ - 示例:
+ {
+ 'A产品': pd.Series([10, 15, 12, 18], index=[1,2,3,4]),
+ 'B产品': pd.Series([8, 14, 16, 20], index=[1,2,3,4])
+ }
+
+ - total: pd.Series
+ - 索引必须与 data_dict 的 key 一致
+ - 值为每个系列的总和(用于图例标注)
+ - 示例:pd.Series({'A产品': 55, 'B产品': 58})
+
+ - x_labels: List[Union[str, int]], 可选
+ - 用于自定义 X 轴刻度标签(通常为原始索引值)
+ - 若为 None,则使用插值后的密集点(不显示原始刻度)
+ - 若提供,则在绘图后手动设置刻度,避免插值后刻度混乱
+
+ 功能说明:
+ - 按 total 值降序排列绘图顺序(重要数据优先显示)
+ - 使用三次样条插值(k=3)生成平滑曲线
+ - 图例中显示原始总值(如 "A产品(55)")
+
+ :param data_dict: 各系列的原始观测数据,结构如上
+ :param total: 每个系列的总和,用于排序和图例标注
+ :param x_labels: 可选,原始 X 轴标签列表(如 [1,2,3,4]),用于控制刻度显示
+ :param color_palette: Seaborn 色盘名称
+ :param dpi: 输出分辨率
+ :return: base64 编码的 SVG 图像 Data URL 字符串
+ """
+ # === 颜色准备 ===
+ colors = chart.get_seaborn_colors(len(data_dict), palette=color_palette)
+ if len(colors) == 0:
+ raise ValueError("color_palette 返回空颜色列表,请检查色盘名称是否有效")
+
+ # === 字体设置 ===
+ available_font = ufont.get_fonts()
+ plt.rcParams['font.sans-serif'] = list(available_font)
+ plt.rcParams['axes.unicode_minus'] = False
+
+ # === 创建画布 ===
+ plt.figure(figsize=(10, 6), facecolor='none', dpi=dpi)
+ ax = plt.gca()
+
+ # === 按总值排序,确保重要数据优先显示 ===
+ total_sorted = total.sort_values(ascending=False)
+
+ # === 绘制每条平滑曲线 ===
+ for i, series_name in enumerate(total_sorted.index):
+ s = data_dict[series_name]
+
+ # 检查数据完整性
+ if len(s) < 2:
+ raise ValueError(f"系列 {series_name} 数据点少于2个,无法进行插值")
+
+ x = s.index.values.astype(float) # 确保为数值类型
+ y = s.values.astype(float)
+
+ # 插值:生成 100 个平滑点(线性插值不足以平滑,样条更优)
+ x_new = np.linspace(x.min(), x.max(), 100)
+ spline = make_interp_spline(x, y, k=3) # 三次样条插值
+ y_smooth = spline(x_new)
+
+ # 转回 Series 便于绘图
+ s_smooth = pd.Series(y_smooth, index=x_new)
+
+ ax.plot(
+ s_smooth.index, s_smooth.values,
+ color=colors[i],
+ linewidth=2,
+ alpha=0.9,
+ label=f'{series_name}({total_sorted[series_name]})', # 图例中显示总值
+ markersize=4,
+ markevery=30 # 标记稀疏,避免视觉干扰
+ )
+
+ # === 手动设置 X 轴刻度(若提供)===
+ if x_labels is not None:
+ # 确保 x_labels 与原始 x 一致(可选验证)
+ plt.xticks(ticks=x_labels, labels=x_labels)
+
+ # === 网格与边框 ===
+ ax.grid(True, which='major', linestyle='--', alpha=0.6)
+ ax.grid(True, which='minor', linestyle=':', alpha=0.3)
+ for spine in ['top', 'right']:
+ ax.spines[spine].set_visible(False)
+ ax.spines['left'].set_color('#d9d9d9')
+ ax.spines['bottom'].set_color('#d9d9d9')
+
+ # === 图例 ===
+ legend = ax.legend(loc='upper left', frameon=True, framealpha=0.7, edgecolor='#f0f0f0')
+ legend.get_frame().set_linewidth(1)
+
+ # === 输出为 base64 SVG ===
+ buffer = BytesIO()
+ plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
+ plt.close()
+
+ image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
+ img_base64 = f"data:image/svg+xml;base64,{image_base64}"
+ return img_base64
\ No newline at end of file
diff --git a/paste/chart/pie.py b/paste/chart/pie.py
new file mode 100644
index 0000000..b24ffea
--- /dev/null
+++ b/paste/chart/pie.py
@@ -0,0 +1,144 @@
+import base64
+from io import BytesIO
+
+from matplotlib import pyplot as plt
+from matplotlib.patches import Rectangle
+
+from paste import chart
+from paste.util import ufont
+
+
+def gen_pie(data_df, value_column, percentage_column, legend_labels, color_palette='BuPu', dpi=128):
+ """
+ 生成环形图(Doughnut Chart),并附带右侧图例(色块 + 文字描述)。
+
+ 注意:虽然函数名为 gen_pie,但实际绘制的是环形图(有中心空洞),非传统饼图。
+
+ 参数说明(数据结构要求):
+ --------------------------
+ data_df : pandas.DataFrame
+ 必须包含以下三列:
+ - value_column (数值列): 每个类别的数值(如设备数量),用于计算扇区角度。
+ - percentage_column (百分比列): 每个类别的百分比(如 35.2%),用于显示在图例中。
+ - legend_labels (标签列): 每个类别的名称(如 '服务器', '交换机'),用于图例文本。
+
+ 示例结构:
+ | server_count | percentage | device_type |
+ |--------------|------------|-------------|
+ | 35 | 35.2% | 服务器 |
+ | 28 | 28.1% | 交换机 |
+ | 22 | 22.0% | 路由器 |
+
+ 要求:
+ - 所有数值必须 ≥ 0;
+ - 百分比列应为字符串格式(含'%'符号),或可被转换为字符串;
+ - 所有列长度必须一致,且与 data_df 的行数匹配。
+
+ value_column : str
+ data_df 中表示数值的列名(如 'server_count')。
+
+ percentage_column : str
+ data_df 中表示百分比的列名(如 'percentage'),用于图例中显示。
+
+ legend_labels : str
+ data_df 中表示类别名称的列名(如 'device_type'),用于图例文本。
+
+ color_palette : str, default='BuPu'
+ Seaborn 颜色调色板名称,用于为每个扇区分配颜色。
+ 可选值参考:'BuPu', 'viridis', 'plasma', 'Set3' 等。
+ 若类别数 > 调色板颜色数,会自动循环使用。
+
+ dpi : int, default=128
+ 输出图像的分辨率(dots per inch),影响 SVG 清晰度。
+
+ 返回值:
+ --------
+ str : Base64 编码的 SVG 图像 Data URL,可直接用于 HTML
。
+ """
+
+ # === 1. 颜色准备 ===
+ # 从 Seaborn 获取指定调色板的颜色序列,长度等于数据行数
+ colors = chart.get_seaborn_colors(len(data_df.index), palette=color_palette)
+
+ # === 2. 字体设置 ===
+ # 获取系统可用中文字体,优先使用支持中文的字体
+ available_font = ufont.get_fonts()
+ plt.rcParams['font.sans-serif'] = list(available_font)
+ plt.rcParams['axes.unicode_minus'] = False # 解决负号显示为方块的问题
+
+ # === 3. 创建画布与子图 ===
+ # 使用非白色背景,便于嵌入网页(透明背景)
+ fig = plt.figure(figsize=(8, 6), facecolor='none', dpi=dpi)
+
+ # 左侧:环形图区域
+ ax1 = fig.add_axes((0.1, 0.05, 0.6, 0.9)) # [left, bottom, width, height]
+ # 右侧:图例区域(色块+文字)
+ ax2 = fig.add_axes((0.75, 0.05, 0.2, 0.9))
+
+ # 移除两个子图的坐标轴(纯图形,无刻度)
+ for ax in [ax1, ax2]:
+ ax.set_axis_off()
+
+ # === 4. 绘制环形图 ===
+ # 使用 wedgeprops 设置内环宽度,实现环形效果
+ wedges, _ = ax1.pie(
+ data_df[value_column], # 数值列,决定扇区大小
+ colors=colors, # 颜色序列
+ startangle=90, # 从正上方开始绘制(更美观)
+ wedgeprops=dict(
+ width=0.6, # 环形宽度(0~1),越大越薄
+ edgecolor='white', # 边缘白色,提升对比度
+ linewidth=1 # 边缘线宽
+ ),
+ radius=1.2, # 半径略大于1,避免边缘被裁剪
+ )
+
+ # === 5. 右侧图例:色块 + 文字 ===
+ # 图例参数定义
+ num_items = len(data_df.index)
+ box_size = 0.08 # 每个色块大小(宽度和高度)
+ text_offset = 0.05 # 文字与色块的水平间距
+ font_size = 24 # 字体大小
+ line_height = 0.1 # 每行占用高度(包括上下间距)
+ total_legend_height = num_items * line_height
+ start_y = 0.45 + total_legend_height / 2 # 使图例垂直居中于右侧区域
+
+ y_pos = start_y # 当前绘制的 y 坐标
+
+ # 遍历每一类,绘制色块和文本
+ for i, (label, color) in enumerate(zip(data_df[legend_labels], colors)):
+ # 绘制色块矩形
+ ax2.add_patch(
+ Rectangle(
+ (0.05, y_pos - box_size / 2), # 左下角坐标:x=0.05,y居中
+ box_size, box_size, # 宽高均为 box_size
+ facecolor=color, # 填充颜色
+ edgecolor='white', # 白色描边
+ lw=1 # 线宽
+ )
+ )
+
+ # 绘制文字标签:名称 + 数值 + 百分比
+ ax2.text(
+ 0.05 + box_size + text_offset, # 文字起始 x 位置:色块右侧 + 间距
+ y_pos, # y 位置居中于色块
+ f"{label}({data_df[value_column][i]}台,{data_df[percentage_column][i]}%)",
+ va='center', # 垂直居中
+ ha='left', # 水平左对齐
+ fontsize=font_size,
+ fontweight='bold'
+ )
+
+ y_pos -= line_height # 下移一行,准备下一个图例项
+
+ # === 6. 输出图像 ===
+ # 使用 BytesIO 缓存 SVG 图像
+ buffer = BytesIO()
+ plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
+ plt.close() # 关闭图形以释放内存
+
+ # 编码为 Base64 并构造 Data URL
+ image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
+ img_base64 = f"data:image/svg+xml;base64,{image_base64}"
+
+ return img_base64
\ No newline at end of file
diff --git a/paste/core/__init__.py b/paste/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/paste/core/aio_pool.py b/paste/core/aio_pool.py
new file mode 100644
index 0000000..adb0b8e
--- /dev/null
+++ b/paste/core/aio_pool.py
@@ -0,0 +1,153 @@
+import asyncio
+import datetime
+from typing import Optional, Callable, Set, Any
+
+import psutil
+from dateutil.relativedelta import relativedelta
+
+from paste.core.logging import echo_log
+
+MAX_BACKGROUND_TASKS = 3000
+"""
+最大任务数量,根据服务器内存调整。
+"""
+
+aio_loop: Optional[asyncio.AbstractEventLoop] = None
+"""
+异步循环对象。
+"""
+
+aio_runner: Optional[Callable] = None
+"""
+异步方法运行器,对应::
+
+ asyncio.events.AbstractEventLoop.run_until_complete()
+
+方法的返回值 。
+"""
+
+global_background_tasks: Set[asyncio.Task] = set()
+"""
+全局后台任务池。
+"""
+
+
+async def run_background_task(coro, max_tasks: int = MAX_BACKGROUND_TASKS):
+ """
+ 运行后台任务。
+
+ :param coro: 要在后台执行的协程。
+ :param max_tasks: 背压总量
+ """
+ global MAX_BACKGROUND_TASKS
+ if max_tasks != MAX_BACKGROUND_TASKS:
+ MAX_BACKGROUND_TASKS = max_tasks
+
+ if len(global_background_tasks) >= MAX_BACKGROUND_TASKS:
+ # 增加背压控制
+ await asyncio.wait(global_background_tasks, return_when=asyncio.FIRST_COMPLETED)
+
+ task = asyncio.create_task(coro)
+ global_background_tasks.add(task)
+ # 任务完成后自动移除
+ task.add_done_callback(global_background_tasks.discard)
+
+
+def get_aio_loop():
+ """
+ 这里必须采用方法,在适当的时间点创建事件循环对象,否则会导致服务无法启动。
+ 主要是测试到 EventLoop 之间的冲突,或异步事件已经在运行,导致无法顺利执行。
+
+ :return: 事件循环对象
+ """
+ global aio_loop
+ if aio_loop:
+ return aio_loop
+ else:
+ try:
+ # 尝试获取当前运行中的事件循环
+ aio_loop = asyncio.get_running_loop()
+ except RuntimeError:
+ # 如果没有运行中的循环,才创建新的
+ aio_loop = asyncio.new_event_loop()
+ return aio_loop
+
+
+def get_aio_runner():
+ """
+ 这里必须采用方法,在适当的时间点创建事件循环对象,否则会导致服务无法启动。
+ 主要是测试到 EventLoop 之间的冲突,或异步事件已经在运行,导致无法顺利执行。
+
+ :return: 运行器对象
+ """
+ global aio_loop, aio_runner
+ if aio_runner:
+ return aio_runner
+ else:
+ aio_runner = get_aio_loop().run_until_complete
+ return aio_runner
+
+
+def process_info(pid: int):
+ """
+ 若传入的 PID 存在,则返回进程信息,否则返回 None。
+
+ :param pid: 进程 ID
+ :return: 进程信息
+ """
+ try:
+ process = psutil.Process(pid)
+ _delta = relativedelta(datetime.datetime.now(), datetime.datetime.fromtimestamp(process.create_time()))
+ _d, _h, _m, _s = _delta.days, _delta.hours, _delta.minutes, _delta.seconds
+ process.cpu_percent()
+ return {
+ 'name': process.name(),
+ 'cpu_usage': f"{process.cpu_percent()}%",
+ 'memory_usage': f"{process.memory_info().rss / (1024 * 1024):.3f}MB",
+ 'running_time': f"{_d}天{_h}时{_m}分{_s}秒",
+ }
+ except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
+ return None
+
+
+async def gather_with_retry(*coro_constructors: Callable[[], Any],
+ max_retries: int = 3,
+ eba: int = 0.5,
+ return_exceptions: bool = False) -> tuple[Any, ...]:
+ """
+ 封装 asyncio.gather,支持对一组并发任务进行 N 次重试。
+
+ 与 asyncio.gather 对齐:
+ - 使用 *args 传参,而非列表
+ - 支持 return_exceptions 参数
+ - 返回 list 类型,顺序一致
+
+ 但要求:每个参数必须是一个**无参函数**,调用后返回一个 awaitable。
+ 这样才能在每次重试时重新创建任务。
+
+ Args:
+ *coro_constructors: Callable 对象,async 方法要用 lambda 封装
+ max_retries (int): 最大重试次数(总尝试次数 = max_retries + 1)
+ eba (int): 指数退避的起始等待时间
+ return_exceptions (bool): 若为 True,异常作为结果返回,不抛出
+
+ Returns:
+ list: 所有任务结果列表。若 return_exceptions=True,异常也会作为列表元素。
+
+ Raises:
+ Exception: 当 return_exceptions=False 且所有重试都失败时,抛出最后一次尝试中的第一个异常。
+ """
+ for attempt in range(max_retries + 1):
+ try:
+ tasks = [ctor() for ctor in coro_constructors]
+ results = await asyncio.gather(*tasks, return_exceptions=return_exceptions)
+ return tuple(results)
+ except Exception as e:
+ if attempt == max_retries:
+ echo_log(f"共执行 {max_retries} 次后全部失败: {str(e)}")
+ else:
+ echo_log(f"执行第 {attempt + 1} 次重试失败.")
+ # 指数退避
+ await asyncio.sleep(eba * (attempt + 1))
+
+ raise RuntimeError("Unreachable")
diff --git a/paste/core/config.py b/paste/core/config.py
new file mode 100644
index 0000000..115cf96
--- /dev/null
+++ b/paste/core/config.py
@@ -0,0 +1,53 @@
+"""
+读取配置信息的方法集合。
+"""
+
+import json
+import os
+
+from paste.util import ufile, udict
+
+GLOBAL_CONFIG = None
+"""
+全局单例配置系统。
+"""
+
+
+def load_config() -> dict:
+ """
+ 生成配置 JSON 对象。
+
+ :return: JSON 对象
+ """
+ global GLOBAL_CONFIG
+ if GLOBAL_CONFIG is None:
+ config_file = os.path.abspath(os.path.join(os.path.curdir, 'config.json'))
+ GLOBAL_CONFIG = json.loads(ufile.read_to_buffer(config_file))
+ return GLOBAL_CONFIG
+
+
+def get_config_by_path(path: str, default=None):
+ """
+ 读取配置参数。若 path 存在则返回值;若 path 不存在,且没有默认值,则抛出异常,否则返回默认值。
+
+ :param path: 字典中的 key 路径,以"."号分隔
+ :param default: 默认值,为 None 时表示未设置,此时若键名不存在,会抛出异常
+ """
+ config = load_config()
+ _result = udict.get_by_path(config, path, default)
+ if _result is None:
+ if default is None:
+ raise AssertionError('未读取到配置参数: %s' % path)
+ else:
+ return default
+ return _result
+
+
+def get_config(key: str, default=None):
+ """
+ 读取配置参数。若 key 存在则返回值;若 key 不存在,且没有默认值,则抛出异常,否则返回默认值。
+
+ :param key: 键名,或配置字典中的 path
+ :param default: 默认值,为 None 时表示未设置,此时若键名不存在,会抛出异常
+ """
+ return get_config_by_path(path=key, default=default)
diff --git a/paste/core/logging.py b/paste/core/logging.py
new file mode 100644
index 0000000..e483015
--- /dev/null
+++ b/paste/core/logging.py
@@ -0,0 +1,160 @@
+"""
+实现对日志文件的配置封装,详细参考 getLogger 方法。
+输出日志使用 echo_log 方法。
+输出到日志文件使用 logToFile 方法。
+"""
+
+import logging
+import sys
+import traceback
+from logging import handlers
+from typing import Any, Union, Optional
+
+from paste.core import config
+
+logger_config_name = 'logger.default'
+"""
+默认配置字段名称,当在 getLogger 方法中设置了不同名称后,该变量会被修改。
+"""
+
+paste_logger: Optional[logging.Logger] = None
+"""
+全局日志对象,获取日志对象时初始化。
+"""
+
+
+def set_logger_config(config_name: str):
+ """
+ 设置新的日志配置名称。
+
+ :param config_name: 日志配置名称
+ """
+ global logger_config_name
+ if config_name != logger_config_name:
+ logger_config_name = config_name
+
+
+def get_logger():
+ """
+ 取得日志对象。先根据配置,更新系统日志配置。若配置了额外的日志文件、格式、层级,则增加响应的日志输出。
+
+ 注意:除非额外配置,否则都使用与系统日志相同的配置参数。
+
+ 配置结构参考::
+
+ {
+ "logger": {
+ "basic": {
+ "filename": "sys_log.log",
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ "level": 20
+ },
+ "filename": "zb_data_log.log",
+ "name": "ZbData"
+ }
+ }
+ """
+ global paste_logger
+
+ if paste_logger is None:
+ # 系统日志基本配置
+ _log_base_cfg: dict = config.get_config(f'{logger_config_name}.basic', {})
+ # 系统日志文件
+ _base_log_file = _log_base_cfg.get('filename', '')
+ _base_log_format = _log_base_cfg.get('format', '')
+ _base_formatter = logging.Formatter(_base_log_format)
+ _base_level = _log_base_cfg.get('level', logging.INFO)
+
+ # 日志名称,若不配置,则使用名称:SQL_ALEX
+ _log_name = config.get_config(f'{logger_config_name}.name', 'SQL_ALEX')
+ # 日志文件最大值
+ _log_max_bytes = config.get_config(f'{logger_config_name}.max_bytes', 0)
+ # 日志文件备份数量
+ _log_backup_count = config.get_config(f'{logger_config_name}.backup_count', 0)
+ # 日志格式,若不配置,则使用 base 中的配置
+ _log_format = config.get_config(f'{logger_config_name}.format', _base_log_format)
+ _formatter = logging.Formatter(_log_format)
+ # 日志层级,若不配置,则使用 base 中的配置
+ _log_level = config.get_config(f'{logger_config_name}.level', _base_level)
+ # 日志文件
+ _log_file = config.get_config(f'{logger_config_name}.filename', '')
+
+ # 更新系统日志基本配置
+ logging.basicConfig(**_log_base_cfg)
+ # 重新绑定系统日志文件句柄
+ _base_log_file_handler: Optional[handlers.RotatingFileHandler] = None
+ if _base_log_file not in (None, ''):
+ _base_log_file_handler = handlers.RotatingFileHandler(
+ _base_log_file, maxBytes=_log_max_bytes, backupCount=_log_backup_count
+ )
+ _base_log_file_handler.setFormatter(_base_formatter)
+ logging.root.handlers = [_base_log_file_handler]
+
+ # 创建日志对象
+ paste_logger = logging.Logger(name=_log_name, level=_log_level)
+ # 绑定日志文件句柄
+ if _log_file not in ('', None):
+ # 若配置了日志文件,则创建文件句柄
+ _file_handler = handlers.RotatingFileHandler(
+ _log_file, maxBytes=_log_max_bytes, backupCount=_log_backup_count
+ )
+ _file_handler.setFormatter(_formatter)
+ paste_logger.addHandler(_file_handler)
+ else:
+ # 若未配置,则使用系统日志文件
+ if _base_log_file_handler is not None:
+ paste_logger.addHandler(_base_log_file_handler)
+
+ # 绑定控制台输出
+ _console_handler = logging.StreamHandler()
+ _console_handler.setFormatter(_formatter)
+ paste_logger.addHandler(_console_handler)
+
+ return paste_logger
+
+
+def echo_log(msg: Union[str, Exception], level: int = logging.INFO, is_log_exc: bool = False):
+ """
+ 输出日志文本。默认输出到日志文件,但是可能不便于查询,这里应该考虑支持输出到日志数据库。
+
+ :param msg: 消息内容,当是 Exception 对象时,从 args 中取出第一项作为消息
+ :param level: 消息等级
+ :param is_log_exc: 是否输出异常的 Traceback 信息到日志文件
+ """
+ _root = logging.root
+ _logging = get_logger()
+ _log_level = level
+ if isinstance(msg, Exception):
+ _log_level = logging.ERROR
+ if len(msg.args) > 0 and isinstance(msg.args[0], str):
+ msg = msg.args[0]
+ else:
+ msg = str(msg)
+
+ _logging.log(level=_log_level, msg=msg)
+ if is_log_exc:
+ exception_to_file()
+
+
+def exception_to_file():
+ """
+ 自动检测异常,并输出异常的 Traceback 信息到日志文件。
+ """
+ _, _, tb = sys.exc_info()
+ if tb is not None:
+ _msg_list = ['Traceback: \n\n'] + traceback.format_tb(tb)
+ log_to_file(msg=''.join(_msg_list), level=logging.ERROR)
+
+
+def log_to_file(msg: Any, level: int = logging.INFO):
+ """
+ 输出消息到日志文件。
+
+ :param msg: 消息
+ :param level: 消息等级
+ """
+ _logger = get_logger()
+ _record = _logger.makeRecord(name=_logger.name, level=level, fn=__file__, lno=0, args=(), exc_info=None, msg=msg)
+ for hdl in _logger.handlers:
+ if isinstance(hdl, logging.FileHandler):
+ hdl.handle(_record)
diff --git a/paste/db/__init__.py b/paste/db/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/paste/db/baseadapter.py b/paste/db/baseadapter.py
new file mode 100755
index 0000000..8f811af
--- /dev/null
+++ b/paste/db/baseadapter.py
@@ -0,0 +1,830 @@
+"""
+数据访问适配器。主要封装了数据访问的基础方法。
+"""
+import datetime
+import random
+import uuid
+from typing import Optional, Any, Union, List, Sequence
+
+import pandas as pd
+from sqlalchemy import Column, String, DateTime, Date, inspect, Table, Integer, select, Numeric, ForeignKey, func, \
+ text, desc, Result, Engine, tuple_
+from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, async_sessionmaker
+from sqlalchemy.orm import sessionmaker, Session, InstanceState, InstrumentedAttribute, registry
+from sqlalchemy.sql import Select
+from sqlalchemy.sql.base import ColumnCollection
+from sqlalchemy.sql.compiler import StrSQLCompiler
+from sqlalchemy.sql.elements import BinaryExpression
+
+from paste.db import engine
+from paste.util import udict
+
+LOCAL_DATE_FORMAT = '%Y-%m-%d'
+LOCAL_TIME_FORMAT = '%H:%M:%S'
+LOCAL_DATETIME_FORMAT = f'{LOCAL_DATE_FORMAT} {LOCAL_TIME_FORMAT}'
+
+
+def guid() -> str:
+ """
+ 生成 GUID。
+
+ :return: GUID
+ """
+ return str(uuid.uuid5(uuid.NAMESPACE_DNS, str(uuid.uuid1()) + str(random.random()))).replace('-', '')
+
+
+def get_session() -> Session:
+ """
+ 获取线程安全的连接会话对象。
+
+ :return: 连接会话对象
+ """
+ _engine: Engine = engine.connect_engine()
+ session_maker = sessionmaker(bind=_engine)
+ return session_maker()
+
+
+def get_aio_session() -> AsyncSession:
+ """
+ 获取异步连接会话对象。
+
+ :return: 连接会话对象
+ """
+ _engine: AsyncEngine = engine.async_connect_engine()
+ session_maker = async_sessionmaker(bind=_engine, class_=AsyncSession)
+ return session_maker()
+
+
+Base = registry().generate_base()
+"""
+取得 SQLAlchemy ORM 的声明式基类。
+"""
+
+
+class BaseAdapter(Base):
+ """
+ 适配器类,集成数据库引擎、连接、会话、元数据操作、部分 DML 数据处理等功能。
+ 主要实现了同步执行方法。
+ """
+ __abstract__ = True
+
+ def __init__(self, **kwargs):
+ """
+ 构造模型对象。
+
+ :param args:
+ :param kwargs:
+ """
+ self._is_new = True
+ super().__init__(**kwargs)
+
+ @classmethod
+ def get_session(cls):
+ """
+ 线程安全的连接会话。
+
+ :return: 同步会话对象
+ """
+ return get_session()
+
+ @classmethod
+ def get_aio_session(cls):
+ """
+ 异步连接会话。
+
+ :return: 异步会话对象
+ """
+ return get_aio_session()
+
+ @classmethod
+ def ping(cls):
+ """
+ 尝试连接数据库服务器。
+
+ :return: 连接结果
+ """
+ engine.connect_engine().connect()
+
+ @classmethod
+ async def tables_in_db(cls):
+ """
+ 取得数据库中所有表名称。
+ :return: 表名称列表
+ """
+ _query = select(text('table_name')).select_from(
+ text('information_schema.tables')
+ ).where(
+ text('table_schema = database()')
+ )
+ _names = list()
+ _session = cls.get_aio_session()
+ try:
+ _result: Result = await _session.execute(_query)
+ _names = list(_result.scalars().all())
+ finally:
+ await _session.close()
+ return _names
+
+ @classmethod
+ async def is_table_exist(cls, table_name: str):
+ """
+ 判断表是否存在。
+ :param table_name: 要判断的表名称
+ """
+ _query = select(func.count()).select_from(
+ text('information_schema.tables')
+ ).where(
+ text('table_schema = database()'), text('table_name = :table_name')
+ ).params(
+ table_name=table_name
+ )
+ _count = 0
+ _session = cls.get_aio_session()
+ try:
+ _result: Result = await _session.execute(_query)
+ _count = int(_result.scalar())
+ finally:
+ await _session.close()
+ return _count > 0
+
+ @classmethod
+ def table(cls) -> Table:
+ """
+ 取得当前模型对象所对应的表对象。
+
+ :return: 表对象
+ """
+ tables = cls.metadata.tables
+ assert cls.__tablename__ is not None, '属性:cls.__tablename__ 未找到.'
+ assert cls.__tablename__ in tables, ('数据库中未找到表:%s.' % cls.__tablename__)
+ return tables[cls.__tablename__]
+
+ @classmethod
+ def columns(cls) -> List[Column]:
+ """
+ 当前对象配置的所有列。
+
+ :return: 所有列
+ """
+ cols = cls.table().columns
+ assert isinstance(cols, ColumnCollection), '%s 列类型错误' % cls.__name__
+ return list(cols)
+
+ @classmethod
+ def instrument_attr(cls, column_name: str) -> Optional[InstrumentedAttribute]:
+ """
+ 取得查询条件绑定属性,用于配置查询条件表达式。
+
+ :param column_name: 列名称
+ :return: 类配置的列信息,即查询条件绑定属性
+ """
+ attr = getattr(cls, column_name)
+ if not isinstance(attr, InstrumentedAttribute):
+ return None
+ return attr
+
+ @classmethod
+ def label(cls, column: Union[Column, Any]):
+ """
+ 以字段的 comment 作为标签返回。
+
+ :return: 指定列的备注
+ """
+ return column.comment
+
+ @classmethod
+ def labels(cls) -> dict:
+ """
+ 取得所有列的名称描述。
+
+ :return: 所有列的描述字典,结构为:{`field`: `comment`}
+ """
+ _label_dict = {}
+ for _col in cls.columns():
+ _label_dict[_col.key] = _col.comment
+ return _label_dict
+
+ @classmethod
+ def field(cls, column: Union[Column, Any]) -> str:
+ """
+ 取得字段名称。
+
+ :param column: 列对象
+ :return: 字段名称
+ """
+ return column.key
+
+ @classmethod
+ def fields(cls) -> List[str]:
+ """
+ 取得所有字段名称列表。
+
+ :return: 所有字段名称列表
+ """
+ return [col.key for col in cls.columns()]
+
+ @classmethod
+ def new_id(cls, **kwargs) -> str:
+ """
+ 用GUID作为新的ID。
+
+ :return: ID
+ """
+ return guid()
+
+ @property
+ def is_new(self):
+ """
+ 通过检查是否已经存在ID判断是否为新建对象。
+
+ :return: 是否为新建对象
+ """
+ return hasattr(self, '_is_new') and self._is_new
+
+ @classmethod
+ def raw_sql(cls, query) -> StrSQLCompiler:
+ """
+ 显示编译后的原始 SQL 命令。
+
+ :return: 编译后的 SQL 命令文本
+ """
+ raw_sql = query.compile(compile_kwargs={'literal_binds': True})
+ return raw_sql
+
+ @classmethod
+ def row_count(cls, *where_expressions: BinaryExpression) -> int:
+ """
+ 按条件取得记录行数。
+
+ :param where_expressions: 查询条件
+ :return: 返回记录行数
+ """
+ query = select(func.count(1)).select_from(cls).where(*where_expressions)
+ session = cls.get_session()
+ count = session.execute(query).scalars().first()
+ session.close()
+ return count
+
+ @classmethod
+ def exist_by_id(cls, id_val: Union[str, int]):
+ """
+ 检查是否有存在的主键。
+
+ :param id_val: 主键值
+ """
+ _wheres = [cls.instrument_attr('id') == id_val]
+ count: int = cls.row_count(*_wheres)
+ return count >= 1
+
+ @classmethod
+ def search_wheres(cls, likes: Optional[dict[str, str]] = None, **kwargs):
+ """
+ 按参数组织查询条件。
+
+ :param likes: 需要执行模糊查询的列
+ :param kwargs: 属性填充参数
+ :return: 查询条件列表
+ """
+ if likes is None:
+ likes = []
+ _query_model = cls().copy_from_dict(kwargs)
+ return _query_model.filter_expressions(likes=likes)
+
+ @classmethod
+ def datalist_query(cls, row_list: List[dict], condition_cols: List[Column]) -> Optional[Select]:
+ """
+ 根据数据列表、查询条件列生成查询。仅支持单表查询。
+
+ :param row_list: 数据列表
+ :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
+ :return: 查询到的数据模型列表
+ """
+ if not row_list or not condition_cols:
+ return None
+
+ _condition_fields = [_col.key for _col in condition_cols]
+
+ # 验证字段存在
+ for field in _condition_fields:
+ if not hasattr(cls, field):
+ raise AttributeError(f"Field '{field}' not found in {cls.__name__}")
+
+ # 构建值列表的元组形式
+ _values_tuples = []
+ for _row in row_list:
+ _val_list = []
+ for _f in _condition_fields:
+ _val = udict.get_with_default(_row, _f, '')
+ # 单独处理日期时间格式,但是此处固定格式
+ if isinstance(_val, datetime.datetime):
+ _val = _val.strftime(LOCAL_DATETIME_FORMAT)
+ if isinstance(_val, datetime.date):
+ _val = _val.strftime(LOCAL_DATE_FORMAT)
+ _val_list.append(_val)
+
+ # 过滤掉全为空的元组
+ if any(v != '' for v in _val_list):
+ _values_tuples.append(tuple(_val_list))
+
+ if not _values_tuples:
+ return None
+
+ # 去重,提高性能
+ _values_tuples = list(set(_values_tuples))
+
+ # 使用参数化查询
+ if len(condition_cols) == 1:
+ # 单字段查询
+ field = getattr(cls, _condition_fields[0])
+ _single_values = [t[0] for t in _values_tuples]
+ query = select(cls).where(field.in_(_single_values))
+ else:
+ # 多字段组合查询
+ conditions = [getattr(cls, field) for field in _condition_fields]
+ query = select(cls).where(tuple_(*conditions).in_(_values_tuples))
+
+ return query
+
+ @classmethod
+ def dataframe_query(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
+ """
+ 根据数据列表、查询条件列生成查询。仅支持单表查询。
+
+ :param row_df: 数据框架
+ :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
+ :return: 查询到的数据模型列表
+ """
+ if row_df.empty or not condition_cols:
+ return None
+
+ _condition_fields = [_col.key for _col in condition_cols]
+
+ # 验证字段存在
+ for field in _condition_fields:
+ if not hasattr(cls, field):
+ raise AttributeError(f"Field '{field}' not found in {cls.__name__}")
+
+ # 去除重复行并删除包含 NaN 的行
+ row_df = row_df[_condition_fields].drop_duplicates().dropna()
+
+ if row_df.empty:
+ return None
+
+ if len(condition_cols) == 1:
+ # 单字段查询
+ _field_name = _condition_fields[0]
+ field = getattr(cls, _field_name)
+ # 直接获取 Series 的值
+ _values_list = row_df[_field_name].tolist()
+ if not _values_list:
+ return None
+ query = select(cls).where(field.in_(_values_list))
+ else:
+ # 多字段组合查询 - 使用向量化操作避免 iterrows()
+ _values_tuples = [tuple(row) for row in row_df.values]
+ # 再次去重确保
+ _values_tuples = list(set(_values_tuples))
+
+ if not _values_tuples:
+ return None
+
+ conditions = [getattr(cls, field) for field in _condition_fields]
+ query = select(cls).where(tuple_(*conditions).in_(_values_tuples))
+
+ return query
+
+ @classmethod
+ def find_by_id(cls, id_val: Union[str, int], reset_session: Optional[bool] = True) -> Optional['BaseAdapter']:
+ """
+ 根据主键ID查找数据,确认有 id 字段后方可使用。
+
+ :param id_val: 主键值
+ :param reset_session: 重置会话连接
+ :return: 数据模型对象
+ """
+ query = select(cls).where(cls.instrument_attr('id') == id_val)
+ session = cls.get_session()
+ model = session.execute(query).scalars().first()
+ session.close()
+
+ if reset_session and isinstance(model, BaseAdapter):
+ model.reset_session()
+
+ return model
+
+ @classmethod
+ def find_by_created(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']:
+ """
+ 按创建时间搜索。本方法使用了固定字段名称,确认有 created_at 字段后方可使用。
+
+ :param from_t: 开始时间
+ :param to_t: 结束时间
+ """
+ if to_t is None:
+ to_t = from_t
+
+ query = select(cls).where(cls.instrument_attr('created_at').between(from_t, to_t))
+ session = cls.get_session()
+ rows = session.execute(query).scalars().all()
+ session.close()
+ return rows
+
+ @classmethod
+ def find_by_updated(cls, from_t: datetime.datetime, to_t: Optional[datetime.datetime] = None) -> Sequence['BaseAdapter']:
+ """
+ 按更新时间搜索。本方法使用了固定字段名称,确认有 updated_at 字段后方可使用。
+
+ :param from_t: 开始时间
+ :param to_t: 结束时间
+ """
+ if to_t is None:
+ to_t = from_t
+
+ query = select(cls).where(cls.instrument_attr('updated_at').between(from_t, to_t))
+ session = cls.get_session()
+ rows = session.execute(query).scalars().all()
+ session.close()
+ return rows
+
+ @classmethod
+ def find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]) -> Sequence['BaseAdapter']:
+ """
+ 根据数据列表查询已经在数据库中的数据。
+
+ :param row_list: 数据列表
+ :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
+ :return: 查询到的数据模型列表
+ """
+ rows: Sequence[cls] = []
+ _query = cls.datalist_query(row_list=row_list, condition_cols=condition_cols)
+ if _query is not None:
+ session = cls.get_session()
+ rows = session.execute(_query).scalars().all()
+ session.close()
+
+ return rows
+
+ @classmethod
+ def find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
+ """
+ 根据数据框架查询已经在数据库中的数据。
+
+ :param row_df: 数据框架
+ :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
+ :return: 查询到的数据模型列表
+ """
+ _query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols)
+ if _query is not None:
+ session = cls.get_session()
+ _rows = session.connection().execute(_query).all()
+ session.close()
+ return pd.DataFrame(_rows)
+
+ return None
+
+ def load(self, model_data: dict):
+ """
+ 从字典对象载入数据。
+ 该方法仅从字典中读取与对象属性对应的数据,忽略其他数据。
+ 注意:与 copyXXX 方法不同,该方法跳过 model_data 中的 None 值,保留原始值不变
+
+ :param model_data: 包含数据的字典对象
+ """
+ for attr in self.__dir__():
+ if model_data.get(attr, None) is not None:
+ self.__setattr__(attr, model_data.get(attr))
+ return self
+
+ def copy_from(self, source: 'BaseAdapter', mapping_fields: Optional[List[str]] = None):
+ """
+ 从源数据模型复制数据,仅复制相同字段的数据。
+
+ :param source: 源数据对象
+ :param mapping_fields: 映射字段列表,若为空,则从源数据模型中获取字段列表
+ """
+
+ # 源对象字段列表
+ if mapping_fields is None:
+ mapping_fields = [column.key for column in source.columns()]
+
+ for column in self.columns():
+ # 若源对象中不包含同名字段,则跳过
+ if column.key not in mapping_fields:
+ continue
+
+ # 取出对象属性值
+ attr_val = source.ins_value(column.key)
+ setattr(self, column.key, attr_val)
+
+ return self
+
+ def copy_from_dict(self, source: dict, mapping_keys: Union[list, tuple, set] = None, skip_none: bool = False):
+ """
+ 从源字典对象复制数据,仅复制相同键值的数据。
+
+ :param source: 源数据对象
+ :param mapping_keys: 映射键列表,若为空,则从源字典对象中获取键列表
+ :param skip_none: 是否跳过 None 值,即若源数据对象中属性值为 None 时,跳过
+ :return 返回自身
+ """
+
+ # 源对象字段列表
+ if mapping_keys is None:
+ mapping_keys = source.keys()
+
+ for column in self.columns():
+ if column.key not in mapping_keys:
+ # 若源对象中不包含同名字段,则跳过
+ continue
+
+ # 取出对象属性值
+ attr_val = source.get(column.key, None)
+ if skip_none and attr_val is None:
+ # 若设置了跳过 None 且此时属性值为空时,跳过
+ continue
+ setattr(self, column.key, attr_val)
+
+ return self
+
+ def inspect(self) -> InstanceState:
+ """
+ 取得对象的监视对象。
+
+ :return: 监视对象
+ """
+ return inspect(self)
+
+ def session(self) -> Session:
+ """
+ 取得当前对象的连接会话对象。
+
+ :return: 连接会话对象
+ """
+ return self.inspect().session
+
+ def has_session(self) -> bool:
+ """
+ 检测对象是否已有连接会话对象。
+
+ :return: 有则返回 True 否则返回 False
+ """
+ return self.session() is not None
+
+ def add_session(self):
+ """
+ 将对象加入到连接会话对象。
+ """
+ if not self.has_session():
+ self.get_session().add(self)
+ return self.session()
+
+ def close_session(self):
+ """
+ 关闭会话。
+ """
+ if self.has_session():
+ self.session().close()
+
+ def reset_session(self):
+ """
+ 重置对象的连接会话。
+ """
+ self.close_session()
+ self.add_session()
+
+ def ins_value(self, column_name: str) -> Any:
+ """
+ 取得查询条件绑定的值,该值与对象属性一致。与::
+ self.column_name
+ 或
+ self['column_name']
+
+ :param column_name: 列名
+ :return: 对象属性值
+ """
+ # 取出对象属性值
+ attr_val = getattr(self, column_name)
+ if attr_val is None:
+ return None
+ return attr_val
+
+ def filter_expressions(self, likes: Optional[dict[str, str]] = None) -> List[BinaryExpression]:
+ """
+ 自动生成查询条件表达式。
+ 参数 likes 为需要执行模糊查询的字段字典::
+
+ {
+ field_name1: '%{}%',
+ field_name2: '{}%',
+ field_name3: '%{}',
+ ...
+ }
+
+ :param likes: 需要执行模糊查询的字段字典
+ :return: 包含查询条件表达式的列表
+ """
+ expressions = list()
+ for column in self.columns():
+ # 取出对象属性值
+ attr_val = self.ins_value(column.key)
+ # 若属性值为空,不增加条件,跳出
+ if attr_val in (None, ''):
+ continue
+
+ # 取出类属性
+ attr = self.instrument_attr(column.key)
+ # 若类属性类型不正确,不增加条件,跳出
+ if attr is None:
+ continue
+
+ # 取出列的数据类型
+ column_type = column.type
+ if isinstance(column_type, (String, Integer, Numeric, ForeignKey)):
+ # 针对属性数据类型,进行不同处理
+ if isinstance(attr_val, (list, tuple, set)):
+ if not attr_val:
+ # 跳过0长数组
+ continue
+ # 数组使用 in
+ expressions.append(attr.in_(attr_val))
+ elif isinstance(attr_val, (int, float)):
+ # 数值类型
+ expressions.append(attr == attr_val)
+ else:
+ # 字符类型,使用 like
+ if likes and attr.key in likes:
+ _f_str = likes[attr.key]
+ expressions.append(attr.like(_f_str.format(attr_val)))
+ else:
+ expressions.append(attr == attr_val)
+ elif isinstance(column_type, (DateTime, Date)):
+ if isinstance(attr_val, (list, tuple, set)):
+ if len(attr_val) == 2:
+ # 如果长度为 2 则使用 between and 方式
+ expressions.append(attr.between(attr_val[0], attr_val[1]))
+ else:
+ # 数组使用 in
+ expressions.append(attr.in_(attr_val))
+ else:
+ expressions.append(attr.between(attr_val, attr_val))
+ else:
+ # 其他类型,用等号
+ expressions.append(attr == attr_val)
+
+ return expressions
+
+ def gen_query(self, likes: Optional[dict[str, str]] = None):
+ """
+ 根据自身参数生成查询对象。
+
+ :param likes: 需要执行模糊查询的字段字典
+ :return: 查询对象
+ """
+ cls = self.__class__
+ expressions = self.filter_expressions(likes)
+ _query: Select = select(cls).where(*expressions)
+ return _query
+
+ def find(self) -> Sequence['BaseAdapter']:
+ """
+ 根据自身参数,查询数据库。
+
+ :return: 查询到的结果对象
+ """
+ session = self.get_session()
+ _model_list = session.execute(self.gen_query()).scalars().all()
+ session.close()
+ return _model_list
+
+ def find_first(self) -> 'BaseAdapter':
+ """
+ 根据自身参数,查询数据库,仅查询第一条。
+
+ :return: 查询到的结果对象
+ """
+ session = self.get_session()
+ _model = session.execute(self.gen_query()).scalars().first()
+ session.close()
+ return _model
+
+ def find_piece(self, *where: BinaryExpression, offset=0, limit=500, is_desc=False,
+ likes: Optional[dict[str, str]] = None):
+ """
+ 根据自身参数,查询数据库。
+
+ :param where: 查询条件
+ :param offset: 偏移量
+ :param limit: 读取数量
+ :param is_desc: 是否逆序排列
+ :param likes: 模糊条件
+ :return: 查询到的结果对象
+ """
+ clz = self.__class__
+ expressions = self.filter_expressions(likes=likes)
+ if where is not None:
+ expressions += where
+
+ query = select(clz).where(*expressions)
+ if limit > 0:
+ query = query.limit(limit=limit)
+ if offset >= 0:
+ query = query.offset(offset=offset)
+ if is_desc:
+ if hasattr(clz, 'id'):
+ query = query.order_by(desc(clz.id))
+
+ session = self.get_session()
+ _model_list = session.execute(query).scalars().all()
+ session.close()
+ return _model_list
+
+ def before_save(self):
+ """
+ 保存前的动作,一般应当在子类中覆盖该方法,增加在保存前应当执行的动作。
+ """
+ if self.is_new:
+ if hasattr(self, 'id'):
+ setattr(self, 'id', self.new_id())
+
+ if hasattr(self, 'created_at'):
+ setattr(self, 'created_at', datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT))
+
+ if hasattr(self, 'updated_at'):
+ setattr(self, 'updated_at', datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT))
+
+ return self
+
+ def save(self, auto_expunge: Optional[bool] = True, session: Optional[Session] = None):
+ """
+ 保存数据模型对象,若该对象尚未有连接会话,则自动加入连接会话。
+
+ :param auto_expunge: 自动刷新对象并将其移出连接会话,不提供外部 session 时有效
+ :param session: 会话对象,主要用于事务
+ :return: 保存状态
+ """
+ _has_session: bool = True
+ """
+ 该参数用于说明是否提供了外部 session 对象,默认为 True 时表示提供。
+ """
+
+ self.before_save()
+ if session is None:
+ _has_session = False
+ session = self.add_session()
+
+ try:
+ session.add(self)
+ if not _has_session:
+ # 使用新会话时,直接提交
+ session.commit()
+ self._is_new = False
+ except Exception as e:
+ session.rollback()
+ raise e
+ else:
+ if auto_expunge and not _has_session:
+ session.refresh(self)
+ session.expunge(self)
+ return True
+
+ def to_dict(self) -> dict:
+ """
+ 数据模型转字典。递归处理内部类型对象。
+
+ :return: 转换后的字典数据
+ """
+ # 模型数据字典
+ m_dict = {}
+
+ # 遍历处理内部转换
+ for _key, _val in dict(self.__dict__).items():
+ if f'{_key}'.startswith('_'):
+ # 跳过所有私有属性
+ continue
+
+ if isinstance(_val, BaseAdapter):
+ # 内部数据对象数据,直接转换
+ m_dict[_key] = _val.to_dict()
+ elif isinstance(_val, list):
+ # 遍历转换数据对象列表
+ _tmp_list = []
+ for _i, _v in enumerate(_val):
+ if isinstance(_v, BaseAdapter):
+ _tmp_list.append(_v.to_dict())
+ else:
+ _tmp_list.append(_v)
+ m_dict[_key] = _tmp_list
+ elif isinstance(_val, dict):
+ # 遍历转换数据对象字典
+ _tmp_dict = {}
+ for _ik, _iv in _val.items():
+ if isinstance(_iv, BaseAdapter):
+ _tmp_dict[_ik] = _iv.to_dict()
+ else:
+ _tmp_dict[_ik] = _iv
+ m_dict[_key] = _tmp_dict
+ else:
+ # 其他属性直接赋值
+ m_dict[_key] = _val
+
+ return m_dict
diff --git a/paste/db/basemodel.py b/paste/db/basemodel.py
new file mode 100644
index 0000000..c14f473
--- /dev/null
+++ b/paste/db/basemodel.py
@@ -0,0 +1,629 @@
+"""
+数据模型基础类,继承于数据表。集成了模型的基础功能,如数据验证、错误消息、数据影射、对象比较等功能。
+"""
+import datetime
+from decimal import Decimal, ROUND_HALF_UP
+from typing import Union, Any, Optional, Callable
+
+import pandas as pd
+from sqlalchemy import Column, text, desc
+
+from paste.db import baseadapter
+from paste.db.basetable import BaseTable
+from paste.util import udict, ustr
+from paste.util.pagination import Pagination
+from paste.util.snow_id import IdWorker
+
+LOCAL_DATE_FORMAT = baseadapter.LOCAL_DATE_FORMAT
+LOCAL_TIME_FORMAT = baseadapter.LOCAL_TIME_FORMAT
+LOCAL_DATETIME_FORMAT = baseadapter.LOCAL_DATETIME_FORMAT
+
+
+class BaseModel(BaseTable):
+ """
+ 数据模型基类。集成了验证辅助功能。
+ """
+
+ __abstract__ = True
+
+ @classmethod
+ def new_id(cls, datacenter_id: int = 1, worker_id: int = 1, sequence: int = 0) -> int:
+ """
+ 生成新的 Snow ID 对象,并生成 ID 值。
+
+ :param datacenter_id: 数据中心(机器区域)ID
+ :param worker_id: 机器ID
+ :param sequence: 起始序号
+ :return: 新的 Snow ID 值
+ """
+ return IdWorker.get_id_worker(datacenter_id, worker_id, sequence).get_id()
+
+ @classmethod
+ def now(cls):
+ """
+ 取得当前时间的格式化字符串。
+
+ :return: 当前时间格式化字符串
+ """
+ return datetime.datetime.now().strftime(LOCAL_DATETIME_FORMAT)
+
+ @classmethod
+ def is_len(cls, v: str, length: int):
+ """
+ 检测字符串长度的函数,例如检测是否是18位。主要用于数据校验。
+
+ :param v: 待检测的值
+ :param length: 目标长度
+ :return: 相同返回 True,否则返回 False
+ """
+ v = f"{v}"
+ return not cls.is_empty_or_none(v) and len(v) == length
+
+ @classmethod
+ def is_in_range(cls, v: Union[int, float], v_min: Union[int, float], v_max: Union[int, float]):
+ """
+ 返回检测数值范围的函数,检测是处于最大最小值范围内。主要用于数据校验。
+
+ :param v: 待检测的值
+ :param v_min: 最小值(包含)
+ :param v_max: 最大值(包含)
+ :return: 在范围内返回 True,否则返回 False
+ """
+ return not cls.is_empty_or_none(v) and v_min <= v <= v_max
+
+ @classmethod
+ def is_in_items(cls, v: Union[int, float], items: list = None):
+ """
+ 返回检测数值是否在列表中。主要用于数据校验。
+
+ :param v: 待检测的值
+ :param items: 所有项目
+ :return: 在列表内返回 True,否则返回 False
+ """
+ return items is not None and v in items
+
+ @classmethod
+ def is_empty_or_none(cls, v: Any):
+ """
+ 检查是 None 或 空字符串。
+
+ :param v: 待检查的内容
+ :return: 为 None 或 Nan 或 '' 时返回 True,否则返回 False
+ """
+ return v is None or pd.isna(v) or f"{v}" == ''
+
+ @classmethod
+ def not_empty_or_none(cls, v: Any):
+ """
+ 与 isEmptyOrNone 函数功能相反。
+ """
+ return not cls.is_empty_or_none(v)
+
+ @classmethod
+ def is_digit(cls, v: str):
+ """
+ 检查字符串是否是整数。
+
+ :param v: 带检查内容
+ :return: 是整数返回 True,否则返回 False
+ """
+ v = f"{v}"
+ return v.isdigit()
+
+ @classmethod
+ def is_decimal(cls, v: str):
+ """
+ 检查是否是浮点数,若为整数,也返回 True。
+ :param v: 待检查内容
+ :return: 浮点数或整数返回 True,否则返回 False
+ """
+ v = f"{v}"
+ is_decimal = True
+ vs = v.replace(',', '').split('.')
+ for _v in vs:
+ is_decimal = is_decimal and cls.is_digit(_v)
+ return is_decimal
+
+ @classmethod
+ def is_datetime(cls, v: str):
+ """
+ 检查是否是日期时间格式。
+ :param v: 待检查内容
+ :return: 日期时间返回 True,否则返回 False
+ """
+ try:
+ datetime.datetime.strptime(v, LOCAL_DATETIME_FORMAT)
+ except (ValueError, Exception):
+ return False
+ return True
+
+ @classmethod
+ def is_date(cls, v: str):
+ """
+ 检查是否是日期格式。
+
+ :param v: 待检查内容
+ :return: 日期返回 True,否则返回 False
+ """
+ try:
+ datetime.datetime.strptime(v, LOCAL_DATE_FORMAT)
+ except (ValueError, Exception):
+ return False
+ return True
+
+ @classmethod
+ def is_time(cls, v: str):
+ """
+ 检查是否是时间格式。
+
+ :param v: 待检查内容
+ :return: 时间返回 True,否则返回 False
+ """
+ try:
+ datetime.datetime.strptime(v, LOCAL_TIME_FORMAT)
+ except (ValueError, Exception):
+ return False
+ return True
+
+ @classmethod
+ def error_empty_msg(cls, column: Union[Column, Any]):
+ """
+ 空字符串错误。主要用于数据校验错误。
+
+ :return: 以字段备注为主的错误消息
+ """
+ return '%s必须包含内容.' % cls.label(column=column)
+
+ @classmethod
+ def error_date_msg(cls, column: Union[Column, Any]):
+ """
+ 日期格式错误。主要用于数据校验错误。
+
+ :return: 以字段备注为主的错误消息
+ """
+ return '%s必须是日期.' % cls.label(column=column)
+
+ @classmethod
+ def error_datetime_msg(cls, column: Union[Column, Any]):
+ """
+ 日期时间格式错误。主要用于数据校验错误。
+
+ :return: 以字段备注为主的错误消息
+ """
+ return '%s必须是日期时间.' % cls.label(column=column)
+
+ @classmethod
+ def error_time_msg(cls, column: Union[Column, Any]):
+ """
+ 时间格式错误。主要用于数据校验错误。
+
+ :return: 以字段备注为主的错误消息
+ """
+ return '%s必须是时间.' % cls.label(column=column)
+
+ @classmethod
+ def error_decimal_msg(cls, column: Union[Column, Any]):
+ """
+ 非浮点或双精度类型错误。主要用于数据校验错误。
+
+ :return: 以字段备注为主的错误消息
+ """
+ return '%s必须是浮点或双进度类型.' % cls.label(column=column)
+
+ @classmethod
+ def error_format_msg(cls, column: Union[Column, Any]):
+ """
+ 格式错误。主要用于数据校验错误。
+
+ :return: 以字段备注为主的错误消息
+ """
+ return '%s格式错误.' % cls.label(column=column)
+
+ @classmethod
+ def error_int_msg(cls, column: Union[Column, Any]):
+ """
+ 非整数类型错误。主要用于数据校验错误。
+
+ :return: 以字段备注为主的错误消息
+ """
+ return '%s必须是整数.' % cls.label(column=column)
+
+ @classmethod
+ def error_len_msg(cls, column: Union[Column, Any], length: int):
+ """
+ 长度错误消息。主要用于数据校验错误。
+
+ :return: 以字段备注为主的错误消息
+ """
+ return '%s必须是%d位.' % (cls.label(column=column), length)
+
+ @classmethod
+ def error_in_range_msg(cls, column: Union[Column, Any],
+ v_min: Union[int, float], v_max: Union[int, float]):
+ """
+ 范围错误。主要用于数据值校验错误。
+
+ :return: 以字段备注为主的错误消息
+ """
+ return '%s必须在:[%s,%s] 范围内.' % (cls.label(column=column), f"{v_min}", f"{v_max}")
+
+ @classmethod
+ def error_in_items_msg(cls, column: Union[Column, Any], items: list = None):
+ """
+ 范围错误。主要用于数据项校验错误。
+
+ :return: 以字段备注为主的错误消息
+ """
+ if items is None:
+ return '%s超出范围.' % cls.label(column=column)
+ else:
+ return '%s必须在:[%s] 范围内.' % (cls.label(column=column), ','.join(items))
+
+ @classmethod
+ def error_str_msg(cls, column: Union[Column, Any]):
+ """
+ 非字符串类型错误。主要用于数据校验错误。
+
+ :return: 以字段备注为主的错误消息
+ """
+ return '%s必须是字符串' % cls.label(column=column)
+
+ field_validators: dict[Column, tuple] = {}
+ """
+ 字段验证器配置。
+ 规则为:字段名 -> 验证配置
+ 验证配置为一个 tuple 数据,各元素说明如下::
+
+ 第 0 项:验证方法与消息方法,类型为 method 或 tuple,若仅有验证方法,则直接是方法名即可,若两者皆有,则为 tuple。
+ 第 1 项:是否跳过 None 值,类型为 bool。
+ 第 2~n 项,验证方法或消息方法的参数,注意验证方法与消息方法除第一项参数外的其他参数必须一致。
+ """
+
+ @classmethod
+ def validate_fields(cls, row: dict) -> list[dict[str, str]]:
+ """
+ 结合 field_validators 的配置,对字段执行验证。
+ 若发现错误,则记录在 _errors 中并返回。
+
+ :param row: 待验证数据。
+ :return: 验证得到的错误描述。
+ """
+ _errors: list[dict[str: str]] = []
+ for _column, _validator in cls.field_validators.items():
+ # 消息函数
+ _message_func = None
+ # 验证函数,是否跳过空值
+ _verify_func, _skip_null = _validator[:2]
+
+ if isinstance(_verify_func, tuple):
+ _verify_func, _message_func = _verify_func
+
+ _value = udict.get_with_default(row, _column.key, None)
+ if _value is None and _skip_null:
+ continue
+
+ _args = _validator[2:]
+ _vfy_args = (_value,) + _args
+ _err_args = (_column,) + _args
+
+ assert isinstance(_verify_func, Callable), '验证器配置错误.'
+ if not _verify_func(*_vfy_args):
+ if isinstance(_message_func, Callable):
+ _errors.append({_column.key: _message_func(*_err_args)})
+ else:
+ _errors.append({_column.key: f'{_column.key} 字段数据错误.'})
+ return _errors
+
+ @classmethod
+ def validate_dict(cls, row: dict, row_list: list[dict], err_list: list[dict]):
+ """
+ 验证字典数据。仅将结果加入对应的列表,不改变原有数据。
+
+ :param row: 待验证的行数据对象
+ :param row_list: 用于存放验证成功模型的列表
+ :param err_list: 用于存放错误消息的列表
+ """
+ try:
+ row_list.append(row)
+ except TypeError:
+ err_list.append(row)
+
+ @classmethod
+ def validate_dict_list(cls, row_list: list[dict]) -> tuple[list[dict], list[dict]]:
+ """
+ 验证字典列表数据,首先清除历史模型列表和错误消息。
+
+ :param row_list: 待验证的字典数组
+ :return: 数据模型列表和错误消息列表
+ """
+ _row_list: list[dict] = []
+ _err_list: list[dict] = []
+
+ for row in row_list:
+ cls.validate_dict(row=row, row_list=_row_list, err_list=_err_list)
+ return _row_list, _err_list
+
+ @classmethod
+ def mapping_data_struct(cls, source: Optional[dict], mapping: Optional[dict]):
+ """
+ 将源数据字典中的数据,按照映射关系字典的方式转换为新的字典对象。
+
+ 下面是一个递归映射关系字典的样本::
+
+ dict_key_mapping = {
+ 'devUseNo': 'dev_use_no',
+ 'mainId': 'id',
+ 'mainCycle': lambda dict_obj: MAIN_CYCLE_LABELS.get(dict_obj['main_cycle'], ''),
+ 'mainCycleCode': 'main_cycle',
+ 'fileList': {
+ '__name__': 'main_files',
+ '__mapping__': {
+ 'fileName': 'file_name',
+ 'filePath': 'file_url',
+ },
+ },
+ 'mainDetailList': {
+ '__name__': 'main_items',
+ '__mapping__': {
+ 'id': 'id',
+ 'itemId': 'item_id',
+ 'itemName': 'item_name',
+ 'itemRequest': 'item_request',
+ 'itemResult': 'item_result',
+ 'remarks': 'remarks',
+ 'itemFileList': {
+ '__name__': 'main_item_files',
+ '__mapping__': {
+ 'fileName': 'file_name',
+ 'filePath': 'file_url',
+ },
+ },
+ },
+ },
+ }
+
+ 映射关系字典遵循:{`目标属性`: `源属性`} 的结构,对`源属性`,允许有以下几种类型::
+
+ 1、为字符串时,表示从源数据字典中直接读取。
+ 2、为函数或 lambda 表达式时,执行函数,并将源数据字典以参数形式传给该函数。
+ 3、为字典时,表示有子对象数据,此时需要配置 __name__ 属性和 __mapping__ 属性。
+ 4、非以上情况的,直接使用该内容作为目标字典属性的数据。
+
+ :param source: 源数据字典
+ :param mapping: 映射关系字典
+
+ :return: 转换后的字典
+ """
+ if source is None or mapping is None:
+ return None
+
+ target = {}
+ for _tar_attr, _src_attr in mapping.items():
+ if isinstance(_src_attr, str):
+ #
+ # 直接处理 key 映射关系
+ # 注意,对于需要强制设置为字符串的,不能直接使用字符串,会被误认为是 key 映射关系,应当使用无参数 lambda 表达式。
+ #
+ target[_tar_attr] = source.get(_src_attr, None)
+ elif isinstance(_src_attr, Callable):
+ #
+ # 处理函数或 lambda 表达式
+ #
+ target[_tar_attr] = _src_attr(source)
+ elif isinstance(_src_attr, dict):
+ if '__name__' in _src_attr and '__mapping__' in _src_attr:
+ #
+ # 包含名称映射的,表示新的映射关系,递归处理
+ # 这里仅处理类型为 dict 和 list 的数据
+ #
+
+ # 取出内部源数据字典和映射关系
+ _sd = source.get(_src_attr.get('__name__'), None)
+ _mp = _src_attr.get('__mapping__', None)
+
+ if isinstance(_sd, dict):
+ #
+ # 直接递归映射
+ #
+ target[_tar_attr] = cls.mapping_data_struct(_sd, _mp)
+ elif isinstance(_sd, list):
+ #
+ # 遍历后递归映射
+ #
+ _t_list = []
+ for _sd_item in _sd:
+ _t_list.append(cls.mapping_data_struct(_sd_item, _mp))
+ target[_tar_attr] = _t_list
+ else:
+ #
+ # 非 dict,list 的,直接设置
+ #
+ target[_tar_attr] = _sd
+ else:
+ #
+ # 无映射关系的,直接设置
+ #
+ target[_tar_attr] = _src_attr
+ else:
+ #
+ # 非 str,function,dict 的,直接设置
+ #
+ target[_tar_attr] = _src_attr
+
+ return target
+
+ @classmethod
+ def transform(cls, rows: list[dict], mapping: dict):
+ """
+ 将源数据 rows 字典中的数据,按照映射关系字典 mapping 的方式转换为新的字典对象。
+
+ 下面是一个递归映射关系字典的样本::
+
+ dict_key_mapping = {
+ 'devUseNo': 'dev_use_no',
+ 'mainId': 'id',
+ 'mainCycle': lambda dict_obj: MAIN_CYCLE_LABELS.get(dict_obj['main_cycle'], ''),
+ 'mainCycleCode': 'main_cycle',
+ 'fileList': {
+ '__name__': 'main_files',
+ '__mapping__': {
+ 'fileName': 'file_name',
+ 'filePath': 'file_url',
+ },
+ },
+ 'mainDetailList': {
+ '__name__': 'main_items',
+ '__mapping__': {
+ 'id': 'id',
+ 'itemId': 'item_id',
+ 'itemName': 'item_name',
+ 'itemRequest': 'item_request',
+ 'itemResult': 'item_result',
+ 'remarks': 'remarks',
+ 'itemFileList': {
+ '__name__': 'main_item_files',
+ '__mapping__': {
+ 'fileName': 'file_name',
+ 'filePath': 'file_url',
+ },
+ },
+ },
+ },
+ }
+
+ 映射关系字典遵循:{`目标属性`: `源属性`} 的结构,对`源属性`,允许有以下几种类型::
+
+ 1、为字符串时,表示从源数据字典中直接读取。
+ 2、为函数或 lambda 表达式时,执行函数,并将源数据字典以参数形式传给该函数。
+ 3、为字典时,表示有子对象数据,此时需要配置 __name__ 属性和 __mapping__ 属性。
+ 4、非以上情况的,直接使用该内容作为目标字典属性的数据。
+ :param rows: 源数据字典列表
+ :param mapping: 映射关系字典
+ :return: 转换结果
+ """
+ _dict_list: list[dict] = []
+ for _r in rows:
+ _tar_dict = cls.mapping_data_struct(_r, mapping)
+ _dict_list.append(_tar_dict)
+ return _dict_list
+
+ @classmethod
+ def convert(cls, dataframe: pd.DataFrame, mapping: dict):
+ """
+ 将源数据框架 dataframe 中的数据,按照映射关系字典 mapping 的方式转换为新的 dataframe 对象。
+
+ 下面是一个递归映关系射字典的样本::
+
+ dict_key_mapping = {
+ 'devUseNo': 'dev_use_no',
+ 'mainId': 'id',
+ 'mainCycle': lambda dict_obj: MAIN_CYCLE_LABELS.get(dict_obj['main_cycle'], ''),
+ 'mainCycleCode': 'main_cycle',
+ }
+
+ 映射关系字典遵循:{`目标属性`: `源属性`} 的结构,对`源属性`,允许有以下几种类型::
+
+ 1、为字符串时,表示从源数据字典中直接读取。
+ 2、为函数或 lambda 表达式时,执行函数,并将源数据字典以参数形式传给该函数。
+ 3、非以上情况的,直接使用该内容作为目标字典属性的数据。
+
+ 注意:与字典映射转换不同,:class:`pd.DataFrame` 映射转换不支持多层递归转换。
+
+ :param dataframe: 源数据 dataframe
+ :param mapping: 映射关系字典
+ :return: 转换结果
+ """
+ _tar_df = pd.DataFrame()
+ for _tar_attr, _src_attr in mapping.items():
+ if isinstance(_src_attr, str):
+ _tar_df[_tar_attr] = dataframe[_src_attr]
+ elif isinstance(_src_attr, Callable):
+ _tar_df[_tar_attr] = dataframe.apply(_src_attr, axis=1)
+ else:
+ _tar_df[_tar_attr] = _tar_attr
+ return _tar_df
+
+ @classmethod
+ def is_equal(cls, data_dict: dict, data_model: 'BaseModel', skip_kes: list[str] = None, decimals: str = '0.00'):
+ """
+ 判断 data_dict 中的值是否都与 equ_model 中的对应值相等。一般而言若相等,则表明无需更新数据模型,否则就需要更新。
+
+ :param data_dict: 数据字典,用于遍历比对的数据,也是用于更新的数据
+ :param data_model: 数据模型
+ :param skip_kes: 允许跳过,不做比较的字段
+ :param decimals: 浮点数保留的小数位,默认 2 位
+ :return: 是否相等,各字段是否相等的对应关系字典
+ """
+ is_equal = True
+ equal_dict: dict = {}
+ if skip_kes is None:
+ skip_kes = []
+
+ for _key, _new_val in data_dict.items():
+ if _key in skip_kes:
+ continue
+
+ if _new_val is None:
+ # 跳过新值中的 None
+ continue
+
+ if _key not in data_model.__dict__:
+ # 跳过不存在的属性
+ continue
+
+ _old_val = data_model.__dict__.get(_key, None)
+ if isinstance(_old_val, (Decimal, float)):
+ _old_val = Decimal(f"{_old_val}").quantize(Decimal(decimals), rounding=ROUND_HALF_UP)
+ _new_val = Decimal(f"{_new_val}").quantize(Decimal(decimals), rounding=ROUND_HALF_UP)
+ elif isinstance(_old_val, int):
+ _new_val = int(_new_val)
+ elif isinstance(_old_val, datetime.datetime):
+ _old_val = _old_val.strftime(LOCAL_DATETIME_FORMAT)
+ _datetime = ustr.to_datetime(_new_val, [LOCAL_DATETIME_FORMAT, LOCAL_DATE_FORMAT])
+ _new_val = _datetime.strftime(LOCAL_DATETIME_FORMAT) if _datetime is not None else f"{_new_val}"
+ elif isinstance(_old_val, datetime.date):
+ _old_val = _old_val.strftime(LOCAL_DATE_FORMAT)
+ _date = ustr.to_datetime(_new_val, [LOCAL_DATE_FORMAT, LOCAL_DATETIME_FORMAT])
+ _new_val = _date.strftime(LOCAL_DATE_FORMAT) if _date is not None else f"{_new_val}"
+ else:
+ _old_val = f"{_old_val}" if _old_val is not None else ''
+ if isinstance(_new_val, float):
+ _new_val = int(_new_val)
+ _new_val = f"{_new_val}"
+
+ _isFieldEqual = _new_val == _old_val
+ is_equal = is_equal and _isFieldEqual
+ equal_dict[_key] = _isFieldEqual
+
+ return is_equal, equal_dict
+
+ @classmethod
+ async def page_info(cls, *where_clause, page_size: int = 20):
+ """
+ 分页参数。
+
+ :return: 页数, 数据行数
+ """
+ _row_count = await cls.async_row_count(*where_clause)
+ _pagination = Pagination(row_count=_row_count)
+ return _pagination.pages(page_size=page_size), _row_count
+
+ @classmethod
+ def sort_clauses(cls, sort_d: dict):
+ """
+ 按照参数 sort_d 中的定义,组织排序表达式。参数 sortd_d 应该具有如下结构::
+
+ {
+ 'field_name1': 'asc',
+ 'field_name2': 'desc',
+ }
+
+ :param sort_d 排序参数
+ """
+ _sort_clause = []
+ for _fn, _st in sort_d.items():
+ if _st in ('', 'asc', 'ascend'):
+ _sort_clause.append(text(_fn))
+ if _st in ('desc', 'descend'):
+ _sort_clause.append(desc(text(_fn)))
+ return _sort_clause
diff --git a/paste/db/basetable.py b/paste/db/basetable.py
new file mode 100755
index 0000000..8415b3e
--- /dev/null
+++ b/paste/db/basetable.py
@@ -0,0 +1,337 @@
+"""
+集成了表的基本操作。
+"""
+import datetime
+from typing import Optional, Union, List, Sequence
+
+import pandas as pd
+from sqlalchemy import func, desc, Column, select, util, Result
+from sqlalchemy.engine import ScalarResult, CursorResult
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
+from sqlalchemy.sql import Select
+from sqlalchemy.sql.elements import BinaryExpression
+from sqlalchemy.sql.operators import ColumnOperators
+
+from paste.db import baseadapter, engine
+
+
+class BaseTable(baseadapter.BaseAdapter):
+ """
+ 表对象(数据模型)基类。集成基础数据操作功能。
+ 实现异步执行方法。
+ """
+
+ __abstract__ = True
+
+ @classmethod
+ async def raw_execute(cls, query, session: AsyncSession=None,
+ params=None, options=None) -> Optional[CursorResult]:
+ """
+ 异步连接执行原始查询,返回游标。链式调用返回对象的 all() 方法后,将得到 list[Row] 对象。
+
+ :param query: 查询语句
+ :param session: 数据库会话对象
+ :param params: 查询参数
+ :param options: 查询选项
+ :return: CursorResult 游标
+ """
+ if options is None:
+ options = util.EMPTY_DICT
+ if session is None:
+ connection: AsyncConnection = await cls.get_aio_session().connection()
+ else:
+ connection = await session.connection()
+
+ try:
+ cursor_result: CursorResult = await connection.execute(query, params, execution_options=options)
+ await connection.commit()
+ return cursor_result
+ finally:
+ await connection.close()
+
+ @classmethod
+ async def orm_execute(cls, query, session: AsyncSession=None, params=None, options=None) -> Optional[Result]:
+ """
+ 异步会话执行查询,返回游标。链式调用返回对象的 all() 方法后,将得到 list[Row] 对象。
+
+ :param query: 查询请求
+ :param session: 数据库会话对象
+ :param params: 查询参数
+ :param options: 查询选项
+ """
+ _has_session: bool = True
+ if session is None:
+ _has_session = False
+ session = cls.get_aio_session()
+
+ try:
+ result: Result = await session.execute(query, params, execution_options=options)
+ return result
+ finally:
+ if not _has_session:
+ await session.close()
+
+ @classmethod
+ async def orm_execute_scalars(cls, query, session: AsyncSession=None, **kwargs) -> ScalarResult:
+ """
+ 使用异步执行查询。并执行 scalars() 方法,这会进行 ORM 映射。
+
+ :param query: 查询对象
+ :param session: 数据库会话对象
+ :return: 数据模型对象列表
+ """
+ result: Result = await cls.orm_execute(query, session, **kwargs)
+ return result.scalars()
+
+ @classmethod
+ async def query_all(cls, query: Select, session: AsyncSession=None) -> List['BaseTable']:
+ """
+ 使用异步执行 ORM 查询。并执行 scalars().all() 方法。
+
+ :param query: 查询对象
+ :param session: 数据库会话对象
+ :return: 数据模型对象列表,注意:当查询部分字段时仅返回查询列表中第一个字段的数据列表
+ """
+ scalars = await cls.orm_execute_scalars(query, session)
+ return list(scalars.all())
+
+ @classmethod
+ async def query_first(cls, query: Select, session: AsyncSession=None) -> 'BaseTable':
+ """
+ 使用异步执行 ORM 查询。并执行 scalars().first() 方法。
+
+ :param query: 查询对象
+ :param session: 数据库会话对象
+ :return: 数据模型对象列表
+ """
+ scalars = await cls.orm_execute_scalars(query, session)
+ return scalars.first()
+
+ @classmethod
+ async def query_count(cls, query: Select, is_only_count: bool = False, session: AsyncSession=None) -> int:
+ """
+ 使用异步执行 ORM 查询,查询原查询的数据行数。
+
+ :param query: 查询对象
+ :param is_only_count: 是否使用仅 count 方式查询
+ :param session: 数据库会话对象
+ :return: 数据行数
+ """
+ if is_only_count:
+ count_query = query.with_only_columns(func.count(1))
+ else:
+ count_query = select(func.count(1)).select_from(query.subquery())
+ # 执行查询
+ row_count: int = (await cls.orm_execute_scalars(count_query, session)).first()
+ return row_count
+
+ @classmethod
+ async def query_as_df(cls, query, session: AsyncSession=None, params=None, options=None):
+ """
+ 执行 RAW 原始查询,并将数据请求转换为 :class:`pd.DataFrame` 返回。
+ 注意:为了保持字段数据精度或便于数据处理,可在 SQL 语句中使用 :meth:`func.ifnull` 或 :meth:`func.convert` 等方法直接转换数据类型。
+
+ :param query: 查询语句
+ :param session: 数据库会话对象
+ :param params: 查询参数
+ :param options: 查询选项
+ :return: 数据 DataFrame
+ """
+ result = await cls.raw_execute(query, session, params, options)
+ return pd.DataFrame(result.all(), columns=pd.Series(result.keys()))
+
+ @classmethod
+ async def async_row_count(cls, *where_expressions: Union[ColumnOperators, BinaryExpression],
+ session: AsyncSession=None) -> int:
+ """
+ 异步按条件取得数据行数量。
+
+ :param where_expressions: 条件表达式列表
+ :param session: 数据库会话对象
+ :return: 行数
+ """
+ query = select(func.count(1)).select_from(cls).where(*where_expressions)
+ result = await cls.orm_execute_scalars(query, session)
+ return result.first()
+
+ @classmethod
+ async def async_exist_by_id(cls, id_val: Union[str, int]):
+ """
+ 检查是否有存在的主键。
+
+ :param id_val: 主键值
+ """
+ count: int = await cls.async_row_count(cls.instrument_attr('id') == id_val)
+ return count >= 1
+
+ @classmethod
+ async def async_find_by_id(cls, id_val: Union[str, int]) -> Optional['BaseTable']:
+ """
+ 根据主键ID查找数据,确认有 id 字段后方可使用。
+
+ :param id_val: 主键值
+ :return: 数据模型对象
+ """
+ query = select(cls).where(cls.instrument_attr('id') == id_val)
+ model = await cls.query_first(query=query)
+ return model
+
+ @classmethod
+ async def async_find_by_created(cls, from_t: datetime.datetime,
+ to_t: Optional[datetime.datetime] = None) -> List['BaseTable']:
+ """
+ 按创建时间搜索。本方法使用了固定字段名称,确认有 created_at 字段后方可使用。
+
+ :param from_t: 开始时间
+ :param to_t: 结束时间
+ """
+ if to_t is None:
+ to_t = from_t
+
+ query = select(cls).where(cls.instrument_attr('created_at').between(from_t, to_t))
+ return await cls.query_all(query)
+
+ @classmethod
+ async def async_find_by_updated(cls, from_t: datetime.datetime,
+ to_t: Optional[datetime.datetime] = None) -> List['BaseTable']:
+ """
+ 按更新时间搜索。本方法使用了固定字段名称,确认有 updated_at 字段后方可使用。
+
+ :param from_t: 开始时间
+ :param to_t: 结束时间
+ """
+ if to_t is None:
+ to_t = from_t
+
+ query = select(cls).where(cls.instrument_attr('updated_at').between(from_t, to_t))
+ return await cls.query_all(query)
+
+ @classmethod
+ async def async_find_by_datalist(cls, row_list: List[dict], condition_cols: List[Column]):
+ """
+ 根据数据列表查询已经在数据库中的数据。
+
+ :param row_list: 数据列表
+ :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
+ :return: 查询到的数据模型列表
+ """
+ model_list: list[cls] = []
+ query = cls.datalist_query(row_list=row_list, condition_cols=condition_cols)
+ if query is not None:
+ model_list = await cls.query_all(query=query)
+
+ return model_list
+
+ @classmethod
+ async def async_find_by_dataframe(cls, row_df: pd.DataFrame, condition_cols: List[Column]):
+ """
+ 根据数据框架查询已经在数据库中的数据。
+
+ :param row_df: 数据框架
+ :param condition_cols: 要查询,且作为条件的列。注意次序与索引次序相同
+ :return: 查询到的数据模型列表
+ """
+ query = cls.dataframe_query(row_df=row_df, condition_cols=condition_cols)
+ if query is not None:
+ _result = await cls.orm_execute_scalars(query=query)
+ _rows: Sequence[cls] = _result.all()
+ return pd.DataFrame(_rows)
+
+ return None
+
+ async def async_find(self, likes: Optional[dict[str, str]] = None) -> List['BaseTable']:
+ """
+ 根据自身参数,查询数据库。
+
+ :param likes: 模糊条件
+ :return: 查询到的结果对象
+ """
+ expressions = self.filter_expressions(likes=likes)
+ query = select(self.__class__).where(*expressions)
+ return await self.query_all(query)
+
+ async def async_find_first(self, likes: Optional[dict[str, str]] = None) -> Optional['BaseTable']:
+ """
+ 根据自身参数,查询数据库,仅查询第一条。
+
+ :param likes: 模糊条件
+ :return: 查询到的结果对象
+ """
+ expressions = self.filter_expressions(likes=likes)
+ query = select(self.__class__).where(*expressions)
+ result = await self.query_first(query)
+ return result
+
+ async def async_find_piece(self, *where: Union[ColumnOperators, BinaryExpression],
+ offset=0, limit=500, is_desc=False,
+ likes: Optional[dict[str, str]] = None) -> List['BaseTable']:
+ """
+ 根据自身参数,查询数据库。
+
+ :param where: 查询条件
+ :param offset: 偏移量
+ :param limit: 读取数量
+ :param is_desc: 是否逆序排列
+ :param likes: 模糊条件
+ :return: 查询到的结果对象
+ """
+ clz = self.__class__
+ expressions = self.filter_expressions(likes=likes)
+ if where is not None:
+ expressions += where
+
+ query = select(clz).where(*expressions)
+ if limit > 0:
+ query = query.limit(limit=limit)
+ if offset >= 0:
+ query = query.offset(offset=offset)
+ if is_desc:
+ if hasattr(clz, 'id'):
+ query = query.order_by(desc(clz.id))
+
+ return await self.query_all(query)
+
+ async def async_save(self, auto_expunge: Optional[bool] = True, session: Optional[AsyncSession] = None):
+ """
+ 保存数据模型对象,首先强制关闭原有会话,获取新的会话,并加入对象。
+
+ :param auto_expunge: 自动刷新对象并将其移出连接会话,不提供外部 session 时有效
+ :param session: 会话对象,主要用于事务
+ :return: 保存状态
+ """
+ _has_session: bool = True
+ """
+ 该参数用于说明是否提供了外部 session 对象,默认为 True 时表示提供。
+ """
+
+ self.before_save()
+ self.close_session()
+ if session is None:
+ _has_session = False
+ session = self.get_aio_session()
+
+ try:
+ session.add(self)
+ if not _has_session:
+ # 使用新会话时,直接提交
+ await session.commit()
+ self._is_new = False
+ except Exception as e:
+ await session.rollback()
+ raise e
+ else:
+ if auto_expunge and not _has_session:
+ await session.refresh(self)
+ session.expunge(self)
+ return True
+ finally:
+ if not _has_session:
+ # 使用新会话时,主动关闭
+ await session.close()
+
+
+def create_all_tables():
+ """
+ 创建所有的表格。
+ """
+ baseadapter.registry.metadata.create_all(engine.connect_engine())
diff --git a/paste/db/engine.py b/paste/db/engine.py
new file mode 100755
index 0000000..23d8ef6
--- /dev/null
+++ b/paste/db/engine.py
@@ -0,0 +1,43 @@
+"""
+从配置文件读取数据引擎连接信息,连接数据库。
+"""
+
+from typing import Union
+
+from sqlalchemy import create_engine
+from sqlalchemy.engine import Engine
+from sqlalchemy.engine.mock import MockConnection
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
+
+from paste.core import config
+
+ASYNC_CONNECTOR_ENGINE = None
+GLOBAL_CONNECTOR_ENGINE = None
+
+
+def connect_engine() -> Union[MockConnection, Engine]:
+ """
+ 全局数据连接引擎。
+
+ :return: 数据连接引擎
+ """
+ global GLOBAL_CONNECTOR_ENGINE
+ if GLOBAL_CONNECTOR_ENGINE is None:
+ GLOBAL_CONNECTOR_ENGINE = create_engine(
+ config.get_config('db_engine.engine'), **config.get_config('db_engine.engine_option')
+ )
+ return GLOBAL_CONNECTOR_ENGINE
+
+
+def async_connect_engine() -> AsyncEngine:
+ """
+ 异步数据连接引擎。
+
+ :return: 异步数据连接引擎
+ """
+ global ASYNC_CONNECTOR_ENGINE
+ if ASYNC_CONNECTOR_ENGINE is None:
+ ASYNC_CONNECTOR_ENGINE = create_async_engine(
+ config.get_config('db_engine.async_engine'), **config.get_config('db_engine.engine_option')
+ )
+ return ASYNC_CONNECTOR_ENGINE
diff --git a/paste/db/gen_models.py b/paste/db/gen_models.py
new file mode 100644
index 0000000..62bad44
--- /dev/null
+++ b/paste/db/gen_models.py
@@ -0,0 +1,51 @@
+"""
+生成数据模型代码,注意这里显式排除了在配置文件 RBAC 中配置的数据表。
+如果不需要排除,可直接使用 sqlacodegen 自带的命令。
+"""
+import subprocess
+from os import path
+
+import pandas as pd
+
+from paste.core import config
+from paste.core.logging import echo_log
+from paste.db.basetable import BaseTable
+
+exclude_tables = list(config.get_config('rbac.table').values())
+"""
+需要排除的数据表,默认排除 RBAC 数据表,这部分表格在 RBAC 模块中已经配置好了。
+"""
+
+
+def db_engin():
+ return config.get_config('db_engine.engine')
+
+
+async def sqlacodegen(is_exclude_rbac_table: bool = True):
+ """
+ 生成代码文件。
+
+ :param is_exclude_rbac_table: 是否排除 RBAC 相关的数据表,默认排除不生成数据模型
+ """
+ _table_names = await BaseTable.tables_in_db()
+
+ # 剔除 RBAC 数据表
+ if _table_names and is_exclude_rbac_table:
+ # 转换为 DataFrame
+ _tables_df = pd.DataFrame(_table_names)
+ # 剔除不要包含的表
+ _name_df: pd.DataFrame = _tables_df.loc[~_tables_df.iloc[:, 0].isin(exclude_tables)]
+ # 转换剩余数据为'表名'字符串列表
+ _table_names = _name_df[0].tolist()
+
+ if len(_table_names) == 0:
+ return
+
+ echo_log(f"将为以下表生成数据模型:{_table_names}")
+ _tables = f"--tables={','.join(_table_names)}"
+ # 默认创建在当前目录的 models 目录中
+ _outfile = f"--outfile={path.join(path.curdir, 'models', 'db_models.py')}"
+
+ _engin = db_engin()
+ subprocess.call(['sqlacodegen', _engin, _tables, _outfile])
+ echo_log(f"生成完成.")
\ No newline at end of file
diff --git a/paste/db/redis.py b/paste/db/redis.py
new file mode 100644
index 0000000..76d9e14
--- /dev/null
+++ b/paste/db/redis.py
@@ -0,0 +1,995 @@
+"""
+封装了 Python 对 Redis 的基本操作。
+同时处理了 Java 在操作 Redis 后留下的字节码问题。
+"""
+import asyncio
+import hashlib
+import pathlib
+import random
+import types
+from logging import ERROR, WARNING
+from typing import Optional, Callable, Awaitable, Union, Tuple, Dict
+
+import javaobj
+import redis
+from redis.asyncio import ConnectionPool, StrictRedis
+from redis.client import Pipeline
+
+from paste.core import aio_pool, config, logging
+from paste.util.snow_id import IdWorker
+
+
+class LuaScriptManager:
+ """
+ Lua 脚本管理器。
+ 负责加载、缓存和执行 Lua 脚本。
+ """
+
+ # 默认 Lua 脚本内容(作为内置默认值,无需外部文件)
+ DEFAULT_SCRIPTS = {
+ "stock_decr": """
+ -- 扣减库存(原子操作)
+ -- KEYS[1]: 库存 key
+ -- ARGV[1]: 扣减数量
+ -- 返回值: 1=成功, 0=库存不足, -1=key不存在
+
+ local key = KEYS[1]
+ local quantity = tonumber(ARGV[1])
+
+ local current = redis.call('GET', key)
+ if not current then
+ return -1
+ end
+
+ current = tonumber(current)
+ if current >= quantity then
+ redis.call('DECRBY', key, quantity)
+ return 1
+ else
+ return 0
+ end
+ """,
+
+ "stock_incr": """
+ -- 增加库存(原子操作)
+ -- KEYS[1]: 库存 key
+ -- ARGV[1]: 增加数量
+ -- 返回值: 当前库存
+
+ local key = KEYS[1]
+ local quantity = tonumber(ARGV[1])
+
+ redis.call('INCRBY', key, quantity)
+ return redis.call('GET', key)
+ """,
+
+ "stock_peek": """
+ -- 查看库存(原子操作)
+ -- KEYS[1]: 库存 key
+ -- 返回值: 当前库存
+
+ local key = KEYS[1]
+ local current = redis.call('GET', key)
+
+ if not current then
+ return -1
+ end
+ return tonumber(current)
+ """,
+ }
+
+ _scripts: Dict[str, Tuple[str, str]] = {} # name -> (sha, script_content)
+ _script_dir: Optional[pathlib.Path] = None
+ _use_external_files: bool = False # 是否使用外部文件
+
+ @classmethod
+ def set_script_dir(cls, script_dir: str, use_external: bool = True):
+ """
+ 设置 Lua 脚本目录
+
+ :param script_dir: 脚本目录路径
+ :param use_external: 是否使用外部文件(False 则使用内置默认脚本)
+ """
+ cls._script_dir = pathlib.Path(script_dir) if script_dir else None
+ cls._use_external_files = use_external
+
+ @classmethod
+ async def load_script(cls, redis_client: StrictRedis, script_name: str) -> str:
+ """
+ 加载并注册 Lua 脚本
+ 优先使用外部文件,不存在则使用内置默认脚本
+
+ :param redis_client: Redis 客户端
+ :param script_name: 脚本名称(如 stock_decr)
+ :return: 脚本 SHA
+ """
+ script_content = None
+
+ # 尝试从外部文件加载
+ if cls._use_external_files and cls._script_dir:
+ script_path = cls._script_dir / f"{script_name}.lua"
+ if script_path.exists():
+ with open(script_path, 'r', encoding='utf-8') as f:
+ script_content = f.read()
+ logging.echo_log(f"Lua 脚本从外部文件加载: {script_path}")
+
+ # 使用内置默认脚本
+ if script_content is None:
+ if script_name not in cls.DEFAULT_SCRIPTS:
+ raise ValueError(f"脚本不存在: {script_name},且无内置默认值")
+ script_content = cls.DEFAULT_SCRIPTS[script_name]
+ logging.echo_log(f"Lua 脚本使用内置默认值: {script_name}")
+
+ # 计算 SHA
+ sha = hashlib.sha1(script_content.encode()).hexdigest()
+
+ # 缓存脚本
+ cls._scripts[script_name] = (sha, script_content)
+
+ # 预加载到 Redis
+ try:
+ await redis_client.script_load(script_content)
+ except Exception:
+ pass # 预加载失败不影响后续使用
+
+ return sha
+
+ @classmethod
+ async def load_default_scripts(cls, redis_client: StrictRedis):
+ """
+ 加载所有默认脚本
+ """
+ for script_name in cls.DEFAULT_SCRIPTS.keys():
+ await cls.load_script(redis_client, script_name)
+ logging.echo_log(f"已加载 {len(cls.DEFAULT_SCRIPTS)} 个默认 Lua 脚本")
+
+ @classmethod
+ async def execute(cls, redis_client: StrictRedis, script_name: str,
+ keys: list, args: list) -> any:
+ """
+ 执行 Lua 脚本
+ 优先使用 evalsha(性能更好),失败则降级到 eval
+
+ :param redis_client: Redis 客户端
+ :param script_name: 脚本名称
+ :param keys: KEYS 参数列表
+ :param args: ARGV 参数列表
+ :return: 脚本执行结果
+ """
+ if script_name not in cls._scripts:
+ # 脚本未加载,尝试加载
+ await cls.load_script(redis_client, script_name)
+
+ sha, script_content = cls._scripts[script_name]
+
+ try:
+ return await redis_client.evalsha(sha, len(keys), *keys, *args)
+ except redis.ResponseError as e:
+ if "NOSCRIPT" in str(e):
+ # 重新加载并重试
+ await redis_client.script_load(script_content)
+ return await redis_client.evalsha(sha, len(keys), *keys, *args)
+ else:
+ raise
+
+ @classmethod
+ async def reload_script(cls, redis_client: StrictRedis, script_name: str) -> str:
+ """重新加载指定的 Lua 脚本"""
+ if script_name in cls._scripts:
+ del cls._scripts[script_name]
+ return await cls.load_script(redis_client, script_name)
+
+
+class Redis:
+ """
+ Redis 基础操作。
+ """
+
+ connect_pool: Optional[ConnectionPool] = None
+
+ prefix = b'\xac\xed\x00\x05'
+ utf_flag = b'\x74'
+
+ lua_scripts = LuaScriptManager
+ """Lua 脚本管理器。"""
+
+ @classmethod
+ def is_java_serialized(cls, bs: Union[bytes, str]):
+ """
+ 判断是否为 Java 序列化后的数据。
+
+ :param bs: 字节流
+ """
+ if not isinstance(bs, bytes):
+ return False
+ return bs[:4] == cls.prefix
+
+ @classmethod
+ async def get_pool(cls) -> ConnectionPool:
+ """
+ 取得 Redis 连接池。
+
+ :return: 连接池对象
+ """
+ if cls.connect_pool is None:
+ _conn_params = config.get_config("redis.connection")
+ cls.connect_pool = ConnectionPool.from_url(**_conn_params)
+ return cls.connect_pool
+
+ @classmethod
+ async def close_pool(cls):
+ if cls.connect_pool is not None:
+ await cls.connect_pool.disconnect()
+ cls.connect_pool = None
+
+ @classmethod
+ async def get_redis(cls) -> StrictRedis:
+ """
+ 取得数据库对象。
+
+ :return: 数据库对象
+ """
+ _pool = await cls.get_pool()
+ return StrictRedis(
+ connection_pool=_pool,
+ socket_timeout=5,
+ socket_connect_timeout=5,
+ health_check_interval=30,
+ socket_keepalive=True
+ )
+
+ @classmethod
+ async def ping(cls):
+ """
+ 测试连接。
+
+ :return: 测试结果
+ """
+ async with await cls.get_redis() as _redis:
+ return await _redis.ping()
+
+ @classmethod
+ async def get_pipe(cls, transaction: bool = True, shard_hint=None) -> Pipeline:
+ """
+ 取得管道对象。
+ """
+ async with await cls.get_redis() as _redis:
+ return _redis.pipeline(transaction=transaction, shard_hint=shard_hint)
+
+ # ========== Lua 脚本初始化 ==========
+
+ @classmethod
+ async def init_lua_scripts(cls, script_dir: str = None, use_external: bool = False):
+ """
+ 初始化 Lua 脚本
+ 建议在应用启动时调用一次
+
+ :param script_dir: 外部脚本目录(可选)
+ :param use_external: 是否使用外部文件,默认 False 使用内置脚本
+ """
+ if script_dir:
+ cls.lua_scripts.set_script_dir(script_dir, use_external)
+
+ async with await cls.get_redis() as _redis:
+ await cls.lua_scripts.load_default_scripts(_redis)
+
+ # ========== 库存核心方法(原子操作) ==========
+
+ @classmethod
+ async def stock_decr(cls, stock_key: str, quantity: int = 1) -> Tuple[bool, str]:
+ """
+ 扣减库存(原子操作)
+ 使用 Lua 脚本保证原子性,防止超卖
+
+ :param stock_key: 库存 Key(支持分片,如 stock:iPhone15:shard:0)
+ :param quantity: 扣减数量
+ :return: (是否成功, 消息)
+ """
+ async with await cls.get_redis() as _redis:
+ try:
+ result = await cls.lua_scripts.execute(
+ _redis,
+ "stock_decr",
+ keys=[stock_key],
+ args=[quantity]
+ )
+
+ if result == 1:
+ return True, "扣减成功"
+ elif result == 0:
+ return False, "库存不足"
+ else:
+ return False, "商品不存在"
+ except Exception as e:
+ logging.echo_log(f"扣减库存异常: {e}", level=ERROR, is_log_exc=True)
+ return False, f"系统异常: {e}"
+
+ @classmethod
+ async def stock_incr(cls, stock_key: str, quantity: int = 1) -> Tuple[bool, int]:
+ """
+ 增加库存(原子操作)
+ 用于退货入库、补货等场景
+
+ :param stock_key: 库存 Key
+ :param quantity: 增加数量
+ :return: (是否成功, 当前库存)
+ """
+ async with await cls.get_redis() as _redis:
+ try:
+ result = await cls.lua_scripts.execute(
+ _redis,
+ "stock_incr",
+ keys=[stock_key],
+ args=[quantity]
+ )
+ return True, int(result)
+ except Exception as e:
+ logging.echo_log(f"增加库存异常: {e}", level=ERROR, is_log_exc=True)
+ return False, 0
+
+ @classmethod
+ async def stock_peek(cls, stock_key: str) -> int:
+ """
+ 查看剩余库存(原子操作)
+
+ :param stock_key: 库存 Key
+ :return: 剩余库存
+ """
+ async with await cls.get_redis() as _redis:
+ try:
+ result = await cls.lua_scripts.execute(
+ _redis,
+ "stock_peek",
+ keys=[stock_key],
+ args=[]
+ )
+ return int(result) if result >= 0 else 0
+ except Exception as e:
+ logging.echo_log(f"查询库存异常: {e}", level=ERROR, is_log_exc=True)
+ return 0
+
+ # ========== 库存分片辅助方法 ==========
+
+ @classmethod
+ def get_shard_key(cls, sku_id: str, shard_id: int) -> str:
+ """
+ 获取分片 Key。
+ 推荐格式:{业务域}:{实体}:{唯一标识}:{分片/维度}:{扩展}
+
+ :param sku_id: 商品ID
+ :param shard_id: 分片ID
+ :return: 分片 Key
+ """
+ return f"stock:{sku_id}:shard:{shard_id}"
+
+ @classmethod
+ def get_user_shard(cls, sku_id: str, user_id: str, shard_count: int = 10) -> str:
+ """
+ 根据用户ID获取分片 Key
+
+ :param sku_id: 商品ID
+ :param user_id: 用户ID
+ :param shard_count: 分片总数
+ :return: 分片 Key
+ """
+ shard = hash(user_id) % shard_count
+ return cls.get_shard_key(sku_id, shard)
+
+ @classmethod
+ async def init_sharded_stock(cls, sku_id: str, total_stock: int, shard_count: int = 10):
+ """
+ 初始化分片库存
+
+ :param sku_id: 商品ID
+ :param total_stock: 总库存
+ :param shard_count: 分片数量
+ """
+ base = total_stock // shard_count
+ remainder = total_stock % shard_count
+
+ async with await cls.get_redis() as _redis:
+ for i in range(shard_count):
+ shard_key = cls.get_shard_key(sku_id, i)
+ stock = base + (1 if i < remainder else 0)
+ await _redis.set(shard_key, stock)
+ logging.echo_log(f"初始化分片 {i}: {shard_key} = {stock}")
+
+ # ========== 基础 KV 操作 ==========
+
+ @classmethod
+ async def keys(cls):
+ """
+ 取得所有的 Key。
+ """
+ async with await cls.get_redis() as _redis:
+ _keys = await _redis.keys()
+ return _keys
+
+ @classmethod
+ async def show_keys(cls):
+ """
+ 控制台显示所有的 Keys。
+ """
+ _keys = await cls.keys()
+ for _key in _keys:
+ if isinstance(_key, bytes):
+ if cls.is_java_serialized(_key):
+ print(_key[7:].decode('utf-8'), '=>', _key)
+ else:
+ print(_key.decode('utf-8'), '=>', _key)
+ else:
+ print(_key)
+
+ @classmethod
+ async def get(cls, key: Union[bytes, str]):
+ """
+ 多种方式读取 Redis 中的数据。
+
+ :param key: Redis Key 名称
+ :return: 数据内容
+ """
+ async with await cls.get_redis() as _redis:
+ _result = await _redis.get(key)
+
+ if _result is None and not cls.is_java_serialized(key):
+ if isinstance(key, str):
+ key_bytes = key.encode('utf-8')
+ else:
+ key_bytes = key
+ _key = cls.prefix + cls.utf_flag + len(key_bytes).to_bytes(2, 'big') + key_bytes
+ _result = await _redis.get(_key)
+
+ if _result is None:
+ return _result
+
+ if isinstance(_result, bytes) and cls.is_java_serialized(_result):
+ return javaobj.loads(_result)
+ else:
+ return _result
+
+ @classmethod
+ async def set(cls, key: str, value: any, ex: int = None):
+ """
+ 设置键值对
+ """
+ async with await cls.get_redis() as _redis:
+ return await _redis.set(key, value, ex=ex)
+
+ @classmethod
+ async def delete(cls, key: str):
+ """
+ 删除键
+ """
+ async with await cls.get_redis() as _redis:
+ return await _redis.delete(key)
+
+ @classmethod
+ async def exists(cls, key: str) -> bool:
+ """
+ 检查键是否存在
+ """
+ async with await cls.get_redis() as _redis:
+ return await _redis.exists(key) > 0
+
+ @classmethod
+ async def expire(cls, key: str, seconds: int):
+ """
+ 设置过期时间
+ """
+ async with await cls.get_redis() as _redis:
+ return await _redis.expire(key, seconds)
+
+ @classmethod
+ async def incr(cls, key: str) -> int:
+ """
+ 原子递增
+ """
+ async with await cls.get_redis() as _redis:
+ return await _redis.incr(key)
+
+ # ========== 回调处理 ==========
+
+ @classmethod
+ def get_func_name(cls, func):
+ """
+ 得到方法名称。
+
+ :param func: 方法对象
+ :return: 方法名称
+ """
+ if isinstance(func, types.FunctionType):
+ return func.__name__
+ elif isinstance(func, types.MethodType):
+ return func.__func__.__name__
+ elif isinstance(func, (classmethod, staticmethod)):
+ return func.__func__.__name__
+ elif hasattr(func, '__call__'):
+ return func.__class__.__name__
+ else:
+ return str(func)
+
+ @classmethod
+ async def callback(cls, func: Callable, message_key: str, is_delete=False):
+ """
+ 根据消息 KEY 读取数据,并执行回调函数,如果回调函数正确执行,则根据参数 is_delete 判断删除消息。
+
+ :param func: 回调函数
+ :param message_key: 消息 KEY
+ :param is_delete: 是否删除处理过的消息
+ """
+ result = None
+ async with await cls.get_redis() as _redis:
+ try:
+ message_data = await _redis.hgetall(message_key)
+ if not message_data:
+ logging.echo_log(f"警告: 空消息数据 {message_key}.", level=WARNING)
+ return result
+
+ if func:
+ # 处理回调
+ result = func(message_data)
+ # 处理协程
+ if isinstance(result, Awaitable):
+ result = await result
+
+ if is_delete:
+ # 回调正确执行,且设置为删除删除的,才会删除消息
+ await _redis.delete(message_key)
+ logging.echo_log(f"消息已删除: {message_key};数据为:{message_data}.")
+ except redis.RedisError as e:
+ logging.echo_log(f"Redis 操作异常: {e}.", level=ERROR, is_log_exc=True)
+ except Exception as e:
+ logging.echo_log(
+ f"执行回调异常:{e};方法:{cls.get_func_name(func)};消息: {message_key}.",
+ level=ERROR, is_log_exc=True
+ )
+ return result
+
+
+class PubSubActor(Redis):
+ """
+ 发布订阅执行器。用于发布消息和订阅消息。
+ 订阅采用阻塞式读取,可以在读取到数据后,执行回调方法,并根据参数确定是否删除历史消息。
+ """
+
+ def __init__(self, hash_name: str):
+ self.hash_name = f"{hash_name}_HASH_NAME"
+ self.channel = f"{hash_name}_CHANNEL"
+
+ self.running = False
+ """
+ 优雅退出控制标志
+ """
+
+ self.stopping = False
+ """
+ 控制整个 run_forever 循环退出
+ """
+
+ async def publish(self, data: dict) -> str:
+ """
+ 数据写入 Redis 并发布消息。
+
+ :param data: 写入 Redis 的数据
+ :return: 消息ID
+ """
+ async with await self.get_redis() as _redis:
+ # 生成雪花 ID 作为 Hash Key
+ _random_num = random.randint(1000, 9999)
+ _id = IdWorker.get_id_worker(3, 3, _random_num).get_id()
+
+ # 写入Redis hash
+ await _redis.hset(f"{self.hash_name}:{_id}", mapping=data)
+ # 发布新消息通知
+ await _redis.publish(self.channel, _id)
+ return _id
+
+ async def subscribe(self, func: Callable = None, is_delete=False):
+ """
+ 监听消息。
+
+ :param func: 监听回调程序
+ :param is_delete: 回调执行完毕后,是否删除消息
+ """
+ async with await self.get_redis() as _redis:
+ _pubsub = _redis.pubsub()
+ await _pubsub.subscribe(self.channel)
+
+ try:
+ self.running = True
+
+ # 使用 while 循环,而不是直接 async for,以便加入超时控制
+ while not self.stopping and self.running:
+ try:
+ # 每次循环都重新获取迭代器
+ listen_iter = _pubsub.listen()
+ message = await asyncio.wait_for(listen_iter.__anext__(), timeout=60.0)
+
+ if message["type"] != "message":
+ continue
+
+ message_id = message["data"]
+ message_key = f"{self.hash_name}:{message_id}"
+
+ try:
+ # 隔离处理回调异常
+ # 采用后台运行的方式处理,防止消息排队,提高消息处理性能
+ await aio_pool.run_background_task(self.callback(func, message_key, is_delete), 10)
+ # await self.callback(func, message_key, is_delete=is_delete)
+ except Exception:
+ # 继续处理下条消息
+ continue
+
+ except asyncio.TimeoutError:
+ # 超时,是心跳成功的标志
+ logging.echo_log("心跳:连接正常,继续监听...")
+ continue
+ except redis.exceptions.ConnectionError as e:
+ # 连接错误,触发重连
+ logging.echo_log(f"检测到连接错误: {e}. 将触发重连...", level=ERROR, is_log_exc=True)
+ raise e
+ except StopAsyncIteration:
+ # pubsub 正常关闭
+ logging.echo_log("PubSub 迭代器已停止.")
+ break
+ except (asyncio.CancelledError, KeyboardInterrupt):
+ logging.echo_log("收到退出信号,停止监听...")
+ self.running = False
+ raise
+ except Exception as e:
+ logging.echo_log(f"监听会话因错误结束: {e}.", level=ERROR, is_log_exc=True)
+ raise e
+ finally:
+ self.running = False
+ try:
+ await _pubsub.unsubscribe(self.channel)
+ await _pubsub.close()
+ except Exception as close_err:
+ logging.echo_log(f"资源关闭异常: {close_err}.")
+ finally:
+ logging.echo_log("监听已完全停止.")
+
+ async def run_forever(self, func: Callable = None, is_delete=False):
+ """
+ 持久运行的监听器,包含自动重连逻辑和优雅退出。
+ """
+ while not self.stopping:
+ try:
+ logging.echo_log("启动新的监听会话...")
+ await self.subscribe(func, is_delete)
+ except (asyncio.CancelledError, KeyboardInterrupt):
+ logging.echo_log("收到退出信号,停止监听...")
+ self.stopping = True
+ break
+ except Exception as e:
+ logging.echo_log(f"监听会话因未知错误结束: {e}. 10秒后重试...", level=ERROR, is_log_exc=True)
+
+ if self.stopping:
+ logging.echo_log("总开关已打开,停止重连.")
+ break
+
+ logging.echo_log("等待重新连接...")
+ try:
+ # 关键改动:直接使用 sleep,它本身就是可中断的
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ logging.echo_log("等待期间被取消,准备退出.")
+ break
+
+ logging.echo_log("监听服务已完全停止.")
+
+ async def history(self, func: Callable = None, is_delete=False):
+ """
+ 处理历史数据。
+
+ :param func: 监听回调程序
+ :param is_delete: 回调执行完毕后,是否删除消息
+ """
+ async with await self.get_redis() as _redis:
+ _keys = await _redis.keys()
+ for _k in _keys:
+ try:
+ # 隔离处理回调异常
+ # 采用后台运行的方式处理,防止消息排队,提高消息处理性能
+ await aio_pool.run_background_task(self.callback(func, _k, is_delete), 10)
+ # await self.callback(func, _k, is_delete=is_delete)
+ except Exception:
+ # 继续处理下条消息
+ continue
+
+ def subscribe_stop(self):
+ self.running = False
+ self.stopping = True
+
+
+class StreamActor(Redis):
+ """
+ 流执行器。使用 Redis Streams 实现发布消息和消费消息。
+ 消费采用消费者组模式,支持消息确认和可靠传递。
+ 方法结构与 PubSubActor 保持一致,便于无缝替换。
+
+ 此版本集成了启动时的僵尸任务自动恢复功能。
+ """
+
+ @classmethod
+ def actor_config(cls, config_path: str):
+ """
+ 根据路径,取得配置信息。
+
+ :param config_path: 配置路径,配置文件中,直到 stream 的 Key,用点【.】分隔
+ :return:
+ """
+ _stream_name = config_path.split(".")[-1].upper()
+ _stream_config = config.get_config(config_path)
+ _group_name = _stream_config.get('group', f"{_stream_name}_GROUP")
+ _consumer_name = _stream_config.get('consumer', f"{_stream_name}_CONSUMER")
+ _snow_id = IdWorker.get_id_worker().get_id()
+ _consumer_name = f"{_consumer_name}_{_snow_id}"
+ return _stream_name, _group_name, _consumer_name
+
+ @classmethod
+ def new_actor(cls, config_path: str):
+ """
+ 根据配置文件中的配置小节创建流执行器。
+
+ :param config_path: 配置路径,配置文件中,直到 stream 的 Key,用点【.】分隔
+ :return: 执行器对象
+ """
+ _stream_name, _group_name, _consumer_name = cls.actor_config(config_path)
+ return cls(_stream_name, _group_name, _consumer_name)
+
+ def __init__(self, stream_name: str, group_name: str, consumer_name: str):
+ """
+ 初始化流执行器。
+
+ :param stream_name: Redis Stream 的名称
+ :param group_name: 消费者组的名称
+ :param consumer_name: 当前消费者的名称
+ """
+ self.stream_name = stream_name
+ self.group_name = group_name
+ self.consumer_name = consumer_name
+
+ self.running = False
+ """
+ 优雅退出控制标志
+ """
+
+ self.stopping = False
+ """
+ 控制整个 run_forever 循环退出
+ """
+
+ async def _ensure_group_exists(self):
+ """确保消费者组已存在,如果不存在则创建。"""
+ try:
+ _redis = await self.get_redis()
+ await _redis.xgroup_create(
+ name=self.stream_name,
+ groupname=self.group_name,
+ id='0', # 从头开始消费
+ mkstream=True # Stream 不存在时自动创建
+ )
+ logging.echo_log(f"消费者组 '{self.group_name}' 已创建.")
+ except redis.exceptions.ResponseError as e:
+ if "Consumer Group name already exists" in str(e):
+ logging.echo_log(f"消费者组 '{self.group_name}' 已存在.")
+ else:
+ raise
+
+ async def publish(self, data: dict) -> str:
+ """
+ 将数据作为消息写入 Redis Stream。
+
+ :param data: 写入 Stream 的数据字典
+ :return: 消息ID
+ """
+ async with await self.get_redis() as _redis:
+ # 添加时会自动生成唯一的消息ID
+ message_id = await _redis.xadd(name=self.stream_name, fields=data)
+ logging.echo_log(f"消息已发布至 Stream '{self.stream_name}',ID: {message_id};数据为:{data}.")
+ return message_id
+
+ async def reclaim_stale_tasks(self, func: Callable, is_delete: bool, stale_threshold_ms: int = 5 * 60 * 1000):
+ """
+ 检查并尝试重新处理僵尸任务。
+
+ Args:
+ func (Callable): 用于处理任务的业务回调函数。
+ is_delete (bool): 处理成功后是否确认消息。
+ stale_threshold_ms (int): 判定为僵尸任务的空闲时间阈值(毫秒)。
+ """
+ async with await self.get_redis() as _redis:
+ # 1. 发现僵尸任务
+ try:
+ stale_tasks = await _redis.xpending_range(
+ name=self.stream_name,
+ groupname=self.group_name,
+ min='-',
+ max='+',
+ count=10, # 每次最多处理10个僵尸任务,避免启动时阻塞太久
+ idle=stale_threshold_ms
+ )
+ except Exception as e:
+ logging.echo_log(f"检查僵尸任务时出错: {e}", level=ERROR, is_log_exc=True)
+ return
+
+ if not stale_tasks:
+ logging.echo_log(f"未发现空闲超过 {stale_threshold_ms / 1000} 秒的僵尸任务.")
+ return
+
+ if not stale_tasks or not isinstance(stale_tasks, list):
+ logging.echo_log(f"未发现空闲超过 {stale_threshold_ms / 1000} 秒的僵尸任务.")
+ return
+ message_ids = [task['message_id'] for task in stale_tasks]
+ logging.echo_log(f"发现 {len(message_ids)} 个僵尸任务,尝试认领并重新处理...")
+
+ # 2. 认领任务
+ try:
+ reclaimed_messages = await _redis.xclaim(
+ name=self.stream_name,
+ groupname=self.group_name,
+ consumername=self.consumer_name, # 认领给自己
+ min_idle_time=stale_threshold_ms,
+ message_ids=message_ids,
+ justid=False # 我们需要消息内容来处理
+ )
+ except Exception as e:
+ logging.echo_log(f"认领僵尸任务时出错: {e}", level=ERROR, is_log_exc=True)
+ return
+
+ if not reclaimed_messages:
+ logging.echo_log("未能成功认领任何僵尸任务.")
+ return
+
+ logging.echo_log(f"成功认领 {len(reclaimed_messages)} 个僵尸任务,开始处理.")
+
+ # 3. 处理被认领的任务
+ for message_id, message_data in reclaimed_messages:
+ # 使用我们已有的 _callback_wrapper 来处理,保证逻辑一致
+ await self._callback_wrapper(
+ func=func,
+ message_id=message_id,
+ message_data=message_data,
+ is_delete=is_delete
+ )
+
+ async def history(self, func: Callable, is_delete: bool):
+ """
+ 启动时的恢复程序。
+ 检查并处理长时间未完成的僵尸任务,确保系统健壮性。
+ """
+ logging.echo_log("执行启动恢复程序,检查僵尸任务...")
+ # 将 func 和 is_delete 传递下去
+ await self.reclaim_stale_tasks(func=func, is_delete=is_delete, stale_threshold_ms=5 * 60 * 1000)
+ logging.echo_log("启动恢复程序执行完毕.")
+
+ async def subscribe(self, func: Callable = None, is_delete=False):
+ """
+ 从消费者组中监听并处理新消息。
+ 启动时会先执行恢复程序,处理僵尸任务。
+ """
+ await self._ensure_group_exists()
+
+ # === 核心改动:启动时先执行恢复,并传入回调参数 ===
+ await self.history(func=func, is_delete=is_delete)
+
+ async with await self.get_redis() as _redis:
+ try:
+ self.running = True
+ logging.echo_log("僵尸任务恢复完成,开始监听新消息...")
+
+ while not self.stopping and self.running:
+ try:
+ # 阻塞读取新消息
+ streams = await _redis.xreadgroup(
+ groupname=self.group_name,
+ consumername=self.consumer_name,
+ streams={self.stream_name: '>'}, # '>' 表示只读取新消息
+ count=1,
+ block=5000 # 5秒超时,类似 PubSub 的心跳
+ )
+
+ if not streams:
+ # 超时,是心跳成功的标志
+ logging.echo_log("心跳:连接正常,继续监听...")
+ continue
+
+ # 解析消息
+ stream, messages = streams[0]
+ message_id, message_data = messages[0]
+ logging.echo_log(f"收到新消息: ID={message_id}, 数据={message_data}")
+
+ try:
+ # 隔离处理回调异常
+ # 采用后台运行的方式处理,防止消息排队,提高消息处理性能
+ await aio_pool.run_background_task(
+ self._callback_wrapper(func, message_id, message_data, is_delete), 10
+ )
+ except Exception:
+ # 回调处理失败,消息未被确认,将留在队列中稍后重试
+ continue
+
+ except redis.exceptions.ConnectionError as e:
+ # 连接错误,触发重连
+ logging.echo_log(f"检测到连接错误: {e}. 将触发重连...", level=ERROR, is_log_exc=True)
+ raise e
+ except (asyncio.CancelledError, KeyboardInterrupt):
+ logging.echo_log("收到退出信号,停止监听...")
+ self.running = False
+ raise
+ except Exception as e:
+ logging.echo_log(f"监听会话因错误结束: {e}.", level=ERROR, is_log_exc=True)
+ raise e
+ finally:
+ self.running = False
+ logging.echo_log("Stream 监听已完全停止.")
+
+ async def _callback_wrapper(self, func: Callable, message_id: str, message_data: dict, is_delete: bool):
+ """
+ 一个包装器,用于将 Stream 的消息处理逻辑适配到基类的 callback 方法签名上。
+ 这样做可以复用基类 callback 中的异常处理逻辑。
+ """
+ # 如果没有提供回调函数,则无法处理,直接返回,避免丢失任务
+ if not func:
+ logging.echo_log(f"警告: 收到消息 {message_id} 但未提供业务回调函数,消息将被忽略.", level=WARNING)
+ return
+
+ # 不能直接调用基类的 callback,因为它会尝试删除
+ # 在这里复制它的异常处理逻辑,但使用 Stream 的操作
+ result = None
+ async with await self.get_redis() as _redis:
+ try:
+ if func:
+ # 处理回调
+ result = func(message_data)
+ # 处理协程
+ if isinstance(result, Awaitable):
+ result = await result
+
+ if is_delete:
+ # 先从 PENDING 列表中移除
+ await _redis.xack(self.stream_name, self.group_name, message_id)
+ # 再从 Stream 中逻辑删除
+ await _redis.xdel(self.stream_name, message_id)
+ logging.echo_log(f"消息已确认 (ACK): {message_id};数据为:{message_data}.")
+ except redis.RedisError as e:
+ logging.echo_log(f"Redis 操作异常: {e}.", level=ERROR, is_log_exc=True)
+ except Exception as e:
+ logging.echo_log(
+ f"执行回调异常:{e};方法:{self.get_func_name(func)};消息: {message_id}.",
+ level=ERROR, is_log_exc=True
+ )
+ return result
+
+ async def run_forever(self, func: Callable = None, is_delete=False):
+ """
+ 持久运行的监听器,包含自动重连逻辑和优雅退出。
+ """
+ while not self.stopping:
+ try:
+ logging.echo_log("启动新的监听会话...")
+ await self.subscribe(func, is_delete)
+ except (asyncio.CancelledError, KeyboardInterrupt):
+ logging.echo_log("收到退出信号,停止监听...")
+ self.stopping = True
+ break
+ except Exception as e:
+ logging.echo_log(f"监听会话因未知错误结束: {e}. 10秒后重试...", level=ERROR, is_log_exc=True)
+
+ if self.stopping:
+ logging.echo_log("总开关已打开,停止重连.")
+ break
+
+ logging.echo_log("等待重新连接...")
+ try:
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ logging.echo_log("等待期间被取消,准备退出.")
+ break
+
+ logging.echo_log("监听服务已完全停止.")
+
+ def subscribe_stop(self):
+ self.running = False
+ self.stopping = True
diff --git a/paste/rbac/__init__.py b/paste/rbac/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/paste/rbac/rbac_assignment.py b/paste/rbac/rbac_assignment.py
new file mode 100644
index 0000000..34f3f3b
--- /dev/null
+++ b/paste/rbac/rbac_assignment.py
@@ -0,0 +1,87 @@
+from typing import Union
+
+from sqlalchemy import select, delete
+
+from paste.rbac.rbac_models import RbacAssignmentModel
+
+
+class RbacAssignment(RbacAssignmentModel):
+ """
+ 权限分配器,负责用户权限的分配。
+ 允许为用户分配角色,也允许为用户直接分配权限。
+ """
+
+ @classmethod
+ def new(cls, user_id: Union[str, int], item_name: str):
+ """
+ 新建角色分配对象。不检查用户是否存在,也不保存数据库。
+
+ :param user_id: 用户 ID
+ :param item_name: 授权项名称
+ :return: 新的分配对象
+ """
+ return cls(user_id=user_id, item_name=item_name).before_save()
+
+ @classmethod
+ async def new_batch(cls, user_id: int, item_names: set[str]):
+ """
+ 为用户创建角色或权限,自动剔除已经分配的角色或授权项。注意:仅创建不保存。
+
+ :param user_id: 用户名
+ :param item_names: 授权项名称列表,这里允许是角色名称,也允许是权限名称
+ """
+ from paste.rbac.rbac_user import RbacUser, Supervisors
+
+ # 确认用户存在,且不是超级用户
+ _mod_user: RbacUser = await RbacUser(**{RbacUser.id.key: user_id}).async_find_first()
+ assert _mod_user is not None, f"ID 为 {user_id} 的用户不存在."
+ assert _mod_user.username not in Supervisors, f"ID 为 {user_id} 的用户 {_mod_user.username} 禁止设置权限."
+
+ # 查库,过滤已经存在的权限分配
+ _query = select(cls.item_name).where(cls.user_id == user_id, cls.item_name.in_(item_names))
+ _rows = await cls.query_all(_query)
+ _exist_names: set[str] = set([_r[cls.item_name.key] for _r in _rows])
+ item_names = set([_name for _name in item_names if _name not in _exist_names])
+
+ # 创建模型列表
+ _new_assignments = [cls(**{cls.user_id.key: user_id, cls.item_name.key: name}) for name in item_names]
+ return _new_assignments
+
+ @classmethod
+ async def assign(cls, user_id: int, item_names: set[str]):
+ """
+ 为用户分配角色或权限,自动剔除已经分配的角色或授权项。
+
+ :param user_id: 用户名
+ :param item_names: 授权项名称列表,这里允许是角色名称,也允许是权限名称
+ """
+ # 创建模型列表
+ _new_assignments = await cls.new_batch(user_id=user_id, item_names=item_names)
+
+ # 保存数据
+ _session = cls.get_aio_session()
+ try:
+ _session.add_all(_new_assignments)
+ await _session.commit()
+ except Exception as e:
+ await _session.rollback()
+ raise e
+ finally:
+ await _session.close()
+
+ @classmethod
+ async def delete(cls, user_id: Union[str, int], item_name: str):
+ """
+ 删除授权。
+
+ :param user_id: 用户 ID
+ :param item_name: 授权项
+ :return: 操作状态,游标返回对象
+ """
+ assert user_id not in ('', None), '必须提供用户 ID.'
+ assert item_name not in ('', None), '必须提供权限或角色名称.'
+
+ _query = delete(cls).where(cls.user_id == user_id, cls.item_name == item_name)
+ _result = await cls.raw_execute(query=_query)
+ _rowcount = _result.rowcount if isinstance(_result.rowcount, int) else 0
+ return _rowcount > 0, _result
diff --git a/paste/rbac/rbac_item.py b/paste/rbac/rbac_item.py
new file mode 100644
index 0000000..85deb97
--- /dev/null
+++ b/paste/rbac/rbac_item.py
@@ -0,0 +1,134 @@
+from sqlalchemy import select, delete
+
+from paste.rbac.rbac_item_child import RbacItemChild
+from paste.rbac.rbac_models import RbacItemModel
+
+
+class RbacItem(RbacItemModel):
+ """
+ 授权项。
+ 分为角色和权限两类,分别由对应子类实现。
+ """
+
+ TYPE_ROLE = 1
+ """
+ 角色类型。
+ """
+
+ TYPE_PERMISSION = 2
+ """
+ 权限类型。
+ """
+
+ @classmethod
+ async def all_parents(cls, item_names: set[str]):
+ """
+ 按层次,递归查询授权项列表中所有授权项的父授权项名称。
+
+ :param item_names: 授权项名称列表
+ :return: 所有各层级的父授权项
+ """
+ _query = select(RbacItemChild.parent).where(RbacItemChild.child.in_(item_names))
+ _rows = await cls.query_all(_query)
+ _item_names: set[str] = set(_rows)
+
+ if _item_names:
+ _parents = await cls.all_parents(_item_names)
+ _item_names = _item_names.union(_parents)
+
+ return _item_names
+
+ @classmethod
+ async def all_children(cls, item_names: set[str]):
+ """
+ 按层次,递归查询授权项列表中所有授权项的子授权项名称。
+
+ :param item_names: 授权项名称列表
+ :return: 所有各层级的子授权项
+ """
+ _query = select(RbacItemChild.child).where(RbacItemChild.parent.in_(item_names))
+ _rows = await cls.query_all(_query)
+ _item_names: set[str] = set(_rows)
+
+ if _item_names:
+ _children = await cls.all_children(_item_names)
+ _item_names = _item_names.union(_children)
+
+ return _item_names
+
+ @classmethod
+ async def find_by_name(cls, name: str):
+ """
+ 根据授权项名称查找授权项。
+
+ :param name: 授权项名称
+ :return: 授权项
+ """
+ _query = select(cls).where(cls.name == name)
+ _model: cls = await cls.query_first(_query)
+ return _model
+
+ async def add_children(self, item_names: set[str]):
+ """
+ 增加子授权项,自动剔除已包含的角色或授权项。角色和权限都能增加子授权项。
+
+ :param item_names: 待分配的子授权项名称列表
+ """
+ # 首先根据授权项名称列表查出所有的授权项,剔除错误的名称
+ _query = select(RbacItem.name).where(RbacItem.name.in_(item_names))
+ _rows = await self.query_all(_query)
+ item_names: set[str] = set(_rows)
+
+ # 取得所有祖先,确保所有子授权项,没有出现在祖先中,防止循环授权
+ _all_parents = await self.all_parents(item_names=item_names)
+
+ # 查库,过滤已经存在的子授权项
+ _query = select(RbacItemChild.child).where(
+ RbacItemChild.parent == self.name,
+ RbacItemChild.child.in_(item_names)
+ )
+ _rows = await self.query_all(_query)
+ _exist_children: set[str] = set(_rows)
+
+ # 创建模型列表,剔除已包含的角色或授权项,剔除出现在祖先中的授权项,以及剔除自身
+ _new_children = [
+ RbacItemChild(**{RbacItemChild.parent.key: self.name, RbacItemChild.child.key: _name})
+ for _name in item_names
+ if _name not in _exist_children and _name not in _all_parents and _name != self.name
+ ]
+
+ # 保存数据
+ _session = self.get_aio_session()
+ try:
+ _session.add_all(_new_children)
+ await _session.commit()
+ except Exception as e:
+ await _session.rollback()
+ raise e
+ finally:
+ await _session.close()
+
+ async def get_children(self):
+ """
+ 通过中间关系查询所有子授权项。注意:查询的是直接子授权项,不包含继承获得的授权项。
+
+ :return: 子权限项列表
+ """
+ _query = select(RbacItem).join(
+ RbacItemChild, RbacItemChild.child == RbacItem.name
+ ).where(
+ RbacItemChild.parent == self.name,
+ )
+ _item_model_list: list[RbacItem] = await self.query_all(_query)
+ return _item_model_list
+
+ async def remove_children(self, item_names: set[str]):
+ """
+ 删除子授权项。
+
+ :param item_names: 子授权项名称集合
+ :return: 成功删除的项数
+ """
+ _delete = delete(RbacItemChild).where(RbacItemChild.parent == self.name, RbacItemChild.child.in_(item_names))
+ _result = await self.raw_execute(_delete)
+ return _result.rowcount
diff --git a/paste/rbac/rbac_item_child.py b/paste/rbac/rbac_item_child.py
new file mode 100644
index 0000000..9dc630d
--- /dev/null
+++ b/paste/rbac/rbac_item_child.py
@@ -0,0 +1,8 @@
+from paste.rbac.rbac_models import RbacItemChildModel
+
+
+class RbacItemChild(RbacItemChildModel):
+ """
+ 授权项目关系。
+ """
+ pass
diff --git a/paste/rbac/rbac_models.py b/paste/rbac/rbac_models.py
new file mode 100644
index 0000000..19b8a6b
--- /dev/null
+++ b/paste/rbac/rbac_models.py
@@ -0,0 +1,93 @@
+#
+# 数据模型配置文件,注意:与数据模型对应的表名称来自配置文件。
+# 若使用 生成代码,则这部分表将不会自动生成。
+#
+
+from sqlalchemy import Column, String, DateTime, LargeBinary, text, BigInteger, Integer, SmallInteger, Text, ForeignKey
+from sqlalchemy.orm import relationship
+
+from paste.core import config
+from paste.db.basemodel import BaseModel
+
+
+class RbacRuleModel(BaseModel):
+ __tablename__ = config.get_config('rbac.table.rule')
+ __table_args__ = {'comment': '规则'}
+
+ name = Column(String(64, 'utf8mb4_unicode_ci'), primary_key=True, comment='名称')
+ data = Column(LargeBinary, comment='规则对象')
+ created_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
+ updated_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='更新时间')
+
+
+class RbacUserModel(BaseModel):
+ __tablename__ = config.get_config('rbac.table.user')
+ __table_args__ = {'comment': '用户'}
+
+ id = Column(BigInteger, primary_key=True, comment='系统编号')
+ username = Column(String(255, 'utf8mb4_unicode_ci'), nullable=False, unique=True, comment='用户名')
+ password_hash = Column(String(255, 'utf8mb4_unicode_ci'), nullable=False, comment='密码')
+ password_reset_token = Column(String(255, 'utf8mb4_unicode_ci'), comment='重置标记')
+ auth_key = Column(String(255, 'utf8mb4_unicode_ci'), comment='授权码')
+ status = Column(Integer, nullable=False, server_default=text("'0'"), comment='用户状态')
+ type = Column(String(64, 'utf8mb4_unicode_ci'), nullable=False, server_default=text("'user'"), comment='用户类型')
+ created_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
+ updated_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='更新时间')
+
+ def before_save(self):
+ super().before_save()
+ if self.is_new:
+ self.status = 1
+
+
+class RbacItemModel(BaseModel):
+ __tablename__ = config.get_config('rbac.table.item')
+ __table_args__ = {'comment': '授权项(角色/权限)'}
+
+ name = Column(String(64, 'utf8mb4_unicode_ci'), primary_key=True, comment='名称')
+ type = Column(SmallInteger, nullable=False, comment='类型,1角色,2权限')
+ description = Column(Text(collation='utf8mb4_unicode_ci'), comment='描述')
+ rule_name = Column(
+ ForeignKey(f"{config.get_config('rbac.table.rule')}.name", ondelete='SET NULL', onupdate='CASCADE'),
+ index=True, comment='规则名称'
+ )
+ created_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
+ updated_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='更新时间')
+
+ rbac_rule = relationship(RbacRuleModel)
+
+
+class RbacAssignmentModel(BaseModel):
+ __tablename__ = config.get_config('rbac.table.assignment')
+ __table_args__ = {'comment': '权限分配'}
+
+ item_name = Column(
+ ForeignKey(f"{config.get_config('rbac.table.item')}.name", ondelete='CASCADE', onupdate='CASCADE'),
+ primary_key=True, nullable=False, comment='授权项(角色/权限)'
+ )
+ user_id = Column(
+ ForeignKey(f"{config.get_config('rbac.table.user')}.id", ondelete='CASCADE', onupdate='CASCADE'),
+ primary_key=True, nullable=False, index=True, comment='用户'
+ )
+ created_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
+
+ rbac_item = relationship(RbacItemModel)
+ rbac_user = relationship(RbacUserModel)
+
+
+class RbacItemChildModel(BaseModel):
+ __tablename__ = config.get_config('rbac.table.item_child')
+ __table_args__ = {'comment': '授权关系'}
+
+ parent = Column(
+ ForeignKey(f"{config.get_config('rbac.table.item')}.name", ondelete='CASCADE', onupdate='CASCADE'),
+ primary_key=True, nullable=False, comment='角色/权限'
+ )
+ child = Column(
+ ForeignKey(f"{config.get_config('rbac.table.item')}.name", ondelete='CASCADE', onupdate='CASCADE'),
+ primary_key=True, nullable=False, index=True, comment='授权项(角色/权限)'
+ )
+ created_at = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
+
+ child_item = relationship(RbacItemModel, primaryjoin='RbacItemChildModel.child == RbacItemModel.name')
+ parent_item = relationship(RbacItemModel, primaryjoin='RbacItemChildModel.parent == RbacItemModel.name')
diff --git a/paste/rbac/rbac_permission.py b/paste/rbac/rbac_permission.py
new file mode 100644
index 0000000..71b1758
--- /dev/null
+++ b/paste/rbac/rbac_permission.py
@@ -0,0 +1,149 @@
+import importlib
+
+from sqlalchemy import delete
+
+from paste.core import config
+from paste.rbac.rbac_item import RbacItem
+from paste.web.application import Application
+
+
+class RbacPermission(RbacItem):
+ """
+ 权限。
+ 大多数情况下,权限都对应着一个具体的请求操作,且经由导入期自动导入。
+ 权限的 name 属性为请求控制器 RequestHandler 的 route_pattern 属性值,
+ 权限的 description 属性对应于 RequestHandler 类的文档注释的第一行。
+ 权限可以分配给用户,也可以分配给角色或其他权限。
+
+ 此外允许手动创建权限,主要是为规则创建一个权限载体,当其他的权限属于这个权限的子权限时,相当于同时拥有了这个权限的规则。
+ """
+
+ @classmethod
+ async def create(cls, name: str, description: str = None, rule_name: str = None):
+ """
+ 创建权限。
+
+ :param name: 权限名称
+ :param rule_name: 规则名称
+ :param description: 权限描述
+ :return: 权限对象
+ """
+ assert name not in ('', None), '必须提供权限名称.'
+ _permission = cls(name=name, description=description, type=cls.TYPE_PERMISSION)
+ _permission.rule_name = rule_name if rule_name else _permission.rule_name
+ await _permission.async_save()
+ return _permission
+
+ @classmethod
+ async def delete(cls, name: str):
+ """
+ 删除权限。
+
+ :param name: 授权项名称
+ :return: 操作状态,游标返回对象
+ """
+ assert name not in ('', None), '必须提供权限名称.'
+ _query = delete(cls).where(cls.name == name)
+ _result = await cls.raw_execute(query=_query)
+ _rowcount = _result.rowcount if isinstance(_result.rowcount, int) else 0
+ return _rowcount > 0, _result
+
+ @classmethod
+ async def modify(cls, name: str, description: str = None, rule_name: str = None):
+ """
+ 编辑权限。
+
+ :param name: 名称
+ :param description: 描述
+ :param rule_name: 规则名称
+ :return: 权限对象
+ """
+ assert name not in ('', None), '必须提供权限名称.'
+ _permission: cls = await cls(name=name).async_find_first()
+ assert _permission, f"未找到名称为:{name} 的权限."
+
+ _permission.description = description if description else _permission.description
+ _permission.rule_name = rule_name if rule_name else _permission.rule_name
+ await _permission.async_save()
+ return _permission
+
+ @classmethod
+ def identify_permission(cls):
+ """
+ 根据应用程序配置,从应用程序目录中识别所有的控制器,及其对应的路由。
+ 读取配置文件中关于 tornado 部分的配置,扫描配置包中的所有 Handler 类。
+ 识别需要授权的接口,即包含 auth_permission 装饰器的接口。
+ 忽略无需授权的接口,如:部分 OpenAPI 或 FrontendAPI。
+
+ :return: 识别到的控制器列表,注意列表中是 tuple(route, handler_type)
+ """
+ apps_config: dict = config.get_config('tornado')
+ _handlers: list[tuple[str, type]] = []
+
+ for _n, _app_cfgs in apps_config.items():
+ for app_cfg in _app_cfgs:
+ _handlers_pkg = app_cfg.get('handlers_pkg')
+ _modules_itr = Application.modules_iterator(package=_handlers_pkg)
+ for file_finder, handler_name, is_package in _modules_itr:
+ if is_package:
+ continue
+
+ _module = importlib.import_module(handler_name)
+ _hls_list = Application.fetch_handlers(module=_module)
+ for _hls in _hls_list:
+ _uri, _hdl = _hls
+ # 检查 post 或 get 是否被 auth_permission 装饰
+ for method_name in ['post', 'get']:
+ method = getattr(_hdl, method_name, None)
+ if method and callable(method):
+ if getattr(method, '__auth_permission__', False):
+ _handlers.append(_hls)
+ break
+
+ return _handlers
+
+ @classmethod
+ async def import_permissions(cls):
+ """
+ 导入所有的可分配权限到数据库,若已经在数据库存在,则更新。
+ """
+ _handlers = cls.identify_permission()
+
+ # 根据所有的路由信息,查出已经有的权限数据
+ _routes: list[str] = [rc[0] for rc in _handlers]
+ _item_model_list: list[cls] = await cls(**{cls.name.key: _routes}).async_find()
+
+ # 利用路由 Key 建立索引
+ _item_model_dict: dict[str: cls] = {_item.name: _item for _item in _item_model_list}
+
+ _permissions: list[cls] = []
+ for _route, _cls in _handlers:
+ # 取得类描述
+ _desc = f"{_cls.__doc__}".strip().split('\n')[0].strip(),
+ # 利用路由取出模型
+ _perm_item: cls = _item_model_dict.get(_route, None)
+ if _perm_item is None:
+ # 未得到模型,创建
+ _perm_item = cls(**{cls.name.key: _route, cls.description.key: _desc})
+ else:
+ # 已得到模型,更新
+ _perm_item.description = _desc
+ _perm_item.close_session()
+
+ # 加入列表,批量保存
+ _permissions.append(_perm_item)
+
+ # 保存数据
+ _session = cls.get_aio_session()
+ try:
+ _session.add_all(_permissions)
+ await _session.commit()
+ except Exception as e:
+ await _session.rollback()
+ raise e
+ finally:
+ await _session.close()
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.type = self.TYPE_PERMISSION
diff --git a/paste/rbac/rbac_role.py b/paste/rbac/rbac_role.py
new file mode 100644
index 0000000..b85922e
--- /dev/null
+++ b/paste/rbac/rbac_role.py
@@ -0,0 +1,63 @@
+from sqlalchemy import delete
+
+from paste.rbac.rbac_item import RbacItem
+
+
+class RbacRole(RbacItem):
+ """
+ 角色。
+ 是一系列关联角色或权限的组合。可以包含权限,也可以包含其他角色。
+ """
+
+ @classmethod
+ async def create(cls, name: str, description: str = None, rule_name: str = None):
+ """
+ 创建角色。
+
+ :param name: 角色名称
+ :param description: 角色描述
+ :param rule_name: 规则名称
+ :return: 角色对象
+ """
+ assert name not in ('', None), '必须提供角色名称.'
+ _role = cls(name=name, description=description, type=cls.TYPE_ROLE)
+ _role.rule_name = rule_name if rule_name else _role.rule_name
+ await _role.async_save()
+ return _role
+
+ @classmethod
+ async def delete(cls, name: str):
+ """
+ 删除角色。
+
+ :param name: 授权项名称
+ :return: 操作状态,游标返回对象
+ """
+ assert name not in ('', None), '必须提供角色名称.'
+ _query = delete(cls).where(cls.name == name)
+ _result = await cls.raw_execute(query=_query)
+ _rowcount = _result.rowcount if isinstance(_result.rowcount, int) else 0
+ return _rowcount > 0, _result
+
+ @classmethod
+ async def modify(cls, name: str, description: str = None, rule_name: str = None):
+ """
+ 编辑角色。
+
+ :param name: 角色名称
+ :param description: 描述
+ :param rule_name: 规则名称
+ :return: 角色对象
+ """
+ assert name not in ('', None), '必须提供角色名称.'
+ _role: cls = await cls(name=name).async_find_first()
+ assert _role, f"未找到名称为:{name} 的角色."
+
+ _role.description = description if description else _role.description
+ _role.rule_name = rule_name if rule_name else _role.rule_name
+ await _role.async_save()
+ return _role
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.type = self.TYPE_ROLE
diff --git a/paste/rbac/rbac_rule.py b/paste/rbac/rbac_rule.py
new file mode 100644
index 0000000..92830b1
--- /dev/null
+++ b/paste/rbac/rbac_rule.py
@@ -0,0 +1,147 @@
+import importlib
+import pickle
+
+from sqlalchemy import select, text, delete
+
+from paste.rbac.rbac_item import RbacItem
+from paste.rbac.rbac_models import RbacRuleModel
+from paste.rbac.rbac_permission import RbacPermission
+from paste.rbac.rbac_user import RbacUser
+
+
+class RbacRule(RbacRuleModel):
+ """
+ 规则是授权项的附带验证条件。在验证授权项时,只能判断是否具有某个授权项,无法对具体数据执行进一步验证。比如
+ 验证某一项数据是否允许某个用户执行某操作,此时就需要用到规则。
+
+ 规则是在单独定义的验证方法,这个方法被持久化保存在数据库中,具体执行某个需要鉴权的操作时,若该权限配置了规
+ 则,那么规则方法会一并参与到鉴权过程中去,以确定用户是否有对具体数据执行操作的权限。
+
+ 必须允许多个规则应用于一个授权项,但事实上是一个授权项自身只能绑定一个规则,解决方案是将一系列需要规则鉴权
+ 的操作作为规则权限的子授权项。这样,在对这个权限进行鉴权操作时,会自底向上逐个检查父权限是否有规则,若有规
+ 则,那么会进入这个父权限的规则方法,执行,并校验其返回值。
+
+ 因此要允许手动创建授权项,且允许手动为授权项绑定规则,然后将其他同样需要执行该规则的权限配置为该权限的子权
+ 限。
+
+ 由于在规则执行系统中,不知道未来将会编写和配置的规则,因此无法调用到对应的规则,只能将来编写好规则后,通过
+ 持久化对象到数据库中,通过查库还原将来的规则对象后再执行 run 方法。
+ """
+
+ @classmethod
+ async def create(cls, full_class_name: str):
+ """
+ 添加规则。
+
+ :param full_class_name: 规则类
+ :return: 保存状态、当前规则
+ """
+ _rule_cls = cls.load_rule_class(full_class_name)
+ assert _rule_cls, f"未找到名称为:{full_class_name} 的规则类."
+
+ _rule_model = _rule_cls()
+ _rule_model.data = _rule_model.dumps()
+ if not _rule_model.name:
+ _rule_model.name = _rule_cls.__name__
+
+ await _rule_model.async_save()
+ return _rule_model
+
+ @classmethod
+ async def delete(cls, name: str):
+ """
+ 删除规则。
+
+ :param name: 要删除的规则名称
+ :return: 是否删除成功、删除的行数
+ """
+ assert name not in ('', None), '必须提供规则名称.'
+ _query = delete(cls).where(cls.name == name)
+ _result = await cls.raw_execute(query=_query)
+ _rowcount = _result.rowcount if isinstance(_result.rowcount, int) else 0
+ return _rowcount > 0, _rowcount
+
+ @classmethod
+ async def modify(cls, name: str, full_class_name: str):
+ """
+ 编辑规则。
+
+ :param name: 名称
+ :param full_class_name: 规则类
+ :return: 保存状态、当前规则
+ """
+ _rule_cls = cls.load_rule_class(full_class_name)
+ assert _rule_cls, f"未找到名称为:{full_class_name} 的规则类."
+ _rule_model = _rule_cls()
+
+ assert name not in ('', None), '必须提供规则名称.'
+ _rule: cls = await cls(name=name).async_find_first()
+ assert _rule, f"未找到名称为:{name} 的规则."
+
+ _rule.data = _rule_model.dumps()
+ _rule.name = _rule_model.name
+ if not _rule.name:
+ _rule.name = _rule_cls.__name__
+
+ await _rule.async_save()
+ return _rule
+
+ @classmethod
+ async def find_by_item_names(cls, item_names: set[str]):
+ """
+ 根据授权项名称(权限名称或角色名称)取得所有角色。
+
+ :param item_names: 授权项名称列表
+ :return: 规则列表
+ """
+ # 取出所有授权项中的规则,忽略没有规则的
+ _query = select(cls).join(
+ RbacItem, RbacItem.rule_name == cls.name
+ ).where(
+ RbacItem.name.in_(item_names),
+ text(f"ifnull({RbacItem.rule_name.key},'')!=''")
+ )
+ _rule_list: list[cls] = await cls.query_all(_query)
+ return _rule_list
+
+ @classmethod
+ def load_rule_class(cls, full_class_name: str):
+ """
+ 通过规则类名称,加载规则类。若类所在的模块不存在,则报异常。
+
+ :param full_class_name: 完整规则名称,从顶层模块名称开始,直到类名称。
+ :return: 找到的规则类,找不到返回 None
+ """
+ _full_paths = full_class_name.split('.')
+ _cls_name = _full_paths[-1]
+ _mod_name = '.'.join(_full_paths[:-1])
+
+ try:
+ _module = importlib.import_module(_mod_name)
+ # 迭代模块成员
+ for _n in dir(_module):
+ _cls = getattr(_module, _n)
+ if isinstance(_cls, type) and issubclass(_cls, RbacRule) and _cls.__name__ == _cls_name:
+ return _cls
+ except Exception:
+ return None
+
+ return None
+
+ def run(self, rbac_user: RbacUser, rbac_permission: RbacPermission, *args, **kwargs) -> bool:
+ """
+ 运行规则。当用户在执行具体操作时,该规则会自动被唤起执行。
+
+ :param rbac_user: 用户
+ :param rbac_permission: 权限对象
+ :return: 允许执行返回 True, 否则返回 False
+ """
+ return True
+
+ def dumps(self):
+ """
+ 序列化为可持久化文本。
+
+ :return: 可持久化文本
+ """
+ return pickle.dumps(self)
diff --git a/paste/rbac/rbac_user.py b/paste/rbac/rbac_user.py
new file mode 100644
index 0000000..b8b03fc
--- /dev/null
+++ b/paste/rbac/rbac_user.py
@@ -0,0 +1,259 @@
+import pickle
+from typing import Optional, Awaitable
+
+from sqlalchemy import select
+
+from paste.rbac.rbac_assignment import RbacAssignment
+from paste.rbac.rbac_item import RbacItem
+from paste.rbac.rbac_models import RbacUserModel
+from paste.rbac.rbac_permission import RbacPermission
+from paste.rbac.rbac_role import RbacRole
+from paste.security import shash
+
+Supervisors = ('administrator', 'root', 'supervisor')
+"""
+超级管理员名称,这些名称不能用于一般用户。
+"""
+
+
+class RbacUser(RbacUserModel):
+ """
+ RBAC 用户类。
+ """
+
+ STATUS_DEFAULT = 0b00000000000000000000000000000
+ """
+ 用户默认状态:0。
+ """
+ STATUS_ENABLED = 0b00000000000000000000000000001
+ """
+ 用户激活状态:1。
+ """
+ STATUS_DISABLED = 0b00000000000000000000000000010
+ """
+ 用户禁用状态:2。
+ """
+ STATUS_DELETED = 0b00000000000000000000000000100
+ """
+ 用户删除状态:4。
+ """
+
+ STATUS_LIST = [STATUS_DEFAULT, STATUS_ENABLED, STATUS_DISABLED, STATUS_DELETED]
+ """
+ 允许的所有用户状态。
+ """
+
+ STATUS_DESCRIPTION = {
+ STATUS_DEFAULT: '默认',
+ STATUS_ENABLED: '激活',
+ STATUS_DISABLED: '已禁用',
+ STATUS_DELETED: '已删除',
+ }
+ """
+ 用户状态描述。
+ """
+
+ TYPE_USER = 'user'
+ """
+ 用户类型:用户。
+ """
+
+ TYPE_ADMINISTRATOR = 'admin'
+ """
+ 用户类型:管理员。
+ """
+
+ @classmethod
+ async def import_supervisors(cls):
+ """
+ 导入超级用户。
+
+ :return: 初始化状态
+ """
+ # 查出已经有的超级用户
+ _user_model_list: list[cls] = await cls(**{cls.username.key: Supervisors}).async_find()
+ _user_model_dict: dict[str: cls] = {_user.username: _user for _user in _user_model_list}
+
+ init_status: bool = True
+ for _username in Supervisors:
+ _user: cls = _user_model_dict.get(_username, None)
+ if _user is None:
+ # 用户不存在,创建超级用户
+ _save_status, _ = await cls.create(username=_username, password=_username)
+
+ return init_status
+
+ @classmethod
+ async def create(cls, username: str, password: str):
+ """
+ 创建用户。
+
+ :param username: 用户名
+ :param password: 密码
+ :return: 保存状态
+ """
+ assert username is not None and password is not None, '必须提供用户名和密码.'
+ _usr_model = cls()
+ _usr_model.before_save()
+ _usr_model.username = username
+ _usr_model.password_hash = shash.generate_password_hash(pwd=password)
+ _status = await _usr_model.async_save()
+ return _status, _usr_model
+
+ @classmethod
+ async def find_by_username(cls, username: str):
+ """
+ 根据用户名查找用户。
+
+ :param username: 用户名
+ :return: 用户对象
+ """
+ query = select(cls).where(cls.username == username)
+ model = await cls.query_first(query)
+ return model
+
+ @classmethod
+ def status_description(cls, status):
+ """
+ 取得状态描述。
+
+ :param status:
+ :return:
+ """
+ if status in cls.STATUS_DESCRIPTION:
+ return cls.STATUS_DESCRIPTION[status]
+ else:
+ return '状态未知'
+
+ async def assign(self, item_names: set[str]):
+ """
+ 为用户分配权限。这里调用了权限分配器的分配方法,自动剔除已经分配过的角色或权限。
+
+ :param item_names: 授权项名称列表
+ """
+ await RbacAssignment.assign(user_id=self.id, item_names=item_names)
+
+ async def can(self, permission_name: str, **kwargs) -> bool:
+ """
+ 验证用户是否具有 permission_name 权限。
+
+ 该方法主要是调用 :class:`RbacRule` 的 run() 方法,执行规则检验。
+
+ 在执行规则验证的时候是自底向上,查询父 RbacItem 中的规则,并调用其 run 方法。调用时不分先后,随机执行。
+
+ 只要有一个规则返回 False,则后续规则方法不再继续执行。
+
+ :param permission_name: 权限名称
+ :param kwargs: 可选参数,传递给规则的 run 方法的参数
+ :return: 验证状态
+ """
+ from paste.rbac.rbac_rule import RbacRule
+
+ # 如果用户没有初始化,直接禁止
+ if self.username is None:
+ return False
+
+ # 如果是超级用户,直接返回 True
+ if self.is_supervisors():
+ return True
+
+ # 取得所有授权项,然后取得所有对应规则
+ _item_names = await self.get_all_permissions()
+ _rule_list: list[RbacRule] = await RbacRule.find_by_item_names(_item_names)
+
+ # 取出授权项
+ _permission: RbacPermission = await RbacPermission.find_by_name(permission_name)
+
+ # 遍历执行 run 方法
+ for _rule in _rule_list:
+ # 还原持久化数据到对象
+ _rule: RbacRule = pickle.loads(_rule.data)
+
+ # 执行 run 方法,若是协程,则继续等待协程完成
+ _result = _rule.run(rbac_user=self, rbac_permission=_permission, **kwargs)
+ if isinstance(_result, Awaitable):
+ _result = await _result
+
+ if not _result:
+ return False
+
+ return True
+
+ async def get_all_roles(self):
+ """
+ 取得所有角色名称。
+
+ :return: 角色名称列表
+ """
+ _roles = await self.roles()
+ _role_names = [_r.name for _r in _roles]
+ return _role_names
+
+ async def get_all_permissions(self):
+ """
+ 取得所有授权项的名称,包含角色和权限。
+
+ :return: 授权项名称列表
+ """
+ _dir_perms = await self.get_direct_permissions()
+ _inh_perms = await self.get_inherit_permissions(direct_perms=_dir_perms)
+ _dir_perms = _dir_perms.union(_inh_perms)
+ return _dir_perms
+
+ async def get_direct_permissions(self):
+ """
+ 通过权限分配器查询用户直接拥有的角色和权限。
+
+ :return: 授权项名称列表
+ """
+ _query = select(RbacAssignment.item_name, RbacAssignment.user_id).where(RbacAssignment.user_id == self.id)
+ _item_names = await RbacAssignment.query_all(_query)
+ return set(_item_names)
+
+ async def get_inherit_permissions(self, direct_perms: Optional[set[str]] = None):
+ """
+ 通过直接拥有的权限查询继承而来的角色和权限。
+ 继承而来的权限包括用户所有角色的权限及其子权限。
+
+ :param direct_perms: 用户直接拥有的权限名称列表
+ :return: 授权项名称列表
+ """
+ # 若不提供用户直接拥有的权限,则查询
+ if direct_perms is None:
+ direct_perms = await self.get_direct_permissions()
+
+ _perm_names = await RbacItem.all_children(item_names=direct_perms)
+ return _perm_names
+
+ async def has_permission(self, permission_name: str):
+ """
+ 验证用户是否有执行某个路由的权限。
+
+ :param permission_name: 路由模式,一般为 route_pattern
+ :return: 是否有权限
+ """
+ _permissions = await self.get_all_permissions()
+ return permission_name in _permissions
+
+ async def roles(self):
+ """
+ 通过中间关系直接查询用户的所有角色。
+
+ :return: 角色列表
+ """
+ _query = select(RbacRole).join(
+ RbacAssignment, RbacAssignment.item_name == RbacRole.name
+ ).where(
+ RbacRole.type == RbacRole.TYPE_ROLE,
+ RbacAssignment.user_id == self.id
+ )
+ _roles: list[RbacRole] = await RbacRole.query_all(_query)
+ return _roles
+
+ def is_supervisors(self):
+ """
+ 检查是否是超级用户。
+
+ :return: 是超级用户放回 True,否则返回 False
+ """
+ return self.username in Supervisors
diff --git a/paste/rbac/test_rule.py b/paste/rbac/test_rule.py
new file mode 100644
index 0000000..70e840e
--- /dev/null
+++ b/paste/rbac/test_rule.py
@@ -0,0 +1,19 @@
+from paste.core import logging
+from paste.rbac.rbac_permission import RbacPermission
+from paste.rbac.rbac_rule import RbacRule
+from paste.rbac.rbac_user import RbacUser
+
+
+class TestRule(RbacRule):
+ """
+ 测试规则类。实际编写时要注意父类继承关系。
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.name = '测试规则'
+ self.data = self.dumps()
+
+ def run(self, rbac_user: RbacUser, rbac_permission: RbacPermission, *args, **kwargs):
+ logging.echo_log(f"正在运行规则:{self.name}.")
+ return True
diff --git a/paste/security/__init__.py b/paste/security/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/paste/security/cryp_rsa.py b/paste/security/cryp_rsa.py
new file mode 100644
index 0000000..72dca12
--- /dev/null
+++ b/paste/security/cryp_rsa.py
@@ -0,0 +1,59 @@
+import base64
+
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+
+def generate_rsa_keypair():
+ """
+ 生成RSA密钥对,并返回公钥的PEM格式字符串。
+
+ 此函数生成一个2048位的RSA密钥对,使用65537作为公钥指数。
+ 仅返回公钥部分的PEM编码字符串,私钥不返回以确保安全性。
+
+ 返回:
+ str: 公钥的PEM格式字符串,包含-----BEGIN PUBLIC KEY-----和-----END PUBLIC KEY-----头尾。
+
+ 示例:
+ -----BEGIN PUBLIC KEY-----
+ MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA...
+ -----END PUBLIC KEY-----
+
+ 注意:
+ - 私钥在此函数中生成但未返回,应由调用方安全存储。
+ - 不建议在生产环境中直接使用此函数生成密钥,应使用更安全的密钥管理服务。
+ """
+ private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+ public_key = private_key.public_key()
+
+ # 获取公钥的 PEM 格式
+ public_pem = public_key.public_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PublicFormat.SubjectPublicKeyInfo
+ ).decode('utf-8')
+
+ return public_pem
+
+
+def rsa_encrypt_pkcs1_v1_5(public_key_pem, plaintext):
+ """
+ 使用 PKCS#1 v1.5 填充对字符串进行 RSA 加密。
+
+ :param public_key_pem: 公钥的 PEM 格式字符串
+ :param plaintext: 要加密的明文字符串
+ :return: Base64 编码的加密后字节串
+ """
+ # 加载公钥
+ public_key = serialization.load_pem_public_key(public_key_pem.encode())
+
+ # 将字符串编码为字节
+ plaintext_bytes = plaintext.encode('utf-8')
+
+ # 使用 PKCS#1 v1.5 填充进行加密
+ ciphertext = public_key.encrypt(
+ plaintext_bytes,
+ padding.PKCS1v15()
+ )
+
+ # 返回 Base64 编码的加密结果
+ return base64.b64encode(ciphertext).decode('utf-8')
\ No newline at end of file
diff --git a/paste/security/shash.py b/paste/security/shash.py
new file mode 100755
index 0000000..99b63e6
--- /dev/null
+++ b/paste/security/shash.py
@@ -0,0 +1,106 @@
+import hashlib
+import hmac
+import secrets
+from typing import Tuple
+
+SALT_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+"""
+加盐字符。
+"""
+
+DEFAULT_PBKDF2_ITERATIONS = 260000
+"""
+加密迭代数。
+"""
+
+
+def gen_salt(length: int) -> str:
+ """
+ Generate a random string of SALT_CHARS with specified ``length``.
+ """
+ if length <= 0:
+ raise ValueError("Salt length must be positive")
+
+ return "".join(secrets.choice(SALT_CHARS) for _ in range(length))
+
+
+def _hash_internal(method: str, salt: str, password: str) -> Tuple[str, str]:
+ """
+ Internal password hash helper. Supports plaintext without salt, unsalted and salted passwords.
+ In case salted passwords are used hmac is used.
+ """
+ if method == "plain":
+ return password, method
+
+ salt = salt.encode("utf-8")
+ password = password.encode("utf-8")
+
+ if method.startswith("pbkdf2:"):
+ if not salt:
+ raise ValueError("Salt is required for PBKDF2")
+
+ args = method[7:].split(":")
+
+ if len(args) not in (1, 2):
+ raise ValueError("Invalid number of arguments for PBKDF2")
+
+ method = args.pop(0)
+ iterations = int(args[0] or 0) if args else DEFAULT_PBKDF2_ITERATIONS
+ return (
+ hashlib.pbkdf2_hmac(method, password, salt, iterations).hex(),
+ f"pbkdf2:{method}:{iterations}",
+ )
+
+ if salt:
+ return hmac.new(salt, password, method).hexdigest(), method
+
+ return hashlib.new(method, password).hexdigest(), method
+
+
+def generate_password_hash(pwd: str, method: str = "pbkdf2:sha256", salt_length: int = 16) -> str:
+ """
+ Hash a password with the given method and salt with a string of
+ the given length. The format of the string returned includes the method
+ that was used so that :func:`check_password_hash` can check the hash.
+
+ The format for the hashed string looks like this::
+
+ method$salt$hash
+
+ This method can **not** generate unsalted passwords but it is possible
+ to set param method='plain' in order to enforce plaintext passwords.
+ If a salt is used, hmac is used internally to salt the password.
+
+ If PBKDF2 is wanted it can be enabled by setting the method to
+ ``pbkdf2:method:iterations`` where iterations is optional::
+
+ pbkdf2:sha256:80000$salt$hash
+ pbkdf2:sha256$salt$hash
+
+ :param pwd: the password to hash
+ :param method: the hash method to use (one that hashlib supports). Can
+ optionally be in the format ``pbkdf2:method:iterations``
+ to enable PBKDF2
+ :param salt_length: the length of the salt in letters
+ """
+ salt = gen_salt(salt_length) if method != "plain" else ""
+ h, actual_method = _hash_internal(method, salt, pwd)
+ return f"{actual_method}${salt}${h}"
+
+
+def check_password_hash(pwd_hash: str, password: str) -> bool:
+ """
+ Check a password against a given salted and hashed password value.
+ In order to support unsalted legacy passwords this method supports
+ plain text passwords, md5 and sha1 hashes (both salted and unsalted).
+
+ Returns `True` if the password matched, `False` otherwise
+ :param pwd_hash: a hashed string like returned by
+ :func:`generate_password_hash`
+ :param password: the plaintext password to compare against the hash
+ """
+ if pwd_hash.count("$") < 2:
+ return False
+
+ method, salt, hash_val = pwd_hash.split("$", 2)
+ return hmac.compare_digest(_hash_internal(method, salt, password)[0], hash_val)
diff --git a/paste/security/token.py b/paste/security/token.py
new file mode 100644
index 0000000..86a3261
--- /dev/null
+++ b/paste/security/token.py
@@ -0,0 +1,89 @@
+import base64
+import datetime
+import uuid
+from typing import Optional
+
+import jwt
+
+SECRET_KEY = 'IrOYFjXtQPuofoXEpsR+2VsvjnaCWkPzvfmym1qNmcI='
+"""
+自定义密钥,生成方法见 generate_secret_key() 函数。该密钥应当在应用程序服务器启动时更新,并使用动态生成的密钥。
+"""
+
+PRIVATE_ISS = 'hǎi_tén_education_technology_co_ltd'
+"""
+签发者签名。
+"""
+
+
+def generate_secret_key():
+ """
+ 生成如 IrOYFjXtQPuofoXEpsR+2VsvjnaCWkPzvfmym1qNmcI= 的随机串。
+ """
+ return base64.b64encode(uuid.uuid4().bytes + uuid.uuid4().bytes).decode()
+
+
+def reset_secret_key():
+ """
+ 重置加密密钥。
+
+ :return: 加密密钥
+ """
+ global SECRET_KEY
+ SECRET_KEY = generate_secret_key()
+ return SECRET_KEY
+
+
+def get_secret_key():
+ """
+ 取得加密密钥。
+
+ :return: 加密密钥
+ """
+ global SECRET_KEY
+ return SECRET_KEY
+
+
+def encode_token(exp: Optional[datetime.datetime] = None, **kwargs):
+ """
+ 对用户信息加密生成令牌。
+
+ :param exp: 过期时间,不传默认 7 天过期
+ :param kwargs: 要加入到数据部分的参数
+ :return: 加密后的 token 内容
+ """
+ try:
+ iat = datetime.datetime.now(datetime.timezone.utc)
+ exp = exp if exp else iat + datetime.timedelta(days=7)
+ # 即 JWT 三部分中的载荷部分
+ # 过期时间最后是和 UTC 作比较,所以设置的时候使用 datetime.datetime.now()
+ payload = {
+ 'iss': PRIVATE_ISS, # 签发者
+ 'iat': iat, # 签发时间
+ 'exp': exp, # 过期时间,这里设置7天
+ 'params': {} # 参数,存放用户自定义数据
+ }
+
+ payload['params'].update(kwargs)
+
+ # 开始进行加密,返回字符串,传入例如密钥,指定加密算法
+ # encode 返回的是 bytes,需要 decode() 得到 str,不转换的话,在封装json的时候报错
+ auth_token = jwt.encode(payload, SECRET_KEY, algorithm='HS256')
+ if isinstance(auth_token, bytes):
+ auth_token = auth_token.decode()
+ return auth_token
+ except Exception as e:
+ raise e
+
+
+def decode_token(auth_token: str):
+ """
+ 解码 token 并从中提取用户信息,进行验证。
+
+ :param auth_token: 令牌
+ """
+ # 如果需要关闭过期时间的验证,可以在 options 中使用 verify_exp
+ # jwt.decode(auth_token, secret_key, issuer=private_iss, algorithms=['HS256'], options={'verify_exp': False})
+ # 传入了密钥和算法,和加密是对应的,因此密钥一定不要泄露
+ token_payload = jwt.decode(auth_token, SECRET_KEY, issuer=PRIVATE_ISS, algorithms=['HS256'])
+ return token_payload
diff --git a/paste/service/__init__.py b/paste/service/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/paste/service/daemonize.py b/paste/service/daemonize.py
new file mode 100644
index 0000000..b82f749
--- /dev/null
+++ b/paste/service/daemonize.py
@@ -0,0 +1,168 @@
+import atexit
+import os
+import signal
+import sys
+from typing import Callable
+
+import psutil
+
+from paste.util import ufile
+
+
+class DaemonizeService:
+ """
+ 驻内存服务。
+ """
+
+ def __init__(self, pid_file, name: str = '', stdin: str = None, stdout: str = None, stderr: str = None):
+ """
+ 初始化服务。
+
+ :param pid_file: pid 文件路径
+ :param name: 服务名称
+ :param stdin: 输入文件路径
+ :param stdout: 输出文件路径
+ :param stderr: 错误日志文件路径
+ """
+ self.start_callback = None
+ self.start_callback_args = ()
+ self.start_callback_kwargs = {}
+
+ self.term_callback = None
+ self.term_callback_args = ()
+ self.term_callback_kwargs = {}
+
+ self.pid_file = pid_file
+ self.name = name
+ self.stdin = stdin
+ self.stdout = stdout
+ self.stderr = stderr
+
+ def set_start_callback(self, callback, *args, **kwargs):
+ """
+ 设置启动回调函数。不返回参数。
+
+ :param callback: 回调函数
+ :param args: 回调参数
+ :param kwargs: 回调参数
+ """
+ self.start_callback = callback
+ self.start_callback_args = args
+ self.start_callback_kwargs = kwargs
+
+ def set_term_callback(self, callback, *args, **kwargs):
+ """
+ 设置终止回调函数。不返回参数。
+
+ :param callback: 回调函数
+ :param args: 回调参数
+ :param kwargs: 回调参数
+ """
+ self.term_callback = callback
+ self.term_callback_args = args
+ self.term_callback_kwargs = kwargs
+
+ def daemonize(self):
+ """
+ 设置和启动常驻服务程序。
+ """
+ if os.path.exists(self.pid_file):
+ _pid = int(ufile.read_to_buffer(self.pid_file).decode('utf8').strip())
+ if psutil.pid_exists(_pid):
+ raise RuntimeError(f"[{self.name}] 正在运行.")
+ else:
+ os.remove(self.pid_file)
+
+ try:
+ if os.fork() > 0:
+ raise SystemExit(0) # Parent exit
+ except OSError:
+ raise RuntimeError('创建子进程失败 #1.')
+
+ os.chdir(os.path.abspath(os.path.curdir))
+ os.umask(0)
+ os.setsid()
+
+ try:
+ if os.fork() > 0:
+ raise SystemExit(0)
+ except OSError:
+ raise RuntimeError('创建子进程失败 #2.')
+
+ sys.stdout.flush()
+ sys.stderr.flush()
+
+ # 替换 stdin, stdout, 和 stderr 的文件描述符
+ if self.stdin is not None:
+ with open(self.stdin, 'rb', 0) as file:
+ os.dup2(file.fileno(), sys.stdin.fileno())
+
+ if self.stdout is not None:
+ with open(self.stdout, 'ab', 0) as file:
+ os.dup2(file.fileno(), sys.stdout.fileno())
+
+ if self.stderr is not None:
+ with open(self.stderr, 'ab', 0) as file:
+ os.dup2(file.fileno(), sys.stderr.fileno())
+
+ # 写入 PID 文件
+ with open(self.pid_file, 'w') as file:
+ print(os.getpid(), file=file)
+
+ # 注册退出函数,进程退出时(包括异常退出)移除pid文件
+ atexit.register(lambda: os.remove(self.pid_file))
+ # 监听终止信号,绑定到处理程序
+ signal.signal(signal.SIGTERM, self.sigterm_handler)
+
+ def sigterm_handler(self, signum, frame):
+ """
+ 终止信号处理程序。这里是执行回调函数,处理服务退出前的准备。
+ :param signum: 信号代码
+ :param frame: 帧
+ :return: 程序退出码
+ """
+ if isinstance(self.term_callback, Callable):
+ self.term_callback(*self.term_callback_args, **self.term_callback_kwargs)
+ raise SystemExit(1)
+
+ def start(self):
+ """
+ 启动服务。
+ """
+ try:
+ self.daemonize()
+ if isinstance(self.start_callback, Callable):
+ self.start_callback(*self.start_callback_args, **self.start_callback_kwargs)
+ except RuntimeError as e:
+ print(e, file=sys.stderr)
+ raise SystemExit(1)
+
+ def stop(self):
+ """
+ 停止服务。
+ """
+ if os.path.exists(self.pid_file):
+ _pid = int(ufile.read_to_buffer(self.pid_file).decode('utf8').strip())
+ if psutil.pid_exists(_pid):
+ os.kill(_pid, signal.SIGTERM)
+ else:
+ os.remove(self.pid_file)
+ else:
+ print(f"[{self.name}] 尚未启动.", file=sys.stderr)
+ raise SystemExit(1)
+
+ def cli_run(self):
+ """
+ 命令行启动。
+ """
+ if len(sys.argv) != 2:
+ print(f"Please use like: python3 {sys.argv[0]} [start|stop].", file=sys.stderr)
+ raise SystemExit(1)
+
+ if sys.argv[1] == 'start':
+ self.start()
+ elif sys.argv[1] == 'stop':
+ self.stop()
+ else:
+ print(f"未知命令: {sys.argv[1]}", file=sys.stderr)
+ raise SystemExit(1)
diff --git a/paste/service/server.py b/paste/service/server.py
new file mode 100644
index 0000000..89d4088
--- /dev/null
+++ b/paste/service/server.py
@@ -0,0 +1,152 @@
+"""
+服务管理包。用这个包能根据服务的名称启动服务。
+"""
+
+import importlib
+import os.path
+from types import ModuleType
+from typing import Callable, Awaitable
+
+import pandas as pd
+
+from paste.core import aio_pool, config
+from paste.core.logging import logger_config_name
+from paste.util import ufile, udict
+
+service_flag = ['service_name', 'pid_file', 'start_service', 'start', 'stop']
+"""
+模块若要成为服务,必须同时具备以上属性。
+"""
+
+
+def get_services(full_log_path=True):
+ """
+ 取得所有服务列表及其运行状态,返回格式为 Pandas DataFrame。
+
+ :param full_log_path: 是否输出完整日志文件路径
+ :return: 服务运行状态数据框架
+ """
+ _service_info = []
+ _service_module = 'service'
+ _service_path = os.path.join(os.path.curdir, _service_module)
+ for _root, _dirs, _files in os.walk(_service_path):
+ for _file in _files:
+ _file_name, _ = os.path.splitext(_file)
+ if not _file_name.endswith('_service'):
+ continue
+ _mod_name = '.'.join([_service_module, _file_name])
+ _service_info.append(read_service_info(_mod_name, full_log_path))
+ _service_df = pd.DataFrame(_service_info)
+ return _service_df
+
+
+def is_service(service_module: ModuleType):
+ """
+ 检查模块是否是服务模块。
+
+ :param service_module:
+ :return:
+ """
+ _is_service = True
+ for _attr in service_flag:
+ if hasattr(service_module, _attr):
+ continue
+ _is_service = False
+ return _is_service
+
+
+def read_service_info(full_module_name: str, full_log_path=True):
+ """
+ 读取模块信息。
+
+ :param full_module_name: 完整模块路径
+ :param full_log_path: 是否输出完整日志文件路径
+ :return: 服务模块信息
+ """
+ _module = importlib.import_module(full_module_name)
+ assert is_service(_module), f"未设置关键属性,请确认是否为服务."
+ _pid = ''
+ _process_info = {}
+ _is_running = False
+ if os.path.exists(_module.pid_file):
+ _pid = ufile.read_to_buffer(_module.pid_file).decode('utf8').strip()
+ _process_info = aio_pool.process_info(int(_pid))
+ _is_running = True if _process_info else False
+
+ _configure = config.load_config()
+ _logger_config_name = getattr(_module, 'logger_config_name', logger_config_name)
+ if full_log_path:
+ _logger_file_path = os.path.abspath(udict.get_by_path(_configure, f"{_logger_config_name}.filename"))
+ else:
+ _logger_file_path = udict.get_by_path(_configure, f"{_logger_config_name}.filename")
+
+ # 代码中的 service_name 在这里作为服务描述
+ # 而 full_module_name 作为服务名称
+ _info = {
+ 'service': full_module_name,
+ 'service_name': _module.service_name,
+ 'is_running': os.path.exists(_module.pid_file),
+ 'logger_config_name': _logger_config_name,
+ 'logger_file_path': _logger_file_path,
+ 'pid_file': _module.pid_file,
+ 'pid': _pid,
+ 'process_name': _process_info.get('name', '') if _process_info else '',
+ 'cpu_usage': _process_info.get('cpu_usage', '') if _process_info else '',
+ 'memory_usage': _process_info.get('memory_usage', '') if _process_info else '',
+ 'running_time': _process_info.get('running_time', '') if _process_info else '',
+ }
+ return _info
+
+
+def start_service(full_module_name: str):
+ """
+ 在控制台启动服务,注意:当控制台关闭时,服务随即停止。
+
+ :param full_module_name: 完整模块路径。
+ :return: 操作状态,仅代表操作状态,并非立即启动服务
+ """
+ _module = importlib.import_module(full_module_name)
+ _start: Callable = getattr(_module, 'start_service', None)
+ if not isinstance(_start, Callable):
+ return
+ _result = _start()
+ # 处理异步方法执行
+ if isinstance(_result, Awaitable):
+ _runner = aio_pool.get_aio_runner()
+ _result = _runner(_result)
+
+
+def start(full_module_name: str):
+ """
+ 启动服务。
+
+ :param full_module_name: 完整模块路径。
+ :return: 操作状态,仅代表操作状态,并非立即启动服务
+ """
+ _module = importlib.import_module(full_module_name)
+ _start: Callable = getattr(_module, 'start', None)
+ if not isinstance(_start, Callable):
+ return
+ _result = _start()
+ # 处理异步方法执行
+ if isinstance(_result, Awaitable):
+ _runner = aio_pool.get_aio_runner()
+ _result = _runner(_result)
+
+
+def stop(full_module_name: str):
+ """
+ 停止服务。
+
+ :param full_module_name: 完整模块路径。
+ :return: 操作状态,仅代表操作状态,并非立即结束服务
+ """
+ _module = importlib.import_module(full_module_name)
+ _stop: Callable = getattr(_module, 'stop', None)
+ if not isinstance(_stop, Callable):
+ return
+ _result = _stop()
+ # 处理异步方法执行
+ if isinstance(_result, Awaitable):
+ _runner = aio_pool.get_aio_runner()
+ _result = _runner(_result)
diff --git a/paste/service/task_service.py b/paste/service/task_service.py
new file mode 100644
index 0000000..56f2e35
--- /dev/null
+++ b/paste/service/task_service.py
@@ -0,0 +1,496 @@
+"""
+系统服务,用于读取服务配置文件,启动或停止相关的服务。
+"""
+
+import asyncio
+import datetime
+import logging
+from asyncio import AbstractEventLoop, Task
+from enum import Enum
+from typing import Callable, Awaitable, Optional
+
+from dateutil.relativedelta import relativedelta
+
+from paste.core import aio_pool
+from paste.core.logging import echo_log
+from paste.db.baseadapter import BaseAdapter
+from paste.service.daemonize import DaemonizeService
+
+
+class PeriodType(Enum):
+ WEEKLY = "weekly"
+ MONTHLY = "monthly"
+ YEARLY = "yearly"
+ QUARTERLY = "quarterly"
+
+
+class TaskService:
+ """
+ 任务服务,专用于创建或停止服务。
+ """
+
+ task_event_loop: Optional[AbstractEventLoop] = None
+ """
+ 任务件循环对象。
+ """
+
+ def __init__(self, service_name: str = None, pid_file: str = None):
+ """
+ 构造函数。
+
+ :param service_name: 服务名称
+ :param pid_file: 进程 ID 文件路径
+ """
+
+ self.service_name = service_name
+ """
+ 服务名称。
+ """
+ if self.service_name is None:
+ self.service_name = '未命名服务'
+
+ self.pid_file = pid_file
+ """
+ PID 文件路径。
+ """
+ if self.pid_file is None:
+ _now = datetime.datetime.now()
+ self.pid_file = f'/tmp/task_service_{_now.strftime("%Y%m%d%H%M%S%f")}.pid'
+
+ self.task_list: list[Task] = []
+ """
+ 任务列表。
+ """
+
+ self._create_task_params: list[dict] = []
+ """
+ 创建任务的参数列表。
+ """
+
+ self.is_running = True
+ """
+ 是否允许运行。
+ """
+
+ self.log_next_time = True
+ """
+ 是否记录下次执行时间。
+ """
+
+ def event_loop(self):
+ """
+ 在需要调用的时间点取得事件循环对象。
+
+ :return: 事件循环对象
+ """
+ self.task_event_loop = aio_pool.get_aio_loop()
+ return self.task_event_loop
+
+ def create_delay_task(self, fn: Callable = None, delay: int = 60, **kwargs):
+ """
+ 创建延时任务工厂,每次任务完成后,将等待固定时长后继续执行。
+
+ :param fn: 要执行的任务函数
+ :param delay: 延时长度,单位:秒
+ :param kwargs fn 函数的参数
+ """
+
+ def log_next(log_next_time: bool):
+ if log_next_time and self.log_next_time:
+ echo_log(f"距下次执行:{fn.__name__} 还有:{delay} 秒.")
+
+ async def task_warp():
+ """
+ 任务包装器。
+ """
+ if fn is not None:
+ try:
+ _result = fn(**kwargs)
+ if isinstance(_result, Awaitable):
+ await _result
+ except Exception as e:
+ echo_log(msg=e, level=logging.ERROR, is_log_exc=True)
+
+ async def task_loop():
+ """
+ 任务循环。
+ """
+ if fn is None:
+ return
+
+ _log_next_time = True
+ _next_time = None
+ while self.is_running:
+ _delta_seconds = relativedelta(datetime.datetime.now(), _next_time).seconds if _next_time else 1
+
+ if _delta_seconds > 0:
+ await task_warp()
+ # 执行服务后,更新日期值
+ _next_time = datetime.datetime.now() + relativedelta(seconds=delay)
+ _log_next_time = True
+ else:
+ log_next(_log_next_time)
+ _log_next_time = False
+ await asyncio.sleep(0.5)
+ continue
+
+ _loop = self.event_loop()
+ _tsk: Task = _loop.create_task(task_loop())
+ return _tsk
+
+ def create_daily_task(self, fn: Callable = None, run_on_start=False,
+ year: int = None, month: int = None, day: int = None,
+ hour: int = None, minute: int = None, **kwargs):
+ """
+ 日常任务工厂,每次任务完成后,在第二天的固定时间继续执行。若设置的时间小于当前时间,则自动加一天。
+
+ :param fn: 要执行的任务函数
+ :param run_on_start: 是否在启动时立即运行一次,默认不运行
+ :param year: 年
+ :param month: 月
+ :param day: 日
+ :param hour: 时
+ :param minute: 分
+ :param kwargs fn 函数的参数
+ """
+ _now = datetime.datetime.now()
+ year = _now.year if year is None else year
+ month = _now.month if month is None else month
+ day = _now.day if day is None else day
+ hour = _now.hour if hour is None else hour
+ minute = _now.minute if minute is None else minute
+ _next_time = datetime.datetime(year, month, day, hour, minute, 0)
+ if relativedelta(_next_time, datetime.datetime.now()).seconds < 0:
+ # 小于当前时间的,自动加一天
+ _next_time = _next_time + relativedelta(days=1)
+
+ def log_next(log_next_time: bool, next_time: datetime.datetime):
+ if log_next_time and self.log_next_time:
+ _delta = relativedelta(next_time, datetime.datetime.now())
+ _d, _h, _m, _s = _delta.days, _delta.hours, _delta.minutes, _delta.seconds
+ echo_log(f"距下次执行:{fn.__name__} 还有:{_d} 天 {_h} 时 {_m} 分 {_s} 秒.")
+
+ async def task_warp():
+ """
+ 任务包装器。
+ """
+ if fn is not None:
+ try:
+ _result = fn(**kwargs)
+ if isinstance(_result, Awaitable):
+ await _result
+ except Exception as e:
+ echo_log(msg=e, level=logging.ERROR, is_log_exc=True)
+
+ async def task_loop(next_time: datetime.datetime):
+ """
+ 任务循环。
+
+ :param next_time: 下次执行时间
+ """
+ if fn is None:
+ return
+
+ _log_next_time = True
+ _run_on_start = run_on_start
+ while self.is_running:
+ _delta_seconds = relativedelta(datetime.datetime.now(), next_time).seconds
+
+ if _run_on_start or _delta_seconds > 0:
+ await task_warp()
+ _run_on_start = False
+ # 执行服务后,更新日期值
+ next_time = next_time + relativedelta(days=1)
+ _log_next_time = True
+ else:
+ log_next(_log_next_time, next_time)
+ _log_next_time = False
+ await asyncio.sleep(0.5)
+ continue
+
+ _loop = self.event_loop()
+ _tsk: Task = _loop.create_task(task_loop(next_time=_next_time))
+ return _tsk
+
+ def create_weekly_task(self, fn: Callable = None, weekday: int = 0,
+ hour: int = 0, minute: int = 0, run_on_start: bool = False, **kwargs):
+ """每周某天固定时间执行(周一=0,周日=6)"""
+ return self.create_periodic_task(
+ fn, PeriodType.WEEKLY, run_on_start=run_on_start,
+ hour=hour, minute=minute, weekday=weekday, **kwargs
+ )
+
+ def create_monthly_task(self, fn: Callable = None, day_of_month: int = 1,
+ hour: int = 0, minute: int = 0, run_on_start: bool = False, **kwargs):
+ """每月固定日期执行"""
+ return self.create_periodic_task(
+ fn, PeriodType.MONTHLY, run_on_start=run_on_start,
+ hour=hour, minute=minute, day_of_month=day_of_month, **kwargs
+ )
+
+ def create_yearly_task(self, fn: Callable = None, month: int = 1, day_of_month: int = 1,
+ hour: int = 0, minute: int = 0, run_on_start: bool = False, **kwargs):
+ """每年固定日期执行"""
+ return self.create_periodic_task(
+ fn, PeriodType.YEARLY, run_on_start=run_on_start,
+ hour=hour, minute=minute, month=month, day_of_month=day_of_month, **kwargs
+ )
+
+ def create_quarterly_task(self, fn: Callable = None, start_month: int = 1, day_of_month: int = 1,
+ hour: int = 0, minute: int = 0, run_on_start: bool = False, **kwargs):
+ """每季度固定日期执行(start_month: 1,4,7,10)"""
+ return self.create_periodic_task(
+ fn, PeriodType.QUARTERLY, run_on_start=run_on_start,
+ hour=hour, minute=minute, month=start_month, day_of_month=day_of_month, **kwargs
+ )
+
+ def create_periodic_task(self, fn: Callable = None, period_type: PeriodType = PeriodType.WEEKLY,
+ run_on_start: bool = False, hour: int = 0, minute: int = 0,
+ weekday: Optional[int] = None, # 0=周一, 6=周日,仅 weekly 有效
+ day_of_month: Optional[int] = None, # 1-31,仅 monthly/quarterly/yearly 有效
+ month: Optional[int] = None, # 1-12,仅 quarterly/yearly 有效
+ **kwargs
+ ):
+ """
+ 通用周期任务工厂。
+
+ :param fn: 任务函数
+ :param period_type: 周期类型 (weekly/monthly/yearly/quarterly)
+ :param run_on_start: 启动时是否立即运行一次
+ :param hour: 时 (0-23)
+ :param minute: 分 (0-59)
+ :param weekday: 星期几 (0=周一, 6=周日),仅 period_type=weekly 时使用
+ :param day_of_month: 每月几号 (1-31),仅 monthly/quarterly/yearly 时使用
+ :param month: 月份 (1-12),仅 quarterly/yearly 时使用(quarterly 时表示起始季度月份)
+ """
+
+ def get_next_run_time(now: datetime.datetime) -> Optional[datetime.datetime]:
+ """根据规则计算下一次执行时间"""
+
+ if period_type == PeriodType.WEEKLY:
+ if weekday is None:
+ raise ValueError("weekly 模式需要指定 weekday")
+
+ # 计算目标星期
+ days_ahead = (weekday - now.weekday()) % 7
+
+ # 如果今天就是目标星期
+ if days_ahead == 0:
+ target_time = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
+ # 如果今天的目标时间已过,则推迟到下周
+ if target_time <= now and not run_on_start:
+ days_ahead = 7
+ else:
+ return target_time
+
+ next_date = now + relativedelta(days=days_ahead)
+ return datetime.datetime(
+ next_date.year, next_date.month, next_date.day,
+ hour, minute, 0
+ )
+
+ elif period_type == PeriodType.MONTHLY:
+ if day_of_month is None:
+ raise ValueError("monthly 模式需要指定 day_of_month")
+
+ # 获取当前月份的最后一天
+ last_day = (now.replace(day=1) + relativedelta(months=1) - relativedelta(days=1)).day
+ target_day = min(day_of_month, last_day)
+
+ candidate = now.replace(day=target_day, hour=hour, minute=minute, second=0, microsecond=0)
+
+ # 如果候选时间已过,则下个月
+ if candidate <= now and not run_on_start:
+ next_month = now + relativedelta(months=1)
+ last_day_next = (next_month.replace(day=1) + relativedelta(months=1) - relativedelta(days=1)).day
+ target_day_next = min(day_of_month, last_day_next)
+ candidate = next_month.replace(day=target_day_next, hour=hour, minute=minute, second=0,
+ microsecond=0)
+
+ return candidate
+
+ elif period_type == PeriodType.QUARTERLY:
+ if day_of_month is None or month is None:
+ raise ValueError("quarterly 模式需要指定 day_of_month 和 month(起始季度月份)")
+
+ # 修正:计算季度的月份
+ q_months = []
+ for i in range(4):
+ qm = month + i * 3
+ if qm > 12:
+ qm -= 12
+ q_months.append(qm)
+
+ # 查找下一个季度月
+ target_month = None
+ target_year = now.year
+
+ for qm in sorted(q_months):
+ if qm > now.month:
+ target_month = qm
+ break
+
+ if target_month is None:
+ target_month = q_months[0]
+ target_year += 1
+
+ # 处理日期有效性
+ last_day = (datetime.datetime(target_year, target_month, 1) +
+ relativedelta(months=1) - relativedelta(days=1)).day
+ target_day = min(day_of_month, last_day)
+
+ candidate = datetime.datetime(target_year, target_month, target_day, hour, minute, 0)
+
+ if candidate <= now and not run_on_start:
+ # 跳到下个季度
+ candidate = candidate + relativedelta(months=3)
+
+ return candidate
+
+ elif period_type == PeriodType.YEARLY:
+ if day_of_month is None or month is None:
+ raise ValueError("yearly 模式需要指定 day_of_month 和 month")
+
+ # 检查年份
+ try:
+ candidate = datetime.datetime(now.year, month, day_of_month, hour, minute, 0)
+ except ValueError:
+ # 日期无效(如2月30日),取当月最后一天
+ last_day = (datetime.datetime(now.year, month, 1) +
+ relativedelta(months=1) - relativedelta(days=1)).day
+ candidate = datetime.datetime(now.year, month, last_day, hour, minute, 0)
+
+ if candidate <= now and not run_on_start:
+ try:
+ candidate = datetime.datetime(now.year + 1, month, day_of_month, hour, minute, 0)
+ except ValueError:
+ last_day = (datetime.datetime(now.year + 1, month, 1) +
+ relativedelta(months=1) - relativedelta(days=1)).day
+ candidate = datetime.datetime(now.year + 1, month, last_day, hour, minute, 0)
+
+ return candidate
+
+ async def task_warp():
+ if fn is not None:
+ try:
+ _result = fn(**kwargs)
+ if isinstance(_result, Awaitable):
+ await _result
+ except Exception as e:
+ echo_log(msg=e, level=logging.ERROR, is_log_exc=True)
+
+ async def task_loop():
+ if fn is None:
+ return
+
+ _log_next_time = True
+ _run_on_start = run_on_start
+ next_time = get_next_run_time(datetime.datetime.now())
+ while self.is_running:
+ now = datetime.datetime.now()
+ if _run_on_start or now >= next_time:
+ await task_warp()
+ _run_on_start = False
+ # 执行完后,基于当前时间重新计算下一次
+ next_time = get_next_run_time(now)
+ _log_next_time = True
+ else:
+ if _log_next_time and self.log_next_time:
+ delta = relativedelta(next_time, now)
+ echo_log(
+ f"距下次执行:{fn.__name__} 还有:{delta.days}天 {delta.hours}时 {delta.minutes}分 {delta.seconds}秒")
+ _log_next_time = False
+ await asyncio.sleep(1)
+
+ _loop = self.event_loop()
+ return _loop.create_task(task_loop())
+
+ def add_task(self, creator: Callable, fn: Callable, **kwargs):
+ """
+ 添加任务。注意:这里只是存储创建参数,直到任务启动前,才实际把任务创建出来。
+
+ :param creator: 任务工厂,对应:应延时任务工厂、日常任务工厂
+ :param fn: 任务函数,即任务对应的实际功能函数
+ :param kwargs: 任务函数的参数
+ """
+ _d = {
+ 'creator': creator, # 创建器,对应延时任务和日常任务
+ 'fn': fn, 'kwargs': kwargs # 任务函数及其参数
+ }
+ self._create_task_params.append(_d)
+
+ def rebuild_task_list(self):
+ """
+ 重建任务列表。
+ :return: 任务列表
+ """
+ self.task_list.clear()
+
+ # 遍历创建器列表,创建任务
+ for _ctp in self._create_task_params:
+ _creator: Callable = _ctp.get('creator')
+ _fn: Callable = _ctp.get('fn')
+ _kwargs: dict = _ctp.get('kwargs')
+ _task = _creator(_fn, **_kwargs)
+ self.task_list.append(_task)
+
+ return self.task_list
+
+ async def run_tasks(self):
+ """
+ 执行任务。支持多任务同时启动,如:预统计服务、设备数据同步服务等。
+ """
+ try:
+ self.rebuild_task_list()
+ echo_log(f'{self.service_name}启动成功.')
+ await asyncio.gather(*self.task_list)
+ except Exception as e:
+ echo_log(msg=e, level=logging.ERROR, is_log_exc=True)
+
+ def start_service(self, env_check: bool = True):
+ """
+ 以控制台服务方式启动服务。
+ """
+ echo_log(f'正在启动{self.service_name}...')
+
+ try:
+ if env_check:
+ # 检测 MySQL 服务是否正常
+ echo_log('检测 Database 服务...')
+ BaseAdapter.ping()
+
+ # 注意,这里是取得协程
+ _future = self.run_tasks()
+ # 开始执行任务事件循环
+ _loop = self.event_loop()
+ _loop.run_until_complete(_future)
+ except KeyboardInterrupt:
+ echo_log(msg='KeyboardInterrupt')
+ self.stop_service()
+ except Exception as e:
+ echo_log(msg=e, level=logging.ERROR, is_log_exc=True)
+
+ def stop_service(self):
+ """
+ 停止服务。
+ """
+ self.is_running = False
+ echo_log(f'{self.service_name}已停止.')
+
+ def start(self, env_check: bool = True):
+ """
+ 以驻内存服务方式启动服务。
+ """
+ ds = DaemonizeService(pid_file=self.pid_file, name=f'{self.service_name}')
+ ds.set_start_callback(self.start_service, env_check=env_check)
+ ds.set_term_callback(self.stop_service)
+ ds.start()
+
+ def stop(self, env_check: bool = True):
+ """
+ 停止驻内存服务。
+ """
+ ds = DaemonizeService(pid_file=self.pid_file, name=f'{self.service_name}')
+ ds.set_start_callback(self.start_service, env_check=env_check)
+ ds.set_term_callback(self.stop_service)
+ ds.stop()
diff --git a/paste/util/__init__.py b/paste/util/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/paste/util/encoder.py b/paste/util/encoder.py
new file mode 100644
index 0000000..ee62746
--- /dev/null
+++ b/paste/util/encoder.py
@@ -0,0 +1,185 @@
+import base64
+import binascii
+import datetime
+import decimal
+import json
+import re
+import zlib
+from typing import Union
+
+import numpy as np
+from pandas._libs.tslibs.nattype import NaTType
+
+from paste.db.baseadapter import BaseAdapter
+from paste.db.basemodel import LOCAL_DATETIME_FORMAT, LOCAL_DATE_FORMAT, LOCAL_TIME_FORMAT
+
+
+class JsonDumpsEncoder(json.JSONEncoder):
+ """
+ JSON 转字符串时对一些特殊类型进行转码(编码)方法。
+ """
+
+ def default(self, obj):
+ if isinstance(obj, NaTType):
+ return ''
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, np.integer):
+ return int(obj)
+ elif isinstance(obj, (np.floating, decimal.Decimal)):
+ return float(obj)
+ elif isinstance(obj, bytes):
+ return obj.decode(encoding='utf-8', errors='ignore')
+ elif isinstance(obj, datetime.datetime):
+ return obj.strftime(LOCAL_DATETIME_FORMAT)
+ elif isinstance(obj, datetime.date):
+ return obj.strftime(LOCAL_DATE_FORMAT)
+ elif isinstance(obj, datetime.time):
+ return obj.strftime(LOCAL_TIME_FORMAT)
+ elif isinstance(obj, BaseAdapter):
+ return obj.to_dict()
+
+ return super().default(obj)
+
+
+class BaseX:
+ """
+ Base 编码解码方法,主要用于解码,针对编码数据自动检测编码类型。
+ 能根据编码方式自动选择解码方法,同时在解码后尝试执行标准 Zip 解压。
+ """
+
+ Type_Base16 = 'b16'
+ Type_Base32 = 'b32'
+ Type_Base64 = 'b64'
+ Type_Base85 = 'b85'
+
+ Decoders = {
+ Type_Base16: base64.b16decode,
+ Type_Base32: base64.b32decode,
+ Type_Base64: base64.b64decode,
+ Type_Base85: base64.b85decode,
+ }
+ """
+ 解码器。
+ """
+
+ @classmethod
+ def base_x_detect(cls, data: Union[bytes, str]):
+ """
+ 检测采用的内置 Base 编码种类。返回数据与以下编码方式对应::
+
+ 1、b16: Base16
+ 2、b32: Base32
+ 3、b64: Base64
+ 4、b85: Base85
+
+ :param data: Base 编码数据,允许为字节流或字符串
+ :return: 编码名称,全小写
+ """
+ if isinstance(data, bytes):
+ try:
+ data = data.decode()
+ except (UnicodeDecodeError, Exception):
+ return
+
+ try:
+ _reg = re.compile("^[0-9A-F=]+$")
+ if _reg.match(data) is not None:
+ return cls.Type_Base16
+ except (re.error, Exception):
+ pass
+
+ try:
+ _reg = re.compile("^[A-Z2-7=]+$")
+ if _reg.match(data) is not None:
+ return cls.Type_Base32
+ except (re.error, Exception):
+ pass
+
+ try:
+ _reg = re.compile("^[A-Za-z0-9+/=]+$")
+ if _reg.match(data) is not None:
+ return cls.Type_Base64
+ except (re.error, Exception):
+ pass
+
+ try:
+ _reg = re.compile("^[A-Za-z0-9!#$%&()*+-;<=>?@^_`{|}~']+$")
+ if _reg.match(data) is not None:
+ return cls.Type_Base85
+ except (re.error, Exception):
+ pass
+
+ @classmethod
+ def base_x_decode(cls, data: Union[bytes, str], base_type: str = None):
+ """
+ 自动检测编码种类后,解码 Base 编码数据。参数 base_type 与以下编码方式对应::
+
+ 1、b16: Base16
+ 2、b32: Base32
+ 3、b64: Base64
+ 4、b85: Base85
+
+ 若在解码过程中发生异常,则返回原始数据。
+
+ :param data: Base 编码数据,允许为字节流或字符串
+ :param base_type: Base 编码类型,如为 b64 则代表须用 Base64 解码
+ :return: 解码后数据
+ """
+ _res_data = b''
+ if isinstance(data, bytes):
+ try:
+ _tmp_data = data.decode()
+ _res_data = _tmp_data
+ except (UnicodeDecodeError, Exception):
+ return data
+ else:
+ _res_data = data
+
+ if base_type not in cls.Decoders:
+ # 检测 Base 编码种类
+ base_type = cls.base_x_detect(_res_data)
+
+ if base_type in cls.Decoders:
+ try:
+ # 尝试 BaseX 解码
+ _decoder = cls.Decoders.get(base_type)
+ _tmp_data = _decoder(_res_data)
+ _res_data = _tmp_data
+ except (binascii.Error, UnicodeDecodeError):
+ return data
+
+ return _res_data
+
+ @classmethod
+ def auto_decode_unzip(cls, data: Union[bytes, str], base_type: str = None):
+ """
+ 对参数尝试执行自动 BaseX 解码 和 Zip 解压::
+
+ 1、若能 BaseX 解码,则执行解码,否则保持原始数据不变。
+ 2、若能 Zip 解压,则执行解压,否则保持上一层数据不变。
+
+ 若各种解码方法都无法顺利解码,或在解码过程中发生异常,则返回原始数据。
+
+ 参数 base_type 与以下编码方式对应::
+
+ 1、b16: Base16
+ 2、b32: Base32
+ 3、b64: Base64
+ 4、b85: Base85
+
+ :param data: Base64 数据,允许为字节流或字符串
+ :param base_type: Base 编码类型,如为 b64 则代表须用 Base64 解码
+ :return: 解码后数据
+ """
+ # 尝试 BaseX 解码
+ _res_data = cls.base_x_decode(data, base_type)
+
+ try:
+ # 尝试 Zip 解压
+ _tmp_data = zlib.decompress(_res_data)
+ _res_data = _tmp_data
+ except (zlib.error, TypeError):
+ return _res_data
+
+ return _res_data
diff --git a/paste/util/pagination.py b/paste/util/pagination.py
new file mode 100644
index 0000000..4637c18
--- /dev/null
+++ b/paste/util/pagination.py
@@ -0,0 +1,125 @@
+"""
+基础分页程序,处理分页计算,后续应当扩展其功能。
+"""
+
+
+class Pagination:
+ """
+ 分页程序。
+ """
+
+ def __init__(self, row_count: int):
+ """
+ 初始化分页。
+
+ :param row_count: 总记录行数
+ """
+ self._offset = 0
+ """
+ 偏移量。
+ """
+
+ self._pages = -1
+ """
+ 总页数。
+ """
+
+ self._page_number = 1
+ """
+ 当前页码。
+ """
+
+ self.row_count = row_count
+ """
+ 数据行数。
+ """
+
+ self.page_size = 20
+ """
+ 每页显示的数据量。默认 20 行每页。
+ """
+
+ @property
+ def page_count(self):
+ """
+ 取得页数。该属性必须在调用 :meth:`.pages` 方法后调用,例如::
+
+ >>> self.pages()
+ >>> self.page_count
+
+ :return: 页数
+ """
+ return self._pages
+
+ @property
+ def page_number(self):
+ """
+ 取得当前页码。该属性必须在调用 :meth:`.number` 方法后调用, 例如::
+
+ >>> self.number(3)
+ >>> self.page_number
+
+ :return: 页码
+ """
+ return self._page_number
+
+ @property
+ def offset_size(self):
+ """
+ 取得偏移量。
+ """
+ return self._offset
+
+ def pages(self, page_size: int = 20):
+ """
+ 计算页数。
+
+ :param page_size: 每页行数,必须处于 [1, 1000] 区间中。若不在此区间,则强制转换到此区间。默认每页 20 条。
+ :return: 计算取得的页数。
+ """
+ page_size = 1 if page_size < 1 else page_size
+ page_size = 1000 if page_size > 1000 else page_size
+ self.page_size = page_size
+
+ if self.row_count == 0:
+ self._pages = 1
+ else:
+ _v1 = self.row_count / page_size
+ _v2 = self.row_count // page_size
+ self._pages = _v2 if _v1 == _v2 else _v2 + 1
+
+ return self._pages
+
+ def number(self, page_number: int):
+ """
+ 检查页码范围。
+
+ :param page_number: 页码
+ :return: 正确页码
+ """
+ _pages = self.pages(self.page_size)
+ self._page_number = 1 if page_number < 1 else page_number
+ self._page_number = _pages if self._page_number > _pages else self._page_number
+ return self._page_number
+
+ def offset(self, page_number: int):
+ """
+ 偏移量。
+
+ :param page_number: 页码
+ :return: 偏移量
+ """
+ self._offset = self.page_size * (page_number - 1)
+ return self._offset
+
+ def paging(self, page_number: int = 1, page_size: int = 20):
+ """
+ 分页计算,支持链式调用。
+
+ :params page_number 页码
+ :params page_size 每页显示的数量
+ :return self
+ """
+ self.pages(page_size=page_size)
+ self.offset(self.number(page_number))
+ return self
diff --git a/paste/util/pdf.py b/paste/util/pdf.py
new file mode 100644
index 0000000..c5f35cd
--- /dev/null
+++ b/paste/util/pdf.py
@@ -0,0 +1,63 @@
+import mimetypes
+import os
+import urllib.parse
+import urllib.request
+
+import weasyprint as wp
+
+from paste.core.logging import echo_log
+
+
+class Html2Pdf:
+ """
+ 将 HTML 内容转换为 PDF 文件。
+ """
+
+ @classmethod
+ def custom_url_fetcher(cls, url, timeout=30, **kwargs):
+ """
+ 自定义 URL 加载器,增加超时时间
+ """
+ # 处理 file:// URLs
+ if url.startswith('file://'):
+ parsed = urllib.parse.urlparse(url)
+ path = urllib.request.url2pathname(parsed.path)
+
+ if not os.path.exists(path):
+ raise ValueError(f"File not found: {path}")
+
+ _mime_type, _ = mimetypes.guess_type(path)
+ if not _mime_type:
+ _mime_type = 'application/octet-stream'
+
+ return {
+ 'mime_type': _mime_type,
+ 'encoding': 'binary',
+ 'filename': os.path.basename(path),
+ 'file_obj': open(path, 'rb'),
+ }
+
+ # 增加超时时间(默认是 30 秒)
+ return wp.default_url_fetcher(url, timeout=timeout, **kwargs)
+
+ @classmethod
+ def write_pdf(cls, content, output_pdf=None, base_url=""):
+ """
+ 将 HTML 转换为 PDF。
+
+ :param content: HTML 字符串
+ :param output_pdf: 输出的 PDF 文件路径,默认为空
+ :param base_url: 跨域默认地址
+ """
+ try:
+ # HTML 转换为 PDF
+ _html = wp.HTML(string=content, url_fetcher=cls.custom_url_fetcher, base_url=base_url)
+ _bytes = _html.write_pdf(output_pdf)
+ if output_pdf:
+ echo_log(f"PDF 已成功生成在: {output_pdf}.")
+ else:
+ echo_log(f"PDF 已成功生成.")
+ return _bytes
+ except Exception as e:
+ echo_log(f"转换失败: {e}")
+ raise e
diff --git a/paste/util/snow_id.py b/paste/util/snow_id.py
new file mode 100644
index 0000000..b00b2b8
--- /dev/null
+++ b/paste/util/snow_id.py
@@ -0,0 +1,126 @@
+"""
+雪花 ID 生成程序。
+"""
+
+import time
+import logging
+
+
+# 64位ID的划分
+WORKER_ID_BITS = 5
+DATACENTER_ID_BITS = 5
+SEQUENCE_BITS = 12
+
+# 最大取值计算
+MAX_WORKER_ID = -1 ^ (-1 << WORKER_ID_BITS) # 2**5-1 0b11111
+MAX_DATACENTER_ID = -1 ^ (-1 << DATACENTER_ID_BITS)
+
+# 移位偏移计算
+WORKER_ID_SHIFT = SEQUENCE_BITS
+DATACENTER_ID_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS
+TIMESTAMP_LEFT_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS + DATACENTER_ID_BITS
+
+# 序号循环掩码
+SEQUENCE_MASK = -1 ^ (-1 << SEQUENCE_BITS)
+
+# Twitter元年时间戳
+TW_EPOCH = 1288834974657
+
+ID_WORKER = None
+
+
+class InvalidSystemClock(Exception):
+ """
+ 时钟回拨异常
+ """
+ pass
+
+
+class IdWorker(object):
+ """
+ 用于生成 Snow ID。
+ """
+
+ @classmethod
+ def get_id_worker(cls, datacenter_id=1, worker_id=1, sequence=0):
+ """
+ 创建 Snow ID 对象。
+
+ :param datacenter_id: 数据中心(机器区域)ID
+ :param worker_id: 机器ID
+ :param sequence: 起始序号
+ """
+ global ID_WORKER
+ if ID_WORKER is None:
+ ID_WORKER = IdWorker(datacenter_id, worker_id, sequence)
+ return ID_WORKER
+
+ def __init__(self, datacenter_id, worker_id, sequence=0):
+ """
+ 初始化。
+
+ :param datacenter_id: 数据中心(机器区域)ID
+ :param worker_id: 机器ID
+ :param sequence: 起始序号
+ """
+ # sanity check
+ if worker_id > MAX_WORKER_ID or worker_id < 0:
+ raise ValueError('worker_id值越界')
+
+ if datacenter_id > MAX_DATACENTER_ID or datacenter_id < 0:
+ raise ValueError('datacenter_id值越界')
+
+ self.worker_id = worker_id
+ self.datacenter_id = datacenter_id
+ self.sequence = sequence
+
+ self.last_timestamp = -1 # 上次计算的时间戳
+
+ @staticmethod
+ def _gen_timestamp():
+ """
+ 生成整数时间戳。
+
+ :return:int timestamp
+ """
+ return int(time.time() * 1000)
+
+ def get_id(self):
+ """
+ 获取新ID。
+
+ :return: 新的 Snow ID
+ """
+ timestamp = self._gen_timestamp()
+
+ # 时钟回拨
+ if timestamp < self.last_timestamp:
+ logging.error(f"时钟正在向后倒转。拒绝请求直至 {self.last_timestamp}.")
+ raise InvalidSystemClock
+
+ if timestamp == self.last_timestamp:
+ self.sequence = (self.sequence + 1) & SEQUENCE_MASK
+ if self.sequence == 0:
+ timestamp = self._til_next_millis(self.last_timestamp)
+ else:
+ self.sequence = 0
+
+ self.last_timestamp = timestamp
+
+ new_id = ((timestamp - TW_EPOCH) << TIMESTAMP_LEFT_SHIFT) | (self.datacenter_id << DATACENTER_ID_SHIFT) | \
+ (self.worker_id << WORKER_ID_SHIFT) | self.sequence
+ return new_id
+
+ def _til_next_millis(self, last_timestamp):
+ """
+ 等到下一毫秒。
+ """
+ timestamp = self._gen_timestamp()
+ while timestamp <= last_timestamp:
+ timestamp = self._gen_timestamp()
+ return timestamp
+
+
+if __name__ == '__main__':
+ worker = IdWorker(1, 1, 0)
+ print(worker.get_id())
diff --git a/paste/util/svg.py b/paste/util/svg.py
new file mode 100644
index 0000000..76b06d7
--- /dev/null
+++ b/paste/util/svg.py
@@ -0,0 +1,704 @@
+import re
+from typing import Optional
+
+import svgwrite
+from svgwrite.container import Group
+from svgwrite.path import Path
+from svgwrite.shapes import Rect
+from svgwrite.text import Text
+
+from paste.util import ustr
+
+
+class TextRect(Group):
+ """
+ 可显示文本的矩形。
+ """
+
+ def __init__(self, text, insert, text_extra: dict = None, rect_extra: dict = None, **extra):
+ # 父类初始化
+ super().__init__(**extra)
+
+ self.text = text
+ """
+ 要显示的文本内容。
+ """
+
+ self.extra = extra
+ """
+ 组合扩展信息。
+ """
+
+ self.rectExtra = rect_extra if rect_extra is not None else {}
+ """
+ 外框扩展信息。
+ """
+
+ self.textExtra = text_extra if text_extra is not None else {}
+ """
+ 文本扩展信息。
+ """
+
+ self.rectInsert = insert
+ """
+ 整体位置参数,即外框的位置参数。
+ """
+
+ # 初始化文本尺寸
+ _fs = self.font_size
+
+ self.textInsert = self.text_pos
+ """
+ 文本位置参数。
+ """
+
+ # 文本初始化
+ self.textElement = Text(self.text, insert=self.textInsert, **self.textExtra)
+ # 矩形初始化
+ self.rectElement = Rect(insert=self.rectInsert, size=self.rect_size, **self.rectExtra)
+
+ # 加入元素
+ self.add(self.rectElement)
+ self.add(self.textElement)
+
+ @property
+ def font_size(self):
+ """
+ 从样式中识别字体大小,单位用像素,缺省 14px。
+
+ :return: 字体大小
+ """
+ _font_size = self.textExtra.get('font-size', self.extra.get('font-size', f"{14}px"))
+ self.textExtra['font-size'] = _font_size
+ _size = re.sub(r'\D', '', _font_size.strip())
+ return int(_size)
+
+ @property
+ def text_width(self):
+ """
+ 文本宽度(近似)。
+ """
+ total = len(self.text)
+ q_count = ustr.str_q_count(self.text)
+ return q_count * self.font_size + (total - q_count) * self.font_size * 0.5
+
+ @property
+ def text_height(self):
+ """
+ 文本高度(近似)。
+ """
+ return self.font_size * 1.2
+
+ @property
+ def rect_width(self):
+ """
+ 外框宽度。
+ """
+ return self.text_width + self.font_size * 1.5
+
+ @property
+ def rect_height(self):
+ """
+ 外框高度。
+ """
+ return self.text_height * 1.9
+
+ @property
+ def rect_size(self):
+ """
+ 外框尺寸。
+ """
+ return self.rect_width, self.rect_height
+
+ @property
+ def text_pos(self):
+ """
+ 文本位置。
+ """
+ return \
+ self.rectInsert[0] + (self.rect_width - self.text_width) * 0.5, \
+ self.rectInsert[1] + self.text_height * 1.25
+
+ def reposition(self, position: tuple):
+ """
+ 重新定位。
+
+ :param position: 位置坐标
+ """
+ self.rectInsert = position
+ self.rectElement.attribs['x'] = self.rectInsert[0]
+ self.rectElement.attribs['y'] = self.rectInsert[1]
+
+ self.textInsert = self.text_pos
+ self.textElement.attribs['x'] = self.textInsert[0]
+ self.textElement.attribs['y'] = self.textInsert[1]
+
+ def point_bottom(self):
+ """
+ 底部点。
+ """
+ return self.rectInsert[0] + self.rect_width / 2, self.rectInsert[1] + self.rect_size[1]
+
+ def point_top(self):
+ """
+ 顶部点。
+ """
+ return self.rectInsert[0] + self.rect_width / 2, self.rectInsert[1]
+
+ def point_left(self):
+ """
+ 左侧点。
+ """
+ return self.rectInsert[0], self.rectInsert[1] + self.rect_height / 2
+
+ def point_right(self):
+ """
+ 右侧点。
+ """
+ return self.rectInsert[0] + self.rect_width, self.rectInsert[1] + self.rect_height / 2
+
+ @classmethod
+ def horizontal_path(cls, start: tuple, end: tuple, **extra):
+ """
+ 生成水平方向连接线。
+
+ :param start: 起点坐标
+ :param end: 终点坐标
+ :param extra: 扩展参数
+ :return: 路径对象
+ """
+ _p_control = [
+ (start[0] + end[0]) * 0.5,
+ start[1]
+ ]
+
+ _p_center = [
+ (start[0] + end[0]) * 0.5,
+ (start[1] + end[1]) * 0.5
+ ]
+
+ _path = Path(**extra)
+
+ _path.push(['M', start])
+ _path.push(['Q', _p_control + _p_center])
+ _path.push(['T', end])
+
+ return _path
+
+ @classmethod
+ def vertical_path(cls, start: tuple, end: tuple, **extra):
+ """
+ 生成垂直方向连接线。
+
+ :param start: 起点坐标
+ :param end: 终点坐标
+ :param extra: 扩展参数
+ :return: 路径对象
+ """
+ _p_control = [
+ start[0],
+ (start[1] + end[1]) * 0.5
+ ]
+
+ _p_center = [
+ (start[0] + end[0]) * 0.5,
+ (start[1] + end[1]) * 0.5
+ ]
+
+ _path = Path(**extra)
+
+ _path.push(['M', start])
+ _path.push(['Q', _p_control + _p_center])
+ _path.push(['T', end])
+
+ return _path
+
+ def choose_point(self, sibling: 'TextRect'):
+ """
+ 选择与目标文本框的连线点。
+
+ 返回起点(tuple)在自生文本框上,终点(tuple)在目标文本框上。
+
+ :param sibling: 目标文本框
+ :return: 起点、终点、是否水平连线
+ """
+ _start = self.point_bottom()
+ _end = sibling.point_top()
+ _is_horizontal = True
+
+ if self.point_bottom()[1] > sibling.point_top()[1]:
+ if self.point_right()[0] < sibling.point_left()[0]:
+ _start = self.point_right()
+ _end = sibling.point_left()
+ _is_horizontal = False
+ elif self.point_left()[0] > sibling.point_right()[0]:
+ _start = self.point_left()
+ _end = sibling.point_right()
+ _is_horizontal = False
+ else:
+ _start = self.point_top()
+ _end = sibling.point_bottom()
+ _is_horizontal = True
+ elif self.point_top()[1] < sibling.point_bottom()[1]:
+ if self.point_right()[0] < sibling.point_left()[0]:
+ _start = self.point_right()
+ _end = sibling.point_left()
+ _is_horizontal = False
+ elif self.point_left()[0] > sibling.point_right()[0]:
+ _start = self.point_left()
+ _end = sibling.point_right()
+ _is_horizontal = False
+ else:
+ _start = self.point_bottom()
+ _end = sibling.point_top()
+ _is_horizontal = True
+ else:
+ if self.point_right()[0] < sibling.point_left()[0]:
+ _start = self.point_right()
+ _end = sibling.point_left()
+ _is_horizontal = False
+ elif self.point_left()[0] > sibling.point_right()[0]:
+ _start = self.point_left()
+ _end = sibling.point_right()
+ _is_horizontal = False
+ else:
+ _start = self.point_bottom()
+ _end = sibling.point_top()
+ _is_horizontal = True
+
+ return _start, _end, _is_horizontal
+
+ def connect(self, sibling: 'TextRect', **extra):
+ """
+ 取得连接路径。
+
+ :param sibling: 目标文本框
+ :param extra: 连线扩展参数
+ :return: 连接路径
+ """
+ _start, _end, _is_horizontal = self.choose_point(sibling)
+ if _is_horizontal:
+ return self.vertical_path(_start, _end, **extra)
+ else:
+ return self.horizontal_path(_start, _end, **extra)
+
+
+class RelationGraph:
+ """
+ SVG 关系图。
+
+ 根据 title 名称和 row_list 列表数据输出 svg 格式的关系图谱。
+ """
+
+ def __init__(self, filename: str = 'noname.svg'):
+ self.filename = filename
+
+ self.width = 800
+ """
+ 画布宽度。
+ """
+ self.height = 600
+ """
+ 画布高度。
+ """
+ self.vhSpace = 170
+ """
+ 内容垂直浮动空间。
+ """
+ self.lrSpace = 100
+ """
+ 内容左右留白空间。
+ """
+
+ self.titleTextExtra = {
+ 'font-size': '16px', 'fill': 'rgb(255, 255, 255)'
+ }
+ """
+ 标题文本样式。
+ """
+
+ self.titleRectExtra = {
+ 'rx': 10, 'ry': 10, 'fill': 'rgb(233, 72, 41)', 'fill-opacity': 1, 'stroke': 'rgb(233, 72, 41)'
+ }
+ """
+ 标题外框样式
+ """
+
+ self.textExtra = {
+ 'font-size': '14px', 'fill': 'rgb(255, 255, 255)'
+ }
+ """
+ 普通文本样式。
+ """
+
+ self.rectExtra = {
+ 'rx': 10, 'ry': 10, 'fill': 'rgb(65, 130, 164)', 'fill-opacity': 1, 'stroke': 'rgb(65, 130, 164)'
+ }
+ """
+ 普通文本外框样式。
+ """
+
+ self.pathExtra = {
+ 'fill': 'none', 'stroke': 'rgb(65, 130, 164)'
+ }
+ """
+ 连线样式。
+ """
+
+ self.drawing = svgwrite.Drawing(filename=self.filename)
+ """
+ 主绘图对象。
+ """
+
+ self.attribs = self.drawing.attribs
+ """
+ 图像样式。
+ """
+
+ self.save = self.drawing.save
+ """
+ 保存文件方法。
+ """
+
+ self.attribs.update({
+ 'width': self.width, 'height': self.height
+ })
+
+ self.titleTr: Optional[TextRect] = None
+ """
+ 标题文本框对象。
+ """
+
+ def draw(self, title: str, row_list: list[dict]):
+ """
+ 绘制图形。
+
+ :param row_list: 数据对象列表,必须包含 unit_name, unit_uscid, enterprise_id 三个字段
+ :param title: 标题文本
+ :return: 自身对象
+ """
+ # 重定设图像参数
+ self.attribs.update(self.attribs)
+
+ # 创建标题文本框
+ self.titleTr = TextRect(
+ text=title, insert=(0, 0), text_extra=self.titleTextExtra, rect_extra=self.titleRectExtra, **{
+ 'debug': False
+ }
+ )
+ self.titleTr.reposition((
+ (self.width - self.titleTr.rect_width) * 0.5, (self.height - self.titleTr.rect_height) * 0.5 - 20
+ ))
+ self.drawing.add(self.titleTr)
+
+ _tr_list: list[TextRect] = []
+ for _i, _row in enumerate(row_list):
+ # 遍历数据,初始创建所有的文本框,得到文本框尺寸信息
+ # 同时保留所有需要输出的数据
+ _text = f"{_row['short_name']} ({_row['count']})"
+ _tr = TextRect(
+ text=_text, insert=(0, 0), rect_extra=self.rectExtra, **self.textExtra, **{
+ 'debug': False,
+ 'data-name': _row['unit_name'],
+ 'data-uscid': _row['unit_uscid'],
+ 'data-enterprise-id': _row['enterprise_id'],
+ }
+ )
+ _tr_list.append(_tr)
+
+ _harf = int(len(_tr_list) / 2) if int(len(_tr_list) % 2) == 0 else int(len(_tr_list) / 2) + 1
+ _top_list = []
+ _lft_list = _tr_list[:_harf]
+ _rit_list = _tr_list[_harf:]
+ _btm_list = []
+
+ if len(_tr_list) >= 12:
+ _top_list = _lft_list[:2]
+ _lft_list = _lft_list[2:]
+ if len(_tr_list) >= 14:
+ _btm_list = _rit_list[-2:]
+ _rit_list = _rit_list[:-2]
+
+ # 遍历所有顶部文本框,重新定位
+ for _i, _tr in enumerate(_top_list):
+ if _i == 0:
+ _position = (
+ self.titleTr.point_top()[0] - _tr.rect_width - 15,
+ self.titleTr.point_top()[1] - self.vhSpace - _tr.rect_height - 15
+ )
+ else:
+ _position = (
+ self.titleTr.point_top()[0] + 15,
+ self.titleTr.point_top()[1] - self.vhSpace - _tr.rect_height - 15
+ )
+ _tr.reposition(_position)
+
+ # 遍历所有底部文本框,重新定位
+ for _i, _tr in enumerate(_btm_list):
+ if _i == 0:
+ _position = (
+ self.titleTr.point_bottom()[0] - _tr.rect_width - 15,
+ self.titleTr.point_bottom()[1] + self.vhSpace + _tr.rect_height + 15
+ )
+ else:
+ _position = (
+ self.titleTr.point_bottom()[0] + 15,
+ self.titleTr.point_bottom()[1] + self.vhSpace + _tr.rect_height + 15
+ )
+ _tr.reposition(_position)
+
+ _top = self.titleTr.point_top()[1] - self.vhSpace
+ # 遍历所有左则文本框,重新定位
+ for _tr in _lft_list:
+ _w = _tr.rect_width
+ _h = _tr.rect_height
+ _space = self.titleTr.point_bottom()[1] - self.titleTr.point_top()[1] + self.vhSpace * 2 + _h
+
+ _margin = 0
+ if len(_lft_list) > 1:
+ _margin = (_space - len(_lft_list) * _h) / (len(_lft_list) - 1)
+
+ _left = self.titleTr.point_left()[0] - _w - self.lrSpace
+ _position = (_left, _top)
+ _tr.reposition(_position)
+ if _tr.point_left()[0] < 20:
+ _left = 20
+ _position = (_left, _top)
+ _tr.reposition(_position)
+ _top += _h + _margin
+
+ _top = self.titleTr.point_top()[1] - self.vhSpace
+ # 遍历所有右侧文本框,重新定位
+ for _tr in _rit_list:
+ _w = _tr.rect_width
+ _h = _tr.rect_height
+ _space = self.titleTr.point_bottom()[1] - self.titleTr.point_top()[1] + self.vhSpace * 2 + _h
+
+ _margin = 0
+ if len(_rit_list) > 1:
+ _margin = (_space - len(_rit_list) * _h) / (len(_rit_list) - 1)
+
+ _left = self.titleTr.point_right()[0] + self.lrSpace
+ _position = (_left, _top)
+ _tr.reposition(_position)
+ if _tr.point_right()[0] > self.width - 20:
+ _left = self.width - _tr.rect_width - 20
+ _position = (_left, _top)
+ _tr.reposition(_position)
+
+ _top += _h + _margin
+
+ for _tr in _tr_list:
+ self.drawing.add(self.titleTr.connect(_tr, **self.pathExtra))
+
+ for _tr in _tr_list:
+ self.drawing.add(_tr)
+
+
+class EnterpriseGraph:
+ """
+ SVG 企业汇总信息图。
+
+ 根据 title 名称和 row_list 列表数据输出 svg 格式的关系图谱。
+ """
+
+ def __init__(self, filename: str = 'noname.svg'):
+ self.filename = filename
+
+ self.width = 800
+ """
+ 画布宽度。
+ """
+ self.height = 300
+ """
+ 画布高度。
+ """
+ self.vhSpace = 50
+ """
+ 内容垂直浮动空间。
+ """
+ self.lrSpace = 100
+ """
+ 内容左右留白空间。
+ """
+
+ self.titleTextExtra = {
+ 'font-size': '16px', 'fill': 'rgb(255, 255, 255)'
+ }
+ """
+ 标题文本样式。
+ """
+
+ self.titleRectExtra = {
+ 'rx': 10, 'ry': 10, 'fill': 'rgb(233, 72, 41)', 'fill-opacity': 1, 'stroke': 'rgb(233, 72, 41)'
+ }
+ """
+ 标题外框样式
+ """
+
+ self.textExtra = {
+ 'font-size': '14px', 'fill': 'rgb(255, 255, 255)'
+ }
+ """
+ 普通文本样式。
+ """
+
+ self.rectExtra = {
+ 'rx': 10, 'ry': 10, 'fill': 'rgb(65, 130, 164)', 'fill-opacity': 1, 'stroke': 'rgb(65, 130, 164)'
+ }
+ """
+ 普通文本外框样式。
+ """
+
+ self.pathExtra = {
+ 'fill': 'none', 'stroke': 'rgb(65, 130, 164)'
+ }
+ """
+ 连线样式。
+ """
+
+ self.drawing = svgwrite.Drawing(filename=self.filename)
+ """
+ 主绘图对象。
+ """
+
+ self.attribs = self.drawing.attribs
+ """
+ 图像样式。
+ """
+
+ self.save = self.drawing.save
+ """
+ 保存文件方法。
+ """
+
+ self.attribs.update({
+ 'width': self.width, 'height': self.height
+ })
+
+ self.titleTr: Optional[TextRect] = None
+ """
+ 标题文本框对象。
+ """
+
+ def draw(self, title: str, data_item: dict):
+ """
+ 绘制图形。
+
+ :param data_item: 数据项字典,中文名称:数据值
+ :param title: 标题文本
+ :return: 自身对象
+ """
+ # 重定设图像参数
+ self.attribs.update(self.attribs)
+
+ # 创建标题文本框
+ self.titleTr = TextRect(
+ text=title, insert=(0, 0), text_extra=self.titleTextExtra, rect_extra=self.titleRectExtra, **{
+ 'debug': False
+ }
+ )
+ self.titleTr.reposition((
+ (self.width - self.titleTr.rect_width) * 0.5, (self.height - self.titleTr.rect_height) * 0.5 - 20
+ ))
+ self.drawing.add(self.titleTr)
+
+ _tr_list: list[TextRect] = []
+ for _key, _val in data_item.items():
+ # 遍历数据,初始创建所有的文本框,得到文本框尺寸信息
+ # 同时保留所有需要输出的数据
+ _text = f"{_key}:{_val}"
+ _tr = TextRect(
+ text=_text, insert=(0, 0), rect_extra=self.rectExtra, **self.textExtra, **{
+ 'debug': False,
+ }
+ )
+ _tr_list.append(_tr)
+
+ _harf = int(len(_tr_list) / 2) if int(len(_tr_list) % 2) == 0 else int(len(_tr_list) / 2) + 1
+ _top_list = []
+ _lft_list = _tr_list[:_harf]
+ _rit_list = _tr_list[_harf:]
+ _btm_list = []
+
+ if len(_tr_list) >= 12:
+ _top_list = _lft_list[:2]
+ _lft_list = _lft_list[2:]
+ if len(_tr_list) >= 14:
+ _btm_list = _rit_list[-2:]
+ _rit_list = _rit_list[:-2]
+
+ # 遍历所有顶部文本框,重新定位
+ for _key, _tr in enumerate(_top_list):
+ if _key == 0:
+ _position = (
+ self.titleTr.point_top()[0] - _tr.rect_width - 15,
+ self.titleTr.point_top()[1] - self.vhSpace - _tr.rect_height - 15
+ )
+ else:
+ _position = (
+ self.titleTr.point_top()[0] + 15,
+ self.titleTr.point_top()[1] - self.vhSpace - _tr.rect_height - 15
+ )
+ _tr.reposition(_position)
+
+ # 遍历所有底部文本框,重新定位
+ for _key, _tr in enumerate(_btm_list):
+ if _key == 0:
+ _position = (
+ self.titleTr.point_bottom()[0] - _tr.rect_width - 15,
+ self.titleTr.point_bottom()[1] + self.vhSpace + _tr.rect_height + 15
+ )
+ else:
+ _position = (
+ self.titleTr.point_bottom()[0] + 15,
+ self.titleTr.point_bottom()[1] + self.vhSpace + _tr.rect_height + 15
+ )
+ _tr.reposition(_position)
+
+ _top = self.titleTr.point_top()[1] - self.vhSpace
+ # 遍历所有左则文本框,重新定位
+ for _tr in _lft_list:
+ _w = _tr.rect_width
+ _h = _tr.rect_height
+ _space = self.titleTr.point_bottom()[1] - self.titleTr.point_top()[1] + self.vhSpace * 2 + _h
+
+ _margin = 0
+ if len(_lft_list) > 1:
+ _margin = (_space - len(_lft_list) * _h) / (len(_lft_list) - 1)
+
+ _left = self.titleTr.point_left()[0] - _w - self.lrSpace
+ _position = (_left, _top)
+ _tr.reposition(_position)
+ if _tr.point_left()[0] < 20:
+ _left = 20
+ _position = (_left, _top)
+ _tr.reposition(_position)
+ _top += _h + _margin
+
+ _top = self.titleTr.point_top()[1] - self.vhSpace
+ # 遍历所有右侧文本框,重新定位
+ for _tr in _rit_list:
+ _w = _tr.rect_width
+ _h = _tr.rect_height
+ _space = self.titleTr.point_bottom()[1] - self.titleTr.point_top()[1] + self.vhSpace * 2 + _h
+
+ _margin = 0
+ if len(_rit_list) > 1:
+ _margin = (_space - len(_rit_list) * _h) / (len(_rit_list) - 1)
+
+ _left = self.titleTr.point_right()[0] + self.lrSpace
+ _position = (_left, _top)
+ _tr.reposition(_position)
+ if _tr.point_right()[0] > self.width - 20:
+ _left = self.width - _tr.rect_width - 20
+ _position = (_left, _top)
+ _tr.reposition(_position)
+
+ _top += _h + _margin
+
+ for _tr in _tr_list:
+ self.drawing.add(self.titleTr.connect(_tr, **self.pathExtra))
+
+ for _tr in _tr_list:
+ self.drawing.add(_tr)
diff --git a/paste/util/tail_read.py b/paste/util/tail_read.py
new file mode 100644
index 0000000..94c0daf
--- /dev/null
+++ b/paste/util/tail_read.py
@@ -0,0 +1,164 @@
+import os
+from typing import Optional
+
+from paste.core import config
+
+
+class TailRead:
+ """
+ 文件逆向读取器。
+ 主要针对读取日志文件。当遇到大日志文件时,需要从后向前读取,这样读取的速度更快。
+ 当日志文件动态增加时,再正向读取,此时仅读取差异内容,实现小数据量交互。
+ """
+
+ @classmethod
+ def logReader(cls, log_fn: str = None):
+ """
+ 取得配置文件设置的日志文件读取器。
+
+ :param log_fn: 日志文件名
+ :return: 默认日志文件读取器
+ """
+ if log_fn is None:
+ log_fn = config.get_config('logger.filename')
+ return TailRead(fn=log_fn)
+
+ def __init__(self, fn: str):
+ self.file_name = fn
+ """
+ 要读取的文件名。
+ """
+
+ self.file_io = open(self.file_name, 'rb')
+ """
+ 文件 IO 对象。
+ """
+
+ self.current_position: Optional[int] = None
+ """
+ 当前读取点位置。
+ """
+
+ _f_size = self.size()
+ if _f_size > 1:
+ # 移动到文件末尾
+ self.file_io.seek(_f_size - 1)
+ else:
+ self.file_io.seek(0)
+
+ def size(self):
+ """
+ 取得文件大小。
+
+ :return: 文件大小
+ """
+ return os.path.getsize(self.file_name)
+
+ def readTail(self, lines: int = 100):
+ """
+ 从文件中,逆向读取 lines 行。读取结束后,将读取定位移动到所有读取到的字节的最后。
+
+ :param lines: 读取的行数
+ :return: 读取到的数据,读取完成点
+ """
+ _buffer: bytes = b''
+ _c_pos = self.file_io.tell()
+
+ # 从当前位置读取一位,判断是否是回车
+ # 若是,则增加一行,确保读取足够的行数
+ _byte = self.file_io.read(1)
+ if _byte == b'\n':
+ lines += 1
+ # 重新回到原始位置
+ self.file_io.seek(_c_pos)
+
+ _r_pos = _c_pos
+ while lines > 0:
+ # 读取一个字节
+ _byte = self.file_io.read(1)
+ if _byte == b'':
+ # 无数据,退出
+ break
+ if _byte == b'\n':
+ # 减少行数
+ lines -= 1
+ # 逆向前移
+ _r_pos -= 1
+ if _r_pos <= 0:
+ # 超出第一位时,退出
+ break
+ self.file_io.seek(_r_pos)
+ # 加入缓存
+ _buffer = _byte + _buffer
+
+ # 扣除首字节回车符号
+ if _buffer[0:1] == b'\n':
+ _buffer = _buffer[1:]
+
+ # 第一位已经读取,因此正向移动一位
+ self.current_position = _c_pos + 1
+ self.file_io.seek(self.current_position)
+ return _buffer, self.current_position
+
+ def readLines(self, lines: int = 100, crt_pos: int = None):
+ """
+ 读取文件数据,默认读取 100 行。当不传入 crt_pos 时逆向读取,传入时正向读取。
+ 具有动态方向,确保第一次是最大量读取,以后每次都是增量读取,减少传递的数据量。
+
+ :param lines: 要读取的行数
+ :param crt_pos: 当前读取位置
+ :return:
+ """
+ _buffer: bytes = b''
+ if crt_pos is None:
+ # 参数 cur_pos 为 None 时,逆向读取
+ self.file_io.seek(self.size()-1)
+ _buffer, crt_pos = self.readTail(lines)
+ else:
+ # 参数 cur_pos 有值时,正向读取
+ self.file_io.seek(crt_pos)
+ while lines:
+ _bytes = self.file_io.readline()
+ if _bytes == b'':
+ # 无数据,退出
+ break
+ else:
+ # 减少行数
+ lines -= 1
+ # 加入缓存
+ _buffer += _bytes
+ crt_pos = self.file_io.tell()
+
+ self.current_position = crt_pos
+ return _buffer, self.current_position
+
+ def read(self, lines: int = 200, crt_pos: int = None):
+ """
+ 读取文件数据。注意::
+
+ 1、首次读取时 crt_pos 应为 None 此时逆向读取,返回读取到的数据流和读取点位置。
+ 2、当有 crt_pos 参数时,先检查文件是否发生了变化,若文件变大,则正向读取增量部分,若文件变小则置空。
+ 3、若有 crt_pos 且文件没有发生变化,则返回空字节流,读取位置不变。
+
+ :param lines: 要读取的最大行数,默认 200 行
+ :param crt_pos: 当前读取位置,为 None 时逆向读取,否则正向读取
+ :return: 读取到的字节流
+ """
+ _buffer = b''
+ if crt_pos is None:
+ # 参数 cur_pos 为 None 时,逆向读取
+ _buffer, crt_pos = self.readLines(lines, crt_pos=crt_pos)
+ else:
+ # 参数 cur_pos 有值时
+ # 检查文件是否发生了变化
+ _log_size = self.size()
+ if _log_size > crt_pos + 1:
+ # 内容增加,继续正向读取
+ _buffer, crt_pos = self.readLines(lines, crt_pos=crt_pos)
+ elif _log_size < crt_pos:
+ # 内容减少,置空读取位置
+ # 置空后,再次调用本函数,执行逆向读取
+ crt_pos = None
+
+ self.current_position = crt_pos
+ return _buffer, self.current_position
diff --git a/paste/util/udict.py b/paste/util/udict.py
new file mode 100644
index 0000000..7f7cacd
--- /dev/null
+++ b/paste/util/udict.py
@@ -0,0 +1,40 @@
+from typing import Any, Optional, Dict
+
+
+def get_with_default(dict_obj: dict, key: Any, default: Optional[Any] = None):
+ """
+ 从字典中取得对应的值,若为 None,则返回默认值。
+ 注意,字典自带 get 方法是当 key 存在,则返回对应的值,无论是否为 None;
+ 而该方法是无论 key 是否存在,只要值为 None 均返回默认值。
+
+ :param dict_obj: 字典对象
+ :param key: 键
+ :param default: 默认值
+ """
+ _val = dict_obj.get(key, default)
+ if _val is None:
+ _val = default
+ return _val
+
+
+def get_by_path(dict_obj: Dict[str, Any], path: str, default: Optional[Any] = None):
+ """
+ 按路径取得字典中的数据。要求路径指向的也必须是字典,除最后一项。
+
+ :param dict_obj: 字典对象
+ :param path: 字典中的 key 路径,以"."号分隔
+ :param default: 默认值
+ :return:
+ """
+ _dict: Optional[Dict[str, Any]] = dict_obj
+ _keys = path.split(".")
+
+ if len(_keys) > 1:
+ # 遍历到倒数第二项
+ for _key in _keys[:-1]:
+ _dict = _dict.get(_key, None)
+ if not isinstance(_dict, dict):
+ return default
+
+ # 返回最后一项内容
+ return _dict.get(_keys[-1], default)
\ No newline at end of file
diff --git a/paste/util/ufile.py b/paste/util/ufile.py
new file mode 100644
index 0000000..891ef0b
--- /dev/null
+++ b/paste/util/ufile.py
@@ -0,0 +1,293 @@
+import base64
+import datetime
+import io
+import os
+import re
+import unicodedata
+from typing import Optional, IO, Union
+
+import cv2
+import numpy as np
+from PIL import Image
+
+file_types = {
+ 'jpeg': (b'\xFF\xD8\xFF', b'\xff\xd8\xff'),
+ 'png': (b'\x89PNG',),
+ 'gif': (b'GIF8',),
+ 'bmp': (b'BM',),
+ 'tiff': (b'II*\x00', b'MM\x00*'),
+ 'webp': (b'RIFF\x00\x00\x00\x00WEBP',),
+ 'ico': (b'\x00\x00\x01\x00',),
+ 'psd': (b'8BPS',),
+ 'svg': (b'