Added analysis features for PyTorch
This commit is contained in:
parent
752ec8b9cd
commit
47995eab85
@ -11,6 +11,7 @@ from enum import Enum
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import itertools
|
||||
|
||||
@ -43,6 +44,9 @@ def line_plot(
|
||||
x: Stat, y: Stat, color: Stat,
|
||||
font_size: int
|
||||
):
|
||||
print(x)
|
||||
print(y)
|
||||
print(color)
|
||||
x_data: dict[str, np.ndarray] = accumulate(data_list, color, x)
|
||||
y_data: dict[str, np.ndarray] = accumulate(data_list, color, y)
|
||||
|
||||
@ -66,7 +70,8 @@ def visualize(
|
||||
x: Stat,
|
||||
ys: list[Stat],
|
||||
color: Stat,
|
||||
filter_list: list[str] = []
|
||||
filter_list: list[str],
|
||||
title: str
|
||||
):
|
||||
# Remove stats entries containing undesired values (like a specific CPU).
|
||||
# data_list = [stats for stats in data_list
|
||||
@ -102,8 +107,11 @@ def visualize(
|
||||
for i, y in enumerate(ys):
|
||||
ax = axes[i % rows][int(i / rows)]
|
||||
line_plot(ax, data_list, x, y, color, font_size)
|
||||
if type(ax.get_xaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter:
|
||||
ax.get_xaxis().get_major_formatter().set_scientific(False)
|
||||
if type(ax.get_yaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter:
|
||||
ax.get_yaxis().get_major_formatter().set_scientific(False)
|
||||
#ax.ticklabel_format(axis='both', style='plain')
|
||||
|
||||
handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels()
|
||||
# else:
|
||||
@ -126,11 +134,12 @@ def visualize(
|
||||
# case Plot.LINE:
|
||||
# #title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}"
|
||||
# title = "altra_spmv"
|
||||
title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}'
|
||||
#title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}'
|
||||
#fig.suptitle(title, fontsize = font_size)
|
||||
fig.legend(handles, labels, fontsize = font_size)
|
||||
fig.supxlabel(x.value, fontsize = font_size)
|
||||
|
||||
#title = f'{plot.value} plot of {[y.name for y in ys]} vs {x.name} by {color.name} excluding {filter_list}'
|
||||
#plt.xticks(fontsize=font_size)
|
||||
#plt.yticks(fontsize=font_size)
|
||||
plt.savefig(title.replace(' ', '_') + ".png", dpi = 100)
|
||||
@ -146,6 +155,8 @@ def main():
|
||||
parser.add_argument('plot',
|
||||
choices=[x.name.lower() for x in Plot],
|
||||
help = 'the type of plot')
|
||||
parser.add_argument('title', type=str,
|
||||
help = 'the name of the plot')
|
||||
parser.add_argument('-r', '--rows', type=int,
|
||||
help = 'the number of rows to display when -y is not specified',
|
||||
default = 5)
|
||||
@ -191,7 +202,8 @@ def main():
|
||||
args.x,
|
||||
args.ys,
|
||||
args.color,
|
||||
args.filter
|
||||
args.filter,
|
||||
args.title
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user