diff --git a/jax/_src/export/shape_poly_decision.py b/jax/_src/export/shape_poly_decision.py index d6a73cbb1450..425ffdee321c 100644 --- a/jax/_src/export/shape_poly_decision.py +++ b/jax/_src/export/shape_poly_decision.py @@ -306,18 +306,21 @@ def _bounds_for_sorted_terms(self, scope=scope): # `c =[eq] 0` AND `t*t_k*t_s + c*c_s` contains only terms smaller than t # AND c_s > 0. - # rest = e[i:]*t_s + c*c_s` AND `rest_ub >= rest >= rest_lb` + # `rest = e[i:]*t_s + c*c_s` AND `rest_ub >= rest >= rest_lb` + # `rest` contains only terms smaller than `t`. rest = _DimExpr._linear_combination_sorted_pairs(e, i, t_s, c._sorted_terms, 0, c_s) rest_lb, rest_ub = self._bounds_for_sorted_terms(scope, rest, 0, BoundsPrecision.BEST) if rest_ub < np.inf: + # We have: e[i:]*t_s = rest - c*c_s <= rest_ub if t_s > 0: ub = min(ub, int(np.floor(rest_ub / t_s))) else: lb = max(lb, int(np.ceil(rest_ub / t_s))) if rest_lb > - np.inf and c_eq == Comparator.EQ: + # We have: e[i:]*t_s = rest - c*c_s = rest >= rest_lb if t_s > 0: lb = max(lb, int(np.ceil(rest_lb / t_s))) else: diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 4f43b54bb478..9ccdc53bbcae 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -1068,22 +1068,35 @@ def threefry_2x32(keypair, count): msg = "threefry_2x32 requires uint32 arguments, got {}" raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]])) - odd_size = count.size % 2 - if not isinstance(odd_size, int): - msg = ("jax.random functions have limited support for shape polymorphism " - "when using threefry. " - f"In particular, the array size ({count.size}) must be even.") - raise core.InconclusiveDimensionOperation(msg) - - if odd_size: - x = list(jnp.split(jnp.concatenate([count.ravel(), np.uint32([0])]), 2)) + flat_count = count.ravel() + odd_size = flat_count.shape[0] % 2 + if core.is_constant_dim(odd_size): + if odd_size: + x = list(jnp.split(jnp.concatenate([flat_count, np.uint32([0])]), 2)) + else: + x = list(jnp.split(flat_count, 2)) else: - x = list(jnp.split(count.ravel(), 2)) + # With symbolic shapes we cannot always tell statically if odd_size is true + # or false, so we rewrite this without a conditional. + flat_count_padded = jnp.concatenate([flat_count, np.uint32([0])]) + flat_count_padded_half_size = flat_count_padded.shape[0] // 2 + x = [ + lax.dynamic_slice(flat_count_padded, (0,), + (flat_count_padded_half_size,)), + lax.dynamic_slice(flat_count_padded, + (flat_count_padded_half_size,), + (flat_count_padded_half_size,)) + ] + assert x[0].shape == x[1].shape, (x[0].shape, x[1].shape) x = threefry2x32_p.bind(key1, key2, x[0], x[1]) out = jnp.concatenate(x) assert out.dtype == np.uint32 - return lax.reshape(out[:-1] if odd_size else out, count.shape) + if core.is_constant_dim(odd_size): + return lax.reshape(out[:-1] if odd_size else out, count.shape) + else: + out_no_padding = lax.dynamic_slice(out, (0,), (flat_count.shape[0],)) + return lax.reshape(out_no_padding, count.shape) def threefry_split(key: typing.Array, shape: Shape) -> typing.Array: diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index f08fa6eb53b4..7077116e2b3d 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -2086,10 +2086,7 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] PolyHarness("random_uniform", f"error_not_even_{flags_name}", lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32), arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5), _f32)], - polymorphic_shapes=[None, "b0, ..."], - expect_error=( - (core.InconclusiveDimensionOperation, - "array size .* must be even") if flags_name == "threefry_non_partitionable" else (None, None)), + polymorphic_shapes=[None, "b0, b1"], override_jax_config_flags=override_jax_config_flags) # type: ignore ] for key_size, flags_name, override_jax_config_flags in [ diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 7679e53c2982..9be11b7da51e 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -459,6 +459,10 @@ def test_bounds_floordiv(self): self.assertEqual(_bounds(-b // (a + 1)), (-np.inf, -1)) self.assertEqual(_bounds(a - a // 2), (1, np.inf)) + self.assertEqual(_bounds((a + 3) - (a + 3) // 2), (2, np.inf)) + self.assertEqual(_bounds((a + 6) - 1 * (a + 6) // 4), (6, np.inf)) + self.assertEqual(_bounds((a + 6) - 2 * ((a + 6) // 4)), (4, np.inf)) + self.assertEqual(_bounds((a + 6) - 3 * ((a + 6) // 4)), (2, np.inf)) self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 1)) with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "Possible division by 0"): @@ -2982,31 +2986,30 @@ def f(x_ref): RandArg((3, 4, 5), _f32)], polymorphic_shapes=["b, ...", "b, w, ..."], tol=1E-5, override_jax_config_flags=override_jax_config_flags), # type: ignore - # TODO(necula): The known dimensions product must be even. PolyHarness("random_categorical", f"axis=0_{flags_name}", lambda key, a: jax.random.categorical( jax.random.wrap_key_data(key), a, axis=0), arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 8), _f32)], - polymorphic_shapes=[None, "b0, ..."], + polymorphic_shapes=[None, "b0, b1"], override_jax_config_flags=override_jax_config_flags), # type: ignore PolyHarness("random_categorical", f"axis=1_{flags_name}", lambda key, a: jax.random.categorical( - jax.random.wrap_key_data(key), a, axis=1), + jax.random.wrap_key_data(key), a, axis=1), arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5, 8), _f32)], - polymorphic_shapes=[None, "b0, b1, ..."], + polymorphic_shapes=[None, "b0, b1, b2"], override_jax_config_flags=override_jax_config_flags), # type: ignore PolyHarness("random_categorical", f"axis=1_then_reshape_{flags_name}", lambda key, a: jax.random.categorical( - jax.random.wrap_key_data(key), a, axis=1).reshape(-1), + jax.random.wrap_key_data(key), a, axis=1).reshape(-1), arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5, 8), _f32)], - polymorphic_shapes=[None, "b0, b1, ..."], + polymorphic_shapes=[None, "b0, b1, b2"], override_jax_config_flags=override_jax_config_flags), # type: ignore PolyHarness("random_categorical", f"0_dim_{flags_name}", # One axis has 0 size lambda key, a: jax.random.categorical( - jax.random.wrap_key_data(key), a, axis=1), + jax.random.wrap_key_data(key), a, axis=1), arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5, 0), _f32)], polymorphic_shapes=[None, "b0, b1, ..."], @@ -3024,14 +3027,13 @@ def f(x_ref): RandArg((64, 12, 4), _f32), # sample on axis=1 RandArg((3, 4), _f32), StaticArg(use_p)], - # TODO(necula): threefry requires even-sized samples. polymorphic_shapes=[None, - "_, 2*b1, _" if arr_poly else None, + "b0, b1, b2" if arr_poly else None, "b3, b4" if shape_poly else None], # The array sampled dimension must be larger than res_shape.size symbolic_constraints=[ - "2*b1 >= 12" if arr_poly else "1 >= 0", - "2*b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0", + "b1 >= 12" if arr_poly else "1 >= 0", + "b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0", "12 >= b3*b4" if shape_poly else "1 >= 0" ], override_jax_config_flags=override_jax_config_flags, @@ -3058,24 +3060,20 @@ def f(x_ref): lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key), a.shape, dtype=_f32), arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4, 5), _f32)], - polymorphic_shapes=[None, "b0, ..."], + polymorphic_shapes=[None, "b0, 4, 5"], override_jax_config_flags=override_jax_config_flags), # type: ignore PolyHarness("random_uniform", f"even_2_{flags_name}", lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key), - (2 * a.shape[0], a.shape[1]), - dtype=_f32), + a.shape, dtype=_f32), arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4), _f32)], - polymorphic_shapes=[None, "b0, b1, ..."], + polymorphic_shapes=[None, "b0, 2*b1"], override_jax_config_flags=override_jax_config_flags), # type: ignore - PolyHarness("random_uniform", f"error_not_even_{flags_name}", + PolyHarness("random_uniform", f"error_unknown_evenness_{flags_name}", lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key), a.shape, dtype=_f32), arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5), _f32)], - polymorphic_shapes=[None, "b0, ..."], - expect_error=( - (core.InconclusiveDimensionOperation, - "array size .* must be even") if flags_name == "threefry_non_partitionable" else None), + polymorphic_shapes=[None, "b0, b1"], override_jax_config_flags=override_jax_config_flags) # type: ignore ] for key_size, flags_name, override_jax_config_flags in [