jacrev, jacfwd and grad do not propagate sharding #25844
Comments
Update: both of these examples work:

```python
from jax.experimental.shard_map import shard_map
from functools import partial

def fn_with_constraint(a):
    return lax.with_sharding_constraint(2 * a, sharding)

@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def fn(a):
    return 2 * a

print("=" * 50)
print(f"Shardmapped Forward pass sharding {fn(x).sharding}")
print(f"Shardmapped Vmapped sharding {jax.vmap(fn)(batched_x).sharding}")
print(f"Shardmapped Gradient sharding {jax.grad(lambda x: jnp.sum(fn(x)))(x).sharding}")
print(f"Shardmapped Jacrev sharding {jax.jacrev(fn)(x).sharding}")
print(f"Shardmapped Jacfwd sharding {jax.jacfwd(fn)(x).sharding}")
print(f"Shardmapped Jvp sharding {jax.jvp(fn, (x,), (jnp.ones_like(x),))[1].sharding}")
print(f"Shardmapped Vjp sharding {jax.vjp(fn, x)[0].sharding}")
```

I think this defeats the purpose of the "compiler take the wheel" strategy, since, if I understood correctly, the gradients end up fully replicated when the user does not manually specify the output sharding. So I would expect the sharding to propagate correctly, or at least to be able to access the sharding at trace time (from a DynamicTracer object).
Another update: here are two code samples that test most of the cases at the four levels of partitioning.

1 - Compiler take the wheel:

```python
import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from jax import lax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from functools import partial
import jax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

mesh = Mesh(jax.devices(), ("x"))
sharding = NamedSharding(mesh, P("x"))

x = jnp.arange(8).astype(jnp.float32)
x = lax.with_sharding_constraint(x, sharding)
batched_x = jnp.stack([x, 2 * x])

def scaler_fn_shardmap(x):
    @partial(shard_map, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
    def inner(x):
        return lax.pmean(x, "x")
    return jnp.mean(inner(x))

def scaler_fn_with_constraint(x):
    x = lax.with_sharding_constraint(x, sharding)
    return jnp.mean(x)

def scaler_fn(x):
    return jnp.mean(x)

def fn_shardmap(x):
    @partial(shard_map, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
    def inner(x):
        return 2 * x
    return inner(x)

def fn_with_constraint(x):
    return lax.with_sharding_constraint(2 * x, sharding)

def fn(x):
    return 2 * x

def trace(fn, name, scaler=False):
    print("=" * 50)
    f_pass = fn(x)
    print(f"{name} Forward pass sharding {fn(x).sharding}")
    print(f"{name} Vmapped sharding {jax.vmap(fn)(batched_x).sharding}")
    if scaler:
        print(f"{name} Gradient sharding {jax.grad(fn)(x).sharding}")
    else:
        print(
            f"{name} Gradient sharding {jax.grad(lambda x: jnp.sum(fn(x)))(x).sharding}"
        )
    print(f"{name} Jacrev sharding {jax.jacrev(fn)(x).sharding}")
    print(f"{name} Jacfwd sharding {jax.jacfwd(fn)(x).sharding}")
    print(f"{name} Jvp sharding {jax.jvp(fn, (x,), (jnp.ones_like(x),))[1].sharding}")
    print(f"{name} Vjp sharding {jax.vjp(fn, x)[1](f_pass)[0].sharding}")

trace(scaler_fn_shardmap, "Scaler Shardmapped", scaler=True)
trace(scaler_fn_with_constraint, "Scaler With Constraint", scaler=True)
trace(scaler_fn, "Scaler No Constraint", scaler=True)
trace(fn_shardmap, "Shardmapped")
trace(fn_with_constraint, "With Constraint")
trace(fn, "No Constraint")
```

2 - Custom partitioning:

```python
import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
from jax import lax
from jax.interpreters import mlir, ad, batching
import jax.numpy as jnp
import jax
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import dispatch
import jax.extend as jex
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P, Mesh
from functools import partial
mesh = Mesh(jax.devices(), ("x"))
sharding = NamedSharding(mesh, P("x"))
# ================================
# Double Primitive Rules
# ================================
# Step 1: Define the Primitive
double_prim_p = jex.core.Primitive("double_prim")
dispatch.prim_requires_devices_during_lowering.add(double_prim_p)
# Step 2: Define the Implementation
def origin_fn(x, scaler):
return scaler * x
@partial(custom_partitioning, static_argnums=(1,))
def double_prim_impl(x, scaler):
return origin_fn(x, scaler)
def infer_sharding_from_operands(scaler, mesh, arg_infos, result_infos):
return arg_infos[0].sharding
def partition(scaler, mesh, arg_infos, result_infos):
input_sharding = arg_infos[0].sharding
output_sharding = result_infos.sharding
input_mesh = input_sharding.mesh
def impl(operand):
return scaler * operand
return input_mesh, impl, output_sharding, (input_sharding,)
@partial(custom_partitioning, static_argnums=(1, 2))
def vmapped_double_prim_impl(x, batch_dims, scaler):
strip_static_fn = lambda x: origin_fn(x, scaler)
return jax.vmap(strip_static_fn, in_axes=batch_dims)(x)
def strip_shaped_dtype_struct(arg_info):
new_shape = arg_info.shape[1:] # Remove the batch dimension
new_spec = arg_info.sharding.spec[1:] # Adjust the PartitionSpec
new_sharding = NamedSharding(arg_info.sharding.mesh, P(*new_spec))
return jax.ShapeDtypeStruct(
new_shape, arg_info.dtype, sharding=new_sharding, weak_type=arg_info.weak_type
)
def strip_shaped_array(shaped_array):
return shaped_array.update(shape=shaped_array.shape[1:])
def add_batch_dimension_to_sharding(arg_info):
# Add the batch dimension back to the shape and PartitionSpec
new_spec = (None, *arg_info.spec) # Add `None` to PartitionSpec
new_sharding = NamedSharding(arg_info.mesh, P(*new_spec))
return new_sharding
def v_infer_sharding_from_operands(batch_dims, scaler, mesh, arg_infos, result_infos):
un_batched_arg_infos = jax.tree.map(strip_shaped_dtype_struct, arg_infos)
un_batched_result_infos = jax.tree.map(strip_shaped_array, result_infos)
un_batched_output_sharding = infer_sharding_from_operands(
scaler, mesh, un_batched_arg_infos, un_batched_result_infos
)
batched_results = add_batch_dimension_to_sharding(un_batched_output_sharding)
return batched_results
def v_partition(batch_dims, scaler, mesh, arg_infos, result_infos):
un_batched_arg_infos = jax.tree.map(strip_shaped_dtype_struct, arg_infos)
un_batched_result_infos = jax.tree.map(strip_shaped_dtype_struct, result_infos)
input_mesh, impl, output_sharding, input_shardings = partition(
scaler, mesh, un_batched_arg_infos, un_batched_result_infos
)
output_sharding = add_batch_dimension_to_sharding(output_sharding)
input_shardings = jax.tree.map(add_batch_dimension_to_sharding, input_shardings)
impl = jax.vmap(impl, in_axes=batch_dims)
return input_mesh, impl, output_sharding, input_shardings
vmapped_double_prim_impl.def_partition(
infer_sharding_from_operands=v_infer_sharding_from_operands, partition=v_partition
)
# Step 3: Define Abstract Evaluation
def double_prim_abstract_eval(*args, **kwargs):
return jax.make_jaxpr(origin_fn, static_argnums=1)(*args, **kwargs).out_avals[0]
# Step 4: Define JVP Rule
def double_prim_jvp_rule(primals, tangents, scaler):
(x,) = primals
(t,) = tangents
# Forward computation
primal_out = double_prim_call(x, scaler)
# Tangent computation (reuse the primitive itself)
tangent_out = double_prim_call(t, scaler)
return primal_out, tangent_out
# Step 5: Define Transpose Rule
def double_prim_transpose_rule(ct_out, x, scaler):
ct_x = 2 * ct_out if ad.is_undefined_primal(x) else None
return (ct_x,)
# Step 6: Define Batch Rule
def double_prim_batch_rule(batched_args, batch_dims, *args, **kwargs):
(x,) = batched_args
(bx,) = batch_dims
# Apply vmapped double operation
res = vmapped_double_prim_impl(x, bx, *args, **kwargs)
return res, 0
# Step 7: Register the Primitive
double_prim_p.def_impl(double_prim_impl) # Implementation
double_prim_p.def_abstract_eval(double_prim_abstract_eval) # Abstract Eval
mlir.register_lowering(
double_prim_p, mlir.lower_fun(double_prim_impl, multiple_results=False)
) # Lowering
ad.primitive_jvps[double_prim_p] = double_prim_jvp_rule # JVP Rule
ad.primitive_transposes[double_prim_p] = double_prim_transpose_rule # Transpose Rule
batching.primitive_batchers[double_prim_p] = double_prim_batch_rule # Batch Rule
# Define a Python wrapper for the primitive
@partial(jax.jit, static_argnums=(1,))
def double_prim_call(x, scaler=2):
return double_prim_p.bind(x, scaler=scaler)
double_prim_impl.def_partition(
infer_sharding_from_operands=infer_sharding_from_operands, partition=partition
)
@jax.custom_vjp
def double_vjp(x):
return double_prim_call(x)
def double_vjp_fwd(x):
return double_vjp(x), None
def double_vjp_bwd(_, g):
return (2 * g,)
@jax.custom_jvp
def double_jvp(x):
return double_prim_call(x)
@double_jvp.defjvp
def double_jvp_fwd(p, t):
(x,) = p
(t,) = t
return double_jvp(x), double_jvp(t)
double_vjp.defvjp(double_vjp_fwd, double_vjp_bwd)
# ================================
# Linear Double Primitive Testing
# ================================
x = jnp.arange(8).astype(jnp.float32)
x = lax.with_sharding_constraint(x, sharding)
batched_x = jnp.stack([x, 2 * x])
def trace(fn, name, scaler=False, vjp=False):
print("=" * 50)
f_pass = fn(x)
print(f"{name} Forward pass sharding {fn(x).sharding}")
print(f"{name} Vmapped sharding {jax.vmap(fn)(batched_x).sharding}")
if scaler:
print(f"{name} Gradient sharding {jax.grad(fn)(x).sharding}")
else:
print(
f"{name} Gradient sharding {jax.grad(lambda x:jnp.sum(fn(x)))(x).sharding}"
)
print(f"{name} Jacrev sharding {jax.jacrev(fn)(x).sharding}")
if not vjp:
print(f"{name} Jacfwd sharding {jax.jacfwd(fn)(x).sharding}")
print(
f"{name} Jvp sharding {jax.jvp(fn , (x,) , (jnp.ones_like(x),))[1].sharding}"
)
print(f"{name} Vjp sharding {jax.vjp(fn , x)[1](f_pass)[0].sharding}")
trace(double_jvp, "custom_partial_double_jvp")
trace(double_vjp, "custom_partial_double_vjp", vjp=True) |
Here is the summary.
In general, the only issue is with using …; IMO, getting access to a static sharding, just like a shape, is crucial in this case.
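As a side note on accessing the sharding at trace time: the closest existing hook I am aware of is the callback-based `jax.debug.inspect_array_sharding`, which reports the sharding a traced value ends up with but does not expose it as a static attribute on the tracer. A rough sketch, assuming that API is available in the installed JAX version:

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("x",))
sharding = NamedSharding(mesh, P("x"))
x = lax.with_sharding_constraint(jnp.arange(8.0), sharding)

@jax.jit
def f(a):
    y = 2 * a
    # Prints the sharding chosen for `y` once it is known; there is no
    # static `y.sharding` on the tracer itself, which is the point here.
    jax.debug.inspect_array_sharding(y, callback=print)
    return y

f(x)
```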
Description
Hello,
I noticed that, unlike vjp and jvp, jacrev, jacfwd, and grad do not propagate sharding correctly.
Is this normal?
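For completeness, a minimal sketch of the comparison I have in mind (assuming 8 host CPU devices, e.g. via `XLA_FLAGS=--xla_force_host_platform_device_count=8`):

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("x",))
sharding = NamedSharding(mesh, P("x"))
x = lax.with_sharding_constraint(jnp.arange(8.0), sharding)

def f(a):
    return 2 * a

print(x.sharding)                                         # sharded over "x"
print(jax.vjp(f, x)[1](x)[0].sharding)                    # vjp: follows the input
print(jax.jvp(f, (x,), (jnp.ones_like(x),))[1].sharding)  # jvp: follows the input
print(jax.grad(lambda a: jnp.sum(f(a)))(x).sharding)      # grad: reportedly not propagated
print(jax.jacrev(f)(x).sharding)                          # jacrev: reportedly not propagated
print(jax.jacfwd(f)(x).sharding)                          # jacfwd: reportedly not propagated
```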
System info (python version, jaxlib version, accelerator, etc.)