Skip to content

Commit

Permalink
Adds a minimal but viable implementation of string arrays (with `nump…
Browse files Browse the repository at this point in the history
…y.dtypes.StringDType`) in JAX. Currently this only supports making of a string array by means of either `jax.numpy.asarray` or `jax.device_put` and reading it back with `jax.device_get`.

PiperOrigin-RevId: 716042460
  • Loading branch information
Google-ML-Automation committed Jan 17, 2025
1 parent 12b59f8 commit 6d12971
Show file tree
Hide file tree
Showing 9 changed files with 341 additions and 24 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ pytype_strict_library(
":traceback_util",
":typing",
":util",
"//jax/_src/lib",
] + py_deps("ml_dtypes") + py_deps("numpy"),
)

Expand Down
30 changes: 30 additions & 0 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member

try:
import numpy.dtypes as np_dtypes # pylint: disable=g-import-not-at-top
except ImportError:
np_dtypes = None # type: ignore

traceback_util.register_exclusion(__file__)

Expand Down Expand Up @@ -989,6 +994,14 @@ def vmap_f(*args, **kwargs):
"to the positional arguments passed to the function, "
f"but got {len(in_axes)=}, {len(args)=}")
args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)

# StringDType arrays are not supported for vmap.
if hasattr(np_dtypes, "StringDType") and any(
hasattr(x, "dtype") and isinstance(x.dtype, np_dtypes.StringDType)
for x in args_flat
):
raise TypeError("StringDType arrays are not supported for vmap")

f = lu.wrap_init(fun)
flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
Expand Down Expand Up @@ -2200,6 +2213,17 @@ def _infer_src_sharding(src, x) -> Sharding | None:
return val.sharding
return None

def _check_string_compatible_sharding(s):
  """Checks if sharding is compatible with StringDType arrays.

  Accepts `s` when it is an `xc.Device` whose `device_kind` is "cpu", or a
  `Sharding` whose first device (in `device_set` iteration order) has
  `device_kind` "cpu". Only that one device of a `Sharding` is inspected.

  Raises:
    TypeError: for any other value of `s`, including values that are neither
      a device nor a sharding.
  """
  # TODO(jmudigonda): Add checks for Layout and TransferToMemoryKind.
  if isinstance(s, xc.Device):
    representative = s
  elif isinstance(s, Sharding):
    representative = next(iter(s.device_set))
  else:
    representative = None

  if representative is not None and representative.device_kind == "cpu":
    return

  raise TypeError(
      "StringDType arrays can only be sharded to CPU devices. Received"
      f" invalid value: {s}"
  )

# TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use
# that to check if shardings are compatible with the input.
Expand All @@ -2211,6 +2235,12 @@ def _check_sharding(aval, s):
"`jax.device_put` only accepts `None`, `jax.sharding.Sharding`,"
" `jax.Device`, `Layout` or a pytree of these values. Received"
f" invalid value: {s}")

if hasattr(np_dtypes, "StringDType") and xla_extension_version >= 304:
if isinstance(aval.dtype, np_dtypes.StringDType):
_check_string_compatible_sharding(s)
return

if isinstance(s, Sharding):
if isinstance(aval, core.AbstractToken):
aval = core.token_shaped_array
Expand Down
17 changes: 12 additions & 5 deletions jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@
is_single_device_sharding)
import numpy as np

try:
import numpy.dtypes as np_dtypes # pylint: disable=g-import-not-at-top
except ImportError:
np_dtypes = None # type: ignore

JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
Expand Down Expand Up @@ -279,12 +283,15 @@ def _is_bint_axis_size(d: core.AxisSize) -> bool:
type(d.aval.dtype) is core.bint)
return False


def check_arg(arg: Any):
  """Validates a single argument, rejecting unsupported values.

  Tracers are accepted as-is. Any other value must satisfy
  `core.valid_jaxtype`, and its abstract value must not carry a NumPy
  `StringDType` dtype.

  Args:
    arg: the value to validate.

  Raises:
    TypeError: if `arg` is neither a tracer nor a valid JAX type, or if its
      abstract value has a `StringDType` dtype.
  """
  if isinstance(arg, core.Tracer):
    return
  if not core.valid_jaxtype(arg):
    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
                    "JAX type.")

  # `np_dtypes` is None when `numpy.dtypes` could not be imported, and older
  # NumPy releases have no `StringDType`; in both cases there is nothing
  # further to reject.
  string_dtype = getattr(np_dtypes, "StringDType", None)
  if string_dtype is None:
    return

  # Read `dtype` defensively: an abstract value without a `dtype` attribute
  # should pass validation rather than fail with an AttributeError.
  # NOTE(review): presumably token avals are such a case -- confirm.
  aval = core.abstractify(arg)
  if isinstance(getattr(aval, "dtype", None), string_dtype):
    raise TypeError("StringDType arrays are not supported by jit")

def jaxpr_replicas(jaxpr: core.Jaxpr) -> int:
"""The number of replicas needed for a jaxpr.
Expand Down
42 changes: 33 additions & 9 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import numpy as np

from jax._src import config
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
from jax._src.typing import Array, DType, DTypeLike
from jax._src.util import set_module, StrictABC

Expand Down Expand Up @@ -478,18 +479,41 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
np.dtype('complex64'),
np.dtype('complex128'),
]
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}

# String dtypes are registered only when NumPy provides `StringDType` and the
# installed xla_extension is at version 304 or later; otherwise the list
# stays empty.
_string_types: list[JAXType] = []
try:
  import numpy.dtypes as np_dtypes
except ImportError:
  np_dtypes = None  # type: ignore
else:
  if hasattr(np_dtypes, 'StringDType') and xla_extension_version >= 304:
    _string_types = [np_dtypes.StringDType()]  # type: ignore

_jax_types = (
_bool_types + _int_types + _float_types + _complex_types + _string_types
)
_jax_dtype_set = {
float0,
*_bool_types,
*_int_types,
*_float_types,
*_complex_types,
*_string_types,
}

_dtype_kinds: dict[str, set] = {
'bool': {*_bool_types},
'signed integer': {*_signed_types},
'unsigned integer': {*_unsigned_types},
'integral': {*_signed_types, *_unsigned_types},
'real floating': {*_float_types},
'complex floating': {*_complex_types},
'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types},
'bool': {*_bool_types},
'signed integer': {*_signed_types},
'unsigned integer': {*_unsigned_types},
'integral': {*_signed_types, *_unsigned_types},
'real floating': {*_float_types},
'complex floating': {*_complex_types},
'numeric': {
*_signed_types,
*_unsigned_types,
*_float_types,
*_complex_types,
},
'string': {*_string_types},
}


Expand Down
35 changes: 30 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,39 @@
from jax._src.lax.lax import (PrecisionLike,_array_copy,
_sort_le_comparator, _sort_lt_comparator)
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.sharding_impls import (
NamedSharding,
PartitionSpec as P,
SingleDeviceSharding,
canonicalize_sharding,
)
from jax._src.typing import (
Array, ArrayLike,
DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar,
)
from jax._src.util import (
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
ceil_of_ratio, partition_list, safe_zip, set_module, unzip2,
tuple_replace)
NumpyComplexWarning,
canonicalize_axis as _canonicalize_axis,
ceil_of_ratio,
partition_list,
safe_zip,
set_module,
tuple_replace,
unzip2,
)
from jax.sharding import Sharding
from jax._src.sharding_impls import (SingleDeviceSharding, NamedSharding,
PartitionSpec as P, canonicalize_sharding)
from jax.tree_util import tree_flatten, tree_leaves, tree_map
import numpy as np

try:
from numpy import dtypes as np_dtypes
except ImportError:
np_dtypes = None # type: ignore
import opt_einsum

export = set_module('jax.numpy')
Expand Down Expand Up @@ -5564,6 +5580,15 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# Keep the output uncommitted.
return jax.device_put(object)

# Do a device_put for string arrays since XLA does not support string dtype.
if xla_extension_version >= 304:
if isinstance(object, np.ndarray) and hasattr(np_dtypes, "StringDType") and isinstance(object.dtype, np_dtypes.StringDType): # type: ignore
if (ndmin > 0) and (ndmin != object.ndim):
raise TypeError(
f"ndmin {ndmin} does not match ndims {object.ndim} of input array"
)
return jax.device_put(x=object, device=device)

# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or
Expand Down
10 changes: 10 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1555,6 +1555,15 @@ jax_py_test(
],
)

jax_multiplatform_test(
name = "string_array_test",
srcs = ["string_array_test.py"],
# deps = [
# "//jax",
# "//jax:test_util",
# ],
)

jax_multiplatform_test(
name = "cudnn_fusion_test",
srcs = ["cudnn_fusion_test.py"],
Expand Down Expand Up @@ -1590,6 +1599,7 @@ exports_files(
"shard_map_test.py",
"transfer_guard_test.py",
"layout_test.py",
"string_array_test.py",
],
visibility = jax_test_file_visibility,
)
Expand Down
21 changes: 18 additions & 3 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@

import numpy as np

try:
import numpy.dtypes as np_dtypes
except ImportError:
np_dtypes = None # type: ignore


import jax
from jax import numpy as jnp
from jax._src import earray
Expand Down Expand Up @@ -770,16 +776,25 @@ def g(x):


class TestPromotionTables(jtu.JaxTestCase):
# Not all types are promotable. For example, currently StringDType is not
# promotable.
if hasattr(np_dtypes, 'StringDType'):
promotable_types = [
x for x in dtypes._jax_types if not isinstance(x, np_dtypes.StringDType)
]
else:
promotable_types = dtypes._jax_types


@parameterized.named_parameters(
{"testcase_name": f"_{jaxtype=}", "jaxtype": jaxtype}
for jaxtype in dtypes._jax_types + dtypes._weak_types)
for jaxtype in promotable_types + dtypes._weak_types)
def testJaxTypeFromType(self, jaxtype):
self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(jaxtype)), jaxtype)

@parameterized.named_parameters(
{"testcase_name": f"_{jaxtype=}", "jaxtype": jaxtype}
for jaxtype in dtypes._jax_types + dtypes._weak_types)
for jaxtype in promotable_types + dtypes._weak_types)
def testJaxTypeFromVal(self, jaxtype):
from jax._src.export import shape_poly
if jaxtype is shape_poly._DimExpr:
Expand All @@ -795,7 +810,7 @@ def testJaxTypeFromVal(self, jaxtype):

@parameterized.named_parameters(
{"testcase_name": f"_{dtype=}", "dtype": dtype}
for dtype in dtypes._jax_types)
for dtype in promotable_types)
def testJaxTypeWeak(self, dtype):
jax_type = dtypes._jax_type(dtype, weak_type=True)
if dtypes.issubdtype(jax_type, np.complexfloating):
Expand Down
13 changes: 11 additions & 2 deletions tests/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@

import numpy as np

try:
import numpy.dtypes as np_dtypes
except ImportError:
np_dtypes = None # type: ignore

# ruff: noqa: F401
try:
import flatbuffers
Expand Down Expand Up @@ -407,7 +412,6 @@ def f(x1, x2):
self.assertEqual(tree_util.tree_structure(res2),
tree_util.tree_structure(res))


def test_error_wrong_intree(self):
def f(a_b_pair, *, c):
return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c
Expand Down Expand Up @@ -1002,6 +1006,12 @@ def f_jax(x): # x: bool[b]
for dtype in dtypes._jax_types if dtype != np.dtype("bool")
])
def test_poly_numeric_dtypes(self, dtype=np.int32):
if hasattr(np_dtypes, "StringDType") and isinstance(
dtype, np_dtypes.StringDType
):
self.skipTest(
"StringDType is not a numeric type"
) # TODO(jmudigonda): revisit.
if str(dtype) in {"float8_e4m3b11fnuz",
"float8_e4m3fnuz",
"float8_e5m2fnuz",
Expand Down Expand Up @@ -1617,7 +1627,6 @@ def test_multi_platform_unknown_platform(self):
platforms=("tpu", "cpu", "cuda", "other"))(x)
self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "other"))


def test_multi_platform_with_donation(self):
f = jax.jit(jnp.sin, donate_argnums=(0,))
x = np.arange(3, dtype=np.float32)
Expand Down
Loading

0 comments on commit 6d12971

Please sign in to comment.