Skip to content

Commit

Permalink
Merge pull request #25731 from gnecula:poly_random_even
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712826758
  • Loading branch information
Google-ML-Automation committed Jan 7, 2025
2 parents 7997f08 + bc3306c commit 712bece
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 36 deletions.
5 changes: 4 additions & 1 deletion jax/_src/export/shape_poly_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,18 +306,21 @@ def _bounds_for_sorted_terms(self,
scope=scope):
# `c =[eq] 0` AND `t*t_k*t_s + c*c_s` contains only terms smaller than t
# AND c_s > 0.
# rest = e[i:]*t_s + c*c_s` AND `rest_ub >= rest >= rest_lb`
# `rest = e[i:]*t_s + c*c_s` AND `rest_ub >= rest >= rest_lb`
# `rest` contains only terms smaller than `t`.
rest = _DimExpr._linear_combination_sorted_pairs(e, i, t_s,
c._sorted_terms, 0, c_s)
rest_lb, rest_ub = self._bounds_for_sorted_terms(scope, rest, 0,
BoundsPrecision.BEST)
if rest_ub < np.inf:
# We have: e[i:]*t_s = rest - c*c_s <= rest_ub
if t_s > 0:
ub = min(ub, int(np.floor(rest_ub / t_s)))
else:
lb = max(lb, int(np.ceil(rest_ub / t_s)))

if rest_lb > - np.inf and c_eq == Comparator.EQ:
# We have: e[i:]*t_s = rest - c*c_s = rest >= rest_lb
if t_s > 0:
lb = max(lb, int(np.ceil(rest_lb / t_s)))
else:
Expand Down
35 changes: 24 additions & 11 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,22 +1068,35 @@ def threefry_2x32(keypair, count):
msg = "threefry_2x32 requires uint32 arguments, got {}"
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))

odd_size = count.size % 2
if not isinstance(odd_size, int):
msg = ("jax.random functions have limited support for shape polymorphism "
"when using threefry. "
f"In particular, the array size ({count.size}) must be even.")
raise core.InconclusiveDimensionOperation(msg)

if odd_size:
x = list(jnp.split(jnp.concatenate([count.ravel(), np.uint32([0])]), 2))
flat_count = count.ravel()
odd_size = flat_count.shape[0] % 2
if core.is_constant_dim(odd_size):
if odd_size:
x = list(jnp.split(jnp.concatenate([flat_count, np.uint32([0])]), 2))
else:
x = list(jnp.split(flat_count, 2))
else:
x = list(jnp.split(count.ravel(), 2))
# With symbolic shapes we cannot always tell statically if odd_size is true
# or false, so we rewrite this without a conditional.
flat_count_padded = jnp.concatenate([flat_count, np.uint32([0])])
flat_count_padded_half_size = flat_count_padded.shape[0] // 2
x = [
lax.dynamic_slice(flat_count_padded, (0,),
(flat_count_padded_half_size,)),
lax.dynamic_slice(flat_count_padded,
(flat_count_padded_half_size,),
(flat_count_padded_half_size,))
]
assert x[0].shape == x[1].shape, (x[0].shape, x[1].shape)

x = threefry2x32_p.bind(key1, key2, x[0], x[1])
out = jnp.concatenate(x)
assert out.dtype == np.uint32
return lax.reshape(out[:-1] if odd_size else out, count.shape)
if core.is_constant_dim(odd_size):
return lax.reshape(out[:-1] if odd_size else out, count.shape)
else:
out_no_padding = lax.dynamic_slice(out, (0,), (flat_count.shape[0],))
return lax.reshape(out_no_padding, count.shape)


def threefry_split(key: typing.Array, shape: Shape) -> typing.Array:
Expand Down
5 changes: 1 addition & 4 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2086,10 +2086,7 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10]
PolyHarness("random_uniform", f"error_not_even_{flags_name}",
lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5), _f32)],
polymorphic_shapes=[None, "b0, ..."],
expect_error=(
(core.InconclusiveDimensionOperation,
"array size .* must be even") if flags_name == "threefry_non_partitionable" else (None, None)),
polymorphic_shapes=[None, "b0, b1"],
override_jax_config_flags=override_jax_config_flags) # type: ignore
]
for key_size, flags_name, override_jax_config_flags in [
Expand Down
38 changes: 18 additions & 20 deletions tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,10 @@ def test_bounds_floordiv(self):
self.assertEqual(_bounds(-b // (a + 1)), (-np.inf, -1))

self.assertEqual(_bounds(a - a // 2), (1, np.inf))
self.assertEqual(_bounds((a + 3) - (a + 3) // 2), (2, np.inf))
self.assertEqual(_bounds((a + 6) - 1 * (a + 6) // 4), (6, np.inf))
self.assertEqual(_bounds((a + 6) - 2 * ((a + 6) // 4)), (4, np.inf))
self.assertEqual(_bounds((a + 6) - 3 * ((a + 6) // 4)), (2, np.inf))
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 1))
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
"Possible division by 0"):
Expand Down Expand Up @@ -2982,31 +2986,30 @@ def f(x_ref):
RandArg((3, 4, 5), _f32)],
polymorphic_shapes=["b, ...", "b, w, ..."], tol=1E-5,
override_jax_config_flags=override_jax_config_flags), # type: ignore
# TODO(necula): The known dimensions product must be even.
PolyHarness("random_categorical", f"axis=0_{flags_name}",
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=0),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 8), _f32)],
polymorphic_shapes=[None, "b0, ..."],
polymorphic_shapes=[None, "b0, b1"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_categorical", f"axis=1_{flags_name}",
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=1),
jax.random.wrap_key_data(key), a, axis=1),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5, 8), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
polymorphic_shapes=[None, "b0, b1, b2"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_categorical", f"axis=1_then_reshape_{flags_name}",
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=1).reshape(-1),
jax.random.wrap_key_data(key), a, axis=1).reshape(-1),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5, 8), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
polymorphic_shapes=[None, "b0, b1, b2"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_categorical", f"0_dim_{flags_name}", # One axis has 0 size
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=1),
jax.random.wrap_key_data(key), a, axis=1),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5, 0), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
Expand All @@ -3024,14 +3027,13 @@ def f(x_ref):
RandArg((64, 12, 4), _f32), # sample on axis=1
RandArg((3, 4), _f32),
StaticArg(use_p)],
# TODO(necula): threefry requires even-sized samples.
polymorphic_shapes=[None,
"_, 2*b1, _" if arr_poly else None,
"b0, b1, b2" if arr_poly else None,
"b3, b4" if shape_poly else None],
# The array sampled dimension must be larger than res_shape.size
symbolic_constraints=[
"2*b1 >= 12" if arr_poly else "1 >= 0",
"2*b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0",
"b1 >= 12" if arr_poly else "1 >= 0",
"b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0",
"12 >= b3*b4" if shape_poly else "1 >= 0"
],
override_jax_config_flags=override_jax_config_flags,
Expand All @@ -3058,24 +3060,20 @@ def f(x_ref):
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4, 5), _f32)],
polymorphic_shapes=[None, "b0, ..."],
polymorphic_shapes=[None, "b0, 4, 5"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_uniform", f"even_2_{flags_name}",
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
(2 * a.shape[0], a.shape[1]),
dtype=_f32),
a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
polymorphic_shapes=[None, "b0, 2*b1"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_uniform", f"error_not_even_{flags_name}",
PolyHarness("random_uniform", f"error_unknown_evenness_{flags_name}",
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5), _f32)],
polymorphic_shapes=[None, "b0, ..."],
expect_error=(
(core.InconclusiveDimensionOperation,
"array size .* must be even") if flags_name == "threefry_non_partitionable" else None),
polymorphic_shapes=[None, "b0, b1"],
override_jax_config_flags=override_jax_config_flags) # type: ignore
]
for key_size, flags_name, override_jax_config_flags in [
Expand Down

0 comments on commit 712bece

Please sign in to comment.