ampere_research/pytorch/spmv.py

48 lines
1.2 KiB
Python
Raw Normal View History

2024-11-28 00:04:57 -05:00
import torch, scipy
import numpy as np
import argparse
import time
2024-12-02 23:32:33 -05:00
import json
import sys
2024-11-28 00:04:57 -05:00
# Command-line interface: the Matrix Market file to load and the number of
# multiplications to time.
parser = argparse.ArgumentParser()
parser.add_argument('matrix_file', help='the input matrix (.mtx) file')
parser.add_argument('iterations', type=int, help='the number of iterations of multiplication to perform')
args = parser.parse_args()

device = 'cpu'

# Load the file with SciPy. The code below reads .row/.col/.data from the
# result, i.e. it relies on mmread handing back a COO-style sparse object.
coo = scipy.io.mmread(args.matrix_file)

# Rebuild the matrix as a torch sparse COO tensor (indices stacked as a
# 2 x nnz array), then convert it to CSR layout with float32 values.
coordinates = np.vstack((coo.row, coo.col))
matrix = torch.sparse_coo_tensor(coordinates, coo.data, coo.shape, device=device)
matrix = matrix.to_sparse_csr().type(torch.float)

# Random dense operand with one entry per matrix column.
vector = torch.rand(matrix.shape[1], device=device)
2024-12-02 23:32:33 -05:00
# Dump both operands to stderr, one after the other, each followed by a newline.
print(matrix, vector, sep='\n', file=sys.stderr)
2024-11-28 00:04:57 -05:00
# Time `args.iterations` back-to-back sparse matrix-vector products.
# perf_counter() is used rather than time.time() because it is monotonic and
# uses the highest-resolution clock available, so the measured interval is not
# skewed by system clock adjustments that happen during the run.
start = time.perf_counter()
for _ in range(args.iterations):
    # sparse.mm takes a 2-D right-hand side, so the vector is viewed as a
    # single column and the product is flattened back to 1-D. The product
    # itself is discarded: only the elapsed time is reported.
    torch.sparse.mm(matrix, vector.unsqueeze(-1)).squeeze(-1)
end = time.perf_counter()
2024-12-02 23:32:33 -05:00
# Gather the run statistics. Each value is echoed to stderr as it is computed;
# the complete record is then written to stdout as a single JSON object.
shape = matrix.shape
print(f"Shape: {shape}", file=sys.stderr)

# Number of explicitly stored values in the CSR tensor.
nnz = matrix.values().shape[0]
print(f"NNZ: {nnz}", file=sys.stderr)

# NOTE: despite the '% density' key, this is the plain ratio
# nnz / (rows * cols); it is not multiplied by 100.
density = nnz / (shape[0] * shape[1])
print(f"Density: {density}", file=sys.stderr)

elapsed = end - start
print(f"Time: {elapsed} seconds", file=sys.stderr)

result = {
    'shape': shape,
    'nnz': nnz,
    '% density': density,
    'time_s': elapsed,
}
print(json.dumps(result))