Skip to content

Commit

Permalink
[torch_xla2] Fix reenabled op info tests (#8548)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvhg authored Jan 10, 2025
1 parent 196cab3 commit af223d3
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
7 changes: 2 additions & 5 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,11 @@

skiplist = {
"_segment_reduce",
"_unsafe_masked_index_put_accumulate",
"bincount", # NOTE: dtype for int input torch gives float. This is weird.
"byte",
"cat",
"cholesky_solve",
"cov",
"diagonal_copy",
"gather",
"geqrf",
"histogram", # hard op: AssertionError: Tensor-likes are not close!
"histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got <class 'list'> at position 1.
Expand All @@ -44,7 +41,6 @@
"normal",
"ormqr",
"pca_lowrank",
"scatter",
"searchsorted",
"special.airy_ai",
"special.scaled_modified_bessel_k0",
Expand Down Expand Up @@ -96,7 +92,8 @@
'nn.functional.dropout',
}

atol_dict = {"linalg.eig": (2e0, 3e0),
atol_dict = {"cov": (2e-1, 2e-4),
"linalg.eig": (2e0, 3e0),
"linalg.eigh": (5e1, 3e0),
"linalg.eigvalsh": (5e1, 3e0),
"linalg.pinv": (8e-1, 2e0),
Expand Down
12 changes: 7 additions & 5 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,9 +1745,8 @@ def _aten_atan(self):
@op(torch.ops.aten.scatter)
@op(torch.ops.aten.scatter_reduce)
def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
if isinstance(src, float):
dtype = _torch_binary_scalar_type(src, input)
src = jnp.array(src, dtype=dtype)
if not isinstance(src, jnp.ndarray):
src = jnp.array(src, dtype=input.dtype)
input_indexes, source_indexes = _scatter_index(dim, index)
# "Zero out" target elements when not included
if not include_self:
Expand Down Expand Up @@ -2596,6 +2595,9 @@ def _aten_frexp(input):
def _aten_gather(input, dim, index):
if input.ndim == 0:
return jnp.broadcast_to(input, index.shape)
# short circuit for empty outputs
if not all(index.shape):
return jnp.zeros(index.shape, dtype=input.dtype)
if dim < 0:
dim += input.ndim
input_indexes, source_indexes = _scatter_index(dim, index)
Expand Down Expand Up @@ -4732,9 +4734,9 @@ def _new_empty_strided(self, size, stride, dtype=None, **kwargs):
return jnp.empty(size, dtype=jax_dtype)


@op(torch.ops.aten._unsafe_index_put, is_jax_function=False)
@op(torch.ops.aten._unsafe_index_put)
def _aten_unsafe_index_put(self, indices, values, accumulate=False):
return self.index_put_(indices, values, accumulate)
return _aten_index_put(self, indices, values, accumulate)


@op(torch.ops.aten.conj_physical,
Expand Down

0 comments on commit af223d3

Please sign in to comment.