import numpy as np import os import traceback from paste.chart.bar import ( gen_vertical_bars, gen_horizontal_stacked_bars, gen_percent_stacked_bars ) class ChartBarExample: """ 图表测试管理器:封装对 paste.chart.bar 中三个函数的调用。 不修改任何参数结构,仅提供清晰的调用封装与输出管理。 """ def __init__(self, output_directory="./charts"): """ 初始化测试器,定义所有测试数据。 数据结构完全匹配原始函数调用方式。 """ self.output_directory = output_directory os.makedirs(self.output_directory, exist_ok=True) # 纵向堆叠柱形图数据(直接对应 gen_vertical_bars 参数) self.primary_vals = [10, 20, 15, 25, 18] self.nested_vals = [5, 8, 3, 10, 6] self.x_labels_vert = ['产品1', '产品2', '产品3', '产品4', '产品5'] self.group_labels_vert = ['销售量', '退货量'] # 横向堆叠柱形图数据(直接对应 gen_horizontal_stacked_bars 参数) self.data_matrix = np.array([ [10, 20, 15], [15, 12, 18], [8, 16, 10], [12, 14, 13] ]) self.x_labels_hori = ['线上销售', '门店销售', '批发销售'] self.y_labels_hori = ['北京', '上海', '广州', '深圳'] self.y_data_unit_hori = '万元' self.title_hori = '销售构成' # 百分比堆叠柱形图数据(直接对应 gen_percent_stacked_bars 参数) self.data_percent = { 'A组': [10, 20, 15, 18], 'B组': [5, 10, 5, 8], 'C组': [3, 7, 10, 4] } self.x_labels_percent = ['Q1', 'Q2', 'Q3', 'Q4'] self.title_percent = '季度占比' def generate_vertical_bars(self) -> str: """调用 gen_vertical_bars,参数完全一致""" try: return gen_vertical_bars( self.primary_vals, self.nested_vals, self.x_labels_vert, self.group_labels_vert ) except Exception as e: print(f"纵向堆叠柱形图生成失败: {e}") traceback.print_exc() raise def generate_horizontal_stacked_bars(self) -> str: """调用 gen_horizontal_stacked_bars,参数完全一致""" try: return gen_horizontal_stacked_bars( self.data_matrix, self.x_labels_hori, self.y_labels_hori, self.y_data_unit_hori, self.title_hori ) except Exception as e: print(f"横向堆叠柱形图生成失败: {e}") traceback.print_exc() raise def generate_percent_stacked_bars(self) -> str: """调用 gen_percent_stacked_bars,参数完全一致""" try: return gen_percent_stacked_bars( self.data_percent, self.x_labels_percent, self.title_percent ) except Exception as e: print(f"百分比堆叠柱形图生成失败: {e}") traceback.print_exc() raise def save_svg(self, svg_data: str, filename: str) -> None: """ 将 SVG 的 base64 Data URL 写入文件(保留原始 SVG 格式)。 注意:svg_data 是 "data:image/svg+xml;base64,...",需提取真实 SVG 内容。 """ if not svg_data or not isinstance(svg_data, str): print(f"生成的 SVG 数据无效(为空或非字符串): {filename}") return # 提取 base64 编码部分(去除 data URL 前缀) if svg_data.startswith("data:image/svg+xml;base64,"): base64_content = svg_data[len("data:image/svg+xml;base64,"):] try: # 解码 base64 得到原始 SVG 字符串 import base64 svg_content = base64.b64decode(base64_content).decode('utf-8') except Exception as e: print(f"解码 base64 失败: {e}") svg_content = svg_data # 退化为直接写入 else: # 如果不是标准格式,直接写入(兼容调试) svg_content = svg_data filepath = os.path.join(self.output_directory, filename) with open(filepath, 'w', encoding='utf-8') as f: f.write(svg_content) print(f"已保存: {filepath}") def run(self) -> None: """按顺序执行所有图表生成与保存""" print("开始生成图表...") try: print("生成纵向堆叠柱形图...") svg1 = self.generate_vertical_bars() self.save_svg(svg1, "vertical_bars.svg") print("生成横向堆叠柱形图...") svg2 = self.generate_horizontal_stacked_bars() self.save_svg(svg2, "horizontal_stacked_bars.svg") print("生成百分比堆叠柱形图...") svg3 = self.generate_percent_stacked_bars() self.save_svg(svg3, "percent_stacked_bars.svg") print("\n所有图表已成功生成。") print(f"输出目录: {self.output_directory}") print("文件列表:") print(" - vertical_bars.svg") print(" - horizontal_stacked_bars.svg") print(" - percent_stacked_bars.svg") except Exception as e: print(f"\n测试失败: {e}") traceback.print_exc() # 程序入口 if __name__ == "__main__": tester = ChartBarExample() tester.run()