ampere_research/pytorch/batch.py

79 lines
2.5 KiB
Python
Raw Normal View History

2024-12-02 23:32:33 -05:00
#! /bin/python3
import argparse
import glob
import os
import subprocess
2024-12-03 08:53:39 -05:00
import random
2024-12-02 23:32:33 -05:00
# Command-line interface. Positional values are forwarded to run.py; the
# optional --perf/--power flags store their own flag string (or None) so they
# can be appended verbatim to the run.py command line.
parser = argparse.ArgumentParser()
parser.add_argument('arch')
parser.add_argument('output_dir')
parser.add_argument('matrix_dir')
for int_positional in ('iterations', 'baseline_time_s', 'baseline_delay_s'):
    parser.add_argument(int_positional, type=int)
for passthrough_flag in ('--perf', '--power'):
    parser.add_argument(passthrough_flag, action='store_const', const=passthrough_flag)
# Spread jobs across explicitly named nodes instead of letting Slurm choose.
parser.add_argument('--distribute', action='store_true')
args = parser.parse_args()
# Slurm options used for every job on the "oasis" account/partition
# (the only configuration the launch loop supports, for arch == 'altra').
# Options tried earlier and disabled: '--mem 28114', '--exclusive', and
# redirecting '--output'/'--error' to /dev/null.
_ALTRA_SLURM_OPTIONS = (
    ('--account', 'oasis'),
    ('--partition', 'oasis'),
    ('--qos', 'oasis-exempt'),
    ('--cpus-per-task', '160'),
    ('--mem', '16G'),
    ('--ntasks-per-node', '1'),
)
# Flatten (option, value) pairs into the flat argv form srun expects.
srun_args_altra = [token for option_pair in _ALTRA_SLURM_OPTIONS for token in option_pair]
def srun(srun_args_list: list, run_args, matrix_file: str) -> list:
    """Build the ``srun ... ./run.py ...`` argv for one matrix.

    Args:
        srun_args_list: Slurm options placed between ``srun`` and ``./run.py``.
        run_args: parsed CLI namespace; its ``arch``, ``iterations``,
            ``baseline_time_s``, ``baseline_delay_s``, ``perf`` and ``power``
            attributes are forwarded to run.py (previously the global ``args``
            was read and this parameter was ignored).
        matrix_file: path of the matrix file passed to run.py.

    Returns:
        The full command as a list of strings, suitable for ``subprocess.Popen``.
    """
    run_args_list = [
        run_args.arch,
        matrix_file,
        str(run_args.iterations),
        str(run_args.baseline_time_s),
        str(run_args.baseline_delay_s),
    ]
    # perf/power hold the flag string itself when given, else None.
    run_args_list += [flag for flag in (run_args.perf, run_args.power) if flag is not None]
    return ['srun', *srun_args_list, './run.py', *run_args_list]
# Launch one srun job per *.mtx file in matrix_dir, writing each job's stdout
# to <name>.json and stderr to <name>.output in output_dir, then wait for all.
ALTRA_NODE_COUNT = 40  # --distribute round-robins over oasis00 .. oasis39
processes = []

# srun_args_altra is the only Slurm configuration defined; other archs used to
# die with a NameError on `srun_args` in the middle of the loop.
if args.arch != 'altra':
    parser.error(f'unsupported arch {args.arch!r}: only "altra" is configured')

# Sorted so that --distribute maps matrices to nodes reproducibly
# (glob returns files in arbitrary order).
matrices = sorted(glob.glob(os.path.join(args.matrix_dir, '*.mtx')))
for i, matrix in enumerate(matrices):
    if args.distribute:
        srun_args = srun_args_altra + ['--nodelist', f'oasis{i % ALTRA_NODE_COUNT:02}']
    else:
        srun_args = srun_args_altra

    output_filename = '_'.join([
        args.arch,
        str(args.baseline_time_s),
        str(args.baseline_delay_s),
        os.path.splitext(os.path.basename(matrix))[0],
        str(args.iterations)])
    json_filepath = os.path.join(args.output_dir, f'{output_filename}.json')
    raw_filepath = os.path.join(args.output_dir, f'{output_filename}.output')

    # Build the command once so the printed command is exactly what runs
    # (previously Popen used srun_args_altra, dropping --nodelist).
    command = srun(srun_args, args, matrix)
    with open(json_filepath, 'w') as json_file, open(raw_filepath, 'w') as raw_file:
        print(command)
        print(json_filepath)
        print(raw_filepath)
        # The child receives its own copies of these descriptors, so closing
        # the parent's handles when the `with` exits does not cut off output.
        processes.append(subprocess.Popen(
            command,
            stdout=json_file,
            stderr=raw_file))

failed = []
for process in processes:
    if process.wait() != 0:
        failed.append(process)
for process in failed:
    print(f'job failed (exit {process.returncode}): {" ".join(process.args)}')
if failed:
    raise SystemExit(f'{len(failed)} of {len(processes)} srun job(s) failed')