diff --git a/pytorch/analyze.py b/pytorch/analyze.py index e97eccc..9ca8f48 100755 --- a/pytorch/analyze.py +++ b/pytorch/analyze.py @@ -40,7 +40,8 @@ def box_plot(ax, data_list: list[dict[str, str | int | float]], x: Stat, y: Stat def line_plot( ax, data_list: list[dict[str, str | int | float]], - x: Stat, y: Stat, color: Stat + x: Stat, y: Stat, color: Stat, + font_size: int ): x_data: dict[str, np.ndarray] = accumulate(data_list, color, x) y_data: dict[str, np.ndarray] = accumulate(data_list, color, y) @@ -53,7 +54,7 @@ def line_plot( #print("Plotted x data: " + str(x_data[category])) #print("Plotted y data: " + str(y_data[category])) - ax.set_ylabel(y.value) + ax.set_ylabel(y.value, fontsize=font_size) ax.grid(True) def visualize( @@ -63,7 +64,7 @@ def visualize( size_multiplier: int, font_size: int, x: Stat, - y: Stat, + ys: list[Stat], color: Stat, filter_list: list[str] = [] ): @@ -76,7 +77,7 @@ def visualize( #y = Stat.DTLB_MISS_RATE #color = Stat.SOLVER - if y is None: + if ys is None: #ys = [stat for stat in Stat if stat.name in data_list[0].keys()] #ys = [stat for stat in data_list[0].keys() if "power" not in stat] #and stat is not x @@ -85,51 +86,64 @@ def visualize( #and stat.value not in filter_list] # Create sorted, deduped list of all stats in data_list. ys = [Stat[stat_name] for stat_name in sorted(list(set([stat_name for data in data_list for stat_name in data if type(data[stat_name]) is not list])))] - print([stat.value for stat in ys]) - fig, axes = plt.subplots(rows, int(math.ceil(len(ys) / rows)), - figsize = (16 * size_multiplier, 9 * size_multiplier)) - match plot: - case Plot.BOX: - for i, y in enumerate(ys): - box_plot(axes[i % rows][int(i / rows)], data_list, x, y) - case Plot.LINE: - for i, y in enumerate(ys): - line_plot(axes[i % rows][int(i / rows)], data_list, x, y, color) + print([stat.name for stat in ys]) + print(len(ys)) + print(rows) + print(int(math.ceil(len(ys) / rows))) + fig, axes = plt.subplots(rows, int(math.ceil(len(ys) / rows)), + figsize = (16 * size_multiplier, 9 * size_multiplier)) - handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels() - else: - fig, ax = plt.subplots() + match plot: + case Plot.BOX: + for i, y in enumerate(ys): + box_plot(axes[i % rows][int(i / rows)], data_list, x, y) + case Plot.LINE: + for i, y in enumerate(ys): + ax = axes[i % rows][int(i / rows)] + line_plot(ax, data_list, x, y, color, font_size) + ax.get_xaxis().get_major_formatter().set_scientific(False) + ax.get_yaxis().get_major_formatter().set_scientific(False) - match plot: - case Plot.BOX: - box_plot(ax, data_list, x, y) - case Plot.LINE: - line_plot(ax, data_list, x, y, color) - - handles, labels = ax.get_legend_handles_labels() + handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels() +# else: +# fig, ax = plt.subplots() +# +# match plot: +# case Plot.BOX: +# box_plot(ax, data_list, x, y) +# case Plot.LINE: +# line_plot(ax, data_list, x, y, color) +# +# handles, labels = ax.get_legend_handles_labels() #box_plot(ax, stats, x, y) #line_plot(ax, stats, x, y, color) - match plot: - case Plot.BOX: - title = f"{plot.value}_plot_of_{y.value.replace(' ', '_')}_vs_{x.value.replace(' ', '_')}_excluding_{filter_list}" - case Plot.LINE: - #title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}" - title = "altra_spmv" - fig.suptitle(title, fontsize = font_size) +# match plot: +# case Plot.BOX: +# title = f"{plot.value}_plot_of_{y.value.replace(' ', '_')}_vs_{x.value.replace(' ', '_')}_excluding_{filter_list}" +# case Plot.LINE: +# #title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}" +# title = "altra_spmv" + title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}' + #fig.suptitle(title, fontsize = font_size) fig.legend(handles, labels, fontsize = font_size) fig.supxlabel(x.value, fontsize = font_size) - plt.savefig(title + ".png", dpi = 100) + #plt.xticks(fontsize=font_size) + #plt.yticks(fontsize=font_size) + plt.savefig(title.replace(' ', '_') + ".png", dpi = 100) plt.show() def main(): + default_figure_size = 4 + default_font_size = 5 * default_figure_size + parser = argparse.ArgumentParser() parser.add_argument('input_dir', help='the input directory') - parser.add_argument('-p', '--plot', + parser.add_argument('plot', choices=[x.name.lower() for x in Plot], help = 'the type of plot') parser.add_argument('-r', '--rows', type=int, @@ -137,18 +151,18 @@ def main(): default = 5) parser.add_argument('-s', '--size', type=int, help = 'figure size multiplier', - default = 4) + default = default_figure_size) parser.add_argument('-fs', '--font_size', type=int, help = 'font size', - default = 40) - parser.add_argument('-x', + default = default_font_size) + parser.add_argument('x', choices=[x.name.lower() for x in Stat], help = 'the name of the x axis') - parser.add_argument('-y', - choices=[x.name.lower() for x in Stat], + parser.add_argument('-ys', nargs='+', + choices=[y.name.lower() for y in Stat], help = 'the name of the y axis') parser.add_argument('-c', '--color', - choices=[x.name.lower() for x in Stat], + choices=[c.name.lower() for c in Stat], help = 'the name of the color') parser.add_argument('-f', '--filter', nargs = '+', help = 'a comma-separated string of names and values to filter out.', @@ -156,7 +170,8 @@ def main(): args = parser.parse_args() args.plot = Plot[args.plot.upper()] args.x = Stat[args.x.upper()] if args.x is not None else None - args.y = Stat[args.y.upper()] if args.y is not None else None + args.ys = ([Stat[y.upper()] for y in args.ys] + if args.ys is not None else None) args.color = Stat[args.color.upper()] if args.color is not None else None data_list: list[dict] = list() @@ -173,7 +188,7 @@ def main(): args.size, args.font_size, args.x, - args.y, + args.ys, args.color, args.filter )