Files
2026-06-02 16:26:10 +08:00

213 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import base64
import datetime
from io import BytesIO
from typing import Dict, List, Optional, Union
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.dates import MonthLocator, DateFormatter
from scipy.interpolate import make_interp_spline
from paste import chart
from paste.util import ufont
def gen_lines(data_dict: Dict[str, pd.Series], color_palette: str = 'BuPu', dpi: int = 128) -> str:
"""
生成多条折线图,用于对比不同年份的时间序列数据(如月度数据)。
数据结构要求:
- data_dict: Dict[str, pd.Series]
- key: 字符串,表示年份(如 '2022', '2023'),必须可转为整数
- value: pd.Series,索引为日期(datetime-like),值为数值(如销售额、访问量等)
- 示例:
{
'2022': pd.Series([100, 120, 90], index=pd.date_range('2022-01-01', periods=3, freq='M')),
'2023': pd.Series([110, 130, 95], index=pd.date_range('2023-01-01', periods=3, freq='M'))
}
功能说明:
- 自动识别当前年份,并高亮显示其曲线(加粗、加圆点标记)
- 使用每两个月一次的主刻度,格式化为 'MM-DD'
- 保存为 base64 编码的 SVG 图像,适用于 Web 前端直接嵌入
:param data_dict: 年度时间序列数据字典,结构如上
:param color_palette: Seaborn 色盘名称,默认 'BuPu'
:param dpi: 输出图像分辨率,默认 128
:return: base64 编码的 SVG 图像 Data URL 字符串
"""
# === 颜色准备 ===
colors = chart.get_seaborn_colors(len(data_dict), palette=color_palette)
if len(colors) == 0:
raise ValueError("color_palette 返回空颜色列表,请检查色盘名称是否有效")
# === 字体设置 ===
available_font = ufont.get_fonts()
plt.rcParams['font.sans-serif'] = list(available_font)
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示为方块的问题
# === 创建画布 ===
plt.figure(figsize=(10, 6), facecolor='none', dpi=dpi)
ax = plt.gca()
# === 绘制每条折线 ===
current_year = str(datetime.date.today().year)
for i, year in enumerate(data_dict.keys()):
s = data_dict[year]
# 检查数据有效性
if len(s) == 0:
raise ValueError(f"数据序列 {year} 为空,请检查输入")
# 判断是否为当前年份,决定样式
is_current_year = (year == current_year)
linewidth = 2.5 if is_current_year else 1.5
marker = 'o' if is_current_year else None
label = f'{year}'
ax.plot(
s.index, s.values,
color=colors[i],
linewidth=linewidth,
alpha=0.9,
label=label,
marker=marker,
markersize=4,
markevery=30 # 每30个点显示一个标记,避免密集
)
# === 智能日期刻度 ===
ax.xaxis.set_major_locator(MonthLocator(bymonth=range(1, 13, 2))) # 每两个月一个主刻度
ax.xaxis.set_major_formatter(DateFormatter('%m-%d')) # 格式化为 MM-DD
# === 网格与边框美化 ===
ax.grid(True, which='major', linestyle='--', alpha=0.6)
ax.grid(True, which='minor', linestyle=':', alpha=0.3)
for spine in ['top', 'right']:
ax.spines[spine].set_visible(False)
ax.spines['left'].set_color('#d9d9d9')
ax.spines['bottom'].set_color('#d9d9d9')
# === 图例 ===
legend = ax.legend(loc='upper right', frameon=True, framealpha=0.7, edgecolor='#f0f0f0')
legend.get_frame().set_linewidth(1)
# === 输出为 base64 SVG ===
buffer = BytesIO()
plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
plt.close()
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
img_base64 = f"data:image/svg+xml;base64,{image_base64}"
return img_base64
def gen_splines(data_dict: Dict[str, pd.Series], total: pd.Series, x_labels: Optional[List[Union[str, int]]] = None,
color_palette: str = 'BuPu', dpi: int = 128) -> str:
"""
生成平滑曲线图,适用于非时间序列数据(如按类别排序的数值),通过样条插值实现曲线平滑。
数据结构要求:
- data_dict: Dict[str, pd.Series]
- key: 字符串,表示数据系列名称(如 'A产品', 'B产品'
- value: pd.Series,索引为数值型(如 1~12),值为观测值(如销量)
- 示例:
{
'A产品': pd.Series([10, 15, 12, 18], index=[1,2,3,4]),
'B产品': pd.Series([8, 14, 16, 20], index=[1,2,3,4])
}
- total: pd.Series
- 索引必须与 data_dict 的 key 一致
- 值为每个系列的总和(用于图例标注)
- 示例:pd.Series({'A产品': 55, 'B产品': 58})
- x_labels: List[Union[str, int]], 可选
- 用于自定义 X 轴刻度标签(通常为原始索引值)
- 若为 None,则使用插值后的密集点(不显示原始刻度)
- 若提供,则在绘图后手动设置刻度,避免插值后刻度混乱
功能说明:
- 按 total 值降序排列绘图顺序(重要数据优先显示)
- 使用三次样条插值(k=3)生成平滑曲线
- 图例中显示原始总值(如 "A产品(55"
:param data_dict: 各系列的原始观测数据,结构如上
:param total: 每个系列的总和,用于排序和图例标注
:param x_labels: 可选,原始 X 轴标签列表(如 [1,2,3,4]),用于控制刻度显示
:param color_palette: Seaborn 色盘名称
:param dpi: 输出分辨率
:return: base64 编码的 SVG 图像 Data URL 字符串
"""
# === 颜色准备 ===
colors = chart.get_seaborn_colors(len(data_dict), palette=color_palette)
if len(colors) == 0:
raise ValueError("color_palette 返回空颜色列表,请检查色盘名称是否有效")
# === 字体设置 ===
available_font = ufont.get_fonts()
plt.rcParams['font.sans-serif'] = list(available_font)
plt.rcParams['axes.unicode_minus'] = False
# === 创建画布 ===
plt.figure(figsize=(10, 6), facecolor='none', dpi=dpi)
ax = plt.gca()
# === 按总值排序,确保重要数据优先显示 ===
total_sorted = total.sort_values(ascending=False)
# === 绘制每条平滑曲线 ===
for i, series_name in enumerate(total_sorted.index):
s = data_dict[series_name]
# 检查数据完整性
if len(s) < 2:
raise ValueError(f"系列 {series_name} 数据点少于2个,无法进行插值")
x = s.index.values.astype(float) # 确保为数值类型
y = s.values.astype(float)
# 插值:生成 100 个平滑点(线性插值不足以平滑,样条更优)
x_new = np.linspace(x.min(), x.max(), 100)
spline = make_interp_spline(x, y, k=3) # 三次样条插值
y_smooth = spline(x_new)
# 转回 Series 便于绘图
s_smooth = pd.Series(y_smooth, index=x_new)
ax.plot(
s_smooth.index, s_smooth.values,
color=colors[i],
linewidth=2,
alpha=0.9,
label=f'{series_name}{total_sorted[series_name]}', # 图例中显示总值
markersize=4,
markevery=30 # 标记稀疏,避免视觉干扰
)
# === 手动设置 X 轴刻度(若提供)===
if x_labels is not None:
# 确保 x_labels 与原始 x 一致(可选验证)
plt.xticks(ticks=x_labels, labels=x_labels)
# === 网格与边框 ===
ax.grid(True, which='major', linestyle='--', alpha=0.6)
ax.grid(True, which='minor', linestyle=':', alpha=0.3)
for spine in ['top', 'right']:
ax.spines[spine].set_visible(False)
ax.spines['left'].set_color('#d9d9d9')
ax.spines['bottom'].set_color('#d9d9d9')
# === 图例 ===
legend = ax.legend(loc='upper left', frameon=True, framealpha=0.7, edgecolor='#f0f0f0')
legend.get_frame().set_linewidth(1)
# === 输出为 base64 SVG ===
buffer = BytesIO()
plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
plt.close()
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
img_base64 = f"data:image/svg+xml;base64,{image_base64}"
return img_base64