Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Whiledce #173

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
build --announce_rc

query --experimental_repo_remote_exec
build --experimental_repo_remote_exec
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
build --cxxopt=-w --host_cxxopt=-w
Expand Down
3 changes: 3 additions & 0 deletions builddeps/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ compile_pip_requirements(
"--build-isolation",
"--rebuild",
],
extra_deps = [
# "@pypi_wheel//:pkg"
],
requirements_in = "requirements.in",
requirements_txt = REQUIREMENTS,
generate_hashes = True,
Expand Down
4 changes: 2 additions & 2 deletions builddeps/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
#
-r test-requirements.txt

jax >= 0.4.21
jaxlib >= 0.4.21
jax
jaxlib
absl_py >= 2.0.0
4 changes: 3 additions & 1 deletion builddeps/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ absl-py
jax
numpy
jaxlib
https://github.com/wsmoses/jax-md/archive/1188490610b95023f8a51166c3f6b92da31e78fe.tar.gz
https://github.com/wsmoses/jax-md/archive/b41e23abc8662e767033c2bcf346eb58b32363a9.tar.gz
# maxtext can't be installed concurrently, but installing it fixes
# https://github.com/wsmoses/maxtext/archive/bc50722be7d89e4003bd830b80e4ac968be658eb.tar.gz
jax[cuda12_pip]; sys_platform == 'linux'
requests; sys_platform == 'linux'
# -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Expand Down
68 changes: 68 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6362,6 +6362,74 @@ struct ReorderElementwiseAndShapeOp final
return success();
}
};


struct WhileDCE final
: OpTraitRewritePattern<WhileOp> {
using OpTraitRewritePattern::OpTraitRewritePattern;

LogicalResult matchAndRewrite(WhileOp wop,
PatternRewriter &rewriter) const override {

SmallVector<BlockArgument> candidates;
SmallSet<Value> used;
ReturnOp ret = cast<ReturnOp>(wop.getBody()->getTerminator());
for (auto &[arg1, arg2, res, yld] : llvm::zip(wop.getCond().getArguments(), wop.getBody().getArguments(), wop.getResults(), ret.getOperands())) {
if (!arg1.use_empty()) {
used.insert(yld);
continue;
}
if (!res.use_empty()) {
used.insert(yld);
continue;
}
candidates.push_back(arg2);
}
if (candidates.size() == 0) return failure();

// Here we assume, perhaps incorrectly, that all operations are readnone
// Reverse traversal of cfg to determine unnecessary operands within while scope
for (auto op : llvm::reverse(wop.getBody().getOperations().without_terminator())) {
bool used = false;
for (auto res : op->getResults()) {
if (used)
}
}

if (op->getOperands().size() != 1)
return rewriter.notifyMatchFailure(op, "expected to be unary");

auto definingOp = op->getOperand(0).getDefiningOp();
if (!definingOp)
return rewriter.notifyMatchFailure(
op, "expected to have an op before elementise op");

if (!isa<mlir::stablehlo::ReshapeOp>(definingOp) &&
!isa<mlir::stablehlo::TransposeOp>(definingOp) &&
!isa<mlir::stablehlo::BroadcastOp>(definingOp))
return rewriter.notifyMatchFailure(
op, "defining operation of unexpected type");

// Only reorder if the defining op has no other uses.
if (!llvm::hasSingleElement(definingOp->getResult(0).getUses()))
return rewriter.notifyMatchFailure(op, "operation has more than one use");

Value input = definingOp->getOperand(0);
Value result = op->getResult(0);
auto intermediateType = input.getType().cast<ShapedType>().clone(
getElementTypeOrSelf(result.getType()));

// Reorder the operation and rewire the inputs/outputs.
op->moveBefore(definingOp);
definingOp->getResult(0).setType(result.getType());
rewriter.replaceAllUsesWith(result, definingOp->getResult(0));
result.setType(intermediateType);
op->setOperands(input);
definingOp->setOperands(result);
return success();
}
};

/////////////// End Imported from stablehlo

#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.cpp.inc"
Expand Down
10 changes: 10 additions & 0 deletions test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,13 @@ py_test(
deps = TEST_DEPS + ["@pypi_jax_md//:pkg"],
timeout='long'
)

# py_test(
# name = "maxtext",
# srcs = [
# "maxtext.py",
# ],
# imports = ["."],
# deps = TEST_DEPS + ["@pypi_maxtext//:pkg"],
# timeout='long'
# )
186 changes: 186 additions & 0 deletions test/maxtext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Steps for getting results here
# Run:
# 1) pip install https://github.com/wsmoses/maxtext
# 2) bazel build -c opt //:wheel
# 3) pip install ./bazel-bin/*whl
# 4) python test/maxtext.py

from absl.testing import absltest
import jax.numpy as jnp
import jax.random
import jax.lax
import enzyme_ad.jax as enzyme_jax
from enzyme_ad.jax import (
enzyme_jax_ir,
NewXLAPipeline,
OldXLAPipeline,
JaXPipeline,
hlo_opts,
)
import numpy as np
import timeit

argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")

import jax.numpy as np
import numpy as onp
from jax import jit
from jax import random
from jax import lax

partialopt = (
"inline{default-pipeline=canonicalize max-iterations=4},"
+ """canonicalize,cse,
enzyme-hlo-generate-td{
patterns=compare_op_canon<16>;
transpose_transpose<16>;
broadcast_in_dim_op_canon<16>;
convert_op_canon<16>;
dynamic_broadcast_in_dim_op_not_actually_dynamic<16>;
chained_dynamic_broadcast_in_dim_canonicalization<16>;
dynamic_broadcast_in_dim_all_dims_non_expanding<16>;
noop_reduce_op_canon<16>;
empty_reduce_op_canon<16>;
dynamic_reshape_op_canon<16>;
get_tuple_element_op_canon<16>;
real_op_canon<16>;
imag_op_canon<16>;
get_dimension_size_op_canon<16>;
gather_op_canon<16>;
reshape_op_canon<16>;
merge_consecutive_reshapes<16>;
transpose_is_reshape<16>;
zero_extent_tensor_canon<16>;
reorder_elementwise_and_shape_op<16>;

cse_broadcast_in_dim<16>;
cse_slice<16>;
cse_transpose<16>;
cse_convert<16>;
cse_pad<16>;
cse_dot_general<16>;
cse_reshape<16>;
cse_mul<16>;
cse_div<16>;
cse_add<16>;
cse_subtract<16>;
cse_min<16>;
cse_max<16>;
cse_neg<16>;
cse_concatenate<16>;

concatenate_op_canon<16>(1024);
select_op_canon<16>(1024);
add_simplify<16>;
sub_simplify<16>;
and_simplify<16>;
max_simplify<16>;
min_simplify<16>;
or_simplify<16>;
negate_simplify<16>;
mul_simplify<16>;
div_simplify<16>;
rem_simplify<16>;
pow_simplify<16>;
sqrt_simplify<16>;
cos_simplify<16>;
sin_simplify<16>;
noop_slice<16>;
const_prop_through_barrier<16>;
slice_slice<16>;
shift_right_logical_simplify<16>;
pad_simplify<16>;
negative_pad_to_slice<16>;
tanh_simplify<16>;
exp_simplify<16>;
slice_simplify<16>;
convert_simplify<16>;
dynamic_slice_to_static<16>;
dynamic_update_slice_elim<16>;
concat_to_broadcast<16>;
reduce_to_reshape<16>;
broadcast_to_reshape<16>;
gather_simplify<16>;
iota_simplify<16>(1024);
broadcast_in_dim_simplify<16>(1024);
convert_concat<1>;
dynamic_update_to_concat<1>;
slice_of_dynamic_update<1>;
slice_elementwise<1>;
slice_pad<1>;
dot_reshape_dot<1>;
concat_const_prop<1>;
concat_fuse<1>;
pad_reshape_pad<1>;
pad_pad<1>;
concat_push_binop_add<1>;
concat_push_binop_mul<1>;
scatter_to_dynamic_update_slice<1>;
reduce_concat<1>;
slice_concat<1>;

bin_broadcast_splat_add<1>;
bin_broadcast_splat_subtract<1>;
bin_broadcast_splat_div<1>;
bin_broadcast_splat_mul<1>;
slice_reshape<1>;

dot_reshape_pad<1>;
pad_dot_general<1>(1);
broadcast_reduce<1>;
},
transform-interpreter,
enzyme-hlo-remove-transform,cse"""
)

pipelines = [
("JaX ", None),
("JaXPipe", JaXPipeline()),
(
"HLOOpt",
JaXPipeline(
"inline{default-pipeline=canonicalize max-iterations=4},"
+ "canonicalize,cse,enzyme-hlo-opt,cse"
),
),
("PartOpt", JaXPipeline(partialopt)),
("DefOpt", JaXPipeline(hlo_opts())),
]


class MaxText(absltest.TestCase):
def setUp(self):
import MaxText
import MaxText.pyconfig

MaxText.pyconfig.initialize(
[
None,
"test/maxtext_configs/base.yml",
"dataset_type=synthetic",
"steps=10",
]
)

def test(self):
import MaxText
import MaxText.pyconfig
import MaxText.train

config = MaxText.pyconfig.config

for name, pipeline in pipelines:
print("name=", name)

def rewrite(fn):
if pipeline is None:
return fn
else:
return enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(fn)

res1 = MaxText.train.train_loop(config, prejit=rewrite)
print("name=", name, res1)


if __name__ == "__main__":
absltest.main()
Loading
Loading