Skip to content

Commit

Permalink
Do not save unused data tensor for SDD operation.
Browse files Browse the repository at this point in the history
  • Loading branch information
tgale96 committed Aug 15, 2023
1 parent d5c2dc6 commit 8aa616a
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions stk/backend/sputnik.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@ def forward(ctx,
ctx.save_for_backward(
lhs,
rhs,
data,
offsets,
row_indices,
column_indices,
Expand All @@ -257,8 +256,8 @@ def forward(ctx,
ctx.shape = shape
out = torch.empty(
data.shape,
dtype=data.dtype,
device=data.device)
dtype=lhs.dtype,
device=lhs.device)
backend.sdd(lhs,
rhs,
shape,
Expand All @@ -272,7 +271,7 @@ def forward(ctx,
@custom_bwd
def backward(ctx, dy):
lhs, rhs = ctx.saved_tensors[:2]
dy = (ctx.shape, dy) + ctx.saved_tensors[3:]
dy = (ctx.shape, dy) + ctx.saved_tensors[2:]
trans_a = _is_transposed(lhs)
trans_b = _is_transposed(rhs)

Expand Down

0 comments on commit 8aa616a

Please sign in to comment.