diff --git a/pytorch/analyze.py b/pytorch/analyze.py index 109915a..e97eccc 100755 --- a/pytorch/analyze.py +++ b/pytorch/analyze.py @@ -18,11 +18,11 @@ class Plot(Enum): BOX = 'box' LINE = 'line' -def accumulate(stats_list: list[dict[str, str | int | float]], category: str, value: str): - print(category) - print(value) - category_list = np.array([stats[category] for stats in stats_list if value in stats]) - value_list = np.array([stats[value] for stats in stats_list if value in stats]) +def accumulate(data_list: list[dict[str, str | int | float]], category: Stat, value: Stat): + #print(category.name) + #print(value.name) + category_list = np.array([stats[category.name] for stats in data_list if value.name in stats]) + value_list = np.array([stats[value.name] for stats in data_list if value.name in stats]) result: dict[np.ndarray] = dict() for category in np.sort(np.unique(category_list)): @@ -30,34 +30,34 @@ def accumulate(stats_list: list[dict[str, str | int | float]], category: str, va return result -def box_plot(ax, stats_list: list[dict[str, str | int | float]], x: Stat, y: Stat): - data: dict[str, np.ndarray] = accumulate(stats_list, x, y) +def box_plot(ax, data_list: list[dict[str, str | int | float]], x: Stat, y: Stat): + data: dict[str, np.ndarray] = accumulate(data_list, x, y) - print("Plotted data: " + str(data)) + #print("Plotted data: " + str(data)) ax.boxplot(data.values(), tick_labels=data.keys()) ax.set_ylabel(y.value) def line_plot( - ax, stats_list: list[dict[str, str | int | float]], + ax, data_list: list[dict[str, str | int | float]], x: Stat, y: Stat, color: Stat ): - x_data: dict[str, np.ndarray] = accumulate(stats_list, color, x) - y_data: dict[str, np.ndarray] = accumulate(stats_list, color, y) + x_data: dict[str, np.ndarray] = accumulate(data_list, color, x) + y_data: dict[str, np.ndarray] = accumulate(data_list, color, y) for category in x_data.keys(): sorted_indices = np.argsort(x_data[category]) x_data[category] = x_data[category][sorted_indices] y_data[category] = y_data[category][sorted_indices] ax.plot(x_data[category], y_data[category], label=category) - print("Plotted x data: " + str(x_data[category])) - print("Plotted y data: " + str(y_data[category])) + #print("Plotted x data: " + str(x_data[category])) + #print("Plotted y data: " + str(y_data[category])) - ax.set_ylabel(y) + ax.set_ylabel(y.value) ax.grid(True) def visualize( - stats_list: list[dict[str, str | int | float]], + data_list: list[dict[str, str | int | float]], plot: Plot, rows: int, size_multiplier: int, @@ -68,7 +68,7 @@ def visualize( filter_list: list[str] = [] ): # Remove stats entries containing undesired values (like a specific CPU). -# stats_list = [stats for stats in stats_list +# data_list = [stats for stats in data_list # if len([stats[key] for key in stats.keys() # if stats[key] in filter_list]) == 0] @@ -77,21 +77,25 @@ def visualize( #color = Stat.SOLVER if y is None: - #ys = [stat for stat in Stat if stat.value in stats_list[0].keys() - ys = [stat for stat in stats_list[0].keys() if "power" not in stat] + #ys = [stat for stat in Stat if stat.name in data_list[0].keys()] + #ys = [stat for stat in data_list[0].keys() if "power" not in stat] #and stat is not x #and y != color #and y != marker #and stat.value not in filter_list] + # Create sorted, deduped list of all stats in data_list. + ys = [Stat[stat_name] for stat_name in sorted(list(set([stat_name for data in data_list for stat_name in data if type(data[stat_name]) is not list])))] + print([stat.value for stat in ys]) + fig, axes = plt.subplots(rows, int(math.ceil(len(ys) / rows)), figsize = (16 * size_multiplier, 9 * size_multiplier)) match plot: case Plot.BOX: for i, y in enumerate(ys): - box_plot(axes[i % rows][int(i / rows)], stats_list, x, y) + box_plot(axes[i % rows][int(i / rows)], data_list, x, y) case Plot.LINE: for i, y in enumerate(ys): - line_plot(axes[i % rows][int(i / rows)], stats_list, x, y, color) + line_plot(axes[i % rows][int(i / rows)], data_list, x, y, color) handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels() else: @@ -99,9 +103,9 @@ def visualize( match plot: case Plot.BOX: - box_plot(ax, stats_list, x, y) + box_plot(ax, data_list, x, y) case Plot.LINE: - line_plot(ax, stats_list, x, y, color) + line_plot(ax, data_list, x, y, color) handles, labels = ax.get_legend_handles_labels() @@ -116,7 +120,7 @@ def visualize( title = "altra_spmv" fig.suptitle(title, fontsize = font_size) fig.legend(handles, labels, fontsize = font_size) - fig.supxlabel(x, fontsize = font_size) + fig.supxlabel(x.value, fontsize = font_size) plt.savefig(title + ".png", dpi = 100) plt.show() @@ -138,19 +142,22 @@ def main(): help = 'font size', default = 40) parser.add_argument('-x', - #choices=[x.name.lower() for x in Stat], + choices=[x.name.lower() for x in Stat], help = 'the name of the x axis') parser.add_argument('-y', - #choices=[x.name.lower() for x in Stat], + choices=[x.name.lower() for x in Stat], help = 'the name of the y axis') parser.add_argument('-c', '--color', - #choices=[x.name.lower() for x in Stat], + choices=[x.name.lower() for x in Stat], help = 'the name of the color') parser.add_argument('-f', '--filter', nargs = '+', help = 'a comma-separated string of names and values to filter out.', default = []) args = parser.parse_args() args.plot = Plot[args.plot.upper()] + args.x = Stat[args.x.upper()] if args.x is not None else None + args.y = Stat[args.y.upper()] if args.y is not None else None + args.color = Stat[args.color.upper()] if args.color is not None else None data_list: list[dict] = list() @@ -159,14 +166,17 @@ def main(): data_list.append(json.load(file)) print(filename + " loaded.") - print(data_list) - #x = Stat[args.x.upper()] if args.x is not None else None - x = args.x - #y = Stat[args.y.upper()] if args.y is not None else None - y = args.y - #color = Stat[args.color.upper()] if args.color is not None else None - color = args.color - visualize(data_list, args.plot, args.rows, args.size, args.font_size, x, y, color, args.filter) + visualize( + data_list, + args.plot, + args.rows, + args.size, + args.font_size, + args.x, + args.y, + args.color, + args.filter + ) if __name__ == '__main__': main()