Merge commit '47296980495f8bbfc9493e93de85dd62de6fa6b9' as 'paste-framework'
This commit is contained in:
@@ -0,0 +1,291 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import List, Dict
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from paste import chart
|
||||
from paste.util import ufont
|
||||
|
||||
|
||||
# ===========================
|
||||
# 工具函数:统一字体与保存逻辑
|
||||
# ===========================
|
||||
|
||||
def _setup_plot():
|
||||
"""
|
||||
设置全局绘图参数(字体、负号、DPI),避免重复代码。
|
||||
"""
|
||||
available_font = ufont.get_fonts()
|
||||
plt.rcParams.update({
|
||||
'font.sans-serif': list(available_font),
|
||||
'axes.unicode_minus': False,
|
||||
})
|
||||
|
||||
|
||||
def _save_to_base64(dpi: int = 128) -> str:
|
||||
"""
|
||||
将当前图形保存为 base64 编码的 SVG 字符串,自动关闭图形。
|
||||
返回 Data URL 格式。
|
||||
"""
|
||||
buffer = BytesIO()
|
||||
plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight')
|
||||
plt.close()
|
||||
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
return f"data:image/svg+xml;base64,{image_base64}"
|
||||
|
||||
|
||||
# ===========================
|
||||
# 1. gen_vertical_bars:竖向堆叠柱状图
|
||||
# ===========================
|
||||
|
||||
def gen_vertical_bars(
|
||||
primary_values: List[float],
|
||||
nested_values: List[float],
|
||||
x_labels: List[str],
|
||||
group_labels: List[str],
|
||||
color_palette: str = 'BuPu',
|
||||
dpi: int = 128
|
||||
) -> str:
|
||||
"""
|
||||
生成竖向堆叠柱状图(两个分组)。
|
||||
|
||||
数据结构说明:
|
||||
- primary_values: [v1, v2, ..., vn] # 主柱高度,长度 = n
|
||||
- nested_values: [v1, v2, ..., vn] # 嵌套柱高度,长度 = n,从主柱顶部叠加
|
||||
- x_labels: ['A', 'B', 'C', ...] # 每个柱子的X轴标签,长度 = n
|
||||
- group_labels: ['Group1', 'Group2'] # 两个分组的图例名称,长度必须为2
|
||||
|
||||
注意:所有列表长度必须一致(n),否则会报错。
|
||||
本函数仅支持两个分组,若需扩展,建议使用 genHBars。
|
||||
|
||||
:param primary_values: 主柱数据(底部),数值列表
|
||||
:param nested_values: 嵌套柱数据(顶部),数值列表
|
||||
:param x_labels: X轴标签列表
|
||||
:param group_labels: 图例标签列表,长度为2
|
||||
:param color_palette: Seaborn 色盘名称
|
||||
:param dpi: 图像分辨率
|
||||
:return: base64 编码的 SVG 图像字符串
|
||||
"""
|
||||
if len(group_labels) != 2:
|
||||
raise ValueError("group_labels 必须包含两个元素:[主组, 嵌套组]")
|
||||
|
||||
_setup_plot()
|
||||
|
||||
colors = chart.get_seaborn_colors(2, palette=color_palette)
|
||||
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
|
||||
|
||||
# 绘制主柱(底部)
|
||||
bars1 = ax.bar(x_labels, primary_values, label=group_labels[0], color=colors[0])
|
||||
|
||||
# 绘制嵌套柱(从主柱顶部开始,即 bottom=primary_values)
|
||||
# 注意:原代码中 bottom=0 是错误的!应为 bottom=primary_values
|
||||
# 修正:从主柱顶部开始叠加
|
||||
bars2 = ax.bar(x_labels, nested_values, bottom=primary_values, label=group_labels[1], color=colors[1])
|
||||
|
||||
# 添加数值标签
|
||||
ax.bar_label(bars1, label_type='edge', padding=3) # 边缘标签
|
||||
ax.bar_label(bars2, label_type='center', padding=1) # 中心标签
|
||||
|
||||
# 旋转X轴标签,避免重叠
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
|
||||
# 添加图例
|
||||
ax.legend()
|
||||
|
||||
# 自动紧凑布局
|
||||
plt.tight_layout()
|
||||
|
||||
return _save_to_base64(dpi)
|
||||
|
||||
|
||||
# ===========================
|
||||
# 2. gen_horizontal_stacked_bars:横向堆叠柱状图(多分组)
|
||||
# ===========================
|
||||
|
||||
def gen_horizontal_stacked_bars(
|
||||
data_shap: np.ndarray,
|
||||
x_labels: List[str],
|
||||
y_labels: List[str],
|
||||
y_data_unit: str = '',
|
||||
legend_title: str = '',
|
||||
color_palette: str = 'BuPu',
|
||||
dpi: int = 128
|
||||
) -> str:
|
||||
"""
|
||||
生成横向堆叠柱状图(支持多个分组)。
|
||||
|
||||
数据结构说明:
|
||||
- data_shap: numpy.ndarray, shape = (len(y_labels), len(x_labels))
|
||||
- 每行代表一个 Y 轴类别(如:城市、部门)
|
||||
- 每列代表一个 X 轴分组(如:年份、产品类型)
|
||||
- 例如:data_shap[2][1] 表示第3个Y类别(y_labels[2])在第2个X分组(x_labels[1])的数值
|
||||
|
||||
- x_labels: ['2021', '2022', '2023'] # 每个分组的名称(对应列)
|
||||
- y_labels: ['北京', '上海', '广州'] # 每个柱子的类别名称(对应行)
|
||||
- y_data_unit: 如 '人'、'万元',用于标注总量
|
||||
|
||||
:param data_shap: 二维数值数组,shape=(行数, 列数)
|
||||
:param x_labels: 每个分组的标签列表
|
||||
:param y_labels: 每个柱子的标签列表
|
||||
:param y_data_unit: Y轴总量单位(如 '人')
|
||||
:param legend_title: 图例标题
|
||||
:param color_palette: Seaborn 色盘
|
||||
:param dpi: 分辨率
|
||||
:return: base64 编码的 SVG 图像字符串
|
||||
"""
|
||||
if data_shap.ndim != 2:
|
||||
raise ValueError("data_shap 必须是二维数组,shape=(Y, X)")
|
||||
|
||||
if len(x_labels) != data_shap.shape[1]:
|
||||
raise ValueError(f"x_labels 长度 ({len(x_labels)}) 与 data_shap 列数 ({data_shap.shape[1]}) 不匹配")
|
||||
|
||||
if len(y_labels) != data_shap.shape[0]:
|
||||
raise ValueError(f"y_labels 长度 ({len(y_labels)}) 与 data_shap 行数 ({data_shap.shape[0]}) 不匹配")
|
||||
|
||||
_setup_plot()
|
||||
|
||||
colors = chart.get_seaborn_colors(len(x_labels), palette=color_palette)
|
||||
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
|
||||
|
||||
# 初始化左侧起始位置(从0开始)
|
||||
left_pos = np.zeros(len(y_labels))
|
||||
bars = [] # 存储每个分组的柱形对象,用于后续标签绘制
|
||||
|
||||
# 按列(分组)逐个绘制横向堆叠柱形
|
||||
for i in range(len(x_labels)):
|
||||
bar = ax.barh(
|
||||
y_labels,
|
||||
data_shap[:, i],
|
||||
left=left_pos,
|
||||
label=x_labels[i],
|
||||
color=colors[i],
|
||||
edgecolor='white',
|
||||
linewidth=0.5
|
||||
)
|
||||
bars.append(bar)
|
||||
left_pos += data_shap[:, i] # 更新下一次的起始位置
|
||||
|
||||
# 计算每个Y类别的总和
|
||||
totals = np.sum(data_shap, axis=1)
|
||||
# 构建带总量的Y标签:如 "北京\n1200人"
|
||||
y_labels_mix = [f"{place}\n{int(total)}{y_data_unit}" for place, total in zip(y_labels, totals)]
|
||||
ax.set_yticks(np.arange(len(y_labels)))
|
||||
ax.set_yticklabels(y_labels_mix)
|
||||
ax.invert_yaxis() # 使第一个类别在顶部
|
||||
|
||||
# 图例
|
||||
ax.legend(title=legend_title, bbox_to_anchor=(0.958, 0.25), frameon=True, framealpha=0.7)
|
||||
|
||||
# 智能标签函数:根据柱宽决定标签位置
|
||||
def add_smart_labels(spacing_pixels: int = 10):
|
||||
dpi = fig.dpi
|
||||
xlim = ax.get_xlim()
|
||||
x_range = xlim[1] - xlim[0]
|
||||
spacing_data = spacing_pixels * x_range / (fig.get_size_inches()[0] * dpi)
|
||||
|
||||
text_style = {
|
||||
'fontsize': 9,
|
||||
'va': 'center',
|
||||
'bbox': {
|
||||
'boxstyle': 'round',
|
||||
'facecolor': 'white',
|
||||
'alpha': 0.7,
|
||||
'edgecolor': 'none',
|
||||
'pad': 0.3
|
||||
}
|
||||
}
|
||||
|
||||
for i, (place, total_width) in enumerate(zip(y_labels, totals)):
|
||||
values = [bar[i].get_width() for bar in bars]
|
||||
label_text = "+".join(str(int(v)) for v in values if v > 0)
|
||||
if not label_text:
|
||||
continue
|
||||
|
||||
text_width = (len(label_text) * 0.015 + 0.05) * x_range
|
||||
y_pos = bars[0][i].get_y() + bars[0][i].get_height() / 2
|
||||
|
||||
if total_width >= text_width + 2 * spacing_data:
|
||||
ax.text(total_width - spacing_data, y_pos, label_text, ha='right', **text_style)
|
||||
else:
|
||||
ax.text(total_width + spacing_data, y_pos, label_text, ha='left', **text_style)
|
||||
|
||||
add_smart_labels()
|
||||
|
||||
# 添加网格线(仅x轴方向)
|
||||
ax.grid(axis='x', linestyle=':', alpha=0.6)
|
||||
plt.tight_layout()
|
||||
|
||||
return _save_to_base64(dpi)
|
||||
|
||||
|
||||
# ===========================
|
||||
# 3. gen_percent_stacked_bars:百分比堆叠柱状图
|
||||
# ===========================
|
||||
|
||||
def gen_percent_stacked_bars(
|
||||
data: Dict[str, List[float]],
|
||||
x_labels: List[str],
|
||||
legend_title: str = '',
|
||||
color_palette: str = 'BuPu',
|
||||
dpi: int = 128
|
||||
) -> str:
|
||||
"""
|
||||
绘制百分比堆叠柱状图(所有柱子高度统一为100%)。
|
||||
|
||||
数据结构说明:
|
||||
- data: Dict[str, List[float]]
|
||||
- 每个 key 代表一个分组(如 'A组', 'B组')
|
||||
- 每个 value 是长度为 len(x_labels) 的数值列表,表示该组在每个X标签下的数值
|
||||
- 示例:{'A组': [30, 40, 50], 'B组': [70, 60, 50]} → 每列总和为100
|
||||
|
||||
- x_labels: ['2021', '2022', '2023'] # 每个柱子的标签
|
||||
|
||||
注意:每个X位置的总和必须一致(或接近),否则百分比可能失真。
|
||||
本函数自动将每列归一化为100%。
|
||||
|
||||
:param data: 分组数据字典,键为组名,值为数值列表
|
||||
:param x_labels: X轴标签列表
|
||||
:param legend_title: 图例标题
|
||||
:param color_palette: Seaborn 色盘
|
||||
:param dpi: 分辨率
|
||||
:return: base64 编码的 SVG 图像字符串
|
||||
"""
|
||||
if not data:
|
||||
raise ValueError("data 不能为空字典")
|
||||
|
||||
# 验证所有组长度一致
|
||||
group_lengths = [len(values) for values in data.values()]
|
||||
if not all(l == len(x_labels) for l in group_lengths):
|
||||
raise ValueError(f"所有分组的数值列表长度必须等于 x_labels 长度 ({len(x_labels)})")
|
||||
|
||||
_setup_plot()
|
||||
|
||||
# 转换为 numpy 数组
|
||||
group_names = list(data.keys())
|
||||
data_array = np.array([data[group] for group in group_names]) # shape: (n_groups, n_x)
|
||||
percent_data = data_array / data_array.sum(axis=0) * 100 # 按列归一化
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
|
||||
|
||||
bottom = np.zeros(len(x_labels)) # 初始底部为0
|
||||
|
||||
# 绘制每个分组
|
||||
for i, (group, color) in enumerate(zip(group_names, chart.get_seaborn_colors(len(group_names), palette=color_palette))):
|
||||
ax.bar(x_labels, percent_data[i], bottom=bottom, label=group, color=color)
|
||||
bottom += percent_data[i]
|
||||
|
||||
# 设置Y轴范围为0-100%
|
||||
ax.set_ylim(0, 100)
|
||||
|
||||
# 添加百分比标签
|
||||
for container in ax.containers:
|
||||
ax.bar_label(container, label_type='center', fmt='%.1f%%', padding=0)
|
||||
|
||||
# 图例
|
||||
ax.legend(title=legend_title, loc='center left', bbox_to_anchor=(1, 0.5), frameon=True, framealpha=0.7)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
return _save_to_base64(dpi)
|
||||
Reference in New Issue
Block a user