213 lines
8.2 KiB
Python
213 lines
8.2 KiB
Python
import base64
|
||
import datetime
|
||
from io import BytesIO
|
||
from typing import Dict, List, Optional, Union
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
from matplotlib import pyplot as plt
|
||
from matplotlib.dates import MonthLocator, DateFormatter
|
||
from scipy.interpolate import make_interp_spline
|
||
|
||
from paste import chart
|
||
from paste.util import ufont
|
||
|
||
|
||
def gen_lines(data_dict: Dict[str, pd.Series], color_palette: str = 'BuPu', dpi: int = 128) -> str:
|
||
"""
|
||
生成多条折线图,用于对比不同年份的时间序列数据(如月度数据)。
|
||
|
||
数据结构要求:
|
||
- data_dict: Dict[str, pd.Series]
|
||
- key: 字符串,表示年份(如 '2022', '2023'),必须可转为整数
|
||
- value: pd.Series,索引为日期(datetime-like),值为数值(如销售额、访问量等)
|
||
- 示例:
|
||
{
|
||
'2022': pd.Series([100, 120, 90], index=pd.date_range('2022-01-01', periods=3, freq='M')),
|
||
'2023': pd.Series([110, 130, 95], index=pd.date_range('2023-01-01', periods=3, freq='M'))
|
||
}
|
||
|
||
功能说明:
|
||
- 自动识别当前年份,并高亮显示其曲线(加粗、加圆点标记)
|
||
- 使用每两个月一次的主刻度,格式化为 'MM-DD'
|
||
- 保存为 base64 编码的 SVG 图像,适用于 Web 前端直接嵌入
|
||
|
||
:param data_dict: 年度时间序列数据字典,结构如上
|
||
:param color_palette: Seaborn 色盘名称,默认 'BuPu'
|
||
:param dpi: 输出图像分辨率,默认 128
|
||
:return: base64 编码的 SVG 图像 Data URL 字符串
|
||
"""
|
||
# === 颜色准备 ===
|
||
colors = chart.get_seaborn_colors(len(data_dict), palette=color_palette)
|
||
if len(colors) == 0:
|
||
raise ValueError("color_palette 返回空颜色列表,请检查色盘名称是否有效")
|
||
|
||
# === 字体设置 ===
|
||
available_font = ufont.get_fonts()
|
||
plt.rcParams['font.sans-serif'] = list(available_font)
|
||
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示为方块的问题
|
||
|
||
# === 创建画布 ===
|
||
plt.figure(figsize=(10, 6), facecolor='none', dpi=dpi)
|
||
ax = plt.gca()
|
||
|
||
# === 绘制每条折线 ===
|
||
current_year = str(datetime.date.today().year)
|
||
for i, year in enumerate(data_dict.keys()):
|
||
s = data_dict[year]
|
||
|
||
# 检查数据有效性
|
||
if len(s) == 0:
|
||
raise ValueError(f"数据序列 {year} 为空,请检查输入")
|
||
|
||
# 判断是否为当前年份,决定样式
|
||
is_current_year = (year == current_year)
|
||
linewidth = 2.5 if is_current_year else 1.5
|
||
marker = 'o' if is_current_year else None
|
||
label = f'{year}年'
|
||
|
||
ax.plot(
|
||
s.index, s.values,
|
||
color=colors[i],
|
||
linewidth=linewidth,
|
||
alpha=0.9,
|
||
label=label,
|
||
marker=marker,
|
||
markersize=4,
|
||
markevery=30 # 每30个点显示一个标记,避免密集
|
||
)
|
||
|
||
# === 智能日期刻度 ===
|
||
ax.xaxis.set_major_locator(MonthLocator(bymonth=range(1, 13, 2))) # 每两个月一个主刻度
|
||
ax.xaxis.set_major_formatter(DateFormatter('%m-%d')) # 格式化为 MM-DD
|
||
|
||
# === 网格与边框美化 ===
|
||
ax.grid(True, which='major', linestyle='--', alpha=0.6)
|
||
ax.grid(True, which='minor', linestyle=':', alpha=0.3)
|
||
for spine in ['top', 'right']:
|
||
ax.spines[spine].set_visible(False)
|
||
ax.spines['left'].set_color('#d9d9d9')
|
||
ax.spines['bottom'].set_color('#d9d9d9')
|
||
|
||
# === 图例 ===
|
||
legend = ax.legend(loc='upper right', frameon=True, framealpha=0.7, edgecolor='#f0f0f0')
|
||
legend.get_frame().set_linewidth(1)
|
||
|
||
# === 输出为 base64 SVG ===
|
||
buffer = BytesIO()
|
||
plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
|
||
plt.close()
|
||
|
||
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||
img_base64 = f"data:image/svg+xml;base64,{image_base64}"
|
||
return img_base64
|
||
|
||
|
||
def gen_splines(data_dict: Dict[str, pd.Series], total: pd.Series, x_labels: Optional[List[Union[str, int]]] = None,
|
||
color_palette: str = 'BuPu', dpi: int = 128) -> str:
|
||
"""
|
||
生成平滑曲线图,适用于非时间序列数据(如按类别排序的数值),通过样条插值实现曲线平滑。
|
||
|
||
数据结构要求:
|
||
- data_dict: Dict[str, pd.Series]
|
||
- key: 字符串,表示数据系列名称(如 'A产品', 'B产品')
|
||
- value: pd.Series,索引为数值型(如 1~12),值为观测值(如销量)
|
||
- 示例:
|
||
{
|
||
'A产品': pd.Series([10, 15, 12, 18], index=[1,2,3,4]),
|
||
'B产品': pd.Series([8, 14, 16, 20], index=[1,2,3,4])
|
||
}
|
||
|
||
- total: pd.Series
|
||
- 索引必须与 data_dict 的 key 一致
|
||
- 值为每个系列的总和(用于图例标注)
|
||
- 示例:pd.Series({'A产品': 55, 'B产品': 58})
|
||
|
||
- x_labels: List[Union[str, int]], 可选
|
||
- 用于自定义 X 轴刻度标签(通常为原始索引值)
|
||
- 若为 None,则使用插值后的密集点(不显示原始刻度)
|
||
- 若提供,则在绘图后手动设置刻度,避免插值后刻度混乱
|
||
|
||
功能说明:
|
||
- 按 total 值降序排列绘图顺序(重要数据优先显示)
|
||
- 使用三次样条插值(k=3)生成平滑曲线
|
||
- 图例中显示原始总值(如 "A产品(55)")
|
||
|
||
:param data_dict: 各系列的原始观测数据,结构如上
|
||
:param total: 每个系列的总和,用于排序和图例标注
|
||
:param x_labels: 可选,原始 X 轴标签列表(如 [1,2,3,4]),用于控制刻度显示
|
||
:param color_palette: Seaborn 色盘名称
|
||
:param dpi: 输出分辨率
|
||
:return: base64 编码的 SVG 图像 Data URL 字符串
|
||
"""
|
||
# === 颜色准备 ===
|
||
colors = chart.get_seaborn_colors(len(data_dict), palette=color_palette)
|
||
if len(colors) == 0:
|
||
raise ValueError("color_palette 返回空颜色列表,请检查色盘名称是否有效")
|
||
|
||
# === 字体设置 ===
|
||
available_font = ufont.get_fonts()
|
||
plt.rcParams['font.sans-serif'] = list(available_font)
|
||
plt.rcParams['axes.unicode_minus'] = False
|
||
|
||
# === 创建画布 ===
|
||
plt.figure(figsize=(10, 6), facecolor='none', dpi=dpi)
|
||
ax = plt.gca()
|
||
|
||
# === 按总值排序,确保重要数据优先显示 ===
|
||
total_sorted = total.sort_values(ascending=False)
|
||
|
||
# === 绘制每条平滑曲线 ===
|
||
for i, series_name in enumerate(total_sorted.index):
|
||
s = data_dict[series_name]
|
||
|
||
# 检查数据完整性
|
||
if len(s) < 2:
|
||
raise ValueError(f"系列 {series_name} 数据点少于2个,无法进行插值")
|
||
|
||
x = s.index.values.astype(float) # 确保为数值类型
|
||
y = s.values.astype(float)
|
||
|
||
# 插值:生成 100 个平滑点(线性插值不足以平滑,样条更优)
|
||
x_new = np.linspace(x.min(), x.max(), 100)
|
||
spline = make_interp_spline(x, y, k=3) # 三次样条插值
|
||
y_smooth = spline(x_new)
|
||
|
||
# 转回 Series 便于绘图
|
||
s_smooth = pd.Series(y_smooth, index=x_new)
|
||
|
||
ax.plot(
|
||
s_smooth.index, s_smooth.values,
|
||
color=colors[i],
|
||
linewidth=2,
|
||
alpha=0.9,
|
||
label=f'{series_name}({total_sorted[series_name]})', # 图例中显示总值
|
||
markersize=4,
|
||
markevery=30 # 标记稀疏,避免视觉干扰
|
||
)
|
||
|
||
# === 手动设置 X 轴刻度(若提供)===
|
||
if x_labels is not None:
|
||
# 确保 x_labels 与原始 x 一致(可选验证)
|
||
plt.xticks(ticks=x_labels, labels=x_labels)
|
||
|
||
# === 网格与边框 ===
|
||
ax.grid(True, which='major', linestyle='--', alpha=0.6)
|
||
ax.grid(True, which='minor', linestyle=':', alpha=0.3)
|
||
for spine in ['top', 'right']:
|
||
ax.spines[spine].set_visible(False)
|
||
ax.spines['left'].set_color('#d9d9d9')
|
||
ax.spines['bottom'].set_color('#d9d9d9')
|
||
|
||
# === 图例 ===
|
||
legend = ax.legend(loc='upper left', frameon=True, framealpha=0.7, edgecolor='#f0f0f0')
|
||
legend.get_frame().set_linewidth(1)
|
||
|
||
# === 输出为 base64 SVG ===
|
||
buffer = BytesIO()
|
||
plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
|
||
plt.close()
|
||
|
||
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||
img_base64 = f"data:image/svg+xml;base64,{image_base64}"
|
||
return img_base64 |