diff --git a/pytorch/analyze.py b/pytorch/analyze.py index bf1c73d..109915a 100755 --- a/pytorch/analyze.py +++ b/pytorch/analyze.py @@ -1,6 +1,6 @@ #! /bin/python3 -from perf_stat import Stat, CPU +from data_stat import Stat, Cpu import argparse import os, glob @@ -122,16 +122,9 @@ def visualize( plt.show() def main(): - class Command(Enum): - PARSE = 'parse' - VISUALIZE = 'visualize' - parser = argparse.ArgumentParser() - parser.add_argument('command', choices=[x.value for x in Command]) - parser.add_argument('filepath', - help='the output for the ' + Command.PARSE.value + ' command or the input for the ' + Command.VISUALIZE.value + ' command') - parser.add_argument('-i', '--input_dir', - help='the input directory for the parse command') + parser.add_argument('input_dir', + help='the input directory') parser.add_argument('-p', '--plot', choices=[x.name.lower() for x in Plot], help = 'the type of plot') @@ -156,34 +149,24 @@ def main(): parser.add_argument('-f', '--filter', nargs = '+', help = 'a comma-separated string of names and values to filter out.', default = []) - args = parser.parse_args() + args.plot = Plot[args.plot.upper()] - stats_list: list[dict] = list() - if args.command == Command.PARSE.value: - if (args.input_dir) is None: - print("An input directory is required with -i") - exit(-1) + data_list: list[dict] = list() - for filename in glob.glob(f'{args.input_dir.rstrip("/")}/*.json'): - with open(filename, 'r') as file: - stats_list.append(json.load(file)) - print(filename + " loaded.") + for filename in glob.glob(f'{args.input_dir.rstrip("/")}/*.json'): + with open(filename, 'r') as file: + data_list.append(json.load(file)) + print(filename + " loaded.") - with open(args.filepath, 'w') as file: - json.dump(stats_list, file, indent = 2) - - elif args.command == Command.VISUALIZE.value: - with open(args.filepath, 'r') as file: - stats_list = json.load(file) - - #x = Stat[args.x.upper()] if args.x is not None else None - x = args.x - #y = Stat[args.y.upper()] if args.y is not None else None - y = args.y - #color = Stat[args.color.upper()] if args.color is not None else None - color = args.color - visualize(stats_list, Plot[args.plot.upper()], args.rows, args.size, args.font_size, x, y, color, args.filter) + print(data_list) + #x = Stat[args.x.upper()] if args.x is not None else None + x = args.x + #y = Stat[args.y.upper()] if args.y is not None else None + y = args.y + #color = Stat[args.color.upper()] if args.color is not None else None + color = args.color + visualize(data_list, args.plot, args.rows, args.size, args.font_size, x, y, color, args.filter) if __name__ == '__main__': main()