Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse matrix (DCSR_matrix) multiplication #1251

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
511e069
Support latest PyTorch release
ClaudiaComito Nov 1, 2022
6c69a58
Merge branch 'release/1.2.x' into create-pull-request/patch
ClaudiaComito Nov 3, 2022
8136c3c
Update setup.py
ClaudiaComito Nov 3, 2022
2b05d3e
Specify allclose tolerance in test_inv()
ClaudiaComito Nov 5, 2022
cbb2f7c
Increase allclose tolerance in test_inv
ClaudiaComito Nov 5, 2022
92e10fd
Increase allclose tolerance for distributed floating-point operations
ClaudiaComito Nov 7, 2022
73ebe07
fix working branches selection
ClaudiaComito Mar 31, 2023
271d8b8
add pr workflow (#1127)
mtar Apr 18, 2023
117745f
Support latest PyTorch release
ClaudiaComito Apr 27, 2023
ebdcefe
Merge branch 'create-pull-request/patch' of github.com:helmholtz-anal…
ClaudiaComito Apr 27, 2023
6ec71f7
expand version check to torch 2
ClaudiaComito Apr 27, 2023
1a6be7f
dndarray.item() to return ValueError if dndarray.size > 1
ClaudiaComito Apr 27, 2023
4bee685
Merge branch 'main' into create-pull-request/patch
ClaudiaComito Apr 27, 2023
f0d50e3
Merge branch 'main' of https://github.com/helmholtz-analytics/heat
Mystic-Slice Jul 31, 2023
f5bc90b
Merge branch 'main' of https://github.com/helmholtz-analytics/heat
Mystic-Slice Aug 29, 2023
538bd21
Merge branch 'main' of https://github.com/helmholtz-analytics/heat
Mystic-Slice Oct 14, 2023
4faa50e
Initial rough solution
Mystic-Slice Oct 21, 2023
751c48f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2023
0983ee5
Merge branch 'main' of https://github.com/helmholtz-analytics/heat in…
Mystic-Slice Oct 21, 2023
319d126
Merge branch 'sparse-matmul-impl' of https://github.com/helmholtz-ana…
Mystic-Slice Oct 21, 2023
994d0b1
matmul regardless of A's split
Mystic-Slice Oct 23, 2023
e53df48
fix pre-commit
Mystic-Slice Oct 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions heat/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from .arithmetics import *
from .dcsr_matrix import *
from .factories import *
from .linalg import *
from ._operations import *
from .manipulations import *
5 changes: 5 additions & 0 deletions heat/sparse/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
Import all sparse linear algebra functions into the ht.sparse.linalg namespace
"""

from .basics import *
35 changes: 35 additions & 0 deletions heat/sparse/linalg/basics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
Basic linear algebra operations on distributed ``DCSR_matrix``
"""

from ..dcsr_matrix import DCSR_matrix
from ..factories import sparse_csr_matrix
from ...core import devices
import torch


def matmul(A: DCSR_matrix, B: DCSR_matrix) -> DCSR_matrix:
"""
Matrix multiplication of two DCSR matrices.
"""
if A.shape[1] != B.shape[0]:
raise ValueError("Incompatible dimensions for matrix multiplication")

out_split = 0 if A.split == 0 or B.split == 0 else None

collected_B = torch.sparse_csr_tensor(
B.indptr,
B.indices,
B.data,
device=B.device.torch_device if B.device is not None else devices.get_device().torch_device,
size=B.shape,
)

matmul_res = A.larray @ collected_B

return sparse_csr_matrix(
matmul_res, dtype=A.dtype, device=A.device, comm=A.comm, is_split=out_split
)


DCSR_matrix.matmul = lambda self, other: matmul(self, other)
Loading