Skip to content

Commit

Permalink
[shape_poly] Improve threefry with symbolic shapes
Browse files Browse the repository at this point in the history
Previously, we could only handle threefry for the case when
it was possible to tell statically that the size of the `count`
array is even or odd. This meant that often we had to add a constraint
that one of the dimensions is even.

Here we rewrite the handling of threefry to not require a Python-level
conditional about evenness of the size of the count array. We use
a couple of `lax.dynamic_slice` calls rather than a `jnp.split`.

We also generalize the tests to cases where the size is fully symbolic,
and we cannot tell statically that it is even.
  • Loading branch information
gnecula committed Jan 7, 2025
1 parent 7997f08 commit bc3306c
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 36 deletions.
5 changes: 4 additions & 1 deletion jax/_src/export/shape_poly_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,18 +306,21 @@ def _bounds_for_sorted_terms(self,
scope=scope):
# `c =[eq] 0` AND `t*t_k*t_s + c*c_s` contains only terms smaller than t
# AND c_s > 0.
# rest = e[i:]*t_s + c*c_s` AND `rest_ub >= rest >= rest_lb`
# `rest = e[i:]*t_s + c*c_s` AND `rest_ub >= rest >= rest_lb`
# `rest` contains only terms smaller than `t`.
rest = _DimExpr._linear_combination_sorted_pairs(e, i, t_s,
c._sorted_terms, 0, c_s)
rest_lb, rest_ub = self._bounds_for_sorted_terms(scope, rest, 0,
BoundsPrecision.BEST)
if rest_ub < np.inf:
# We have: e[i:]*t_s = rest - c*c_s <= rest_ub
if t_s > 0:
ub = min(ub, int(np.floor(rest_ub / t_s)))
else:
lb = max(lb, int(np.ceil(rest_ub / t_s)))

if rest_lb > - np.inf and c_eq == Comparator.EQ:
# We have: e[i:]*t_s = rest - c*c_s = rest >= rest_lb
if t_s > 0:
lb = max(lb, int(np.ceil(rest_lb / t_s)))
else:
Expand Down
35 changes: 24 additions & 11 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,22 +1068,35 @@ def threefry_2x32(keypair, count):
msg = "threefry_2x32 requires uint32 arguments, got {}"
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))

odd_size = count.size % 2
if not isinstance(odd_size, int):
msg = ("jax.random functions have limited support for shape polymorphism "
"when using threefry. "
f"In particular, the array size ({count.size}) must be even.")
raise core.InconclusiveDimensionOperation(msg)

if odd_size:
x = list(jnp.split(jnp.concatenate([count.ravel(), np.uint32([0])]), 2))
flat_count = count.ravel()
odd_size = flat_count.shape[0] % 2
if core.is_constant_dim(odd_size):
if odd_size:
x = list(jnp.split(jnp.concatenate([flat_count, np.uint32([0])]), 2))
else:
x = list(jnp.split(flat_count, 2))
else:
x = list(jnp.split(count.ravel(), 2))
# With symbolic shapes we cannot always tell statically if odd_size is true
# or false, so we rewrite this without a conditional.
flat_count_padded = jnp.concatenate([flat_count, np.uint32([0])])
flat_count_padded_half_size = flat_count_padded.shape[0] // 2
x = [
lax.dynamic_slice(flat_count_padded, (0,),
(flat_count_padded_half_size,)),
lax.dynamic_slice(flat_count_padded,
(flat_count_padded_half_size,),
(flat_count_padded_half_size,))
]
assert x[0].shape == x[1].shape, (x[0].shape, x[1].shape)

x = threefry2x32_p.bind(key1, key2, x[0], x[1])
out = jnp.concatenate(x)
assert out.dtype == np.uint32
return lax.reshape(out[:-1] if odd_size else out, count.shape)
if core.is_constant_dim(odd_size):
return lax.reshape(out[:-1] if odd_size else out, count.shape)
else:
out_no_padding = lax.dynamic_slice(out, (0,), (flat_count.shape[0],))
return lax.reshape(out_no_padding, count.shape)


def threefry_split(key: typing.Array, shape: Shape) -> typing.Array:
Expand Down
5 changes: 1 addition & 4 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2086,10 +2086,7 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10]
PolyHarness("random_uniform", f"error_not_even_{flags_name}",
lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5), _f32)],
polymorphic_shapes=[None, "b0, ..."],
expect_error=(
(core.InconclusiveDimensionOperation,
"array size .* must be even") if flags_name == "threefry_non_partitionable" else (None, None)),
polymorphic_shapes=[None, "b0, b1"],
override_jax_config_flags=override_jax_config_flags) # type: ignore
]
for key_size, flags_name, override_jax_config_flags in [
Expand Down
38 changes: 18 additions & 20 deletions tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,10 @@ def test_bounds_floordiv(self):
self.assertEqual(_bounds(-b // (a + 1)), (-np.inf, -1))

self.assertEqual(_bounds(a - a // 2), (1, np.inf))
self.assertEqual(_bounds((a + 3) - (a + 3) // 2), (2, np.inf))
self.assertEqual(_bounds((a + 6) - 1 * (a + 6) // 4), (6, np.inf))
self.assertEqual(_bounds((a + 6) - 2 * ((a + 6) // 4)), (4, np.inf))
self.assertEqual(_bounds((a + 6) - 3 * ((a + 6) // 4)), (2, np.inf))
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 1))
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
"Possible division by 0"):
Expand Down Expand Up @@ -2982,31 +2986,30 @@ def f(x_ref):
RandArg((3, 4, 5), _f32)],
polymorphic_shapes=["b, ...", "b, w, ..."], tol=1E-5,
override_jax_config_flags=override_jax_config_flags), # type: ignore
# TODO(necula): The known dimensions product must be even.
PolyHarness("random_categorical", f"axis=0_{flags_name}",
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=0),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 8), _f32)],
polymorphic_shapes=[None, "b0, ..."],
polymorphic_shapes=[None, "b0, b1"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_categorical", f"axis=1_{flags_name}",
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=1),
jax.random.wrap_key_data(key), a, axis=1),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5, 8), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
polymorphic_shapes=[None, "b0, b1, b2"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_categorical", f"axis=1_then_reshape_{flags_name}",
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=1).reshape(-1),
jax.random.wrap_key_data(key), a, axis=1).reshape(-1),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5, 8), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
polymorphic_shapes=[None, "b0, b1, b2"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_categorical", f"0_dim_{flags_name}", # One axis has 0 size
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=1),
jax.random.wrap_key_data(key), a, axis=1),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5, 0), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
Expand All @@ -3024,14 +3027,13 @@ def f(x_ref):
RandArg((64, 12, 4), _f32), # sample on axis=1
RandArg((3, 4), _f32),
StaticArg(use_p)],
# TODO(necula): threefry requires even-sized samples.
polymorphic_shapes=[None,
"_, 2*b1, _" if arr_poly else None,
"b0, b1, b2" if arr_poly else None,
"b3, b4" if shape_poly else None],
# The array sampled dimension must be larger than res_shape.size
symbolic_constraints=[
"2*b1 >= 12" if arr_poly else "1 >= 0",
"2*b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0",
"b1 >= 12" if arr_poly else "1 >= 0",
"b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0",
"12 >= b3*b4" if shape_poly else "1 >= 0"
],
override_jax_config_flags=override_jax_config_flags,
Expand All @@ -3058,24 +3060,20 @@ def f(x_ref):
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4, 5), _f32)],
polymorphic_shapes=[None, "b0, ..."],
polymorphic_shapes=[None, "b0, 4, 5"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_uniform", f"even_2_{flags_name}",
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
(2 * a.shape[0], a.shape[1]),
dtype=_f32),
a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
polymorphic_shapes=[None, "b0, 2*b1"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_uniform", f"error_not_even_{flags_name}",
PolyHarness("random_uniform", f"error_unknown_evenness_{flags_name}",
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5), _f32)],
polymorphic_shapes=[None, "b0, ..."],
expect_error=(
(core.InconclusiveDimensionOperation,
"array size .* must be even") if flags_name == "threefry_non_partitionable" else None),
polymorphic_shapes=[None, "b0, b1"],
override_jax_config_flags=override_jax_config_flags) # type: ignore
]
for key_size, flags_name, override_jax_config_flags in [
Expand Down

0 comments on commit bc3306c

Please sign in to comment.