Skip to content

Commit

Permalink
refactor: Reduced calls in torch and tf backends of ivy.concat
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTz committed Nov 8, 2023
1 parent 96a8ce7 commit 2d95392
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 28 deletions.
39 changes: 19 additions & 20 deletions ivy/functional/backends/tensorflow/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,25 @@ def concat(
axis: int = 0,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
is_tuple = type(xs) is tuple
is_axis_none = axis is None
if is_tuple:
xs = list(xs)
highest_dtype = xs[0].dtype
for i in xs:
highest_dtype = ivy.as_native_dtype(ivy.promote_types(highest_dtype, i.dtype))

for i in range(len(xs)):
if is_axis_none:
xs[i] = tf.reshape(xs[i], -1)
xs[i] = ivy.astype(xs[i], highest_dtype, copy=False).to_native()
if is_axis_none:
axis = 0
if is_tuple:
xs = tuple(xs)
try:
return tf.concat(xs, axis)
except (tf.errors.InvalidArgumentError, np.AxisError) as error:
raise ivy.utils.exceptions.IvyIndexError(error)
if axis is not None:
try:
return tf.concat(xs, axis)
except tf.errors.InvalidArgumentError as error:
if "(zero-based) was expected to be" in error.message:
highest_dtype = xs[0].dtype
for i in xs:
highest_dtype = ivy.promote_types(highest_dtype, i.dtype)
highest_dtype = ivy.as_native_dtype(highest_dtype)
return tf.concat(
[
tf.cast(x, highest_dtype) if x.dtype != highest_dtype else x
for x in xs
],
axis,
)
else:
raise
return concat([tf.reshape(x, -1) for x in xs], axis=0)


def expand_dims(
Expand Down
9 changes: 1 addition & 8 deletions ivy/functional/backends/torch/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,7 @@ def concat(
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if axis is None:
is_tuple = type(xs) is tuple
if is_tuple:
xs = list(xs)
for i in range(len(xs)):
xs[i] = torch.flatten(xs[i])
if is_tuple:
xs = tuple(xs)
axis = 0
return torch.cat([torch.flatten(x) for x in xs], dim=0, out=out)
return torch.cat(xs, dim=axis, out=out)


Expand Down

0 comments on commit 2d95392

Please sign in to comment.