jacrev, jacfwd and grad do not propagate sharding #25844

Open
ASKabalan opened this issue Jan 10, 2025 · 3 comments

Labels: bug (Something isn't working)

ASKabalan commented Jan 10, 2025

Description

Hello,

I noticed that, unlike vjp and jvp, the jacrev, jacfwd, and grad transforms do not propagate sharding correctly:

import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from jax import lax
import jax.numpy as jnp
import jax

from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
pdims = (8,)
mesh = jax.make_mesh(pdims, axis_names=('x',))
sharding = NamedSharding(mesh, P('x'))


def fn(a):
    return 2 * a

x = jnp.arange(8).astype(jnp.float32)
x = lax.with_sharding_constraint(x, sharding)
batched_x = jnp.stack([x , 2*x])

batch_sharding = batched_x.sharding

print(f"Forward pass sharding {fn(x).sharding}")
print(f"Vmapped sharding {jax.vmap(fn)(batched_x).sharding}")
print(f"Gradient sharding {jax.grad(lambda x:jnp.sum(fn(x)))(x).sharding}")
print(f"Jacrev sharding {jax.jacrev(fn)(x).sharding}")
print(f"Jacfwd sharding {jax.jacfwd(fn)(x).sharding}")
print(f"Jvp sharding {jax.jvp(fn , (x,) , (jnp.ones_like(x),))[1].sharding}")
print(f"Vjp sharding {jax.vjp(fn , x)[0].sharding}")
# Output:
#Forward pass sharding NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec('x',), memory_kind=unpinned_host)
#Vmapped sharding NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec(None, 'x'), memory_kind=unpinned_host)
#Gradient sharding NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec(), memory_kind=unpinned_host)
#Jacrev sharding NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec(), memory_kind=unpinned_host)
#Jacfwd sharding NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec(), memory_kind=unpinned_host)
#Jvp sharding NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec('x',), memory_kind=unpinned_host)
#Vjp sharding NamedSharding(mesh=Mesh('x': 8), spec=PartitionSpec('x',), memory_kind=unpinned_host)

Is this normal?
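
For a quick visual check of the replication, jax.debug.visualize_array_sharding can be applied to the arrays from the repro above. This is only an inspection aid, and the comments below simply restate the printed specs:

# Visual confirmation sketch, reusing fn and x from the snippet above
from jax.debug import visualize_array_sharding

visualize_array_sharding(jax.grad(lambda v: jnp.sum(fn(v)))(x))  # fully replicated layout
visualize_array_sharding(jax.vjp(fn, x)[0])                      # sharded over 'x'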

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.2.1
python: 3.10.4 | packaged by conda-forge | (main, Mar 24 2022, 17:39:04) [GCC 10.3.0]
device info: cpu-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='apc2324', release='6.8.0-51-generic', version='#52-Ubuntu SMP PREEMPT_DYNAMIC Thu Dec  5 13:09:44 UTC 2024', machine='x86_64')


$ nvidia-smi
Sat Jan 11 00:30:36 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4060 ...    Off |   00000000:01:00.0 Off |                  N/A |
| N/A   43C    P8              6W /   35W |     110MiB /   8188MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1579      G   /usr/lib/xorg/Xorg                              4MiB |
|    0   N/A  N/A      9369      C   .../micromamba/envs/jax/bin/python3.10         98MiB |
+-----------------------------------------------------------------------------------------+

ASKabalan commented Jan 11, 2025

Update

Both of these examples work:

from jax.experimental.shard_map import shard_map
from functools import partial


def fn_with_constraint(a):
    return lax.with_sharding_constraint(2 * a , sharding)

@partial(shard_map , mesh=mesh , in_specs=P('x'), out_specs=P('x'))
def fn(a):
    return 2 * a

print(f"="*50)
print(f"Shardmapped Forward pass sharding {fn(x).sharding}")
print(f"Shardmapped Vmapped sharding {jax.vmap(fn)(batched_x).sharding}")
print(f"Shardmapped Gradient sharding {jax.grad(lambda x:jnp.sum(fn(x)))(x).sharding}")
print(f"Shardmapped Jacrev sharding {jax.jacrev(fn)(x).sharding}")
print(f"Shardmapped Jacfwd sharding {jax.jacfwd(fn)(x).sharding}")
print(f"Shardmapped Jvp sharding {jax.jvp(fn , (x,) , (jnp.ones_like(x),))[1].sharding}")
print(f"Shardmapped Vjp sharding {jax.vjp(fn , x)[0].sharding}")

I think this defeats the purpose of the "compiler takes the wheel" strategy: in this case, if I understood correctly, the gradients end up fully replicated whenever the user does not manually specify the output sharding.

So I would expect the sharding to propagate correctly, or at least to be able to access the sharding at trace time (from the tracer object).
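
As a stopgap, the input's sharding can be re-imposed on the gradient by hand. This is only a minimal sketch, assuming the cotangent should live on the same mesh/spec as the primal; sharded_grad and plain_fn are hypothetical names, not JAX APIs:

# Workaround sketch (assumption: the gradient should carry the same sharding as
# the primal input). `sharded_grad` is a hypothetical helper, not part of JAX.
def plain_fn(a):
    # the un-shard_mapped function from the original repro
    return 2 * a

def sharded_grad(f):
    def wrapped(v):
        g = jax.grad(f)(v)
        # v is a concrete, committed jax.Array here, so its .sharding can be reused
        return lax.with_sharding_constraint(g, v.sharding)
    return wrapped

print(sharded_grad(lambda v: jnp.sum(plain_fn(v)))(x).sharding)
# expected under this assumption: PartitionSpec('x',) rather than the replicated spec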

@ASKabalan

Another update

Here are two scripts that test most cases with the four levels of partitioning:

1 - Compiler takes the wheel
2 - with_sharding_constraint
3 - shard_map
4 - custom_partitioning

import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from jax import lax
import jax.numpy as jnp

from jax.experimental.shard_map import shard_map
from functools import partial
import jax

from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

mesh = Mesh(jax.devices(), ("x",))
sharding = NamedSharding(mesh, P("x"))

x = jnp.arange(8).astype(jnp.float32)
x = lax.with_sharding_constraint(x, sharding)
batched_x = jnp.stack([x, 2 * x])

def scaler_fn_shardmap(x):
    @partial(shard_map, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
    def inner(x):
        return lax.pmean(x, "x")

    return jnp.mean(inner(x))

def scaler_fn_with_constraint(x):
    x = lax.with_sharding_constraint(x, sharding)
    return jnp.mean(x)

def scaler_fn(x):
    return jnp.mean(x)

def fn_shardmap(x):
    @partial(shard_map, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
    def inner(x):
        return 2 * x

    return inner(x)

def fn_with_constraint(x):
    return lax.with_sharding_constraint(2 * x, sharding)

def fn(x):
    return 2 * x


def trace(fn, name, scaler=False):
    print("=" * 50)
    f_pass = fn(x)
    print(f"{name} Forward pass sharding {fn(x).sharding}")
    print(f"{name} Vmapped sharding {jax.vmap(fn)(batched_x).sharding}")
    if scaler:
        print(f"{name} Gradient sharding {jax.grad(fn)(x).sharding}")
    else:
        print(
            f"{name} Gradient sharding {jax.grad(lambda x:jnp.sum(fn(x)))(x).sharding}"
        )
    print(f"{name} Jacrev sharding {jax.jacrev(fn)(x).sharding}")
    print(f"{name} Jacfwd sharding {jax.jacfwd(fn)(x).sharding}")
    print(f"{name} Jvp sharding {jax.jvp(fn , (x,) , (jnp.ones_like(x),))[1].sharding}")
    print(f"{name} Vjp sharding {jax.vjp(fn , x)[1](f_pass)[0].sharding}")


trace(scaler_fn_shardmap, "Scaler Shardmapped" , scaler=True)
trace(scaler_fn_with_constraint, "Scaler With Constraint" , scaler=True)
trace(scaler_fn, "Scaler No Constraint" , scaler=True)

trace(fn_shardmap, "Shardmapped")
trace(fn_with_constraint, "With Constraint")
trace(fn, "No Constraint")

Custom Partitioning

import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from jax import lax
from jax.interpreters import mlir, ad, batching
import jax.numpy as jnp
import jax
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import dispatch
import jax.extend as jex

from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P, Mesh
from functools import partial

mesh = Mesh(jax.devices(), ("x",))
sharding = NamedSharding(mesh, P("x"))

# ================================
# Double Primitive Rules
# ================================

# Step 1: Define the Primitive
double_prim_p = jex.core.Primitive("double_prim")
dispatch.prim_requires_devices_during_lowering.add(double_prim_p)
# Step 2: Define the Implementation


def origin_fn(x, scaler):
    return scaler * x


@partial(custom_partitioning, static_argnums=(1,))
def double_prim_impl(x, scaler):
    return origin_fn(x, scaler)


def infer_sharding_from_operands(scaler, mesh, arg_infos, result_infos):
    return arg_infos[0].sharding


def partition(scaler, mesh, arg_infos, result_infos):
    input_sharding = arg_infos[0].sharding
    output_sharding = result_infos.sharding
    input_mesh = input_sharding.mesh

    def impl(operand):
        return scaler * operand

    return input_mesh, impl, output_sharding, (input_sharding,)


@partial(custom_partitioning, static_argnums=(1, 2))
def vmapped_double_prim_impl(x, batch_dims, scaler):
    strip_static_fn = lambda x: origin_fn(x, scaler)
    return jax.vmap(strip_static_fn, in_axes=batch_dims)(x)


def strip_shaped_dtype_struct(arg_info):
    new_shape = arg_info.shape[1:]  # Remove the batch dimension
    new_spec = arg_info.sharding.spec[1:]  # Adjust the PartitionSpec
    new_sharding = NamedSharding(arg_info.sharding.mesh, P(*new_spec))
    return jax.ShapeDtypeStruct(
        new_shape, arg_info.dtype, sharding=new_sharding, weak_type=arg_info.weak_type
    )


def strip_shaped_array(shaped_array):
    return shaped_array.update(shape=shaped_array.shape[1:])


def add_batch_dimension_to_sharding(arg_info):
    # Add the batch dimension back to the shape and PartitionSpec
    new_spec = (None, *arg_info.spec)  # Add `None` to PartitionSpec
    new_sharding = NamedSharding(arg_info.mesh, P(*new_spec))
    return new_sharding


def v_infer_sharding_from_operands(batch_dims, scaler, mesh, arg_infos, result_infos):
    un_batched_arg_infos = jax.tree.map(strip_shaped_dtype_struct, arg_infos)
    un_batched_result_infos = jax.tree.map(strip_shaped_array, result_infos)
    un_batched_output_sharding = infer_sharding_from_operands(
        scaler, mesh, un_batched_arg_infos, un_batched_result_infos
    )
    batched_results = add_batch_dimension_to_sharding(un_batched_output_sharding)

    return batched_results


def v_partition(batch_dims, scaler, mesh, arg_infos, result_infos):
    un_batched_arg_infos = jax.tree.map(strip_shaped_dtype_struct, arg_infos)
    un_batched_result_infos = jax.tree.map(strip_shaped_dtype_struct, result_infos)
    input_mesh, impl, output_sharding, input_shardings = partition(
        scaler, mesh, un_batched_arg_infos, un_batched_result_infos
    )

    output_sharding = add_batch_dimension_to_sharding(output_sharding)
    input_shardings = jax.tree.map(add_batch_dimension_to_sharding, input_shardings)
    impl = jax.vmap(impl, in_axes=batch_dims)

    return input_mesh, impl, output_sharding, input_shardings


vmapped_double_prim_impl.def_partition(
    infer_sharding_from_operands=v_infer_sharding_from_operands, partition=v_partition
)


# Step 3: Define Abstract Evaluation
def double_prim_abstract_eval(*args, **kwargs):
    return jax.make_jaxpr(origin_fn, static_argnums=1)(*args, **kwargs).out_avals[0]


# Step 4: Define JVP Rule
def double_prim_jvp_rule(primals, tangents, scaler):
    (x,) = primals
    (t,) = tangents
    # Forward computation
    primal_out = double_prim_call(x, scaler)

    # Tangent computation (reuse the primitive itself)
    tangent_out = double_prim_call(t, scaler)
    return primal_out, tangent_out


# Step 5: Define Transpose Rule
def double_prim_transpose_rule(ct_out, x, scaler):
    # transpose of the linear map x -> scaler * x is ct -> scaler * ct
    ct_x = scaler * ct_out if ad.is_undefined_primal(x) else None
    return (ct_x,)


# Step 6: Define Batch Rule
def double_prim_batch_rule(batched_args, batch_dims, *args, **kwargs):
    (x,) = batched_args
    (bx,) = batch_dims
    # Apply vmapped double operation
    res = vmapped_double_prim_impl(x, bx, *args, **kwargs)
    return res, 0


# Step 7: Register the Primitive
double_prim_p.def_impl(double_prim_impl)  # Implementation
double_prim_p.def_abstract_eval(double_prim_abstract_eval)  # Abstract Eval
mlir.register_lowering(
    double_prim_p, mlir.lower_fun(double_prim_impl, multiple_results=False)
)  # Lowering
ad.primitive_jvps[double_prim_p] = double_prim_jvp_rule  # JVP Rule
ad.primitive_transposes[double_prim_p] = double_prim_transpose_rule  # Transpose Rule
batching.primitive_batchers[double_prim_p] = double_prim_batch_rule  # Batch Rule


# Define a Python wrapper for the primitive
@partial(jax.jit, static_argnums=(1,))
def double_prim_call(x, scaler=2):
    return double_prim_p.bind(x, scaler=scaler)


double_prim_impl.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands, partition=partition
)


@jax.custom_vjp
def double_vjp(x):
    return double_prim_call(x)


def double_vjp_fwd(x):
    return double_vjp(x), None


def double_vjp_bwd(_, g):
    return (2 * g,)


@jax.custom_jvp
def double_jvp(x):
    return double_prim_call(x)


@double_jvp.defjvp
def double_jvp_fwd(p, t):
    (x,) = p
    (t,) = t
    return double_jvp(x), double_jvp(t)


double_vjp.defvjp(double_vjp_fwd, double_vjp_bwd)

# ================================
# Linear Double Primitive Testing
# ================================


x = jnp.arange(8).astype(jnp.float32)
x = lax.with_sharding_constraint(x, sharding)
batched_x = jnp.stack([x, 2 * x])


def trace(fn, name, scaler=False, vjp=False):
    print("=" * 50)
    f_pass = fn(x)
    print(f"{name} Forward pass sharding {fn(x).sharding}")
    print(f"{name} Vmapped sharding {jax.vmap(fn)(batched_x).sharding}")
    if scaler:
        print(f"{name} Gradient sharding {jax.grad(fn)(x).sharding}")
    else:
        print(
            f"{name} Gradient sharding {jax.grad(lambda x:jnp.sum(fn(x)))(x).sharding}"
        )
    print(f"{name} Jacrev sharding {jax.jacrev(fn)(x).sharding}")
    if not vjp:
        print(f"{name} Jacfwd sharding {jax.jacfwd(fn)(x).sharding}")
        print(
            f"{name} Jvp sharding {jax.jvp(fn , (x,) , (jnp.ones_like(x),))[1].sharding}"
        )
    print(f"{name} Vjp sharding {jax.vjp(fn , x)[1](f_pass)[0].sharding}")


trace(double_jvp, "custom_partial_double_jvp")
trace(double_vjp, "custom_partial_double_vjp", vjp=True)


ASKabalan commented Jan 12, 2025

Here is the summary (transforms tested per function: Forward Pass, Vmap, Grad, Jacfwd, Jacrev, Jvp, Vjp):

- scalar Shardmapped: ⚠️ N/A: scalar returned
- scalar With Constraint: ⚠️ N/A: scalar returned
- scalar No Constraint: ⚠️ N/A: scalar returned
- Shardmapped: ✅ (extra dim handled), ✅ (extra dim handled)
- With Constraint: ✅ (extra dim handled), ✅ (extra dim handled)
- No Constraint: ✅ (extra dim handled)
- custom_par_jvp: ✅ (extra dim handled)
- custom_par_vjp: ✅ (extra dim handled), ⚠️ Cannot jvp

In general, custom_partitioning works just like automatic pjit sharding, which makes sense, but gradient propagation is not supported in either case.

The only issue with using shard_map or with_sharding_constraint is that the sharding is not propagated for functions that output a scalar and are differentiated with jacfwd.
I think this is fine, since in that case we would always use grad or jacrev anyway.

IMO, getting access to a static sharding, just like a shape, is crucial in this case.
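
For reference, here is a minimal sketch of how shardings can already be inspected ahead of time through the jit/AOT path. This is only partial relief for the trace-time request, and it assumes the plain fn(x) = 2 * x and the sharded x from the scripts above; g and compiled are just local names:

# AOT sketch: lowering and compiling under jit exposes the shardings the
# partitioner chose, without executing anything
g = jax.jit(jax.grad(lambda v: jnp.sum(fn(v))))
compiled = g.lower(x).compile()
print(compiled.input_shardings)   # shardings expected for the arguments
print(compiled.output_shardings)  # shardings assigned to the outputs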
