ampere_research/pytorch/analyze.py

289 lines
11 KiB
Python
Raw Permalink Normal View History

2024-12-03 22:16:08 -05:00
#! /bin/python3
2024-12-05 14:17:08 -05:00
from data_stat import Stat, Cpu
2024-12-03 22:16:08 -05:00
import argparse
import os, glob
import re
import json
from enum import Enum
import math
import numpy as np
2024-12-12 15:47:26 -05:00
import matplotlib
2024-12-03 22:16:08 -05:00
import matplotlib.pyplot as plt
import itertools
class Plot(Enum):
    """Kind of chart drawn for each selected stat.

    The command line accepts the lowercase member name (``box``/``line``)
    and converts it back with ``Plot[name.upper()]``.
    """
    # Box plot: distribution of a stat for each distinct x value.
    BOX = 'box'
    # Line plot: stat against x, one line per category.
    LINE = 'line'
2024-12-16 00:10:30 -05:00
def accumulate(
data_list: list[dict[str, str | int | float]],
category: Stat,
category2: Stat,
value: Stat
) -> dict[np.ndarray]:
result: dict[np.ndarray] = dict()
2024-12-05 15:21:29 -05:00
value_list = np.array([stats[value.name] for stats in data_list if value.name in stats])
2024-12-03 22:16:08 -05:00
2024-12-16 00:10:30 -05:00
if category2 is None:
category_list = np.array([(stats[category.name])
for stats in data_list if value.name in stats])
for category in np.unique(category_list):
result[category] = value_list[category_list == category]
else:
category_list = np.array([(stats[category.name],
stats[category2.name])
for stats in data_list if value.name in stats], dtype='object')
for category in category_list:
mask = np.logical_and(category_list[:, 0] == category[0], category_list[:, 1] == category[1])
assert (tuple(category) not in result.keys()
or np.array_equal(result[tuple(category)], value_list[mask]))
result[tuple(category)] = value_list[mask]
2024-12-03 22:16:08 -05:00
return result
2024-12-05 15:21:29 -05:00
def box_plot(ax, data_list: list[dict[str, str | int | float]], x: Stat, y: Stat):
    """Draw one box per distinct ``x`` value showing the spread of ``y``.

    Args:
        ax: Matplotlib Axes to draw on.
        data_list: One dict per run, keyed by ``Stat`` names.
        x: Categorical stat; each distinct value becomes one box.
        y: Stat whose distribution is plotted; also used as the y label.
    """
    # accumulate() takes (category, category2, value). A box plot groups by a
    # single category, so category2 is None. The previous 3-argument call
    # raised TypeError.
    data: dict[object, np.ndarray] = accumulate(data_list, x, None, y)
    ax.boxplot(list(data.values()), tick_labels=[str(key) for key in data])
    ax.set_ylabel(y.value)
def line_plot(
    ax, data_list: list[dict[str, str | int | float]],
    x: Stat, y: Stat, color: Stat, linestyle: Stat | None
):
    """Plot ``y`` against ``x`` with one line per category.

    Lines are grouped by ``color`` alone, or by the ``(color, linestyle)``
    pair when ``linestyle`` is given. In the pair case each distinct
    ``color`` value gets a fixed base colour and each distinct ``linestyle``
    value a fixed dash pattern, so the two dimensions can be read apart.
    Without ``linestyle``, matplotlib's default colour cycle is used. Points
    within each line are sorted by ``x``.

    Args:
        ax: Matplotlib Axes to draw on.
        data_list: One dict per run, keyed by ``Stat`` names.
        x: Stat on the horizontal axis.
        y: Stat on the vertical axis.
        color: Stat that selects the line; must not be None.
        linestyle: Optional second grouping stat mapped to a dash pattern.
    """
    # (name, matplotlib dash spec) pairs, handed out in this order.
    linestyle_tuple = [
        ('solid', (0, ())),
        ('dotted', (0, (1, 1))),
        ('long dash with offset', (5, (10, 3))),
        ('loosely dashed', (0, (5, 10))),
        ('dashed', (0, (5, 5))),
        ('densely dashed', (0, (5, 1))),
        ('loosely dashdotted', (0, (3, 10, 1, 10))),
        ('dashdotted', (0, (3, 5, 1, 5))),
        ('densely dashdotted', (0, (3, 1, 1, 1))),
        ('dashdotdotted', (0, (3, 5, 1, 5, 1, 5))),
        ('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
        ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1)))]

    # Only use runs that report both coordinates. Otherwise x and y would be
    # grouped from different subsets of runs, and their per-category arrays
    # would not line up point for point.
    rows = [stats for stats in data_list if x.name in stats and y.name in stats]
    x_data: dict[object, np.ndarray] = accumulate(rows, color, linestyle, x)
    y_data: dict[object, np.ndarray] = accumulate(rows, color, linestyle, y)

    if linestyle is not None:
        # Skip white ('w'): it is invisible on the default white background.
        # Both palettes are cycled so that more categories than styles reuse
        # entries instead of raising KeyError.
        palette = [rgb for name, rgb in matplotlib.colors.BASE_COLORS.items()
                   if name != 'w']
        color_mapping = dict(zip(
            sorted({c for (c, _) in x_data}),
            itertools.cycle(palette)))
        linestyle_mapping = dict(zip(
            sorted({s for (_, s) in x_data}),
            itertools.cycle([dash for (_, dash) in linestyle_tuple])))

    for category in sorted(x_data):
        order = np.argsort(x_data[category])
        xs = x_data[category][order]
        ys = y_data[category][order]
        if linestyle is None:
            ax.plot(xs, ys, label=str(category), marker='o')
        else:
            ax.plot(xs, ys, label=str(category[0]) + ", " + str(category[1]),
                    color=color_mapping[category[0]],
                    linestyle=linestyle_mapping[category[1]],
                    marker='o')
2024-12-03 22:16:08 -05:00
def visualize(
2024-12-05 15:21:29 -05:00
data_list: list[dict[str, str | int | float]],
2024-12-03 22:16:08 -05:00
plot: Plot,
rows: int,
size_multiplier: int,
font_size: int,
x: Stat,
2024-12-09 15:07:23 -05:00
ys: list[Stat],
2024-12-03 22:16:08 -05:00
color: Stat,
2024-12-16 00:10:30 -05:00
linestyle: Stat,
x_log: bool,
y_log: bool,
2024-12-12 15:47:26 -05:00
filter_list: list[str],
title: str
2024-12-03 22:16:08 -05:00
):
# Remove stats entries containing undesired values (like a specific CPU).
2024-12-05 15:21:29 -05:00
# data_list = [stats for stats in data_list
2024-12-03 22:16:08 -05:00
# if len([stats[key] for key in stats.keys()
# if stats[key] in filter_list]) == 0]
#x = Stat.MAXWELL_SIZE
#y = Stat.DTLB_MISS_RATE
#color = Stat.SOLVER
2024-12-09 15:07:23 -05:00
if ys is None:
2024-12-05 15:21:29 -05:00
#ys = [stat for stat in Stat if stat.name in data_list[0].keys()]
#ys = [stat for stat in data_list[0].keys() if "power" not in stat]
2024-12-03 22:16:08 -05:00
#and stat is not x
#and y != color
#and y != marker
#and stat.value not in filter_list]
2024-12-05 15:21:29 -05:00
# Create sorted, deduped list of all stats in data_list.
ys = [Stat[stat_name] for stat_name in sorted(list(set([stat_name for data in data_list for stat_name in data if type(data[stat_name]) is not list])))]
2024-12-03 22:16:08 -05:00
2024-12-09 15:07:23 -05:00
print([stat.name for stat in ys])
print(len(ys))
print(rows)
print(int(math.ceil(len(ys) / rows)))
fig, axes = plt.subplots(rows, int(math.ceil(len(ys) / rows)),
figsize = (16 * size_multiplier, 9 * size_multiplier))
2024-12-16 00:10:30 -05:00
print(len(axes))
2024-12-03 22:16:08 -05:00
match plot:
case Plot.BOX:
2024-12-09 15:07:23 -05:00
for i, y in enumerate(ys):
2024-12-16 00:10:30 -05:00
ax = axes[i % rows][int(i / rows)]
box_plot(ax, data_list, x, y)
2024-12-03 22:16:08 -05:00
case Plot.LINE:
2024-12-09 15:07:23 -05:00
for i, y in enumerate(ys):
2024-12-16 00:10:30 -05:00
ax = axes[i % rows]
if len(ys) > len(axes):
ax = ax[int(i / rows)]
line_plot(ax, data_list, x, y, color, linestyle)
ax.set_ylabel(y.value, fontsize=font_size)
ax.grid(True)
if x_log:
ax.set_xscale('log')
if y_log:
ax.set_yscale('log')
2024-12-12 15:47:26 -05:00
if type(ax.get_xaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter:
ax.get_xaxis().get_major_formatter().set_scientific(False)
if type(ax.get_yaxis().get_major_formatter()) is matplotlib.ticker.ScalarFormatter:
ax.get_yaxis().get_major_formatter().set_scientific(False)
2024-12-09 15:07:23 -05:00
2024-12-16 00:10:30 -05:00
handles, labels = ax.get_legend_handles_labels()
2024-12-09 15:07:23 -05:00
# else:
# fig, ax = plt.subplots()
#
# match plot:
# case Plot.BOX:
# box_plot(ax, data_list, x, y)
# case Plot.LINE:
# line_plot(ax, data_list, x, y, color)
#
# handles, labels = ax.get_legend_handles_labels()
#box_plot(ax, stats, x, y)
#line_plot(ax, stats, x, y, color)
# match plot:
# case Plot.BOX:
# title = f"{plot.value}_plot_of_{y.value.replace(' ', '_')}_vs_{x.value.replace(' ', '_')}_excluding_{filter_list}"
# case Plot.LINE:
# #title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}"
# title = "altra_spmv"
2024-12-12 15:47:26 -05:00
#title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}'
2024-12-16 00:10:30 -05:00
fig.suptitle(title, fontsize = font_size)
fig.legend(handles, labels,
title=color.value if linestyle is None
else color.value + ", " + linestyle.value,
title_fontsize=font_size,
fontsize = font_size)
2024-12-05 15:21:29 -05:00
fig.supxlabel(x.value, fontsize = font_size)
2024-12-03 22:16:08 -05:00
2024-12-12 15:47:26 -05:00
#title = f'{plot.value} plot of {[y.name for y in ys]} vs {x.name} by {color.name} excluding {filter_list}'
2024-12-09 15:07:23 -05:00
#plt.xticks(fontsize=font_size)
#plt.yticks(fontsize=font_size)
plt.savefig(title.replace(' ', '_') + ".png", dpi = 100)
2024-12-03 22:16:08 -05:00
plt.show()
def main():
    """Command-line entry point.

    Parses the arguments, loads every ``*.json`` file in ``input_dir`` (one
    dict of stats per file, in sorted filename order) and passes them to
    ``visualize``.
    """
    default_figure_size = 4
    default_font_size = 5 * default_figure_size
    stat_names = [stat.name.lower() for stat in Stat]

    parser = argparse.ArgumentParser()
    parser.add_argument('input_dir',
                        help='the input directory')
    parser.add_argument('plot',
                        choices=[p.name.lower() for p in Plot],
                        help='the type of plot')
    parser.add_argument('title', type=str,
                        help='the name of the plot')
    parser.add_argument('-r', '--rows', type=int,
                        help='the number of subplot rows',
                        default=5)
    parser.add_argument('-s', '--size', type=int,
                        help='figure size multiplier',
                        default=default_figure_size)
    parser.add_argument('-fs', '--font_size', type=int,
                        help='font size',
                        default=default_font_size)
    parser.add_argument('x',
                        choices=stat_names,
                        help='the name of the x axis')
    parser.add_argument('-ys', nargs='+',
                        choices=stat_names,
                        help='the names of the y axes '
                             '(default: every non-list stat in the data)')
    parser.add_argument('-c', '--color',
                        choices=stat_names,
                        help='the stat that selects line color '
                             '(required for line plots)')
    parser.add_argument('-l', '--linestyle',
                        choices=stat_names,
                        help='the stat that selects line style')
    parser.add_argument('--x_log', action='store_true',
                        help='set x axis scale to log')
    parser.add_argument('--y_log', action='store_true',
                        help='set y axis scale to log')
    parser.add_argument('-f', '--filter', nargs='+',
                        help='values to filter out (currently ignored)',
                        default=[])
    args = parser.parse_args()

    plot = Plot[args.plot.upper()]
    x = Stat[args.x.upper()]
    ys = [Stat[name.upper()] for name in args.ys] if args.ys is not None else None
    color = Stat[args.color.upper()] if args.color is not None else None
    linestyle = Stat[args.linestyle.upper()] if args.linestyle is not None else None
    if plot is Plot.LINE and color is None:
        # Line grouping dereferences the color stat, so fail early with a clear message.
        parser.error('line plots require -c/--color')

    # Escape the directory so characters like '[' are not treated as glob
    # syntax. Sort so the load order is reproducible.
    pattern = os.path.join(glob.escape(args.input_dir), '*.json')
    data_list: list[dict] = []
    for filename in sorted(glob.glob(pattern)):
        with open(filename, 'r', encoding='utf-8') as file:
            try:
                data_list.append(json.load(file))
            except json.JSONDecodeError as err:
                parser.error(f'{filename}: invalid JSON: {err}')
        print(filename + " loaded.")
    if not data_list:
        parser.error(f'no .json files found in {args.input_dir}')

    visualize(
        data_list=data_list,
        plot=plot,
        rows=args.rows,
        size_multiplier=args.size,
        font_size=args.font_size,
        x=x,
        ys=ys,
        color=color,
        linestyle=linestyle,
        x_log=args.x_log,
        y_log=args.y_log,
        filter_list=args.filter,
        title=args.title,
    )
2024-12-03 22:16:08 -05:00
if __name__ == '__main__':
    # Run the command-line interface only when executed directly, not on import.
    main()