diff --git a/pytorch/analyze.py b/pytorch/analyze.py index e67764e..d023bfa 100755 --- a/pytorch/analyze.py +++ b/pytorch/analyze.py @@ -11,6 +11,7 @@ from enum import Enum import math import numpy as np +import matplotlib import matplotlib.pyplot as plt import itertools @@ -43,6 +44,9 @@ def line_plot( x: Stat, y: Stat, color: Stat, font_size: int ): + print(x) + print(y) + print(color) x_data: dict[str, np.ndarray] = accumulate(data_list, color, x) y_data: dict[str, np.ndarray] = accumulate(data_list, color, y) @@ -66,7 +70,8 @@ def visualize( x: Stat, ys: list[Stat], color: Stat, - filter_list: list[str] = [] + filter_list: list[str], + title: str ): # Remove stats entries containing undesired values (like a specific CPU). # data_list = [stats for stats in data_list @@ -102,8 +107,11 @@ def visualize( for i, y in enumerate(ys): ax = axes[i % rows][int(i / rows)] line_plot(ax, data_list, x, y, color, font_size) - ax.get_xaxis().get_major_formatter().set_scientific(False) - ax.get_yaxis().get_major_formatter().set_scientific(False) + if type(ax.get_xaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter: + ax.get_xaxis().get_major_formatter().set_scientific(False) + if type(ax.get_yaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter: + ax.get_yaxis().get_major_formatter().set_scientific(False) + #ax.ticklabel_format(axis='both', style='plain') handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels() # else: @@ -126,11 +134,12 @@ def visualize( # case Plot.LINE: # #title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}" # title = "altra_spmv" - title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}' + #title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}' #fig.suptitle(title, fontsize = font_size) fig.legend(handles, labels, fontsize = font_size) fig.supxlabel(x.value, fontsize = font_size) + #title = f'{plot.value} plot of {[y.name for y in ys]} vs {x.name} by {color.name} excluding {filter_list}' #plt.xticks(fontsize=font_size) #plt.yticks(fontsize=font_size) plt.savefig(title.replace(' ', '_') + ".png", dpi = 100) @@ -146,6 +155,8 @@ def main(): parser.add_argument('plot', choices=[x.name.lower() for x in Plot], help = 'the type of plot') + parser.add_argument('title', type=str, + help = 'the name of the plot') parser.add_argument('-r', '--rows', type=int, help = 'the number of rows to display when -y is not specified', default = 5) @@ -191,7 +202,8 @@ def main(): args.x, args.ys, args.color, - args.filter + args.filter, + args.title ) if __name__ == '__main__':