"""Benchmark PyTorch sparse-CSR matrix-vector multiplication.

Reads a Matrix Market (.mtx) file, converts it to a float32 sparse CSR
tensor, multiplies it by a random dense vector ``iterations`` times and
reports the wall-clock time spent in the multiplication loop.

Usage: python <script> MATRIX_FILE ITERATIONS [--device DEVICE]
"""

import argparse
import time

import numpy as np
# Imported explicitly: on older SciPy releases a bare ``import scipy`` does
# not load the ``io`` / ``sparse`` submodules.
import scipy.io
import scipy.sparse
import torch


def _synchronize(device):
    """Block until queued work on *device* (a ``torch.device``) finishes; no-op off CUDA."""
    # CUDA kernels launch asynchronously; without a sync the timer would
    # mostly measure launch overhead rather than the multiplications.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def load_matrix(path, device="cpu"):
    """Read a Matrix Market file and return it as a float32 sparse CSR tensor.

    Args:
        path: Path to the ``.mtx`` file.
        device: Torch device on which to allocate the tensor.

    Returns:
        A tensor with layout ``torch.sparse_csr`` and dtype ``torch.float32``.
    """
    raw = scipy.io.mmread(path)
    # mmread yields a dense ndarray for "array"-format files and a sparse
    # matrix for "coordinate" files; normalise both to COO form.
    if scipy.sparse.issparse(raw):
        coo = raw.tocoo()
    else:
        coo = scipy.sparse.coo_matrix(raw)
    indices = torch.as_tensor(np.vstack((coo.row, coo.col)), dtype=torch.int64)
    values = torch.as_tensor(coo.data, dtype=torch.float32)
    coo_tensor = torch.sparse_coo_tensor(indices, values, coo.shape, device=device)
    return coo_tensor.to_sparse_csr()


def benchmark_spmv(matrix, vector, iterations):
    """Time ``iterations`` repetitions of ``torch.sparse.mm(matrix, vector)``.

    Args:
        matrix: 2-D sparse CSR tensor of shape ``(m, n)``.
        vector: Dense 1-D tensor of length ``n`` with the same dtype/device.
        iterations: Number of multiplications to perform (0 is allowed).

    Returns:
        Elapsed wall-clock seconds for the multiplication loop.
    """
    # torch.sparse.mm needs a 2-D dense operand; build the (n, 1) view once
    # so the timed loop contains only the multiplication itself.
    column = vector.unsqueeze(-1)
    _synchronize(matrix.device)
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump when the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(iterations):
        torch.sparse.mm(matrix, column)
    _synchronize(matrix.device)
    return time.perf_counter() - start


def format_shape(shape):
    """Return the ``Shape:`` report line: one dimension if square, else the full shape."""
    rows, cols = shape
    if rows == cols:
        return f"Shape: {cols}"
    return f"Shape: {shape}"


def main(argv=None):
    """Parse command-line arguments, run the benchmark and print the results.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark sparse CSR matrix-vector multiplication."
    )
    parser.add_argument('matrix_file', help='the input matrix (.mtx) file')
    parser.add_argument('iterations', type=int, help='the number of iterations of multiplication to perform')
    parser.add_argument('--device', default='cpu',
                        help="torch device to run on (default: 'cpu')")
    args = parser.parse_args(argv)
    if args.iterations < 0:
        parser.error("iterations must be a non-negative integer")

    matrix = load_matrix(args.matrix_file, device=args.device)
    vector = torch.rand(matrix.shape[1], device=args.device)

    print(matrix)
    print(vector)

    elapsed = benchmark_spmv(matrix, vector, args.iterations)

    print(format_shape(matrix.shape))
    print(f"Time: {elapsed} seconds")


if __name__ == "__main__":
    main()