Altra visualization works!

This commit is contained in:
cephi_sui 2024-12-05 15:21:29 -05:00
parent 2c115359c3
commit 81b966f485

View File

@ -18,11 +18,11 @@ class Plot(Enum):
BOX = 'box' BOX = 'box'
LINE = 'line' LINE = 'line'
def accumulate(stats_list: list[dict[str, str | int | float]], category: str, value: str): def accumulate(data_list: list[dict[str, str | int | float]], category: Stat, value: Stat):
print(category) #print(category.name)
print(value) #print(value.name)
category_list = np.array([stats[category] for stats in stats_list if value in stats]) category_list = np.array([stats[category.name] for stats in data_list if value.name in stats])
value_list = np.array([stats[value] for stats in stats_list if value in stats]) value_list = np.array([stats[value.name] for stats in data_list if value.name in stats])
result: dict[np.ndarray] = dict() result: dict[np.ndarray] = dict()
for category in np.sort(np.unique(category_list)): for category in np.sort(np.unique(category_list)):
@ -30,34 +30,34 @@ def accumulate(stats_list: list[dict[str, str | int | float]], category: str, va
return result return result
def box_plot(ax, stats_list: list[dict[str, str | int | float]], x: Stat, y: Stat): def box_plot(ax, data_list: list[dict[str, str | int | float]], x: Stat, y: Stat):
data: dict[str, np.ndarray] = accumulate(stats_list, x, y) data: dict[str, np.ndarray] = accumulate(data_list, x, y)
print("Plotted data: " + str(data)) #print("Plotted data: " + str(data))
ax.boxplot(data.values(), tick_labels=data.keys()) ax.boxplot(data.values(), tick_labels=data.keys())
ax.set_ylabel(y.value) ax.set_ylabel(y.value)
def line_plot( def line_plot(
ax, stats_list: list[dict[str, str | int | float]], ax, data_list: list[dict[str, str | int | float]],
x: Stat, y: Stat, color: Stat x: Stat, y: Stat, color: Stat
): ):
x_data: dict[str, np.ndarray] = accumulate(stats_list, color, x) x_data: dict[str, np.ndarray] = accumulate(data_list, color, x)
y_data: dict[str, np.ndarray] = accumulate(stats_list, color, y) y_data: dict[str, np.ndarray] = accumulate(data_list, color, y)
for category in x_data.keys(): for category in x_data.keys():
sorted_indices = np.argsort(x_data[category]) sorted_indices = np.argsort(x_data[category])
x_data[category] = x_data[category][sorted_indices] x_data[category] = x_data[category][sorted_indices]
y_data[category] = y_data[category][sorted_indices] y_data[category] = y_data[category][sorted_indices]
ax.plot(x_data[category], y_data[category], label=category) ax.plot(x_data[category], y_data[category], label=category)
print("Plotted x data: " + str(x_data[category])) #print("Plotted x data: " + str(x_data[category]))
print("Plotted y data: " + str(y_data[category])) #print("Plotted y data: " + str(y_data[category]))
ax.set_ylabel(y) ax.set_ylabel(y.value)
ax.grid(True) ax.grid(True)
def visualize( def visualize(
stats_list: list[dict[str, str | int | float]], data_list: list[dict[str, str | int | float]],
plot: Plot, plot: Plot,
rows: int, rows: int,
size_multiplier: int, size_multiplier: int,
@ -68,7 +68,7 @@ def visualize(
filter_list: list[str] = [] filter_list: list[str] = []
): ):
# Remove stats entries containing undesired values (like a specific CPU). # Remove stats entries containing undesired values (like a specific CPU).
# stats_list = [stats for stats in stats_list # data_list = [stats for stats in data_list
# if len([stats[key] for key in stats.keys() # if len([stats[key] for key in stats.keys()
# if stats[key] in filter_list]) == 0] # if stats[key] in filter_list]) == 0]
@ -77,21 +77,25 @@ def visualize(
#color = Stat.SOLVER #color = Stat.SOLVER
if y is None: if y is None:
#ys = [stat for stat in Stat if stat.value in stats_list[0].keys() #ys = [stat for stat in Stat if stat.name in data_list[0].keys()]
ys = [stat for stat in stats_list[0].keys() if "power" not in stat] #ys = [stat for stat in data_list[0].keys() if "power" not in stat]
#and stat is not x #and stat is not x
#and y != color #and y != color
#and y != marker #and y != marker
#and stat.value not in filter_list] #and stat.value not in filter_list]
# Create sorted, deduped list of all stats in data_list.
ys = [Stat[stat_name] for stat_name in sorted(list(set([stat_name for data in data_list for stat_name in data if type(data[stat_name]) is not list])))]
print([stat.value for stat in ys])
fig, axes = plt.subplots(rows, int(math.ceil(len(ys) / rows)), fig, axes = plt.subplots(rows, int(math.ceil(len(ys) / rows)),
figsize = (16 * size_multiplier, 9 * size_multiplier)) figsize = (16 * size_multiplier, 9 * size_multiplier))
match plot: match plot:
case Plot.BOX: case Plot.BOX:
for i, y in enumerate(ys): for i, y in enumerate(ys):
box_plot(axes[i % rows][int(i / rows)], stats_list, x, y) box_plot(axes[i % rows][int(i / rows)], data_list, x, y)
case Plot.LINE: case Plot.LINE:
for i, y in enumerate(ys): for i, y in enumerate(ys):
line_plot(axes[i % rows][int(i / rows)], stats_list, x, y, color) line_plot(axes[i % rows][int(i / rows)], data_list, x, y, color)
handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels() handles, labels = axes[i % rows][int(i / rows)].get_legend_handles_labels()
else: else:
@ -99,9 +103,9 @@ def visualize(
match plot: match plot:
case Plot.BOX: case Plot.BOX:
box_plot(ax, stats_list, x, y) box_plot(ax, data_list, x, y)
case Plot.LINE: case Plot.LINE:
line_plot(ax, stats_list, x, y, color) line_plot(ax, data_list, x, y, color)
handles, labels = ax.get_legend_handles_labels() handles, labels = ax.get_legend_handles_labels()
@ -116,7 +120,7 @@ def visualize(
title = "altra_spmv" title = "altra_spmv"
fig.suptitle(title, fontsize = font_size) fig.suptitle(title, fontsize = font_size)
fig.legend(handles, labels, fontsize = font_size) fig.legend(handles, labels, fontsize = font_size)
fig.supxlabel(x, fontsize = font_size) fig.supxlabel(x.value, fontsize = font_size)
plt.savefig(title + ".png", dpi = 100) plt.savefig(title + ".png", dpi = 100)
plt.show() plt.show()
@ -138,19 +142,22 @@ def main():
help = 'font size', help = 'font size',
default = 40) default = 40)
parser.add_argument('-x', parser.add_argument('-x',
#choices=[x.name.lower() for x in Stat], choices=[x.name.lower() for x in Stat],
help = 'the name of the x axis') help = 'the name of the x axis')
parser.add_argument('-y', parser.add_argument('-y',
#choices=[x.name.lower() for x in Stat], choices=[x.name.lower() for x in Stat],
help = 'the name of the y axis') help = 'the name of the y axis')
parser.add_argument('-c', '--color', parser.add_argument('-c', '--color',
#choices=[x.name.lower() for x in Stat], choices=[x.name.lower() for x in Stat],
help = 'the name of the color') help = 'the name of the color')
parser.add_argument('-f', '--filter', nargs = '+', parser.add_argument('-f', '--filter', nargs = '+',
help = 'a comma-separated string of names and values to filter out.', help = 'a comma-separated string of names and values to filter out.',
default = []) default = [])
args = parser.parse_args() args = parser.parse_args()
args.plot = Plot[args.plot.upper()] args.plot = Plot[args.plot.upper()]
args.x = Stat[args.x.upper()] if args.x is not None else None
args.y = Stat[args.y.upper()] if args.y is not None else None
args.color = Stat[args.color.upper()] if args.color is not None else None
data_list: list[dict] = list() data_list: list[dict] = list()
@ -159,14 +166,17 @@ def main():
data_list.append(json.load(file)) data_list.append(json.load(file))
print(filename + " loaded.") print(filename + " loaded.")
print(data_list) visualize(
#x = Stat[args.x.upper()] if args.x is not None else None data_list,
x = args.x args.plot,
#y = Stat[args.y.upper()] if args.y is not None else None args.rows,
y = args.y args.size,
#color = Stat[args.color.upper()] if args.color is not None else None args.font_size,
color = args.color args.x,
visualize(data_list, args.plot, args.rows, args.size, args.font_size, x, y, color, args.filter) args.y,
args.color,
args.filter
)
if __name__ == '__main__': if __name__ == '__main__':
main() main()