Files
d3i-szct/paste/chart/bar.py
T
zwf 4729698049 Squashed 'paste-framework/' content from commit 34e8684
git-subtree-dir: paste-framework
git-subtree-split: 34e8684c4bc3cebbe177509f42ab4ef5b5425a7a
2026-06-02 19:09:22 +08:00

291 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import base64
from io import BytesIO
from typing import List, Dict
import numpy as np
from matplotlib import pyplot as plt
from paste import chart
from paste.util import ufont
# ===========================
# 工具函数:统一字体与保存逻辑
# ===========================
def _setup_plot():
"""
设置全局绘图参数(字体、负号、DPI),避免重复代码。
"""
available_font = ufont.get_fonts()
plt.rcParams.update({
'font.sans-serif': list(available_font),
'axes.unicode_minus': False,
})
def _save_to_base64(dpi: int = 128) -> str:
"""
将当前图形保存为 base64 编码的 SVG 字符串,自动关闭图形。
返回 Data URL 格式。
"""
buffer = BytesIO()
plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight')
plt.close()
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
return f"data:image/svg+xml;base64,{image_base64}"
# ===========================
# 1. gen_vertical_bars:竖向堆叠柱状图
# ===========================
def gen_vertical_bars(
primary_values: List[float],
nested_values: List[float],
x_labels: List[str],
group_labels: List[str],
color_palette: str = 'BuPu',
dpi: int = 128
) -> str:
"""
生成竖向堆叠柱状图(两个分组)。
数据结构说明:
- primary_values: [v1, v2, ..., vn] # 主柱高度,长度 = n
- nested_values: [v1, v2, ..., vn] # 嵌套柱高度,长度 = n,从主柱顶部叠加
- x_labels: ['A', 'B', 'C', ...] # 每个柱子的X轴标签,长度 = n
- group_labels: ['Group1', 'Group2'] # 两个分组的图例名称,长度必须为2
注意:所有列表长度必须一致(n),否则会报错。
本函数仅支持两个分组,若需扩展,建议使用 genHBars。
:param primary_values: 主柱数据(底部),数值列表
:param nested_values: 嵌套柱数据(顶部),数值列表
:param x_labels: X轴标签列表
:param group_labels: 图例标签列表,长度为2
:param color_palette: Seaborn 色盘名称
:param dpi: 图像分辨率
:return: base64 编码的 SVG 图像字符串
"""
if len(group_labels) != 2:
raise ValueError("group_labels 必须包含两个元素:[主组, 嵌套组]")
_setup_plot()
colors = chart.get_seaborn_colors(2, palette=color_palette)
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
# 绘制主柱(底部)
bars1 = ax.bar(x_labels, primary_values, label=group_labels[0], color=colors[0])
# 绘制嵌套柱(从主柱顶部开始,即 bottom=primary_values
# 注意:原代码中 bottom=0 是错误的!应为 bottom=primary_values
# 修正:从主柱顶部开始叠加
bars2 = ax.bar(x_labels, nested_values, bottom=primary_values, label=group_labels[1], color=colors[1])
# 添加数值标签
ax.bar_label(bars1, label_type='edge', padding=3) # 边缘标签
ax.bar_label(bars2, label_type='center', padding=1) # 中心标签
# 旋转X轴标签,避免重叠
plt.xticks(rotation=45, ha='right')
# 添加图例
ax.legend()
# 自动紧凑布局
plt.tight_layout()
return _save_to_base64(dpi)
# ===========================
# 2. gen_horizontal_stacked_bars:横向堆叠柱状图(多分组)
# ===========================
def gen_horizontal_stacked_bars(
data_shap: np.ndarray,
x_labels: List[str],
y_labels: List[str],
y_data_unit: str = '',
legend_title: str = '',
color_palette: str = 'BuPu',
dpi: int = 128
) -> str:
"""
生成横向堆叠柱状图(支持多个分组)。
数据结构说明:
- data_shap: numpy.ndarray, shape = (len(y_labels), len(x_labels))
- 每行代表一个 Y 轴类别(如:城市、部门)
- 每列代表一个 X 轴分组(如:年份、产品类型)
- 例如:data_shap[2][1] 表示第3个Y类别(y_labels[2])在第2个X分组(x_labels[1])的数值
- x_labels: ['2021', '2022', '2023'] # 每个分组的名称(对应列)
- y_labels: ['北京', '上海', '广州'] # 每个柱子的类别名称(对应行)
- y_data_unit: 如 '''万元',用于标注总量
:param data_shap: 二维数值数组,shape=(行数, 列数)
:param x_labels: 每个分组的标签列表
:param y_labels: 每个柱子的标签列表
:param y_data_unit: Y轴总量单位(如 ''
:param legend_title: 图例标题
:param color_palette: Seaborn 色盘
:param dpi: 分辨率
:return: base64 编码的 SVG 图像字符串
"""
if data_shap.ndim != 2:
raise ValueError("data_shap 必须是二维数组,shape=(Y, X)")
if len(x_labels) != data_shap.shape[1]:
raise ValueError(f"x_labels 长度 ({len(x_labels)}) 与 data_shap 列数 ({data_shap.shape[1]}) 不匹配")
if len(y_labels) != data_shap.shape[0]:
raise ValueError(f"y_labels 长度 ({len(y_labels)}) 与 data_shap 行数 ({data_shap.shape[0]}) 不匹配")
_setup_plot()
colors = chart.get_seaborn_colors(len(x_labels), palette=color_palette)
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
# 初始化左侧起始位置(从0开始)
left_pos = np.zeros(len(y_labels))
bars = [] # 存储每个分组的柱形对象,用于后续标签绘制
# 按列(分组)逐个绘制横向堆叠柱形
for i in range(len(x_labels)):
bar = ax.barh(
y_labels,
data_shap[:, i],
left=left_pos,
label=x_labels[i],
color=colors[i],
edgecolor='white',
linewidth=0.5
)
bars.append(bar)
left_pos += data_shap[:, i] # 更新下一次的起始位置
# 计算每个Y类别的总和
totals = np.sum(data_shap, axis=1)
# 构建带总量的Y标签:如 "北京\n1200人"
y_labels_mix = [f"{place}\n{int(total)}{y_data_unit}" for place, total in zip(y_labels, totals)]
ax.set_yticks(np.arange(len(y_labels)))
ax.set_yticklabels(y_labels_mix)
ax.invert_yaxis() # 使第一个类别在顶部
# 图例
ax.legend(title=legend_title, bbox_to_anchor=(0.958, 0.25), frameon=True, framealpha=0.7)
# 智能标签函数:根据柱宽决定标签位置
def add_smart_labels(spacing_pixels: int = 10):
dpi = fig.dpi
xlim = ax.get_xlim()
x_range = xlim[1] - xlim[0]
spacing_data = spacing_pixels * x_range / (fig.get_size_inches()[0] * dpi)
text_style = {
'fontsize': 9,
'va': 'center',
'bbox': {
'boxstyle': 'round',
'facecolor': 'white',
'alpha': 0.7,
'edgecolor': 'none',
'pad': 0.3
}
}
for i, (place, total_width) in enumerate(zip(y_labels, totals)):
values = [bar[i].get_width() for bar in bars]
label_text = "+".join(str(int(v)) for v in values if v > 0)
if not label_text:
continue
text_width = (len(label_text) * 0.015 + 0.05) * x_range
y_pos = bars[0][i].get_y() + bars[0][i].get_height() / 2
if total_width >= text_width + 2 * spacing_data:
ax.text(total_width - spacing_data, y_pos, label_text, ha='right', **text_style)
else:
ax.text(total_width + spacing_data, y_pos, label_text, ha='left', **text_style)
add_smart_labels()
# 添加网格线(仅x轴方向)
ax.grid(axis='x', linestyle=':', alpha=0.6)
plt.tight_layout()
return _save_to_base64(dpi)
# ===========================
# 3. gen_percent_stacked_bars:百分比堆叠柱状图
# ===========================
def gen_percent_stacked_bars(
data: Dict[str, List[float]],
x_labels: List[str],
legend_title: str = '',
color_palette: str = 'BuPu',
dpi: int = 128
) -> str:
"""
绘制百分比堆叠柱状图(所有柱子高度统一为100%)。
数据结构说明:
- data: Dict[str, List[float]]
- 每个 key 代表一个分组(如 'A组', 'B组'
- 每个 value 是长度为 len(x_labels) 的数值列表,表示该组在每个X标签下的数值
- 示例:{'A组': [30, 40, 50], 'B组': [70, 60, 50]} → 每列总和为100
- x_labels: ['2021', '2022', '2023'] # 每个柱子的标签
注意:每个X位置的总和必须一致(或接近),否则百分比可能失真。
本函数自动将每列归一化为100%
:param data: 分组数据字典,键为组名,值为数值列表
:param x_labels: X轴标签列表
:param legend_title: 图例标题
:param color_palette: Seaborn 色盘
:param dpi: 分辨率
:return: base64 编码的 SVG 图像字符串
"""
if not data:
raise ValueError("data 不能为空字典")
# 验证所有组长度一致
group_lengths = [len(values) for values in data.values()]
if not all(l == len(x_labels) for l in group_lengths):
raise ValueError(f"所有分组的数值列表长度必须等于 x_labels 长度 ({len(x_labels)})")
_setup_plot()
# 转换为 numpy 数组
group_names = list(data.keys())
data_array = np.array([data[group] for group in group_names]) # shape: (n_groups, n_x)
percent_data = data_array / data_array.sum(axis=0) * 100 # 按列归一化
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
bottom = np.zeros(len(x_labels)) # 初始底部为0
# 绘制每个分组
for i, (group, color) in enumerate(zip(group_names, chart.get_seaborn_colors(len(group_names), palette=color_palette))):
ax.bar(x_labels, percent_data[i], bottom=bottom, label=group, color=color)
bottom += percent_data[i]
# 设置Y轴范围为0-100%
ax.set_ylim(0, 100)
# 添加百分比标签
for container in ax.containers:
ax.bar_label(container, label_type='center', fmt='%.1f%%', padding=0)
# 图例
ax.legend(title=legend_title, loc='center left', bbox_to_anchor=(1, 0.5), frameon=True, framealpha=0.7)
plt.tight_layout()
return _save_to_base64(dpi)