4729698049
git-subtree-dir: paste-framework git-subtree-split: 34e8684c4bc3cebbe177509f42ab4ef5b5425a7a
291 lines
10 KiB
Python
291 lines
10 KiB
Python
import base64
|
||
from io import BytesIO
|
||
from typing import List, Dict
|
||
|
||
import numpy as np
|
||
from matplotlib import pyplot as plt
|
||
|
||
from paste import chart
|
||
from paste.util import ufont
|
||
|
||
|
||
# ===========================
|
||
# 工具函数:统一字体与保存逻辑
|
||
# ===========================
|
||
|
||
def _setup_plot():
|
||
"""
|
||
设置全局绘图参数(字体、负号、DPI),避免重复代码。
|
||
"""
|
||
available_font = ufont.get_fonts()
|
||
plt.rcParams.update({
|
||
'font.sans-serif': list(available_font),
|
||
'axes.unicode_minus': False,
|
||
})
|
||
|
||
|
||
def _save_to_base64(dpi: int = 128) -> str:
|
||
"""
|
||
将当前图形保存为 base64 编码的 SVG 字符串,自动关闭图形。
|
||
返回 Data URL 格式。
|
||
"""
|
||
buffer = BytesIO()
|
||
plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight')
|
||
plt.close()
|
||
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||
return f"data:image/svg+xml;base64,{image_base64}"
|
||
|
||
|
||
# ===========================
|
||
# 1. gen_vertical_bars:竖向堆叠柱状图
|
||
# ===========================
|
||
|
||
def gen_vertical_bars(
|
||
primary_values: List[float],
|
||
nested_values: List[float],
|
||
x_labels: List[str],
|
||
group_labels: List[str],
|
||
color_palette: str = 'BuPu',
|
||
dpi: int = 128
|
||
) -> str:
|
||
"""
|
||
生成竖向堆叠柱状图(两个分组)。
|
||
|
||
数据结构说明:
|
||
- primary_values: [v1, v2, ..., vn] # 主柱高度,长度 = n
|
||
- nested_values: [v1, v2, ..., vn] # 嵌套柱高度,长度 = n,从主柱顶部叠加
|
||
- x_labels: ['A', 'B', 'C', ...] # 每个柱子的X轴标签,长度 = n
|
||
- group_labels: ['Group1', 'Group2'] # 两个分组的图例名称,长度必须为2
|
||
|
||
注意:所有列表长度必须一致(n),否则会报错。
|
||
本函数仅支持两个分组,若需扩展,建议使用 genHBars。
|
||
|
||
:param primary_values: 主柱数据(底部),数值列表
|
||
:param nested_values: 嵌套柱数据(顶部),数值列表
|
||
:param x_labels: X轴标签列表
|
||
:param group_labels: 图例标签列表,长度为2
|
||
:param color_palette: Seaborn 色盘名称
|
||
:param dpi: 图像分辨率
|
||
:return: base64 编码的 SVG 图像字符串
|
||
"""
|
||
if len(group_labels) != 2:
|
||
raise ValueError("group_labels 必须包含两个元素:[主组, 嵌套组]")
|
||
|
||
_setup_plot()
|
||
|
||
colors = chart.get_seaborn_colors(2, palette=color_palette)
|
||
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
|
||
|
||
# 绘制主柱(底部)
|
||
bars1 = ax.bar(x_labels, primary_values, label=group_labels[0], color=colors[0])
|
||
|
||
# 绘制嵌套柱(从主柱顶部开始,即 bottom=primary_values)
|
||
# 注意:原代码中 bottom=0 是错误的!应为 bottom=primary_values
|
||
# 修正:从主柱顶部开始叠加
|
||
bars2 = ax.bar(x_labels, nested_values, bottom=primary_values, label=group_labels[1], color=colors[1])
|
||
|
||
# 添加数值标签
|
||
ax.bar_label(bars1, label_type='edge', padding=3) # 边缘标签
|
||
ax.bar_label(bars2, label_type='center', padding=1) # 中心标签
|
||
|
||
# 旋转X轴标签,避免重叠
|
||
plt.xticks(rotation=45, ha='right')
|
||
|
||
# 添加图例
|
||
ax.legend()
|
||
|
||
# 自动紧凑布局
|
||
plt.tight_layout()
|
||
|
||
return _save_to_base64(dpi)
|
||
|
||
|
||
# ===========================
|
||
# 2. gen_horizontal_stacked_bars:横向堆叠柱状图(多分组)
|
||
# ===========================
|
||
|
||
def gen_horizontal_stacked_bars(
|
||
data_shap: np.ndarray,
|
||
x_labels: List[str],
|
||
y_labels: List[str],
|
||
y_data_unit: str = '',
|
||
legend_title: str = '',
|
||
color_palette: str = 'BuPu',
|
||
dpi: int = 128
|
||
) -> str:
|
||
"""
|
||
生成横向堆叠柱状图(支持多个分组)。
|
||
|
||
数据结构说明:
|
||
- data_shap: numpy.ndarray, shape = (len(y_labels), len(x_labels))
|
||
- 每行代表一个 Y 轴类别(如:城市、部门)
|
||
- 每列代表一个 X 轴分组(如:年份、产品类型)
|
||
- 例如:data_shap[2][1] 表示第3个Y类别(y_labels[2])在第2个X分组(x_labels[1])的数值
|
||
|
||
- x_labels: ['2021', '2022', '2023'] # 每个分组的名称(对应列)
|
||
- y_labels: ['北京', '上海', '广州'] # 每个柱子的类别名称(对应行)
|
||
- y_data_unit: 如 '人'、'万元',用于标注总量
|
||
|
||
:param data_shap: 二维数值数组,shape=(行数, 列数)
|
||
:param x_labels: 每个分组的标签列表
|
||
:param y_labels: 每个柱子的标签列表
|
||
:param y_data_unit: Y轴总量单位(如 '人')
|
||
:param legend_title: 图例标题
|
||
:param color_palette: Seaborn 色盘
|
||
:param dpi: 分辨率
|
||
:return: base64 编码的 SVG 图像字符串
|
||
"""
|
||
if data_shap.ndim != 2:
|
||
raise ValueError("data_shap 必须是二维数组,shape=(Y, X)")
|
||
|
||
if len(x_labels) != data_shap.shape[1]:
|
||
raise ValueError(f"x_labels 长度 ({len(x_labels)}) 与 data_shap 列数 ({data_shap.shape[1]}) 不匹配")
|
||
|
||
if len(y_labels) != data_shap.shape[0]:
|
||
raise ValueError(f"y_labels 长度 ({len(y_labels)}) 与 data_shap 行数 ({data_shap.shape[0]}) 不匹配")
|
||
|
||
_setup_plot()
|
||
|
||
colors = chart.get_seaborn_colors(len(x_labels), palette=color_palette)
|
||
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
|
||
|
||
# 初始化左侧起始位置(从0开始)
|
||
left_pos = np.zeros(len(y_labels))
|
||
bars = [] # 存储每个分组的柱形对象,用于后续标签绘制
|
||
|
||
# 按列(分组)逐个绘制横向堆叠柱形
|
||
for i in range(len(x_labels)):
|
||
bar = ax.barh(
|
||
y_labels,
|
||
data_shap[:, i],
|
||
left=left_pos,
|
||
label=x_labels[i],
|
||
color=colors[i],
|
||
edgecolor='white',
|
||
linewidth=0.5
|
||
)
|
||
bars.append(bar)
|
||
left_pos += data_shap[:, i] # 更新下一次的起始位置
|
||
|
||
# 计算每个Y类别的总和
|
||
totals = np.sum(data_shap, axis=1)
|
||
# 构建带总量的Y标签:如 "北京\n1200人"
|
||
y_labels_mix = [f"{place}\n{int(total)}{y_data_unit}" for place, total in zip(y_labels, totals)]
|
||
ax.set_yticks(np.arange(len(y_labels)))
|
||
ax.set_yticklabels(y_labels_mix)
|
||
ax.invert_yaxis() # 使第一个类别在顶部
|
||
|
||
# 图例
|
||
ax.legend(title=legend_title, bbox_to_anchor=(0.958, 0.25), frameon=True, framealpha=0.7)
|
||
|
||
# 智能标签函数:根据柱宽决定标签位置
|
||
def add_smart_labels(spacing_pixels: int = 10):
|
||
dpi = fig.dpi
|
||
xlim = ax.get_xlim()
|
||
x_range = xlim[1] - xlim[0]
|
||
spacing_data = spacing_pixels * x_range / (fig.get_size_inches()[0] * dpi)
|
||
|
||
text_style = {
|
||
'fontsize': 9,
|
||
'va': 'center',
|
||
'bbox': {
|
||
'boxstyle': 'round',
|
||
'facecolor': 'white',
|
||
'alpha': 0.7,
|
||
'edgecolor': 'none',
|
||
'pad': 0.3
|
||
}
|
||
}
|
||
|
||
for i, (place, total_width) in enumerate(zip(y_labels, totals)):
|
||
values = [bar[i].get_width() for bar in bars]
|
||
label_text = "+".join(str(int(v)) for v in values if v > 0)
|
||
if not label_text:
|
||
continue
|
||
|
||
text_width = (len(label_text) * 0.015 + 0.05) * x_range
|
||
y_pos = bars[0][i].get_y() + bars[0][i].get_height() / 2
|
||
|
||
if total_width >= text_width + 2 * spacing_data:
|
||
ax.text(total_width - spacing_data, y_pos, label_text, ha='right', **text_style)
|
||
else:
|
||
ax.text(total_width + spacing_data, y_pos, label_text, ha='left', **text_style)
|
||
|
||
add_smart_labels()
|
||
|
||
# 添加网格线(仅x轴方向)
|
||
ax.grid(axis='x', linestyle=':', alpha=0.6)
|
||
plt.tight_layout()
|
||
|
||
return _save_to_base64(dpi)
|
||
|
||
|
||
# ===========================
|
||
# 3. gen_percent_stacked_bars:百分比堆叠柱状图
|
||
# ===========================
|
||
|
||
def gen_percent_stacked_bars(
|
||
data: Dict[str, List[float]],
|
||
x_labels: List[str],
|
||
legend_title: str = '',
|
||
color_palette: str = 'BuPu',
|
||
dpi: int = 128
|
||
) -> str:
|
||
"""
|
||
绘制百分比堆叠柱状图(所有柱子高度统一为100%)。
|
||
|
||
数据结构说明:
|
||
- data: Dict[str, List[float]]
|
||
- 每个 key 代表一个分组(如 'A组', 'B组')
|
||
- 每个 value 是长度为 len(x_labels) 的数值列表,表示该组在每个X标签下的数值
|
||
- 示例:{'A组': [30, 40, 50], 'B组': [70, 60, 50]} → 每列总和为100
|
||
|
||
- x_labels: ['2021', '2022', '2023'] # 每个柱子的标签
|
||
|
||
注意:每个X位置的总和必须一致(或接近),否则百分比可能失真。
|
||
本函数自动将每列归一化为100%。
|
||
|
||
:param data: 分组数据字典,键为组名,值为数值列表
|
||
:param x_labels: X轴标签列表
|
||
:param legend_title: 图例标题
|
||
:param color_palette: Seaborn 色盘
|
||
:param dpi: 分辨率
|
||
:return: base64 编码的 SVG 图像字符串
|
||
"""
|
||
if not data:
|
||
raise ValueError("data 不能为空字典")
|
||
|
||
# 验证所有组长度一致
|
||
group_lengths = [len(values) for values in data.values()]
|
||
if not all(l == len(x_labels) for l in group_lengths):
|
||
raise ValueError(f"所有分组的数值列表长度必须等于 x_labels 长度 ({len(x_labels)})")
|
||
|
||
_setup_plot()
|
||
|
||
# 转换为 numpy 数组
|
||
group_names = list(data.keys())
|
||
data_array = np.array([data[group] for group in group_names]) # shape: (n_groups, n_x)
|
||
percent_data = data_array / data_array.sum(axis=0) * 100 # 按列归一化
|
||
|
||
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
|
||
|
||
bottom = np.zeros(len(x_labels)) # 初始底部为0
|
||
|
||
# 绘制每个分组
|
||
for i, (group, color) in enumerate(zip(group_names, chart.get_seaborn_colors(len(group_names), palette=color_palette))):
|
||
ax.bar(x_labels, percent_data[i], bottom=bottom, label=group, color=color)
|
||
bottom += percent_data[i]
|
||
|
||
# 设置Y轴范围为0-100%
|
||
ax.set_ylim(0, 100)
|
||
|
||
# 添加百分比标签
|
||
for container in ax.containers:
|
||
ax.bar_label(container, label_type='center', fmt='%.1f%%', padding=0)
|
||
|
||
# 图例
|
||
ax.legend(title=legend_title, loc='center left', bbox_to_anchor=(1, 0.5), frameon=True, framealpha=0.7)
|
||
|
||
plt.tight_layout()
|
||
|
||
return _save_to_base64(dpi) |