import base64 from io import BytesIO from typing import List, Dict import numpy as np from matplotlib import pyplot as plt from paste import chart from paste.util import ufont # =========================== # 工具函数:统一字体与保存逻辑 # =========================== def _setup_plot(): """ 设置全局绘图参数(字体、负号、DPI),避免重复代码。 """ available_font = ufont.get_fonts() plt.rcParams.update({ 'font.sans-serif': list(available_font), 'axes.unicode_minus': False, }) def _save_to_base64(dpi: int = 128) -> str: """ 将当前图形保存为 base64 编码的 SVG 字符串,自动关闭图形。 返回 Data URL 格式。 """ buffer = BytesIO() plt.savefig(buffer, format='svg', dpi=dpi, bbox_inches='tight') plt.close() image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') return f"data:image/svg+xml;base64,{image_base64}" # =========================== # 1. gen_vertical_bars:竖向堆叠柱状图 # =========================== def gen_vertical_bars( primary_values: List[float], nested_values: List[float], x_labels: List[str], group_labels: List[str], color_palette: str = 'BuPu', dpi: int = 128 ) -> str: """ 生成竖向堆叠柱状图(两个分组)。 数据结构说明: - primary_values: [v1, v2, ..., vn] # 主柱高度,长度 = n - nested_values: [v1, v2, ..., vn] # 嵌套柱高度,长度 = n,从主柱顶部叠加 - x_labels: ['A', 'B', 'C', ...] # 每个柱子的X轴标签,长度 = n - group_labels: ['Group1', 'Group2'] # 两个分组的图例名称,长度必须为2 注意:所有列表长度必须一致(n),否则会报错。 本函数仅支持两个分组,若需扩展,建议使用 genHBars。 :param primary_values: 主柱数据(底部),数值列表 :param nested_values: 嵌套柱数据(顶部),数值列表 :param x_labels: X轴标签列表 :param group_labels: 图例标签列表,长度为2 :param color_palette: Seaborn 色盘名称 :param dpi: 图像分辨率 :return: base64 编码的 SVG 图像字符串 """ if len(group_labels) != 2: raise ValueError("group_labels 必须包含两个元素:[主组, 嵌套组]") _setup_plot() colors = chart.get_seaborn_colors(2, palette=color_palette) fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi) # 绘制主柱(底部) bars1 = ax.bar(x_labels, primary_values, label=group_labels[0], color=colors[0]) # 绘制嵌套柱(从主柱顶部开始,即 bottom=primary_values) # 注意:原代码中 bottom=0 是错误的!应为 bottom=primary_values # 修正:从主柱顶部开始叠加 bars2 = ax.bar(x_labels, nested_values, bottom=primary_values, label=group_labels[1], color=colors[1]) # 添加数值标签 ax.bar_label(bars1, label_type='edge', padding=3) # 边缘标签 ax.bar_label(bars2, label_type='center', padding=1) # 中心标签 # 旋转X轴标签,避免重叠 plt.xticks(rotation=45, ha='right') # 添加图例 ax.legend() # 自动紧凑布局 plt.tight_layout() return _save_to_base64(dpi) # =========================== # 2. gen_horizontal_stacked_bars:横向堆叠柱状图(多分组) # =========================== def gen_horizontal_stacked_bars( data_shap: np.ndarray, x_labels: List[str], y_labels: List[str], y_data_unit: str = '', legend_title: str = '', color_palette: str = 'BuPu', dpi: int = 128 ) -> str: """ 生成横向堆叠柱状图(支持多个分组)。 数据结构说明: - data_shap: numpy.ndarray, shape = (len(y_labels), len(x_labels)) - 每行代表一个 Y 轴类别(如:城市、部门) - 每列代表一个 X 轴分组(如:年份、产品类型) - 例如:data_shap[2][1] 表示第3个Y类别(y_labels[2])在第2个X分组(x_labels[1])的数值 - x_labels: ['2021', '2022', '2023'] # 每个分组的名称(对应列) - y_labels: ['北京', '上海', '广州'] # 每个柱子的类别名称(对应行) - y_data_unit: 如 '人'、'万元',用于标注总量 :param data_shap: 二维数值数组,shape=(行数, 列数) :param x_labels: 每个分组的标签列表 :param y_labels: 每个柱子的标签列表 :param y_data_unit: Y轴总量单位(如 '人') :param legend_title: 图例标题 :param color_palette: Seaborn 色盘 :param dpi: 分辨率 :return: base64 编码的 SVG 图像字符串 """ if data_shap.ndim != 2: raise ValueError("data_shap 必须是二维数组,shape=(Y, X)") if len(x_labels) != data_shap.shape[1]: raise ValueError(f"x_labels 长度 ({len(x_labels)}) 与 data_shap 列数 ({data_shap.shape[1]}) 不匹配") if len(y_labels) != data_shap.shape[0]: raise ValueError(f"y_labels 长度 ({len(y_labels)}) 与 data_shap 行数 ({data_shap.shape[0]}) 不匹配") _setup_plot() colors = chart.get_seaborn_colors(len(x_labels), palette=color_palette) fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi) # 初始化左侧起始位置(从0开始) left_pos = np.zeros(len(y_labels)) bars = [] # 存储每个分组的柱形对象,用于后续标签绘制 # 按列(分组)逐个绘制横向堆叠柱形 for i in range(len(x_labels)): bar = ax.barh( y_labels, data_shap[:, i], left=left_pos, label=x_labels[i], color=colors[i], edgecolor='white', linewidth=0.5 ) bars.append(bar) left_pos += data_shap[:, i] # 更新下一次的起始位置 # 计算每个Y类别的总和 totals = np.sum(data_shap, axis=1) # 构建带总量的Y标签:如 "北京\n1200人" y_labels_mix = [f"{place}\n{int(total)}{y_data_unit}" for place, total in zip(y_labels, totals)] ax.set_yticks(np.arange(len(y_labels))) ax.set_yticklabels(y_labels_mix) ax.invert_yaxis() # 使第一个类别在顶部 # 图例 ax.legend(title=legend_title, bbox_to_anchor=(0.958, 0.25), frameon=True, framealpha=0.7) # 智能标签函数:根据柱宽决定标签位置 def add_smart_labels(spacing_pixels: int = 10): dpi = fig.dpi xlim = ax.get_xlim() x_range = xlim[1] - xlim[0] spacing_data = spacing_pixels * x_range / (fig.get_size_inches()[0] * dpi) text_style = { 'fontsize': 9, 'va': 'center', 'bbox': { 'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.7, 'edgecolor': 'none', 'pad': 0.3 } } for i, (place, total_width) in enumerate(zip(y_labels, totals)): values = [bar[i].get_width() for bar in bars] label_text = "+".join(str(int(v)) for v in values if v > 0) if not label_text: continue text_width = (len(label_text) * 0.015 + 0.05) * x_range y_pos = bars[0][i].get_y() + bars[0][i].get_height() / 2 if total_width >= text_width + 2 * spacing_data: ax.text(total_width - spacing_data, y_pos, label_text, ha='right', **text_style) else: ax.text(total_width + spacing_data, y_pos, label_text, ha='left', **text_style) add_smart_labels() # 添加网格线(仅x轴方向) ax.grid(axis='x', linestyle=':', alpha=0.6) plt.tight_layout() return _save_to_base64(dpi) # =========================== # 3. gen_percent_stacked_bars:百分比堆叠柱状图 # =========================== def gen_percent_stacked_bars( data: Dict[str, List[float]], x_labels: List[str], legend_title: str = '', color_palette: str = 'BuPu', dpi: int = 128 ) -> str: """ 绘制百分比堆叠柱状图(所有柱子高度统一为100%)。 数据结构说明: - data: Dict[str, List[float]] - 每个 key 代表一个分组(如 'A组', 'B组') - 每个 value 是长度为 len(x_labels) 的数值列表,表示该组在每个X标签下的数值 - 示例:{'A组': [30, 40, 50], 'B组': [70, 60, 50]} → 每列总和为100 - x_labels: ['2021', '2022', '2023'] # 每个柱子的标签 注意:每个X位置的总和必须一致(或接近),否则百分比可能失真。 本函数自动将每列归一化为100%。 :param data: 分组数据字典,键为组名,值为数值列表 :param x_labels: X轴标签列表 :param legend_title: 图例标题 :param color_palette: Seaborn 色盘 :param dpi: 分辨率 :return: base64 编码的 SVG 图像字符串 """ if not data: raise ValueError("data 不能为空字典") # 验证所有组长度一致 group_lengths = [len(values) for values in data.values()] if not all(l == len(x_labels) for l in group_lengths): raise ValueError(f"所有分组的数值列表长度必须等于 x_labels 长度 ({len(x_labels)})") _setup_plot() # 转换为 numpy 数组 group_names = list(data.keys()) data_array = np.array([data[group] for group in group_names]) # shape: (n_groups, n_x) percent_data = data_array / data_array.sum(axis=0) * 100 # 按列归一化 fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi) bottom = np.zeros(len(x_labels)) # 初始底部为0 # 绘制每个分组 for i, (group, color) in enumerate(zip(group_names, chart.get_seaborn_colors(len(group_names), palette=color_palette))): ax.bar(x_labels, percent_data[i], bottom=bottom, label=group, color=color) bottom += percent_data[i] # 设置Y轴范围为0-100% ax.set_ylim(0, 100) # 添加百分比标签 for container in ax.containers: ax.bar_label(container, label_type='center', fmt='%.1f%%', padding=0) # 图例 ax.legend(title=legend_title, loc='center left', bbox_to_anchor=(1, 0.5), frameon=True, framealpha=0.7) plt.tight_layout() return _save_to_base64(dpi)