Added analysis features for PyTorch

This commit is contained in:
cephi_sui 2024-12-12 15:47:26 -05:00
parent 752ec8b9cd
commit 47995eab85

View File

@ -11,6 +11,7 @@ from enum import Enum
import math
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import itertools
@ -43,6 +44,9 @@ def line_plot(
x: Stat, y: Stat, color: Stat,
font_size: int
):
print(x)
print(y)
print(color)
x_data: dict[str, np.ndarray] = accumulate(data_list, color, x)
y_data: dict[str, np.ndarray] = accumulate(data_list, color, y)
@ -66,7 +70,8 @@ def visualize(
x: Stat,
ys: list[Stat],
color: Stat,
filter_list: list[str] = []
filter_list: list[str],
title: str
):
# Remove stats entries containing undesired values (like a specific CPU).
# data_list = [stats for stats in data_list
@ -102,8 +107,11 @@ def visualize(
for i, y in enumerate(ys):
ax = axes[i % rows][int(i / rows)]
line_plot(ax, data_list, x, y, color, font_size)
if type(ax.get_xaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter:
ax.get_xaxis().get_major_formatter().set_scientific(False)
if type(ax.get_yaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter:
ax.get_yaxis().get_major_formatter().set_scientific(False)
#ax.ticklabel_format(axis='both', style='plain')
handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels()
# else:
@ -126,11 +134,12 @@ def visualize(
# case Plot.LINE:
# #title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}"
# title = "altra_spmv"
title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}'
#title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}'
#fig.suptitle(title, fontsize = font_size)
fig.legend(handles, labels, fontsize = font_size)
fig.supxlabel(x.value, fontsize = font_size)
#title = f'{plot.value} plot of {[y.name for y in ys]} vs {x.name} by {color.name} excluding {filter_list}'
#plt.xticks(fontsize=font_size)
#plt.yticks(fontsize=font_size)
plt.savefig(title.replace(' ', '_') + ".png", dpi = 100)
@ -146,6 +155,8 @@ def main():
parser.add_argument('plot',
choices=[x.name.lower() for x in Plot],
help = 'the type of plot')
parser.add_argument('title', type=str,
help = 'the name of the plot')
parser.add_argument('-r', '--rows', type=int,
help = 'the number of rows to display when -y is not specified',
default = 5)
@ -191,7 +202,8 @@ def main():
args.x,
args.ys,
args.color,
args.filter
args.filter,
args.title
)
if __name__ == '__main__':