Added analysis features for PyTorch

This commit is contained in:
cephi_sui 2024-12-12 15:47:26 -05:00
parent 752ec8b9cd
commit 47995eab85

View File

@ -11,6 +11,7 @@ from enum import Enum
import math import math
import numpy as np import numpy as np
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import itertools import itertools
@ -43,6 +44,9 @@ def line_plot(
x: Stat, y: Stat, color: Stat, x: Stat, y: Stat, color: Stat,
font_size: int font_size: int
): ):
print(x)
print(y)
print(color)
x_data: dict[str, np.ndarray] = accumulate(data_list, color, x) x_data: dict[str, np.ndarray] = accumulate(data_list, color, x)
y_data: dict[str, np.ndarray] = accumulate(data_list, color, y) y_data: dict[str, np.ndarray] = accumulate(data_list, color, y)
@ -66,7 +70,8 @@ def visualize(
x: Stat, x: Stat,
ys: list[Stat], ys: list[Stat],
color: Stat, color: Stat,
filter_list: list[str] = [] filter_list: list[str],
title: str
): ):
# Remove stats entries containing undesired values (like a specific CPU). # Remove stats entries containing undesired values (like a specific CPU).
# data_list = [stats for stats in data_list # data_list = [stats for stats in data_list
@ -102,8 +107,11 @@ def visualize(
for i, y in enumerate(ys): for i, y in enumerate(ys):
ax = axes[i % rows][int(i / rows)] ax = axes[i % rows][int(i / rows)]
line_plot(ax, data_list, x, y, color, font_size) line_plot(ax, data_list, x, y, color, font_size)
ax.get_xaxis().get_major_formatter().set_scientific(False) if type(ax.get_xaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter:
ax.get_yaxis().get_major_formatter().set_scientific(False) ax.get_xaxis().get_major_formatter().set_scientific(False)
if type(ax.get_yaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter:
ax.get_yaxis().get_major_formatter().set_scientific(False)
#ax.ticklabel_format(axis='both', style='plain')
handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels() handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels()
# else: # else:
@ -126,11 +134,12 @@ def visualize(
# case Plot.LINE: # case Plot.LINE:
# #title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}" # #title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}"
# title = "altra_spmv" # title = "altra_spmv"
title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}' #title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}'
#fig.suptitle(title, fontsize = font_size) #fig.suptitle(title, fontsize = font_size)
fig.legend(handles, labels, fontsize = font_size) fig.legend(handles, labels, fontsize = font_size)
fig.supxlabel(x.value, fontsize = font_size) fig.supxlabel(x.value, fontsize = font_size)
#title = f'{plot.value} plot of {[y.name for y in ys]} vs {x.name} by {color.name} excluding {filter_list}'
#plt.xticks(fontsize=font_size) #plt.xticks(fontsize=font_size)
#plt.yticks(fontsize=font_size) #plt.yticks(fontsize=font_size)
plt.savefig(title.replace(' ', '_') + ".png", dpi = 100) plt.savefig(title.replace(' ', '_') + ".png", dpi = 100)
@ -146,6 +155,8 @@ def main():
parser.add_argument('plot', parser.add_argument('plot',
choices=[x.name.lower() for x in Plot], choices=[x.name.lower() for x in Plot],
help = 'the type of plot') help = 'the type of plot')
parser.add_argument('title', type=str,
help = 'the name of the plot')
parser.add_argument('-r', '--rows', type=int, parser.add_argument('-r', '--rows', type=int,
help = 'the number of rows to display when -y is not specified', help = 'the number of rows to display when -y is not specified',
default = 5) default = 5)
@ -191,7 +202,8 @@ def main():
args.x, args.x,
args.ys, args.ys,
args.color, args.color,
args.filter args.filter,
args.title
) )
if __name__ == '__main__': if __name__ == '__main__':