import torch, scipy import numpy as np import argparse import time parser = argparse.ArgumentParser() parser.add_argument('matrix_file', help='the input matrix (.mtx) file') parser.add_argument('iterations', type=int, help='the number of iterations of multiplication to perform') args = parser.parse_args() device = 'cpu' matrix = scipy.io.mmread(args.matrix_file) matrix = torch.sparse_coo_tensor( np.vstack((matrix.row, matrix.col)), matrix.data, matrix.shape, device=device ).to_sparse_csr().type(torch.float) vector = torch.rand(matrix.shape[1], device=device) print(matrix) print(vector) start = time.time() for i in range(0, args.iterations): torch.sparse.mm(matrix, vector.unsqueeze(-1)).squeeze(-1) #print(i) end = time.time() if matrix.shape[0] == matrix.shape[1]: print(f"Shape: {matrix.shape[1]}") else: print(f"Shape: {matrix.shape}") print(f"Time: {end - start} seconds")