diff --git a/jax/experimental/roofline/__init__.py b/jax/experimental/roofline/__init__.py new file mode 100644 index 000000000000..8d76c46858c7 --- /dev/null +++ b/jax/experimental/roofline/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from jax.experimental.roofline.roofline import ( + RooflineRuleContext as RooflineRuleContext, +) +from jax.experimental.roofline.roofline import RooflineShape as RooflineShape +from jax.experimental.roofline.roofline import RooflineResult as RooflineResult +from jax.experimental.roofline.roofline import roofline as roofline +from jax.experimental.roofline.roofline import register_roofline as register_roofline +from jax.experimental.roofline.roofline import ( + register_standard_roofline as register_standard_roofline, +) +from jax.experimental.roofline.roofline import roofline_and_grad as roofline_and_grad + + +import jax.experimental.roofline.rooflines as rooflines + +del rooflines diff --git a/jax/experimental/roofline/roofline.py b/jax/experimental/roofline/roofline.py new file mode 100644 index 000000000000..42f72f005034 --- /dev/null +++ b/jax/experimental/roofline/roofline.py @@ -0,0 +1,342 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Protocol, Sequence +import numpy as np + +import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax._src import api +from jax._src import core +from jax._src import source_info_util +from jax._src import traceback_util +from jax._src import util +from jax._src.api import make_jaxpr +from jax._src.interpreters.partial_eval import dce_jaxpr +from jax._src.interpreters.xla import abstractify +from jax._src.mesh import AbstractMesh, Mesh +from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map +from jax.experimental import shard_map + + +ShapeDtypeStructTree = Any + + +map = util.safe_map + + +@dataclass(frozen=True, slots=True, kw_only=True) +class RooflineRuleContext: + name_stack: source_info_util.NameStack + primitive: core.Primitive + avals_in: Sequence[core.AbstractValue] + avals_out: Sequence[core.AbstractValue] + jaxpr_eqn_ctx: core.JaxprEqnContext + mesh: Mesh | AbstractMesh + pin_lhs_in_vmem: bool + pin_rhs_in_vmem: bool + + +@dataclass(frozen=True, slots=True, kw_only=True) +class RooflineShape: + shape: tuple[int, ...] 
+ dtype: np.dtype + + @classmethod + def from_aval(cls, aval: core.AbstractValue) -> "RooflineShape": + if not isinstance(aval, core.ShapedArray): + raise TypeError(f"Expected ShapedArray, got {type(aval)}.") + if not isinstance(aval.dtype, np.dtype): + raise TypeError(f"Expected numpy dtype, got {type(aval.dtype)}.") + return cls(shape=aval.shape, dtype=aval.dtype) + + @property + def size(self) -> int: + return int(np.prod(self.shape)) + + @property + def bytes(self) -> int: + return int(self.size * self.dtype.itemsize) + + @classmethod + def total_bytes(cls, avals: Sequence[core.AbstractValue]) -> int: + return sum(cls.from_aval(aval).bytes for aval in avals) + + +@dataclass(frozen=True, slots=True, kw_only=True) +class RooflineResult: + flops: int = 0 + ici_bytes: dict[str, int] = field(default_factory=dict) + ici_latency: dict[str, int] = field(default_factory=dict) + hbm_bytes: int = 0 + peak_hbm_bytes: int = 0 + + @classmethod + def zeros(cls) -> "RooflineResult": + return cls() + + def __add__(self, other: "RooflineResult") -> "RooflineResult": + def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]: + return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)} + + return RooflineResult( + flops=self.flops + other.flops, + ici_bytes=merge_ici_dicts(self.ici_bytes, other.ici_bytes), + ici_latency=merge_ici_dicts(self.ici_latency, other.ici_latency), + hbm_bytes=self.hbm_bytes + other.hbm_bytes, + peak_hbm_bytes=max(self.peak_hbm_bytes, other.peak_hbm_bytes), + ) + + def __mul__(self, constant: int | float) -> "RooflineResult": + return RooflineResult( + flops=int(self.flops * constant), + ici_bytes={k: int(v * constant) for k, v in self.ici_bytes.items()}, + ici_latency={k: int(v * constant) for k, v in self.ici_latency.items()}, + hbm_bytes=int(self.hbm_bytes * constant), + peak_hbm_bytes=int(self.peak_hbm_bytes * constant), + ) + + def __rmul__(self, constant: int | float) -> "RooflineResult": + return self.__mul__(constant) + + +class _RooflineRule(Protocol): + def __call__( + self, ctx: RooflineRuleContext, *args: RooflineShape, **kw + ) -> RooflineResult: ... 
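+
+
+# Example of writing a rule (illustrative; ``my_prim_p`` is a placeholder
+# primitive, not one defined in this change): a rule receives the equation
+# context plus one RooflineShape per input and returns a RooflineResult,
+# e.g. counting only the bytes moved to and from HBM:
+#
+#   @register_roofline(my_prim_p)
+#   def _my_prim_roofline(ctx, *args, **kw):
+#     return RooflineResult(
+#       hbm_bytes=RooflineShape.total_bytes(ctx.avals_in)
+#       + RooflineShape.total_bytes(ctx.avals_out)
+#     )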
+ + +_rooflines: dict[core.Primitive, _RooflineRule] = {} + + +def _roofline_interpreter( + f_name: str, + jaxpr: core.Jaxpr, + mesh: Mesh | AbstractMesh, + *, + pin_lhs_in_vmem: bool = False, + pin_rhs_in_vmem: bool = False, +) -> RooflineResult: + name_stack = source_info_util.new_name_stack(util.wrap_name(f_name, "roofline")) + + result = RooflineResult.zeros() + + env: dict[core.Var, RooflineShape] = {} + + def write(v: core.Var, node: RooflineShape): + assert node is not None + env[v] = node + + def read(v: core.Atom) -> RooflineShape: + if type(v) is core.Literal: + return RooflineShape.from_aval(abstractify(v.val)) + else: + assert isinstance(v, core.Var) + return env[v] + + def aval(v: core.Atom) -> core.AbstractValue: + if type(v) is core.Literal: + return abstractify(v.val) + else: + return v.aval + + def calculate_peak_hbm_bytes() -> int: + return int( + sum(np.prod(shape.shape) * shape.dtype.itemsize for shape in env.values()) + ) + + make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x)) + map( + write, + jaxpr.constvars, + map(make_roofline_shape, jaxpr.constvars), + ) + map(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars)) + last_used = core.last_used(jaxpr) + for eqn in jaxpr.eqns: + source_info = eqn.source_info.replace( + name_stack=name_stack + eqn.source_info.name_stack + ) + with source_info_util.user_context( + eqn.source_info.traceback, name_stack=source_info.name_stack + ): + if "jaxpr" in eqn.params: + result += _roofline_interpreter( + util.wrap_name(f_name, eqn.primitive.name), + eqn.params["jaxpr"], + mesh, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + ) + else: + if eqn.primitive not in _rooflines: + msg = f"No roofline rule for {eqn.primitive}." + for attr in dir(eqn): + if not attr.startswith("_"): + msg += f"\n{attr}: {getattr(eqn, attr)}" + raise NotImplementedError(msg) + rule = _rooflines[eqn.primitive] + result += rule( + RooflineRuleContext( + name_stack=source_info.name_stack, + primitive=eqn.primitive, + avals_in=map(aval, eqn.invars), + avals_out=map(aval, eqn.outvars), + jaxpr_eqn_ctx=eqn.ctx, + mesh=mesh, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + ), + *map(read, eqn.invars), + **eqn.params, + ) + + map(write, eqn.outvars, map(make_roofline_shape, eqn.outvars)) + core.clean_up_dead_vars(eqn, env, last_used) + result += RooflineResult(peak_hbm_bytes=calculate_peak_hbm_bytes()) + + return result + + +def _f_with_vjp(f: Callable): + @util.wraps(f) + def wrapped(*args): + primals, f_vjp = api.vjp(f, *args) + return f_vjp(tree_map(jnp.bfloat16, primals)) + + return wrapped + + +def roofline( + f: Callable, + mesh: Mesh | AbstractMesh, + in_specs: shard_map.Specs, + out_specs: shard_map.Specs, + *, + pin_lhs_in_vmem: bool = False, + pin_rhs_in_vmem: bool = False, + vjp: bool = False, + print_jaxpr: bool = False, +) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult]]: + @util.wraps(f) + @traceback_util.api_boundary + def wrapped(*args): + wrapped_f = shard_map.shard_map(f, mesh, in_specs, out_specs) + if vjp: + wrapped_f = _f_with_vjp(wrapped_f) + + jaxpr, out_shapes = make_jaxpr(wrapped_f, return_shape=True)(*args) + + def make_sharded_shape_dtype_struct( + shape: api.ShapeDtypeStruct, out_spec: shard_map.Specs + ) -> api.ShapeDtypeStruct: + return api.ShapeDtypeStruct( + shape.shape, shape.dtype, sharding=NamedSharding(mesh, out_spec) + ) + + out_specs_flat = broadcast_prefix(out_specs, out_shapes) + flat_out_shapes, treedef = tree_flatten(out_shapes) + flat_out_shapes 
= map(
+      make_sharded_shape_dtype_struct, flat_out_shapes, out_specs_flat
+    )
+    out_shapes = tree_unflatten(treedef, flat_out_shapes)
+
+    used_outputs = (True,) * len(jaxpr.jaxpr.outvars)
+    jaxpr, _ = dce_jaxpr(jaxpr.jaxpr, used_outputs)
+    try:
+      jaxpr = [e for e in jaxpr.eqns if e.primitive == shard_map.shard_map_p][
+        -1
+      ].params["jaxpr"]
+    except IndexError:
+      raise ValueError(f"Missing shard_map jaxpr in {jaxpr}.")
+
+    if print_jaxpr:
+      print(jaxpr)
+
+    return out_shapes, _roofline_interpreter(
+      util.fun_qual_name(f),
+      jaxpr,
+      mesh,
+      pin_lhs_in_vmem=pin_lhs_in_vmem,
+      pin_rhs_in_vmem=pin_rhs_in_vmem,
+    )
+
+  return wrapped
+
+
+def register_roofline(prim: core.Primitive):
+  def register(rule: _RooflineRule):
+    _rooflines[prim] = rule
+    return rule
+
+  return register
+
+
+def register_standard_roofline(prim: core.Primitive):
+  def standard_rule(ctx: RooflineRuleContext, *args, **kwargs):
+    return RooflineResult.zeros()
+
+  _rooflines[prim] = standard_rule
+
+
+def roofline_and_grad(
+  f: Callable,
+  mesh: Mesh | AbstractMesh,
+  in_specs: shard_map.Specs,
+  out_specs: shard_map.Specs,
+  *,
+  pin_lhs_in_vmem: bool = False,
+  pin_rhs_in_vmem: bool = False,
+  print_jaxpr: bool = False,
+) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult, RooflineResult]]:
+  @util.wraps(f)
+  @traceback_util.api_boundary
+  def wrapped(*args):
+    primal_shapes, fwd_result = roofline(
+      f,
+      mesh,
+      in_specs,
+      out_specs,
+      pin_lhs_in_vmem=pin_lhs_in_vmem,
+      pin_rhs_in_vmem=pin_rhs_in_vmem,
+      print_jaxpr=print_jaxpr,
+    )(*args)
+
+    return (
+      primal_shapes,
+      fwd_result,
+      roofline(
+        f,
+        mesh,
+        in_specs,
+        out_specs,
+        pin_lhs_in_vmem=pin_lhs_in_vmem,
+        pin_rhs_in_vmem=pin_rhs_in_vmem,
+        vjp=True,
+        print_jaxpr=print_jaxpr,
+      )(
+        *tree_map(
+          lambda x: api.ShapeDtypeStruct(
+            x.shape,
+            jnp.int32 if x.dtype == jnp.int32 else jnp.bfloat16,
+            sharding=x.sharding,
+          ),
+          args,
+        )
+      )[1],
+    )
+
+  return wrapped
diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py
new file mode 100644
index 000000000000..cfdb6358bc76
--- /dev/null
+++ b/jax/experimental/roofline/rooflines.py
@@ -0,0 +1,270 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
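+
+# Roofline rules for JAX primitives. Every primitive found in the modules
+# below is first registered with a zero-cost standard rule; primitives with
+# a meaningful cost model (dot_general, collectives, etc.) are then given
+# dedicated rules estimating FLOPs, HBM bytes, and ICI bytes/latency.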
+from collections import defaultdict +from dataclasses import replace +import itertools as it +import numpy as np + +from jax._src import ad_util +from jax._src import core, util +from jax._src import ops +from jax._src import prng +from jax._src import random +from jax._src.lax import ( + ann, + convolution, + fft, + lax, + linalg, + parallel as lax_parallel, + slicing, + special, + windowed_reductions, +) +from jax.experimental import roofline +from jax.experimental import shard_map + + +for prim in it.chain( + ad_util.__dict__.values(), + ann.__dict__.values(), + convolution.__dict__.values(), + fft.__dict__.values(), + lax.__dict__.values(), + linalg.__dict__.values(), + ops.__dict__.values(), + prng.__dict__.values(), + random.__dict__.values(), + shard_map.__dict__.values(), + slicing.__dict__.values(), + special.__dict__.values(), + windowed_reductions.__dict__.values(), +): + if isinstance(prim, core.Primitive): + roofline.register_standard_roofline(prim) + + +@roofline.register_roofline(lax.dot_general_p) +def _dot_general_roofline( + ctx: roofline.RooflineRuleContext, + *args, + dimension_numbers: lax.DotDimensionNumbers, + **kw, +) -> roofline.RooflineResult: + lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + (lhs_contract, _), (lhs_batch, _) = dimension_numbers + + flops = ( + 2 + * lhs.size + * rhs.size + / np.prod([lhs.shape[i] for i in lhs_contract]) + / np.prod([lhs.shape[i] for i in lhs_batch]) + ) + + hbm_bytes = 0 + if not ctx.pin_lhs_in_vmem: + hbm_bytes += lhs.bytes + hbm_bytes += out.bytes + if not ctx.pin_rhs_in_vmem: + hbm_bytes += rhs.bytes + + return roofline.RooflineResult(flops=int(flops), hbm_bytes=hbm_bytes) + + +def _return_zeros_if_one_sized_axis( + ctx: roofline.RooflineRuleContext, axes: tuple[str, ...] +) -> roofline.RooflineResult | None: + axes_size = np.prod([ctx.mesh.shape[axis] for axis in axes]) + if axes_size > 1: + return None + return roofline.RooflineResult( + ici_bytes={axis: 0 for axis in axes}, + ici_latency={axis: 0 for axis in axes}, + ) + + +def _ring_collective_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + is_reduce: bool = True, + **kw, +) -> roofline.RooflineResult: + if zeros_result := _return_zeros_if_one_sized_axis(ctx, axes): + return zeros_result + + mesh = ctx.mesh.shape + current_shard_size = roofline.RooflineShape.total_bytes(ctx.avals_in) + if is_reduce: + current_shard_size /= np.prod([mesh[axis] for axis in axes]) + + # We model the slowest color as the bottleneck. + sorted_axes = sorted(axes, key=lambda x: mesh[x], reverse=True) + num_axes = len(sorted_axes) + + ici_bytes = 0 + # Phase split. + current_shard_size //= num_axes + for axis in sorted_axes: + axis_size = mesh[axis] + # Do phase. + ici_bytes += current_shard_size * (axis_size - 1) + # Increase shard size. + current_shard_size *= axis_size + + # Bottleneck is the longest axis. 
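+  # Roughly: each of the `num_axes` phases takes about `axis_size` hops, so
+  # we bound latency by the largest axis size times the number of axes.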
+ ici_latency = mesh[sorted_axes[0]] * num_axes + + return roofline.RooflineResult( + ici_bytes={axis: int(ici_bytes) for axis in sorted_axes}, + ici_latency={axis: int(ici_latency) for axis in sorted_axes}, + ) + + +roofline.register_roofline(lax_parallel.reduce_scatter_p)( + lambda *args, axis_name, **kw: _ring_collective_roofline(*args, axes=axis_name, **kw) +) +roofline.register_roofline(lax_parallel.all_gather_p)( + lambda *args, axis_name, **kw: _ring_collective_roofline( + *args, axes=axis_name, is_reduce=False, **kw + ) +) + + +def _scalar_collective_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + **kw, +) -> roofline.RooflineResult: + shapes = [roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in] + ctx = replace(ctx, avals_in=[core.ShapedArray((1,), shape.dtype) for shape in shapes]) + return _ring_collective_roofline(ctx, *args, axes=axes, is_reduce=False, **kw) + + +roofline.register_roofline(lax_parallel.pmin_p)(_scalar_collective_roofline) +roofline.register_roofline(lax_parallel.pmax_p)(_scalar_collective_roofline) + + +@roofline.register_roofline(shard_map.psum2_p) +def _psum2_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + **kw, +) -> roofline.RooflineResult: + ring_roofline = _ring_collective_roofline(ctx, *args, axes=axes, **kw) + + def double_dict(d: dict[str, int]) -> dict[str, int]: + return {k: v * 2 for k, v in d.items()} + + return roofline.RooflineResult( + ici_bytes=double_dict(ring_roofline.ici_bytes), + ici_latency=double_dict(ring_roofline.ici_latency), + ) + + +@roofline.register_roofline(lax_parallel.all_to_all_p) +def _all_to_all_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axis_name: tuple[str, ...], + **kw, +) -> roofline.RooflineResult: + if zeros_result := _return_zeros_if_one_sized_axis(ctx, axis_name): + return zeros_result + + mesh = ctx.mesh.shape + size = roofline.RooflineShape.total_bytes(ctx.avals_in) * np.prod([ + mesh[axis] for axis in axis_name + ]) + + smallest_axis = sorted(axis_name, key=lambda x: mesh[x])[0] + num_axes = len(axis_name) + bisection_bw = mesh[smallest_axis] ** (num_axes - 1) + if mesh[smallest_axis] > 2: + # Times 2 because of wraparound. + bisection_bw *= 2 + + # Half the data needs to cross the bisection on average. + ici_bytes = size / 2 / bisection_bw + + # The latency is the max number of hops across the mesh. + ici_latency = sum(mesh[axis] / 2 for axis in axis_name) + + return roofline.RooflineResult( + ici_bytes={axis: int(ici_bytes) for axis in axis_name}, + ici_latency={axis: int(ici_latency) for axis in axis_name}, + ) + + +@roofline.register_roofline(lax_parallel.ppermute_p) +def _ppermute_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axis_name: tuple[str, ...], + perm: tuple[tuple[int, int], ...], + **kw, +) -> roofline.RooflineResult: + if zeros_result := _return_zeros_if_one_sized_axis(ctx, axis_name): + return zeros_result + + mesh = ctx.mesh.shape + mesh_dims: list[int] = [mesh.get(axis, 1) for axis in axis_name] + shard_size = roofline.RooflineShape.total_bytes(ctx.avals_in) + + ici_contention: dict[tuple[tuple[int, ...], ...], float] = defaultdict(float) + ici_latency = 0 + + for src, dst in perm: + if src == dst: + continue + # Perms are linearized. + src_coords = tuple(int(i) for i in np.unravel_index(src, mesh_dims)) + dst_coords = tuple(int(i) for i in np.unravel_index(dst, mesh_dims)) + + ici_latency_for_perm = 0 + + # For each dimension. 
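+    # Route dimension by dimension: walk each mesh axis ring in the shorter
+    # direction, incrementing the contention count of every link traversed.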
+ for i in range(len(axis_name)): + dim_size = mesh_dims[i] + src_pos = src_coords[i] + dst_pos = dst_coords[i] + + if src_pos != dst_pos: + # Calculate distance with wraparound. + clockwise_dist = (dst_pos - src_pos) % dim_size + counter_dist = (src_pos - dst_pos) % dim_size + direction = 1 if clockwise_dist <= counter_dist else -1 + + curr_pos = src_pos + while curr_pos != dst_pos: + curr_coords = util.tuple_update(src_coords, i, curr_pos) + next_pos = (curr_pos + direction) % dim_size + next_coords = util.tuple_update(curr_coords, i, next_pos) + ici_contention[tuple(sorted([curr_coords, next_coords]))] += 1 + curr_pos = next_pos + + distance = min(clockwise_dist, counter_dist) + ici_latency_for_perm += distance + + ici_latency = max(ici_latency, ici_latency_for_perm) + + ici_bytes = shard_size * max(ici_contention.values(), default=0) + return roofline.RooflineResult( + ici_bytes={axis: int(ici_bytes) for axis in axis_name}, + ici_latency={axis: int(ici_latency) for axis in axis_name}, + ) diff --git a/tests/roofline_test.py b/tests/roofline_test.py new file mode 100644 index 000000000000..e5003947181b --- /dev/null +++ b/tests/roofline_test.py @@ -0,0 +1,426 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from functools import partial +import contextlib + +from absl.testing import absltest +from jax.sharding import PartitionSpec as P +import jax +import jax.lax as lax +import jax.numpy as jnp + +from jax._src import test_util as jtu + +from jax.experimental import roofline + + +jax.config.parse_flags_with_absl() + + +def create_inputs( + *shardings: P, + dtype: jnp.dtype = jnp.float32, + mesh_shape: tuple[int, ...] = (2, 2, 2), +) -> tuple[jax.sharding.Mesh, tuple[jax.ShapeDtypeStruct, ...]]: + mesh = jtu.create_mesh(mesh_shape, ("x", "y", "z")) + arrays = [] + for sharding in shardings: + array = jax.ShapeDtypeStruct( + (8, 8), dtype, sharding=jax.sharding.NamedSharding(mesh, sharding) + ) + arrays.append(array) + return mesh, tuple(arrays) + + +# Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() + + +def setUpModule(): + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + + +def tearDownModule(): + _exit_stack.close() + + +class RooflineTest(jtu.JaxTestCase): + def test_scalar_collectives(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), "z") + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=(P("z", None), P(("x", "y"), None)), + ) + def scalar_collectives(a, b): + a = lax.pmin(a, ("x", "y")) + b = lax.pmax(b, "z") + return a, b + + _, results = scalar_collectives(a, b) + + itemsize = 4 + + axis_size = 2 + axis_size_m1 = axis_size - 1 + + xy_num_axes = 2 + xy_ici_bytes = int( + itemsize + # 2 phases. + * ( + (1 / xy_num_axes * axis_size_m1) + (1 * axis_size / xy_num_axes * axis_size_m1) + ) + ) + # 2 phases times 2 hops. 
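+    # (From the ring model: largest axis size (2) times number of axes (2).)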
+ xy_ici_latency = 2 * 2 + + z_ici_bytes = int(itemsize * 1 * axis_size_m1) + # 2 hops. + z_ici_latency = 2 + expected = roofline.RooflineResult( + ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes}, + ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency}, + peak_hbm_bytes=itemsize * 2 * 4 * 2, + ) + self.assertDataclassEqual(results, expected) + + def test_collective_matmul(self): + a_spec = P(None, "x") + b_spec = P(None, "x") + c_spec = P("x", None) + mesh, (a, b, c) = create_inputs(a_spec, b_spec, c_spec, dtype=jnp.int8) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec, c_spec), + out_specs=a_spec, + ) + def collective_matmul(a, b, c): + a = lax.all_gather(a, "x", axis=1, tiled=True) + # Test broadcasting and slicing works. + a = a[None, :, :] + b = b[:, None, :] + ab = jnp.einsum("bij,jbk->ikb", a, b).astype(jnp.int8)[..., 0] + abc = jnp.einsum("ik,kc->ic", ab, c).astype(jnp.int8) + abc = lax.psum_scatter(abc, "x", scatter_dimension=1, tiled=True) + return abc + + _, results = collective_matmul(a, b, c) + + itemsize = 1 + m, k, n = 8, 4, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + sharded_mk = mk + + # Times 2 for ag + rs. + ici_bytes = 2 * int(itemsize * sharded_mk * axis_size_m1) + ici_latency = 2 * 2 + expected = roofline.RooflineResult( + flops=2 * 2 * m * k * n, + ici_bytes={"x": ici_bytes}, + ici_latency={"x": ici_latency}, + hbm_bytes=2 * itemsize * (mk + kn + mn), + # Right after all_gather. + peak_hbm_bytes=itemsize * (mk * axis_size + mk + kn), + ) + self.assertDataclassEqual(results, expected) + + def test_matmul_psum(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), None) + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=P("z", None), + ) + def matmul_psum(a, b): + c = a @ b + c = lax.psum(c, ("x", "y")) + return c + + _, results = matmul_psum(a, b) + + itemsize = 4 + m, k, n = 4, 2, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + num_axes = 2 + sharded_mn = mn / axis_size / num_axes + + # Times 2 for ag + rs. + ici_bytes = 2 * int( + itemsize + # 2 phases. + * ( + (sharded_mn / num_axes * axis_size_m1) + + (sharded_mn * axis_size / num_axes * axis_size_m1) + ) + ) + ici_latency = 2 * 2 * 2 + expected = roofline.RooflineResult( + flops=2 * m * k * n, + ici_bytes={axis: ici_bytes for axis in ("x", "y")}, + ici_latency={axis: ici_latency for axis in ("x", "y")}, + hbm_bytes=itemsize * (mk + kn + mn), + peak_hbm_bytes=itemsize * (mn), + ) + self.assertDataclassEqual(results, expected) + + def test_all_to_all(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), "z") + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=(P(("z", "x", "y"), None), P(("x", "y", "z"), None)), + ) + def all_to_all(a, b): + a = lax.all_to_all(a, ("x", "y"), split_axis=0, concat_axis=1, tiled=True) + b = lax.all_to_all(b, "z", split_axis=0, concat_axis=1, tiled=True) + return a, b + + _, results = all_to_all(a, b) + + itemsize = 4 + + xy_size = itemsize * 8 * 8 / 2 + # Half the data over 2 links. + xy_ici_bytes = int(xy_size / 2 / 2) + # 2 hops. + xy_ici_latency = 2 + + z_size = itemsize * 8 * 8 / 2 / 2 + # Half the data over 1 link. + z_ici_bytes = int(z_size / 2) + # 1 hop. 
+ z_ici_latency = 1 + expected = roofline.RooflineResult( + ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes}, + ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency}, + peak_hbm_bytes=itemsize * 2 * 4 * 2, + ) + self.assertDataclassEqual(results, expected) + + def test_ppermute(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), "z") + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=(a_spec, b_spec), + ) + def ppermute(a, b): + a = lax.ppermute(a, ("x", "y"), perm=((0, 3), (3, 0), (1, 2), (2, 1))) + b = lax.ppermute(b, "z", perm=((1, 0), (0, 1))) + return a, b + + _, results = ppermute(a, b) + + itemsize = 4 + shard_size = itemsize * 4 * 2 + + # At most 2 shards contend for 1 link. + xy_ici_bytes = int(shard_size * 2) + # 2 hops. + xy_ici_latency = 2 + + # No contention but there is a single link. + z_ici_bytes = int(shard_size * 2) + # 1 hop. + z_ici_latency = 1 + expected = roofline.RooflineResult( + ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes}, + ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency}, + peak_hbm_bytes=itemsize * 2 * 4 * 2, + ) + self.assertDataclassEqual(results, expected) + + def test_grad_matmuls(self): + a_spec = P(None, "x") + b_spec = P(None, None) + mesh, (a, b) = create_inputs(a_spec, b_spec, dtype=jnp.int8) + + @partial( + roofline.roofline_and_grad, + mesh=mesh, + in_specs=(a_spec, b_spec), + # Numerically incorrect AD, but tests that we handle it properly. + out_specs=P("x", None), + ) + def collective_matmul(a, b): + a = lax.all_gather(a, "x", axis=1, tiled=True) + return a @ b + + c, fwd_results, bwd_results = collective_matmul(a, b) + + itemsize = 1 + m, k, n = 8, 8, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + sharded_mk = mk // axis_size + + ici_bytes = int(itemsize * sharded_mk * axis_size_m1) + ici_latency = 2 + expected = roofline.RooflineResult( + flops=2 * m * k * n, + ici_bytes={"x": ici_bytes}, + ici_latency={"x": ici_latency}, + hbm_bytes=itemsize * (mk + kn + mn), + peak_hbm_bytes=itemsize * (mk + kn), + ) + self.assertDataclassEqual(fwd_results, expected) + + bwd_itemsize = 2 + # 2 for psum + 1 for rs. + bwd_ici_bytes = 3 * int(bwd_itemsize * sharded_mk * axis_size_m1) + expected = roofline.RooflineResult( + flops=2 * 2 * m * k * n, + ici_bytes={"x": bwd_ici_bytes}, + ici_latency={"x": 3 * ici_latency}, + hbm_bytes=2 * bwd_itemsize * (mk + kn + mn), + # Residuals + cotangents. 
+ peak_hbm_bytes=bwd_itemsize * (mk + kn + mn), + ) + self.assertDataclassEqual(bwd_results, expected) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=c.sharding.spec, + out_specs=c.sharding.spec, + ) + def mul_2(c): + return c * 2 + + results = mul_2(c) + self.assertLen(results, 2) + + def test_one_sized_axis_collectives(self): + a_spec = P("x") + mesh, (a,) = create_inputs(a_spec, mesh_shape=(1, 2, 4)) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=a_spec, + out_specs=a_spec, + ) + def one_sized_axis_collectives(a): + a = lax.pmin(a, "x") + a = lax.all_gather(a, "x", axis=1, tiled=True) + a = lax.psum_scatter(a, "x", scatter_dimension=1, tiled=True) + a = lax.psum(a, "x") + a = lax.all_to_all(a, "x", split_axis=0, concat_axis=1, tiled=True) + a = lax.ppermute(a, "x", perm=((1, 0), (0, 1))) + return a + + _, results = one_sized_axis_collectives(a) + expected = roofline.RooflineResult( + ici_bytes={"x": 0}, + ici_latency={"x": 0}, + peak_hbm_bytes=4 * 8 * 8, + ) + self.assertDataclassEqual(results, expected) + + def test_remat(self): + a_spec = P("x", None) + b_spec = P("x", None) + mesh, (a, b) = create_inputs(a_spec, b_spec) + + def fsdp_checkpoint_policy(prim, *args, **kwargs): + if prim == lax.all_gather_p and kwargs["axis_name"] == "x": + return True + return False + + @partial( + roofline.roofline_and_grad, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=P("x", None), + ) + @partial(jax.checkpoint, policy=fsdp_checkpoint_policy) + def collective_matmul(a, b): + b = lax.all_gather(b, "x", axis=0, tiled=True) + return a @ b + + _, fwd_results, bwd_results = collective_matmul(a, b) + + itemsize = 4 + m, k, n = 4, 8, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + sharded_kn = kn // axis_size + + ici_bytes = int(itemsize * sharded_kn * axis_size_m1) + ici_latency = 2 + expected = roofline.RooflineResult( + flops=2 * m * k * n, + ici_bytes={"x": ici_bytes}, + ici_latency={"x": ici_latency}, + hbm_bytes=itemsize * (mk + kn + mn), + peak_hbm_bytes=itemsize * (mk + kn), + ) + self.assertDataclassEqual(fwd_results, expected) + + bwd_itemsize = 2 + # Remat ag + rs. + bwd_ici_bytes = 2 * int(bwd_itemsize * sharded_kn * axis_size_m1) + expected = roofline.RooflineResult( + flops=2 * 2 * m * k * n, + ici_bytes={"x": bwd_ici_bytes}, + ici_latency={"x": 2 * ici_latency}, + hbm_bytes=2 * bwd_itemsize * (mk + kn + mn), + # Residuals + cotangents. + # We gather kn while computing the kn cotangents. + peak_hbm_bytes=bwd_itemsize * (kn + kn + mn), + ) + self.assertDataclassEqual(bwd_results, expected) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())
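
For reference, a minimal usage sketch of the new API (illustrative, mirroring the tests above; assumes at least two devices are available, e.g. via --xla_force_host_platform_device_count=8, and `f` is a placeholder function):

  import numpy as np
  import jax
  import jax.numpy as jnp
  from jax.experimental import roofline
  from jax.sharding import Mesh, PartitionSpec as P

  # Build a 1D mesh over the first two devices (axis "x").
  mesh = Mesh(np.asarray(jax.devices()[:2]), ("x",))

  def f(a, b):
    # Contraction-sharded matmul followed by a psum over the "x" axis.
    return jax.lax.psum(a @ b, "x")

  a = jax.ShapeDtypeStruct((8, 8), jnp.float32)
  b = jax.ShapeDtypeStruct((8, 8), jnp.float32)

  out_shapes, result = roofline.roofline(
    f, mesh=mesh, in_specs=(P(None, "x"), P("x", None)), out_specs=P(None, None)
  )(a, b)
  print(result.flops, result.ici_bytes, result.hbm_bytes, result.peak_hbm_bytes)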