Added analysis features for PyTorch
This commit is contained in:
parent
752ec8b9cd
commit
47995eab85
@ -11,6 +11,7 @@ from enum import Enum
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import matplotlib
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
@ -43,6 +44,9 @@ def line_plot(
|
|||||||
x: Stat, y: Stat, color: Stat,
|
x: Stat, y: Stat, color: Stat,
|
||||||
font_size: int
|
font_size: int
|
||||||
):
|
):
|
||||||
|
print(x)
|
||||||
|
print(y)
|
||||||
|
print(color)
|
||||||
x_data: dict[str, np.ndarray] = accumulate(data_list, color, x)
|
x_data: dict[str, np.ndarray] = accumulate(data_list, color, x)
|
||||||
y_data: dict[str, np.ndarray] = accumulate(data_list, color, y)
|
y_data: dict[str, np.ndarray] = accumulate(data_list, color, y)
|
||||||
|
|
||||||
@ -66,7 +70,8 @@ def visualize(
|
|||||||
x: Stat,
|
x: Stat,
|
||||||
ys: list[Stat],
|
ys: list[Stat],
|
||||||
color: Stat,
|
color: Stat,
|
||||||
filter_list: list[str] = []
|
filter_list: list[str],
|
||||||
|
title: str
|
||||||
):
|
):
|
||||||
# Remove stats entries containing undesired values (like a specific CPU).
|
# Remove stats entries containing undesired values (like a specific CPU).
|
||||||
# data_list = [stats for stats in data_list
|
# data_list = [stats for stats in data_list
|
||||||
@ -102,8 +107,11 @@ def visualize(
|
|||||||
for i, y in enumerate(ys):
|
for i, y in enumerate(ys):
|
||||||
ax = axes[i % rows][int(i / rows)]
|
ax = axes[i % rows][int(i / rows)]
|
||||||
line_plot(ax, data_list, x, y, color, font_size)
|
line_plot(ax, data_list, x, y, color, font_size)
|
||||||
ax.get_xaxis().get_major_formatter().set_scientific(False)
|
if type(ax.get_xaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter:
|
||||||
ax.get_yaxis().get_major_formatter().set_scientific(False)
|
ax.get_xaxis().get_major_formatter().set_scientific(False)
|
||||||
|
if type(ax.get_yaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter:
|
||||||
|
ax.get_yaxis().get_major_formatter().set_scientific(False)
|
||||||
|
#ax.ticklabel_format(axis='both', style='plain')
|
||||||
|
|
||||||
handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels()
|
handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels()
|
||||||
# else:
|
# else:
|
||||||
@ -126,11 +134,12 @@ def visualize(
|
|||||||
# case Plot.LINE:
|
# case Plot.LINE:
|
||||||
# #title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}"
|
# #title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}"
|
||||||
# title = "altra_spmv"
|
# title = "altra_spmv"
|
||||||
title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}'
|
#title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}'
|
||||||
#fig.suptitle(title, fontsize = font_size)
|
#fig.suptitle(title, fontsize = font_size)
|
||||||
fig.legend(handles, labels, fontsize = font_size)
|
fig.legend(handles, labels, fontsize = font_size)
|
||||||
fig.supxlabel(x.value, fontsize = font_size)
|
fig.supxlabel(x.value, fontsize = font_size)
|
||||||
|
|
||||||
|
#title = f'{plot.value} plot of {[y.name for y in ys]} vs {x.name} by {color.name} excluding {filter_list}'
|
||||||
#plt.xticks(fontsize=font_size)
|
#plt.xticks(fontsize=font_size)
|
||||||
#plt.yticks(fontsize=font_size)
|
#plt.yticks(fontsize=font_size)
|
||||||
plt.savefig(title.replace(' ', '_') + ".png", dpi = 100)
|
plt.savefig(title.replace(' ', '_') + ".png", dpi = 100)
|
||||||
@ -146,6 +155,8 @@ def main():
|
|||||||
parser.add_argument('plot',
|
parser.add_argument('plot',
|
||||||
choices=[x.name.lower() for x in Plot],
|
choices=[x.name.lower() for x in Plot],
|
||||||
help = 'the type of plot')
|
help = 'the type of plot')
|
||||||
|
parser.add_argument('title', type=str,
|
||||||
|
help = 'the name of the plot')
|
||||||
parser.add_argument('-r', '--rows', type=int,
|
parser.add_argument('-r', '--rows', type=int,
|
||||||
help = 'the number of rows to display when -y is not specified',
|
help = 'the number of rows to display when -y is not specified',
|
||||||
default = 5)
|
default = 5)
|
||||||
@ -191,7 +202,8 @@ def main():
|
|||||||
args.x,
|
args.x,
|
||||||
args.ys,
|
args.ys,
|
||||||
args.color,
|
args.color,
|
||||||
args.filter
|
args.filter,
|
||||||
|
args.title
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
Reference in New Issue
Block a user