ampere_research/pytorch/analyze.py

190 lines
6.9 KiB
Python
Raw Normal View History

2024-12-03 22:16:08 -05:00
#! /bin/python3
from perf_stat import Stat, CPU
import argparse
import os, glob
import re
import json
from enum import Enum
import math
import numpy as np
import matplotlib.pyplot as plt
import itertools
class Plot(Enum):
    """Kinds of figure that ``visualize`` can draw.

    Member values appear in generated titles; lower-cased member names are
    the choices accepted by the ``-p/--plot`` command-line option.
    """

    BOX = "box"  # box-and-whisker boxes, one per distinct x value
    LINE = "line"  # one line per distinct color value
def accumulate(stats_list: list[dict[str, str | int | float]], category: str, value: str):
print(category)
print(value)
category_list = np.array([stats[category] for stats in stats_list if value in stats])
value_list = np.array([stats[value] for stats in stats_list if value in stats])
result: dict[np.ndarray] = dict()
for category in np.sort(np.unique(category_list)):
result[category] = value_list[category_list == category]
return result
def box_plot(ax, stats_list: list[dict[str, str | int | float]], x: Stat, y: Stat):
    """Draw on ``ax`` one box per distinct ``x`` value, showing the spread of ``y``.

    Args:
        ax: matplotlib Axes to draw on.
        stats_list: parsed stats records, one dict per run.
        x: key whose distinct values label the boxes.
        y: key whose values are summarised in each box. It may be a ``Stat``
            member or a plain key string; ``main`` passes plain strings.
    """
    data: dict[str, np.ndarray] = accumulate(stats_list, x, y)
    print("Plotted data: " + str(data))
    ax.boxplot(list(data.values()), tick_labels=list(data.keys()))
    # Plain strings have no ``.value``. Fall back to the object itself so both
    # Stat members and raw key strings produce a label.
    ax.set_ylabel(getattr(y, "value", y))
def line_plot(
    ax, stats_list: list[dict[str, str | int | float]],
    x: Stat, y: Stat, color: Stat
):
    """Draw on ``ax`` one line of ``y`` versus ``x`` per distinct ``color`` value.

    Points in each line are sorted by ``x``. Records missing any of ``x``,
    ``y`` or ``color`` are ignored.

    Args:
        ax: matplotlib Axes to draw on.
        stats_list: parsed stats records, one dict per run.
        x: key plotted on the horizontal axis.
        y: key plotted on the vertical axis (``Stat`` member or key string).
        color: key whose distinct values become separate, labelled lines.
    """
    # Group x and y over the SAME records. Otherwise a record carrying only one
    # of them shifts the arrays against each other, or drops a category from
    # one side (KeyError below).
    complete = [stats for stats in stats_list
                if x in stats and y in stats and color in stats]
    x_data: dict[str, np.ndarray] = accumulate(complete, color, x)
    y_data: dict[str, np.ndarray] = accumulate(complete, color, y)
    for category, xs in x_data.items():
        order = np.argsort(xs)
        xs = xs[order]
        ys = y_data[category][order]
        ax.plot(xs, ys, label=category)
        print("Plotted x data: " + str(xs))
        print("Plotted y data: " + str(ys))
    # Use the same label rule as box plots: a Stat's value, else the raw key.
    ax.set_ylabel(getattr(y, "value", y))
    ax.grid(True)
def visualize(
stats_list: list[dict[str, str | int | float]],
plot: Plot,
rows: int,
size_multiplier: int,
font_size: int,
x: Stat,
y: Stat,
color: Stat,
filter_list: list[str] = []
):
# Remove stats entries containing undesired values (like a specific CPU).
# stats_list = [stats for stats in stats_list
# if len([stats[key] for key in stats.keys()
# if stats[key] in filter_list]) == 0]
#x = Stat.MAXWELL_SIZE
#y = Stat.DTLB_MISS_RATE
#color = Stat.SOLVER
if y is None:
#ys = [stat for stat in Stat if stat.value in stats_list[0].keys()
ys = [stat for stat in stats_list[0].keys() if "power" not in stat]
#and stat is not x
#and y != color
#and y != marker
#and stat.value not in filter_list]
fig, axes = plt.subplots(rows, int(math.ceil(len(ys) / rows)),
figsize = (16 * size_multiplier, 9 * size_multiplier))
match plot:
case Plot.BOX:
for i, y in enumerate(ys):
box_plot(axes[i % rows][int(i / rows)], stats_list, x, y)
case Plot.LINE:
for i, y in enumerate(ys):
line_plot(axes[i % rows][int(i / rows)], stats_list, x, y, color)
handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels()
else:
fig, ax = plt.subplots()
match plot:
case Plot.BOX:
box_plot(ax, stats_list, x, y)
case Plot.LINE:
line_plot(ax, stats_list, x, y, color)
handles, labels = ax.get_legend_handles_labels()
#box_plot(ax, stats, x, y)
#line_plot(ax, stats, x, y, color)
match plot:
case Plot.BOX:
title = f"{plot.value}_plot_of_{y.value.replace(' ', '_')}_vs_{x.value.replace(' ', '_')}_excluding_{filter_list}"
case Plot.LINE:
#title = f"{plot.value}_plot_of_{y.replace(' ', '_')}_vs_{x.replace(' ', '_')}_by_{color.replace(' ', '_')}_excluding_{filter_list}"
title = "altra_spmv"
fig.suptitle(title, fontsize = font_size)
fig.legend(handles, labels, fontsize = font_size)
fig.supxlabel(x, fontsize = font_size)
plt.savefig(title + ".png", dpi = 100)
plt.show()
def main():
    """Command-line entry point.

    ``parse``: load every ``*.json`` file in ``--input_dir`` (in sorted
    filename order) and write them as one JSON list to ``filepath``.
    ``visualize``: load such a list from ``filepath`` and plot it with
    ``visualize``. This command requires ``-p/--plot``.
    """
    class Command(Enum):
        PARSE = 'parse'
        VISUALIZE = 'visualize'

    parser = argparse.ArgumentParser()
    parser.add_argument('command', choices=[x.value for x in Command])
    parser.add_argument('filepath',
                        help='the output for the ' + Command.PARSE.value + ' command or the input for the ' + Command.VISUALIZE.value + ' command')
    parser.add_argument('-i', '--input_dir',
                        help='the input directory for the parse command')
    parser.add_argument('-p', '--plot',
                        choices=[x.name.lower() for x in Plot],
                        help='the type of plot')
    parser.add_argument('-r', '--rows', type=int,
                        help='the number of rows to display when -y is not specified',
                        default=5)
    parser.add_argument('-s', '--size', type=int,
                        help='figure size multiplier',
                        default=4)
    parser.add_argument('-fs', '--font_size', type=int,
                        help='font size',
                        default=40)
    parser.add_argument('-x',
                        help='the name of the x axis')
    parser.add_argument('-y',
                        help='the name of the y axis')
    parser.add_argument('-c', '--color',
                        help='the name of the color')
    # nargs='+' collects space-separated tokens. The old help text said
    # "comma-separated", which was wrong.
    parser.add_argument('-f', '--filter', nargs='+',
                        help='one or more space-separated names and values to filter out.',
                        default=[])
    args = parser.parse_args()

    stats_list: list[dict] = []
    if args.command == Command.PARSE.value:
        if args.input_dir is None:
            # parser.error prints usage and exits; it does not rely on the
            # site-installed ``exit`` builtin.
            parser.error("An input directory is required with -i")
        # glob order is filesystem-dependent; sort it for reproducible output.
        for filename in sorted(glob.glob(os.path.join(args.input_dir, '*.json'))):
            with open(filename, 'r', encoding='utf-8') as file:
                stats_list.append(json.load(file))
            print(filename + " loaded.")
        with open(args.filepath, 'w', encoding='utf-8') as file:
            json.dump(stats_list, file, indent=2)
    elif args.command == Command.VISUALIZE.value:
        if args.plot is None:
            # Without this check, args.plot.upper() raised AttributeError.
            parser.error("A plot type is required with -p for the visualize command")
        with open(args.filepath, 'r', encoding='utf-8') as file:
            stats_list = json.load(file)
        # x, y and color are plain key strings taken straight from the CLI.
        visualize(stats_list, Plot[args.plot.upper()], args.rows, args.size,
                  args.font_size, args.x, args.y, args.color, args.filter)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    # main() returns None, so the process exits with status 0 as before.
    raise SystemExit(main())