From 8aa616a73141d431a89ad5c4e5b960cc7e56175c Mon Sep 17 00:00:00 2001 From: Trevor Gale Date: Tue, 15 Aug 2023 07:27:30 -0700 Subject: [PATCH] Do not save unused data tensor for SDD operation. --- stk/backend/sputnik.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/stk/backend/sputnik.py b/stk/backend/sputnik.py index 04e330d..c6d94e8 100644 --- a/stk/backend/sputnik.py +++ b/stk/backend/sputnik.py @@ -247,7 +247,6 @@ def forward(ctx, ctx.save_for_backward( lhs, rhs, - data, offsets, row_indices, column_indices, @@ -257,8 +256,8 @@ def forward(ctx, ctx.shape = shape out = torch.empty( data.shape, - dtype=data.dtype, - device=data.device) + dtype=lhs.dtype, + device=lhs.device) backend.sdd(lhs, rhs, shape, @@ -272,7 +271,7 @@ def forward(ctx, @custom_bwd def backward(ctx, dy): lhs, rhs = ctx.saved_tensors[:2] - dy = (ctx.shape, dy) + ctx.saved_tensors[3:] + dy = (ctx.shape, dy) + ctx.saved_tensors[2:] trans_a = _is_transposed(lhs) trans_b = _is_transposed(rhs)