Skip to content

Commit

Permalink
fix: ivy rnn and while_loop changes to fix lstm transpilations (ivy-l…
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong authored Jan 10, 2024
1 parent 70efc60 commit 0650802
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 15 deletions.
5 changes: 3 additions & 2 deletions ivy/functional/backends/torch/control_flow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ def cond(*_):


def while_loop(test_fn, body_fn, vars):
result = vars
if isinstance(vars, dict):
result = list(vars.values())
while test_fn(*result):
else:
result = list(vars)
while test_fn(*result) is True:
result = body_fn(*result)
if not isinstance(result, tuple):
result = (result,)
Expand Down
8 changes: 0 additions & 8 deletions ivy/functional/ivy/control_flow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,8 @@
from ivy.utils.backend import current_backend
from ivy.func_wrapper import (
handle_array_like_without_promotion,
handle_backend_invalid,
handle_device,
outputs_to_ivy_arrays,
to_native_arrays_and_back,
)
from ivy.utils.exceptions import handle_exceptions


def if_else(
Expand Down Expand Up @@ -65,10 +61,6 @@ def _if_else(cond, body_fn, orelse_fn, **vars):
return _if_else(cond, body_fn, orelse_fn, **vars)


@handle_exceptions
@handle_backend_invalid
@outputs_to_ivy_arrays
@handle_device
def while_loop(
test_fn: Callable,
body_fn: Callable,
Expand Down
14 changes: 9 additions & 5 deletions ivy/functional/ivy/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3328,7 +3328,7 @@ def rnn(
if not ivy.is_bool_dtype(mask):
mask = ivy.astype(mask, ivy.bool)
if len(mask.shape) == 2:
mask = ivy.expand_dims(mask)
mask = ivy.expand_dims(mask, axis=-1)
if not time_major:
mask = ivy.permute_dims(mask, (1, 0, *range(2, len(mask.shape))))

Expand Down Expand Up @@ -3423,7 +3423,13 @@ def _expand_mask(mask_t, input_t, fixed_dim=1):
input_time_zero, tuple(initial_states) + tuple(constants)
)

output_size = int(time_steps_t) if return_all_outputs else 1
if return_all_outputs:
if ivy.is_array(time_steps_t):
output_size = time_steps_t.to_scalar()
else:
output_size = time_steps_t
else:
output_size = 1
output_loop = ivy.empty(
(output_size, *output_time_zero.shape), dtype=output_time_zero.dtype
)
Expand All @@ -3437,10 +3443,8 @@ def _expand_mask(mask_t, input_t, fixed_dim=1):
if go_backwards:
mask = ivy.flip(mask, axis=0)

mask_list = ivy.unstack(mask)

def masking_fn(time):
return mask_list[time]
return mask[time]

def compute_masked_output(mask_t, output, mask):
tiled_mask_t = tuple(
Expand Down

0 comments on commit 0650802

Please sign in to comment.