ampere_research/pytorch/spmv.py

90 lines
3.0 KiB
Python
Raw Normal View History

from data_stat import Stat, Format
2024-12-04 22:47:16 -05:00
2024-11-28 00:04:57 -05:00
import torch, scipy
import numpy as np
import argparse
import time
2024-12-02 23:32:33 -05:00
import json
2024-12-03 08:53:39 -05:00
import sys, os
2024-11-28 00:04:57 -05:00
# Command-line interface: a positional iteration count and sparse format,
# plus either a Matrix Market file or the synthetic-matrix parameters.
parser = argparse.ArgumentParser()
parser.add_argument('iterations', type=int,
                    help='the number of iterations of multiplication to perform')
# Valid format names are the lowercase names of the Format enum members.
parser.add_argument('format', type=str,
                    choices=[fmt.name.lower() for fmt in Format],
                    help='the sparse format to use')
parser.add_argument('-m', '--matrix_file',
                    help='the input matrix (.mtx) file')
parser.add_argument('-ss', '--synthetic_size', type=int,
                    help='the synthetic matrix parameters size (rows)')
# argparse applies %-formatting to help strings, so a literal percent sign
# has to be written as '%%'; a bare '%' makes help rendering raise ValueError.
parser.add_argument('-sd', '--synthetic_density', type=float,
                    help='the synthetic matrix density (%%)')

args = parser.parse_args()
# Replace the validated lowercase name with the matching Format member.
args.format = Format[args.format.upper()]
2024-11-28 00:04:57 -05:00
# Build the sparse operand on the CPU, from a file or from random data.
device = 'cpu'
if args.matrix_file is not None:
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.io` is loaded on every SciPy version.
    import scipy.io

    # NOTE(review): assumes mmread returns a COO-style object exposing
    # .row/.col/.data (i.e. a coordinate-format file) -- confirm for inputs.
    matrix = scipy.io.mmread(args.matrix_file)
    matrix = torch.sparse_coo_tensor(
        np.vstack((matrix.row, matrix.col)),
        matrix.data, matrix.shape,
        device=device, dtype=torch.float32)
elif args.synthetic_size is not None and args.synthetic_density is not None:
    # The density argument is a percentage of the size x size entries.
    nnz = int((args.synthetic_size ** 2) * (args.synthetic_density / 100))
    # Coordinates are drawn independently, so the same (row, col) pair may
    # occur more than once; the tensor is not coalesced here.
    row_indices = torch.randint(0, args.synthetic_size, (nnz,))
    col_indices = torch.randint(0, args.synthetic_size, (nnz,))
    indices = torch.stack([row_indices, col_indices])
    values = torch.randn(nnz)
    matrix = torch.sparse_coo_tensor(
        indices, values,
        size=(args.synthetic_size, args.synthetic_size),
        device=device, dtype=torch.float32)
else:
    # sys.exit(str) writes the message to stderr and exits with status 1,
    # which keeps stdout free of anything but the JSON result.
    sys.exit("No matrix specified!")

# The tensor is built as COO; convert only when CSR was requested.
if args.format == Format.CSR:
    matrix = matrix.to_sparse_csr().type(torch.float32)
elif args.format == Format.COO:
    pass
else:
    sys.exit("Unrecognized format!")
2024-11-28 00:04:57 -05:00
# Dense right-hand side with one entry per matrix column.
vector = torch.rand(matrix.shape[1], device=device)

# Diagnostics are written to stderr; stdout carries only the JSON result.
print(matrix, file=sys.stderr)
print(vector, file=sys.stderr)

# Time the repeated sparse matrix-vector products. perf_counter() is a
# monotonic, high-resolution clock, so the measured interval cannot be
# distorted by system clock adjustments the way time.time() can.
# Alternative kernel previously tried here:
#   torch.sparse.mm(matrix, vector.unsqueeze(-1)).squeeze(-1)
start = time.perf_counter()
for _ in range(args.iterations):
    torch.mv(matrix, vector)
end = time.perf_counter()
2024-12-02 23:32:33 -05:00
# Collect run statistics keyed by Stat member names. Each value is echoed
# to stderr for humans; the dict is emitted as JSON on stdout at the end.
result = dict()

# Identify the input: file stem for a real matrix, "synthetic" otherwise.
if args.matrix_file is not None:
    result[Stat.MATRIX_FILE.name] = os.path.splitext(os.path.basename(args.matrix_file))[0]
else:
    result[Stat.MATRIX_FILE.name] = "synthetic"
print(f"Matrix: {result[Stat.MATRIX_FILE.name]}", file=sys.stderr)

result[Stat.MATRIX_SHAPE.name] = matrix.shape
print(f"Shape: {result[Stat.MATRIX_SHAPE.name]}", file=sys.stderr)

# Total number of entries (rows * columns), reused for the density below.
size = matrix.shape[0] * matrix.shape[1]
result[Stat.MATRIX_SIZE.name] = size
print(f"Size: {result[Stat.MATRIX_SIZE.name]}", file=sys.stderr)

# Tensor.values() raises RuntimeError on an uncoalesced COO tensor, so
# coalesce first in that layout. This runs after the timed section and
# therefore does not influence the measurement.
if matrix.layout == torch.sparse_coo:
    nnz = matrix.coalesce().values().shape[0]
else:
    nnz = matrix.values().shape[0]
result[Stat.MATRIX_NNZ.name] = nnz
print(f"NNZ: {result[Stat.MATRIX_NNZ.name]}", file=sys.stderr)

result[Stat.MATRIX_DENSITY.name] = nnz / size
print(f"Density: {result[Stat.MATRIX_DENSITY.name]}", file=sys.stderr)

result[Stat.TIME_S.name] = end - start
print(f"Time: {result[Stat.TIME_S.name]} seconds", file=sys.stderr)

print(json.dumps(result))