ampere_research/pytorch/analyze.py

199 lines
7.1 KiB
Python
Raw Normal View History

2024-12-03 22:16:08 -05:00
#! /bin/python3
2024-12-05 14:17:08 -05:00
from data_stat import Stat, Cpu
2024-12-03 22:16:08 -05:00
import argparse
import os, glob
import re
import json
from enum import Enum
import math
import numpy as np
import matplotlib.pyplot as plt
import itertools
class Plot(Enum):
    """Kinds of figure this script can draw.

    Each member's value is the lower-case word for that kind of plot; it is
    also used as the first word of the generated figure title.
    """

    BOX = 'box'    # one box per distinct x value
    LINE = 'line'  # one line per distinct colour value
2024-12-05 15:21:29 -05:00
def accumulate(data_list: list[dict[str, str | int | float]], category: Stat, value: Stat) -> dict:
    """Group the recorded values of one stat by the values of another.

    Args:
        data_list: One dict of stats per run, keyed by stat name.
        category: Stat whose value decides which group an entry belongs to.
        value: Stat whose values are collected into the groups.

    Returns:
        A dict mapping each distinct category value (as a NumPy scalar, in
        ascending order) to a NumPy array holding the ``value`` entries of
        every run in that category. Runs that lack either stat are skipped.
    """
    # Only keep runs that recorded BOTH stats; looking up a missing category
    # key would otherwise raise KeyError and abort the whole plot.
    complete = [stats for stats in data_list
                if category.name in stats and value.name in stats]

    category_list = np.array([stats[category.name] for stats in complete])
    value_list = np.array([stats[value.name] for stats in complete])

    # np.unique already returns its result sorted, so the insertion order of
    # the dict is the ascending order of the category values.
    result: dict = {}
    for key in np.unique(category_list):
        result[key] = value_list[category_list == key]

    return result
2024-12-05 15:21:29 -05:00
def box_plot(ax, data_list: list[dict[str, str | int | float]], x: Stat, y: Stat):
    """Draw one box per distinct ``x`` value showing the spread of ``y``.

    Args:
        ax: Matplotlib axes to draw on.
        data_list: One dict of stats per run, keyed by stat name.
        x: Stat whose distinct values become the boxes (and tick labels).
        y: Stat whose values are summarised by each box; its ``value`` is
            used as the y-axis label.
    """
    data: dict[str, np.ndarray] = accumulate(data_list, x, y)

    ax.set_ylabel(y.value)

    # Nothing recorded for this stat: leave the axes empty instead of handing
    # boxplot an empty dataset with mismatching tick labels.
    if not data:
        return

    # boxplot documents a sequence of vectors and a list of labels, so
    # materialise the dict views rather than passing them through.
    ax.boxplot(list(data.values()), tick_labels=list(data.keys()))
def line_plot(
    ax, data_list: list[dict[str, str | int | float]],
    x: Stat, y: Stat, color: Stat,
    font_size: int
):
    """Plot ``y`` against ``x`` as one line per distinct ``color`` value.

    Args:
        ax: Matplotlib axes to draw on.
        data_list: One dict of stats per run, keyed by stat name.
        x: Stat plotted along the x axis.
        y: Stat plotted along the y axis; its ``value`` is the y-axis label.
        color: Stat whose distinct values each get their own labelled line.
        font_size: Font size of the y-axis label.
    """
    # Restrict to runs that recorded all three stats. Accumulating x and y
    # from differently-filtered subsets would produce arrays of different
    # lengths (or categories present in one but not the other), so the sort
    # order computed from x could not be applied to y.
    required = (color.name, x.name, y.name)
    complete = [stats for stats in data_list
                if all(name in stats for name in required)]

    x_data: dict[str, np.ndarray] = accumulate(complete, color, x)
    y_data: dict[str, np.ndarray] = accumulate(complete, color, y)

    for category, x_values in x_data.items():
        # Order the points by x so the line is drawn left to right.
        order = np.argsort(x_values)
        ax.plot(x_values[order], y_data[category][order], label=category)

    ax.set_ylabel(y.value, fontsize=font_size)
    ax.grid(True)
def visualize(
2024-12-05 15:21:29 -05:00
data_list: list[dict[str, str | int | float]],
2024-12-03 22:16:08 -05:00
plot: Plot,
rows: int,
size_multiplier: int,
font_size: int,
x: Stat,
2024-12-09 15:07:23 -05:00
ys: list[Stat],
2024-12-03 22:16:08 -05:00
color: Stat,
filter_list: list[str] = []
):
# Remove stats entries containing undesired values (like a specific CPU).
2024-12-05 15:21:29 -05:00
# data_list = [stats for stats in data_list
2024-12-03 22:16:08 -05:00
# if len([stats[key] for key in stats.keys()
# if stats[key] in filter_list]) == 0]
#x = Stat.MAXWELL_SIZE
#y = Stat.DTLB_MISS_RATE
#color = Stat.SOLVER
2024-12-09 15:07:23 -05:00
if ys is None:
2024-12-05 15:21:29 -05:00
#ys = [stat for stat in Stat if stat.name in data_list[0].keys()]
#ys = [stat for stat in data_list[0].keys() if "power" not in stat]
2024-12-03 22:16:08 -05:00
#and stat is not x
#and y != color
#and y != marker
#and stat.value not in filter_list]
2024-12-05 15:21:29 -05:00
# Create sorted, deduped list of all stats in data_list.
ys = [Stat[stat_name] for stat_name in sorted(list(set([stat_name for data in data_list for stat_name in data if type(data[stat_name]) is not list])))]
2024-12-03 22:16:08 -05:00
2024-12-09 15:07:23 -05:00
print([stat.name for stat in ys])
print(len(ys))
print(rows)
print(int(math.ceil(len(ys) / rows)))
fig, axes = plt.subplots(rows, int(math.ceil(len(ys) / rows)),
figsize = (16 * size_multiplier, 9 * size_multiplier))
2024-12-03 22:16:08 -05:00
match plot:
case Plot.BOX:
2024-12-09 15:07:23 -05:00
for i, y in enumerate(ys):
box_plot(axes[i % rows][int(i / rows)], data_list, x, y)
2024-12-03 22:16:08 -05:00
case Plot.LINE:
2024-12-09 15:07:23 -05:00
for i, y in enumerate(ys):
ax = axes[i % rows][int(i / rows)]
line_plot(ax, data_list, x, y, color, font_size)
ax.get_xaxis().get_major_formatter().set_scientific(False)
ax.get_yaxis().get_major_formatter().set_scientific(False)
handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels()
# else:
# fig, ax = plt.subplots()
#
# match plot:
# case Plot.BOX:
# box_plot(ax, data_list, x, y)
# case Plot.LINE:
# line_plot(ax, data_list, x, y, color)
#
# handles, labels = ax.get_legend_handles_labels()
#box_plot(ax, stats, x, y)
#line_plot(ax, stats, x, y, color)
# match plot:
# case Plot.BOX:
# title = f"{plot.value}_plot_of_{y.value.replace(' ', '_')}_vs_{x.value.replace(' ', '_')}_excluding_{filter_list}"
# case Plot.LINE:
# #title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}"
# title = "altra_spmv"
title = f'{plot.value} plot of {[y.value for y in ys]} vs {x.value} by {color.value} excluding {filter_list}'
#fig.suptitle(title, fontsize = font_size)
2024-12-03 22:16:08 -05:00
fig.legend(handles, labels, fontsize = font_size)
2024-12-05 15:21:29 -05:00
fig.supxlabel(x.value, fontsize = font_size)
2024-12-03 22:16:08 -05:00
2024-12-09 15:07:23 -05:00
#plt.xticks(fontsize=font_size)
#plt.yticks(fontsize=font_size)
plt.savefig(title.replace(' ', '_') + ".png", dpi = 100)
2024-12-03 22:16:08 -05:00
plt.show()
def main():
    """Parse the command line, load every ``*.json`` file and plot the data.

    Each JSON file found directly inside ``input_dir`` is loaded as one entry
    of the data list handed to ``visualize``. Invalid argument combinations
    and an input directory without JSON files are reported through
    ``parser.error`` (usage message and exit) instead of failing later
    inside the plotting code.
    """
    default_figure_size = 4
    default_font_size = 5 * default_figure_size

    # Map the lower-case names offered on the command line back to members.
    plots_by_name = {member.name.lower(): member for member in Plot}
    stats_by_name = {member.name.lower(): member for member in Stat}

    parser = argparse.ArgumentParser()
    parser.add_argument('input_dir',
        help='the input directory')
    parser.add_argument('plot',
        choices=list(plots_by_name),
        help = 'the type of plot')
    parser.add_argument('-r', '--rows', type=int,
        help = 'the number of rows to display when -y is not specified',
        default = 5)
    parser.add_argument('-s', '--size', type=int,
        help = 'figure size multiplier',
        default = default_figure_size)
    parser.add_argument('-fs', '--font_size', type=int,
        help = 'font size',
        default = default_font_size)
    parser.add_argument('x',
        choices=list(stats_by_name),
        help = 'the name of the x axis')
    parser.add_argument('-ys', nargs='+',
        choices=list(stats_by_name),
        help = 'the name of the y axis')
    parser.add_argument('-c', '--color',
        choices=list(stats_by_name),
        help = 'the name of the color')
    # NOTE(review): nargs='+' takes whitespace-separated values, while the
    # help text says "comma-separated" - confirm which one is intended.
    parser.add_argument('-f', '--filter', nargs = '+',
        help = 'a comma-separated string of names and values to filter out.',
        default = [])
    args = parser.parse_args()

    args.plot = plots_by_name[args.plot]
    args.x = stats_by_name[args.x]
    args.ys = ([stats_by_name[y] for y in args.ys]
        if args.ys is not None else None)
    args.color = stats_by_name[args.color] if args.color is not None else None

    # A line plot draws one line per colour value, so it cannot work without
    # a colour stat; fail here with a usage message rather than deep inside
    # the plotting code.
    if args.plot is Plot.LINE and args.color is None:
        parser.error('a line plot requires -c/--color')

    data_list: list[dict] = list()
    # Sorted so the load order does not depend on directory listing order.
    for filename in sorted(glob.glob(os.path.join(args.input_dir, '*.json'))):
        with open(filename, 'r', encoding='utf-8') as file:
            print(filename)
            data_list.append(json.load(file))
            print(filename + " loaded.")

    if not data_list:
        parser.error(f'no .json files found in {args.input_dir}')

    visualize(
        data_list,
        args.plot,
        args.rows,
        args.size,
        args.font_size,
        args.x,
        args.ys,
        args.color,
        args.filter
    )
2024-12-03 22:16:08 -05:00
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()