首次提交
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
import seaborn
|
||||
|
||||
|
||||
def get_seaborn_colors(n, palette='husl'):
|
||||
"""
|
||||
使用Seaborn的调色板生成颜色。
|
||||
"""
|
||||
return seaborn.color_palette(palette, n_colors=n).as_hex()
|
||||
@@ -0,0 +1,291 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import List, Dict
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from paste import chart
|
||||
from paste.util import ufont
|
||||
|
||||
|
||||
# ===========================
|
||||
# 工具函数:统一字体与保存逻辑
|
||||
# ===========================
|
||||
|
||||
def _setup_plot():
|
||||
"""
|
||||
设置全局绘图参数(字体、负号、DPI),避免重复代码。
|
||||
"""
|
||||
available_font = ufont.get_fonts()
|
||||
plt.rcParams.update({
|
||||
'font.sans-serif': list(available_font),
|
||||
'axes.unicode_minus': False,
|
||||
})
|
||||
|
||||
|
||||
def _save_to_base64(dpi: int = 128) -> str:
|
||||
"""
|
||||
将当前图形保存为 base64 编码的 SVG 字符串,自动关闭图形。
|
||||
返回 Data URL 格式。
|
||||
"""
|
||||
buffer = BytesIO()
|
||||
plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight')
|
||||
plt.close()
|
||||
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
return f"data:image/svg+xml;base64,{image_base64}"
|
||||
|
||||
|
||||
# ===========================
|
||||
# 1. gen_vertical_bars:竖向堆叠柱状图
|
||||
# ===========================
|
||||
|
||||
def gen_vertical_bars(
|
||||
primary_values: List[float],
|
||||
nested_values: List[float],
|
||||
x_labels: List[str],
|
||||
group_labels: List[str],
|
||||
color_palette: str = 'BuPu',
|
||||
dpi: int = 128
|
||||
) -> str:
|
||||
"""
|
||||
生成竖向堆叠柱状图(两个分组)。
|
||||
|
||||
数据结构说明:
|
||||
- primary_values: [v1, v2, ..., vn] # 主柱高度,长度 = n
|
||||
- nested_values: [v1, v2, ..., vn] # 嵌套柱高度,长度 = n,从主柱顶部叠加
|
||||
- x_labels: ['A', 'B', 'C', ...] # 每个柱子的X轴标签,长度 = n
|
||||
- group_labels: ['Group1', 'Group2'] # 两个分组的图例名称,长度必须为2
|
||||
|
||||
注意:所有列表长度必须一致(n),否则会报错。
|
||||
本函数仅支持两个分组,若需扩展,建议使用 genHBars。
|
||||
|
||||
:param primary_values: 主柱数据(底部),数值列表
|
||||
:param nested_values: 嵌套柱数据(顶部),数值列表
|
||||
:param x_labels: X轴标签列表
|
||||
:param group_labels: 图例标签列表,长度为2
|
||||
:param color_palette: Seaborn 色盘名称
|
||||
:param dpi: 图像分辨率
|
||||
:return: base64 编码的 SVG 图像字符串
|
||||
"""
|
||||
if len(group_labels) != 2:
|
||||
raise ValueError("group_labels 必须包含两个元素:[主组, 嵌套组]")
|
||||
|
||||
_setup_plot()
|
||||
|
||||
colors = chart.get_seaborn_colors(2, palette=color_palette)
|
||||
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
|
||||
|
||||
# 绘制主柱(底部)
|
||||
bars1 = ax.bar(x_labels, primary_values, label=group_labels[0], color=colors[0])
|
||||
|
||||
# 绘制嵌套柱(从主柱顶部开始,即 bottom=primary_values)
|
||||
# 注意:原代码中 bottom=0 是错误的!应为 bottom=primary_values
|
||||
# 修正:从主柱顶部开始叠加
|
||||
bars2 = ax.bar(x_labels, nested_values, bottom=primary_values, label=group_labels[1], color=colors[1])
|
||||
|
||||
# 添加数值标签
|
||||
ax.bar_label(bars1, label_type='edge', padding=3) # 边缘标签
|
||||
ax.bar_label(bars2, label_type='center', padding=1) # 中心标签
|
||||
|
||||
# 旋转X轴标签,避免重叠
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
|
||||
# 添加图例
|
||||
ax.legend()
|
||||
|
||||
# 自动紧凑布局
|
||||
plt.tight_layout()
|
||||
|
||||
return _save_to_base64(dpi)
|
||||
|
||||
|
||||
# ===========================
|
||||
# 2. gen_horizontal_stacked_bars:横向堆叠柱状图(多分组)
|
||||
# ===========================
|
||||
|
||||
def gen_horizontal_stacked_bars(
|
||||
data_shap: np.ndarray,
|
||||
x_labels: List[str],
|
||||
y_labels: List[str],
|
||||
y_data_unit: str = '',
|
||||
legend_title: str = '',
|
||||
color_palette: str = 'BuPu',
|
||||
dpi: int = 128
|
||||
) -> str:
|
||||
"""
|
||||
生成横向堆叠柱状图(支持多个分组)。
|
||||
|
||||
数据结构说明:
|
||||
- data_shap: numpy.ndarray, shape = (len(y_labels), len(x_labels))
|
||||
- 每行代表一个 Y 轴类别(如:城市、部门)
|
||||
- 每列代表一个 X 轴分组(如:年份、产品类型)
|
||||
- 例如:data_shap[2][1] 表示第3个Y类别(y_labels[2])在第2个X分组(x_labels[1])的数值
|
||||
|
||||
- x_labels: ['2021', '2022', '2023'] # 每个分组的名称(对应列)
|
||||
- y_labels: ['北京', '上海', '广州'] # 每个柱子的类别名称(对应行)
|
||||
- y_data_unit: 如 '人'、'万元',用于标注总量
|
||||
|
||||
:param data_shap: 二维数值数组,shape=(行数, 列数)
|
||||
:param x_labels: 每个分组的标签列表
|
||||
:param y_labels: 每个柱子的标签列表
|
||||
:param y_data_unit: Y轴总量单位(如 '人')
|
||||
:param legend_title: 图例标题
|
||||
:param color_palette: Seaborn 色盘
|
||||
:param dpi: 分辨率
|
||||
:return: base64 编码的 SVG 图像字符串
|
||||
"""
|
||||
if data_shap.ndim != 2:
|
||||
raise ValueError("data_shap 必须是二维数组,shape=(Y, X)")
|
||||
|
||||
if len(x_labels) != data_shap.shape[1]:
|
||||
raise ValueError(f"x_labels 长度 ({len(x_labels)}) 与 data_shap 列数 ({data_shap.shape[1]}) 不匹配")
|
||||
|
||||
if len(y_labels) != data_shap.shape[0]:
|
||||
raise ValueError(f"y_labels 长度 ({len(y_labels)}) 与 data_shap 行数 ({data_shap.shape[0]}) 不匹配")
|
||||
|
||||
_setup_plot()
|
||||
|
||||
colors = chart.get_seaborn_colors(len(x_labels), palette=color_palette)
|
||||
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
|
||||
|
||||
# 初始化左侧起始位置(从0开始)
|
||||
left_pos = np.zeros(len(y_labels))
|
||||
bars = [] # 存储每个分组的柱形对象,用于后续标签绘制
|
||||
|
||||
# 按列(分组)逐个绘制横向堆叠柱形
|
||||
for i in range(len(x_labels)):
|
||||
bar = ax.barh(
|
||||
y_labels,
|
||||
data_shap[:, i],
|
||||
left=left_pos,
|
||||
label=x_labels[i],
|
||||
color=colors[i],
|
||||
edgecolor='white',
|
||||
linewidth=0.5
|
||||
)
|
||||
bars.append(bar)
|
||||
left_pos += data_shap[:, i] # 更新下一次的起始位置
|
||||
|
||||
# 计算每个Y类别的总和
|
||||
totals = np.sum(data_shap, axis=1)
|
||||
# 构建带总量的Y标签:如 "北京\n1200人"
|
||||
y_labels_mix = [f"{place}\n{int(total)}{y_data_unit}" for place, total in zip(y_labels, totals)]
|
||||
ax.set_yticks(np.arange(len(y_labels)))
|
||||
ax.set_yticklabels(y_labels_mix)
|
||||
ax.invert_yaxis() # 使第一个类别在顶部
|
||||
|
||||
# 图例
|
||||
ax.legend(title=legend_title, bbox_to_anchor=(0.958, 0.25), frameon=True, framealpha=0.7)
|
||||
|
||||
# 智能标签函数:根据柱宽决定标签位置
|
||||
def add_smart_labels(spacing_pixels: int = 10):
|
||||
dpi = fig.dpi
|
||||
xlim = ax.get_xlim()
|
||||
x_range = xlim[1] - xlim[0]
|
||||
spacing_data = spacing_pixels * x_range / (fig.get_size_inches()[0] * dpi)
|
||||
|
||||
text_style = {
|
||||
'fontsize': 9,
|
||||
'va': 'center',
|
||||
'bbox': {
|
||||
'boxstyle': 'round',
|
||||
'facecolor': 'white',
|
||||
'alpha': 0.7,
|
||||
'edgecolor': 'none',
|
||||
'pad': 0.3
|
||||
}
|
||||
}
|
||||
|
||||
for i, (place, total_width) in enumerate(zip(y_labels, totals)):
|
||||
values = [bar[i].get_width() for bar in bars]
|
||||
label_text = "+".join(str(int(v)) for v in values if v > 0)
|
||||
if not label_text:
|
||||
continue
|
||||
|
||||
text_width = (len(label_text) * 0.015 + 0.05) * x_range
|
||||
y_pos = bars[0][i].get_y() + bars[0][i].get_height() / 2
|
||||
|
||||
if total_width >= text_width + 2 * spacing_data:
|
||||
ax.text(total_width - spacing_data, y_pos, label_text, ha='right', **text_style)
|
||||
else:
|
||||
ax.text(total_width + spacing_data, y_pos, label_text, ha='left', **text_style)
|
||||
|
||||
add_smart_labels()
|
||||
|
||||
# 添加网格线(仅x轴方向)
|
||||
ax.grid(axis='x', linestyle=':', alpha=0.6)
|
||||
plt.tight_layout()
|
||||
|
||||
return _save_to_base64(dpi)
|
||||
|
||||
|
||||
# ===========================
|
||||
# 3. gen_percent_stacked_bars:百分比堆叠柱状图
|
||||
# ===========================
|
||||
|
||||
def gen_percent_stacked_bars(
|
||||
data: Dict[str, List[float]],
|
||||
x_labels: List[str],
|
||||
legend_title: str = '',
|
||||
color_palette: str = 'BuPu',
|
||||
dpi: int = 128
|
||||
) -> str:
|
||||
"""
|
||||
绘制百分比堆叠柱状图(所有柱子高度统一为100%)。
|
||||
|
||||
数据结构说明:
|
||||
- data: Dict[str, List[float]]
|
||||
- 每个 key 代表一个分组(如 'A组', 'B组')
|
||||
- 每个 value 是长度为 len(x_labels) 的数值列表,表示该组在每个X标签下的数值
|
||||
- 示例:{'A组': [30, 40, 50], 'B组': [70, 60, 50]} → 每列总和为100
|
||||
|
||||
- x_labels: ['2021', '2022', '2023'] # 每个柱子的标签
|
||||
|
||||
注意:每个X位置的总和必须一致(或接近),否则百分比可能失真。
|
||||
本函数自动将每列归一化为100%。
|
||||
|
||||
:param data: 分组数据字典,键为组名,值为数值列表
|
||||
:param x_labels: X轴标签列表
|
||||
:param legend_title: 图例标题
|
||||
:param color_palette: Seaborn 色盘
|
||||
:param dpi: 分辨率
|
||||
:return: base64 编码的 SVG 图像字符串
|
||||
"""
|
||||
if not data:
|
||||
raise ValueError("data 不能为空字典")
|
||||
|
||||
# 验证所有组长度一致
|
||||
group_lengths = [len(values) for values in data.values()]
|
||||
if not all(l == len(x_labels) for l in group_lengths):
|
||||
raise ValueError(f"所有分组的数值列表长度必须等于 x_labels 长度 ({len(x_labels)})")
|
||||
|
||||
_setup_plot()
|
||||
|
||||
# 转换为 numpy 数组
|
||||
group_names = list(data.keys())
|
||||
data_array = np.array([data[group] for group in group_names]) # shape: (n_groups, n_x)
|
||||
percent_data = data_array / data_array.sum(axis=0) * 100 # 按列归一化
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
|
||||
|
||||
bottom = np.zeros(len(x_labels)) # 初始底部为0
|
||||
|
||||
# 绘制每个分组
|
||||
for i, (group, color) in enumerate(zip(group_names, chart.get_seaborn_colors(len(group_names), palette=color_palette))):
|
||||
ax.bar(x_labels, percent_data[i], bottom=bottom, label=group, color=color)
|
||||
bottom += percent_data[i]
|
||||
|
||||
# 设置Y轴范围为0-100%
|
||||
ax.set_ylim(0, 100)
|
||||
|
||||
# 添加百分比标签
|
||||
for container in ax.containers:
|
||||
ax.bar_label(container, label_type='center', fmt='%.1f%%', padding=0)
|
||||
|
||||
# 图例
|
||||
ax.legend(title=legend_title, loc='center left', bbox_to_anchor=(1, 0.5), frameon=True, framealpha=0.7)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
return _save_to_base64(dpi)
|
||||
@@ -0,0 +1,213 @@
|
||||
import base64
|
||||
import datetime
|
||||
from io import BytesIO
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.dates import MonthLocator, DateFormatter
|
||||
from scipy.interpolate import make_interp_spline
|
||||
|
||||
from paste import chart
|
||||
from paste.util import ufont
|
||||
|
||||
|
||||
def gen_lines(data_dict: Dict[str, pd.Series], color_palette: str = 'BuPu', dpi: int = 128) -> str:
|
||||
"""
|
||||
生成多条折线图,用于对比不同年份的时间序列数据(如月度数据)。
|
||||
|
||||
数据结构要求:
|
||||
- data_dict: Dict[str, pd.Series]
|
||||
- key: 字符串,表示年份(如 '2022', '2023'),必须可转为整数
|
||||
- value: pd.Series,索引为日期(datetime-like),值为数值(如销售额、访问量等)
|
||||
- 示例:
|
||||
{
|
||||
'2022': pd.Series([100, 120, 90], index=pd.date_range('2022-01-01', periods=3, freq='M')),
|
||||
'2023': pd.Series([110, 130, 95], index=pd.date_range('2023-01-01', periods=3, freq='M'))
|
||||
}
|
||||
|
||||
功能说明:
|
||||
- 自动识别当前年份,并高亮显示其曲线(加粗、加圆点标记)
|
||||
- 使用每两个月一次的主刻度,格式化为 'MM-DD'
|
||||
- 保存为 base64 编码的 SVG 图像,适用于 Web 前端直接嵌入
|
||||
|
||||
:param data_dict: 年度时间序列数据字典,结构如上
|
||||
:param color_palette: Seaborn 色盘名称,默认 'BuPu'
|
||||
:param dpi: 输出图像分辨率,默认 128
|
||||
:return: base64 编码的 SVG 图像 Data URL 字符串
|
||||
"""
|
||||
# === 颜色准备 ===
|
||||
colors = chart.get_seaborn_colors(len(data_dict), palette=color_palette)
|
||||
if len(colors) == 0:
|
||||
raise ValueError("color_palette 返回空颜色列表,请检查色盘名称是否有效")
|
||||
|
||||
# === 字体设置 ===
|
||||
available_font = ufont.get_fonts()
|
||||
plt.rcParams['font.sans-serif'] = list(available_font)
|
||||
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示为方块的问题
|
||||
|
||||
# === 创建画布 ===
|
||||
plt.figure(figsize=(10, 6), facecolor='none', dpi=dpi)
|
||||
ax = plt.gca()
|
||||
|
||||
# === 绘制每条折线 ===
|
||||
current_year = str(datetime.date.today().year)
|
||||
for i, year in enumerate(data_dict.keys()):
|
||||
s = data_dict[year]
|
||||
|
||||
# 检查数据有效性
|
||||
if len(s) == 0:
|
||||
raise ValueError(f"数据序列 {year} 为空,请检查输入")
|
||||
|
||||
# 判断是否为当前年份,决定样式
|
||||
is_current_year = (year == current_year)
|
||||
linewidth = 2.5 if is_current_year else 1.5
|
||||
marker = 'o' if is_current_year else None
|
||||
label = f'{year}年'
|
||||
|
||||
ax.plot(
|
||||
s.index, s.values,
|
||||
color=colors[i],
|
||||
linewidth=linewidth,
|
||||
alpha=0.9,
|
||||
label=label,
|
||||
marker=marker,
|
||||
markersize=4,
|
||||
markevery=30 # 每30个点显示一个标记,避免密集
|
||||
)
|
||||
|
||||
# === 智能日期刻度 ===
|
||||
ax.xaxis.set_major_locator(MonthLocator(bymonth=range(1, 13, 2))) # 每两个月一个主刻度
|
||||
ax.xaxis.set_major_formatter(DateFormatter('%m-%d')) # 格式化为 MM-DD
|
||||
|
||||
# === 网格与边框美化 ===
|
||||
ax.grid(True, which='major', linestyle='--', alpha=0.6)
|
||||
ax.grid(True, which='minor', linestyle=':', alpha=0.3)
|
||||
for spine in ['top', 'right']:
|
||||
ax.spines[spine].set_visible(False)
|
||||
ax.spines['left'].set_color('#d9d9d9')
|
||||
ax.spines['bottom'].set_color('#d9d9d9')
|
||||
|
||||
# === 图例 ===
|
||||
legend = ax.legend(loc='upper right', frameon=True, framealpha=0.7, edgecolor='#f0f0f0')
|
||||
legend.get_frame().set_linewidth(1)
|
||||
|
||||
# === 输出为 base64 SVG ===
|
||||
buffer = BytesIO()
|
||||
plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
|
||||
plt.close()
|
||||
|
||||
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
img_base64 = f"data:image/svg+xml;base64,{image_base64}"
|
||||
return img_base64
|
||||
|
||||
|
||||
def gen_splines(data_dict: Dict[str, pd.Series], total: pd.Series, x_labels: Optional[List[Union[str, int]]] = None,
|
||||
color_palette: str = 'BuPu', dpi: int = 128) -> str:
|
||||
"""
|
||||
生成平滑曲线图,适用于非时间序列数据(如按类别排序的数值),通过样条插值实现曲线平滑。
|
||||
|
||||
数据结构要求:
|
||||
- data_dict: Dict[str, pd.Series]
|
||||
- key: 字符串,表示数据系列名称(如 'A产品', 'B产品')
|
||||
- value: pd.Series,索引为数值型(如 1~12),值为观测值(如销量)
|
||||
- 示例:
|
||||
{
|
||||
'A产品': pd.Series([10, 15, 12, 18], index=[1,2,3,4]),
|
||||
'B产品': pd.Series([8, 14, 16, 20], index=[1,2,3,4])
|
||||
}
|
||||
|
||||
- total: pd.Series
|
||||
- 索引必须与 data_dict 的 key 一致
|
||||
- 值为每个系列的总和(用于图例标注)
|
||||
- 示例:pd.Series({'A产品': 55, 'B产品': 58})
|
||||
|
||||
- x_labels: List[Union[str, int]], 可选
|
||||
- 用于自定义 X 轴刻度标签(通常为原始索引值)
|
||||
- 若为 None,则使用插值后的密集点(不显示原始刻度)
|
||||
- 若提供,则在绘图后手动设置刻度,避免插值后刻度混乱
|
||||
|
||||
功能说明:
|
||||
- 按 total 值降序排列绘图顺序(重要数据优先显示)
|
||||
- 使用三次样条插值(k=3)生成平滑曲线
|
||||
- 图例中显示原始总值(如 "A产品(55)")
|
||||
|
||||
:param data_dict: 各系列的原始观测数据,结构如上
|
||||
:param total: 每个系列的总和,用于排序和图例标注
|
||||
:param x_labels: 可选,原始 X 轴标签列表(如 [1,2,3,4]),用于控制刻度显示
|
||||
:param color_palette: Seaborn 色盘名称
|
||||
:param dpi: 输出分辨率
|
||||
:return: base64 编码的 SVG 图像 Data URL 字符串
|
||||
"""
|
||||
# === 颜色准备 ===
|
||||
colors = chart.get_seaborn_colors(len(data_dict), palette=color_palette)
|
||||
if len(colors) == 0:
|
||||
raise ValueError("color_palette 返回空颜色列表,请检查色盘名称是否有效")
|
||||
|
||||
# === 字体设置 ===
|
||||
available_font = ufont.get_fonts()
|
||||
plt.rcParams['font.sans-serif'] = list(available_font)
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# === 创建画布 ===
|
||||
plt.figure(figsize=(10, 6), facecolor='none', dpi=dpi)
|
||||
ax = plt.gca()
|
||||
|
||||
# === 按总值排序,确保重要数据优先显示 ===
|
||||
total_sorted = total.sort_values(ascending=False)
|
||||
|
||||
# === 绘制每条平滑曲线 ===
|
||||
for i, series_name in enumerate(total_sorted.index):
|
||||
s = data_dict[series_name]
|
||||
|
||||
# 检查数据完整性
|
||||
if len(s) < 2:
|
||||
raise ValueError(f"系列 {series_name} 数据点少于2个,无法进行插值")
|
||||
|
||||
x = s.index.values.astype(float) # 确保为数值类型
|
||||
y = s.values.astype(float)
|
||||
|
||||
# 插值:生成 100 个平滑点(线性插值不足以平滑,样条更优)
|
||||
x_new = np.linspace(x.min(), x.max(), 100)
|
||||
spline = make_interp_spline(x, y, k=3) # 三次样条插值
|
||||
y_smooth = spline(x_new)
|
||||
|
||||
# 转回 Series 便于绘图
|
||||
s_smooth = pd.Series(y_smooth, index=x_new)
|
||||
|
||||
ax.plot(
|
||||
s_smooth.index, s_smooth.values,
|
||||
color=colors[i],
|
||||
linewidth=2,
|
||||
alpha=0.9,
|
||||
label=f'{series_name}({total_sorted[series_name]})', # 图例中显示总值
|
||||
markersize=4,
|
||||
markevery=30 # 标记稀疏,避免视觉干扰
|
||||
)
|
||||
|
||||
# === 手动设置 X 轴刻度(若提供)===
|
||||
if x_labels is not None:
|
||||
# 确保 x_labels 与原始 x 一致(可选验证)
|
||||
plt.xticks(ticks=x_labels, labels=x_labels)
|
||||
|
||||
# === 网格与边框 ===
|
||||
ax.grid(True, which='major', linestyle='--', alpha=0.6)
|
||||
ax.grid(True, which='minor', linestyle=':', alpha=0.3)
|
||||
for spine in ['top', 'right']:
|
||||
ax.spines[spine].set_visible(False)
|
||||
ax.spines['left'].set_color('#d9d9d9')
|
||||
ax.spines['bottom'].set_color('#d9d9d9')
|
||||
|
||||
# === 图例 ===
|
||||
legend = ax.legend(loc='upper left', frameon=True, framealpha=0.7, edgecolor='#f0f0f0')
|
||||
legend.get_frame().set_linewidth(1)
|
||||
|
||||
# === 输出为 base64 SVG ===
|
||||
buffer = BytesIO()
|
||||
plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
|
||||
plt.close()
|
||||
|
||||
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
img_base64 = f"data:image/svg+xml;base64,{image_base64}"
|
||||
return img_base64
|
||||
@@ -0,0 +1,144 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.patches import Rectangle
|
||||
|
||||
from paste import chart
|
||||
from paste.util import ufont
|
||||
|
||||
|
||||
def gen_pie(data_df, value_column, percentage_column, legend_labels, color_palette='BuPu', dpi=128):
|
||||
"""
|
||||
生成环形图(Doughnut Chart),并附带右侧图例(色块 + 文字描述)。
|
||||
|
||||
注意:虽然函数名为 gen_pie,但实际绘制的是环形图(有中心空洞),非传统饼图。
|
||||
|
||||
参数说明(数据结构要求):
|
||||
--------------------------
|
||||
data_df : pandas.DataFrame
|
||||
必须包含以下三列:
|
||||
- value_column (数值列): 每个类别的数值(如设备数量),用于计算扇区角度。
|
||||
- percentage_column (百分比列): 每个类别的百分比(如 35.2%),用于显示在图例中。
|
||||
- legend_labels (标签列): 每个类别的名称(如 '服务器', '交换机'),用于图例文本。
|
||||
|
||||
示例结构:
|
||||
| server_count | percentage | device_type |
|
||||
|--------------|------------|-------------|
|
||||
| 35 | 35.2% | 服务器 |
|
||||
| 28 | 28.1% | 交换机 |
|
||||
| 22 | 22.0% | 路由器 |
|
||||
|
||||
要求:
|
||||
- 所有数值必须 ≥ 0;
|
||||
- 百分比列应为字符串格式(含'%'符号),或可被转换为字符串;
|
||||
- 所有列长度必须一致,且与 data_df 的行数匹配。
|
||||
|
||||
value_column : str
|
||||
data_df 中表示数值的列名(如 'server_count')。
|
||||
|
||||
percentage_column : str
|
||||
data_df 中表示百分比的列名(如 'percentage'),用于图例中显示。
|
||||
|
||||
legend_labels : str
|
||||
data_df 中表示类别名称的列名(如 'device_type'),用于图例文本。
|
||||
|
||||
color_palette : str, default='BuPu'
|
||||
Seaborn 颜色调色板名称,用于为每个扇区分配颜色。
|
||||
可选值参考:'BuPu', 'viridis', 'plasma', 'Set3' 等。
|
||||
若类别数 > 调色板颜色数,会自动循环使用。
|
||||
|
||||
dpi : int, default=128
|
||||
输出图像的分辨率(dots per inch),影响 SVG 清晰度。
|
||||
|
||||
返回值:
|
||||
--------
|
||||
str : Base64 编码的 SVG 图像 Data URL,可直接用于 HTML <img src="...">。
|
||||
"""
|
||||
|
||||
# === 1. 颜色准备 ===
|
||||
# 从 Seaborn 获取指定调色板的颜色序列,长度等于数据行数
|
||||
colors = chart.get_seaborn_colors(len(data_df.index), palette=color_palette)
|
||||
|
||||
# === 2. 字体设置 ===
|
||||
# 获取系统可用中文字体,优先使用支持中文的字体
|
||||
available_font = ufont.get_fonts()
|
||||
plt.rcParams['font.sans-serif'] = list(available_font)
|
||||
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示为方块的问题
|
||||
|
||||
# === 3. 创建画布与子图 ===
|
||||
# 使用非白色背景,便于嵌入网页(透明背景)
|
||||
fig = plt.figure(figsize=(8, 6), facecolor='none', dpi=dpi)
|
||||
|
||||
# 左侧:环形图区域
|
||||
ax1 = fig.add_axes((0.1, 0.05, 0.6, 0.9)) # [left, bottom, width, height]
|
||||
# 右侧:图例区域(色块+文字)
|
||||
ax2 = fig.add_axes((0.75, 0.05, 0.2, 0.9))
|
||||
|
||||
# 移除两个子图的坐标轴(纯图形,无刻度)
|
||||
for ax in [ax1, ax2]:
|
||||
ax.set_axis_off()
|
||||
|
||||
# === 4. 绘制环形图 ===
|
||||
# 使用 wedgeprops 设置内环宽度,实现环形效果
|
||||
wedges, _ = ax1.pie(
|
||||
data_df[value_column], # 数值列,决定扇区大小
|
||||
colors=colors, # 颜色序列
|
||||
startangle=90, # 从正上方开始绘制(更美观)
|
||||
wedgeprops=dict(
|
||||
width=0.6, # 环形宽度(0~1),越大越薄
|
||||
edgecolor='white', # 边缘白色,提升对比度
|
||||
linewidth=1 # 边缘线宽
|
||||
),
|
||||
radius=1.2, # 半径略大于1,避免边缘被裁剪
|
||||
)
|
||||
|
||||
# === 5. 右侧图例:色块 + 文字 ===
|
||||
# 图例参数定义
|
||||
num_items = len(data_df.index)
|
||||
box_size = 0.08 # 每个色块大小(宽度和高度)
|
||||
text_offset = 0.05 # 文字与色块的水平间距
|
||||
font_size = 24 # 字体大小
|
||||
line_height = 0.1 # 每行占用高度(包括上下间距)
|
||||
total_legend_height = num_items * line_height
|
||||
start_y = 0.45 + total_legend_height / 2 # 使图例垂直居中于右侧区域
|
||||
|
||||
y_pos = start_y # 当前绘制的 y 坐标
|
||||
|
||||
# 遍历每一类,绘制色块和文本
|
||||
for i, (label, color) in enumerate(zip(data_df[legend_labels], colors)):
|
||||
# 绘制色块矩形
|
||||
ax2.add_patch(
|
||||
Rectangle(
|
||||
(0.05, y_pos - box_size / 2), # 左下角坐标:x=0.05,y居中
|
||||
box_size, box_size, # 宽高均为 box_size
|
||||
facecolor=color, # 填充颜色
|
||||
edgecolor='white', # 白色描边
|
||||
lw=1 # 线宽
|
||||
)
|
||||
)
|
||||
|
||||
# 绘制文字标签:名称 + 数值 + 百分比
|
||||
ax2.text(
|
||||
0.05 + box_size + text_offset, # 文字起始 x 位置:色块右侧 + 间距
|
||||
y_pos, # y 位置居中于色块
|
||||
f"{label}({data_df[value_column][i]}台,{data_df[percentage_column][i]}%)",
|
||||
va='center', # 垂直居中
|
||||
ha='left', # 水平左对齐
|
||||
fontsize=font_size,
|
||||
fontweight='bold'
|
||||
)
|
||||
|
||||
y_pos -= line_height # 下移一行,准备下一个图例项
|
||||
|
||||
# === 6. 输出图像 ===
|
||||
# 使用 BytesIO 缓存 SVG 图像
|
||||
buffer = BytesIO()
|
||||
plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight', facecolor='none')
|
||||
plt.close() # 关闭图形以释放内存
|
||||
|
||||
# 编码为 Base64 并构造 Data URL
|
||||
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
img_base64 = f"data:image/svg+xml;base64,{image_base64}"
|
||||
|
||||
return img_base64
|
||||
Reference in New Issue
Block a user