diff --git a/.bazelrc b/.bazelrc
index 8596b954..8fe5e0bf 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -1,5 +1,6 @@
 build --announce_rc
+query --experimental_repo_remote_exec
 build --experimental_repo_remote_exec
 build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
 build --cxxopt=-w --host_cxxopt=-w
diff --git a/builddeps/BUILD b/builddeps/BUILD
index be0ae35f..451e9e46 100644
--- a/builddeps/BUILD
+++ b/builddeps/BUILD
@@ -10,6 +10,9 @@ compile_pip_requirements(
         "--build-isolation",
         "--rebuild",
     ],
+    extra_deps = [
+        # "@pypi_wheel//:pkg"
+    ],
     requirements_in = "requirements.in",
     requirements_txt = REQUIREMENTS,
     generate_hashes = True,
diff --git a/builddeps/requirements.in b/builddeps/requirements.in
index 58fa9a1e..226b8b1d 100644
--- a/builddeps/requirements.in
+++ b/builddeps/requirements.in
@@ -3,6 +3,6 @@
 #
 -r test-requirements.txt
 
-jax >= 0.4.21
-jaxlib >= 0.4.21
+jax
+jaxlib
 absl_py >= 2.0.0
diff --git a/builddeps/test-requirements.txt b/builddeps/test-requirements.txt
index 569c677b..7bac118d 100644
--- a/builddeps/test-requirements.txt
+++ b/builddeps/test-requirements.txt
@@ -2,7 +2,9 @@ absl-py
 jax
 numpy
 jaxlib
-https://github.com/wsmoses/jax-md/archive/1188490610b95023f8a51166c3f6b92da31e78fe.tar.gz
+https://github.com/wsmoses/jax-md/archive/b41e23abc8662e767033c2bcf346eb58b32363a9.tar.gz
+# maxtext can't be installed concurrently with the other requirements; install it manually to enable the maxtext test.
+# https://github.com/wsmoses/maxtext/archive/bc50722be7d89e4003bd830b80e4ac968be658eb.tar.gz
 jax[cuda12_pip]; sys_platform == 'linux'
 requests; sys_platform == 'linux'
 # -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 3f2a1177..5342ca44 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -6362,6 +6362,67 @@ struct ReorderElementwiseAndShapeOp final
     return success();
   }
 };
+
+struct WhileDCE final : OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp wop,
+                                PatternRewriter &rewriter) const override {
+    Block &cond = wop.getCond().front();
+    Block &body = wop.getBody().front();
+    auto ret = cast<ReturnOp>(body.getTerminator());
+
+    // A loop-carried variable is dead when the condition region never reads
+    // it, nobody uses the corresponding while result, and its body argument
+    // is unused: it exists only to feed its own yield.
+    SmallVector<unsigned> deadIdxs;
+    for (unsigned i = 0, e = wop.getNumResults(); i != e; ++i) {
+      if (!cond.getArgument(i).use_empty())
+        continue;
+      if (!wop.getResult(i).use_empty())
+        continue;
+      if (!body.getArgument(i).use_empty())
+        continue;
+      // Already a pass-through; nothing left to simplify.
+      if (ret.getOperand(i) == body.getArgument(i))
+        continue;
+      deadIdxs.push_back(i);
+    }
+    if (deadIdxs.empty())
+      return failure();
+
+    // Rewire each dead variable to a pass-through. Result types do not
+    // change, so the while op itself need not be rebuilt.
+    rewriter.modifyOpInPlace(ret, [&] {
+      for (unsigned i : deadIdxs)
+        ret.setOperand(i, body.getArgument(i));
+    });
+
+    // Reverse traversal of the body to drop the now-unused computations:
+    // users appear after definers in the straight-line block, so a backward
+    // walk sees every user before its definer and erasure cascades in a
+    // single pass. Here we assume, perhaps incorrectly, that all operations
+    // are readnone; isMemoryEffectFree makes that assumption explicit.
+    for (Operation &op : llvm::make_early_inc_range(
+             llvm::reverse(body.without_terminator()))) {
+      if (isMemoryEffectFree(&op) &&
+          llvm::all_of(op.getResults(),
+                       [](Value v) { return v.use_empty(); }))
+        rewriter.eraseOp(&op);
+    }
+
+    // Variables that only become dead once these ops are gone are picked up
+    // when the greedy driver re-applies the pattern.
+    return success();
+  }
+};
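+// Illustration (hypothetical IR sketch): given
+//   %r:2 = stablehlo.while(%a = %x, %b = %y), where the condition reads only
+//   %a, %b is unused in the body, and the second yield is %t = multiply %a, %a,
+// then if %r#1 is also unused, WhileDCE rewires the second yield to %b itself
+// and the reverse sweep erases the dead multiply; a later canonicalization can
+// drop the resulting pass-through variable entirely.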
+
 /////////////// End Imported from stablehlo
 
 #include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.cpp.inc"
diff --git a/test/BUILD b/test/BUILD
index 490e15ef..822b268c 100644
--- a/test/BUILD
+++ b/test/BUILD
@@ -131,3 +131,13 @@ py_test(
     deps = TEST_DEPS + ["@pypi_jax_md//:pkg"],
     timeout='long'
 )
+
+# py_test(
+#     name = "maxtext",
+#     srcs = [
+#         "maxtext.py",
+#     ],
+#     imports = ["."],
+#     deps = TEST_DEPS + ["@pypi_maxtext//:pkg"],
+#     timeout='long'
+# )
diff --git a/test/maxtext.py b/test/maxtext.py
new file mode 100644
index 00000000..b14c1874
--- /dev/null
+++ b/test/maxtext.py
@@ -0,0 +1,192 @@
+# Steps for getting results here.
+# Run:
+# 1) pip install git+https://github.com/wsmoses/maxtext
+# 2) bazel build -c opt //:wheel
+# 3) pip install ./bazel-bin/*.whl
+# 4) python test/maxtext.py
+
+from absl.testing import absltest
+import jax.numpy as jnp
+import jax.random
+import jax.lax
+import enzyme_ad.jax as enzyme_jax
+from enzyme_ad.jax import (
+    enzyme_jax_ir,
+    NewXLAPipeline,
+    OldXLAPipeline,
+    JaXPipeline,
+    hlo_opts,
+)
+import numpy as np
+import timeit
+
+argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")
+
+import jax.numpy as np
+import numpy as onp
+from jax import jit
+from jax import random
+from jax import lax
+
+partialopt = (
+    "inline{default-pipeline=canonicalize max-iterations=4},"
+    + """canonicalize,cse,
+enzyme-hlo-generate-td{
+            patterns=compare_op_canon<16>;
+transpose_transpose<16>;
+broadcast_in_dim_op_canon<16>;
+convert_op_canon<16>;
+dynamic_broadcast_in_dim_op_not_actually_dynamic<16>;
+chained_dynamic_broadcast_in_dim_canonicalization<16>;
+dynamic_broadcast_in_dim_all_dims_non_expanding<16>;
+noop_reduce_op_canon<16>;
+empty_reduce_op_canon<16>;
+dynamic_reshape_op_canon<16>;
+get_tuple_element_op_canon<16>;
+real_op_canon<16>;
+imag_op_canon<16>;
+get_dimension_size_op_canon<16>;
+gather_op_canon<16>;
+reshape_op_canon<16>;
+merge_consecutive_reshapes<16>;
+transpose_is_reshape<16>;
+zero_extent_tensor_canon<16>;
+reorder_elementwise_and_shape_op<16>;
+
+cse_broadcast_in_dim<16>;
+cse_slice<16>;
+cse_transpose<16>;
+cse_convert<16>;
+cse_pad<16>;
+cse_dot_general<16>;
+cse_reshape<16>;
+cse_mul<16>;
+cse_div<16>;
+cse_add<16>;
+cse_subtract<16>;
+cse_min<16>;
+cse_max<16>;
+cse_neg<16>;
+cse_concatenate<16>;
+
+concatenate_op_canon<16>(1024);
+select_op_canon<16>(1024);
+add_simplify<16>;
+sub_simplify<16>;
+and_simplify<16>;
+max_simplify<16>;
+min_simplify<16>;
+or_simplify<16>;
+negate_simplify<16>;
+mul_simplify<16>;
+div_simplify<16>;
+rem_simplify<16>;
+pow_simplify<16>;
+sqrt_simplify<16>;
+cos_simplify<16>;
+sin_simplify<16>;
+noop_slice<16>;
+const_prop_through_barrier<16>;
+slice_slice<16>;
+shift_right_logical_simplify<16>;
+pad_simplify<16>;
+negative_pad_to_slice<16>;
+tanh_simplify<16>;
+exp_simplify<16>;
+slice_simplify<16>;
+convert_simplify<16>;
+dynamic_slice_to_static<16>;
+dynamic_update_slice_elim<16>;
+concat_to_broadcast<16>;
+reduce_to_reshape<16>;
+broadcast_to_reshape<16>;
+gather_simplify<16>;
+iota_simplify<16>(1024);
+broadcast_in_dim_simplify<16>(1024);
+convert_concat<1>;
+dynamic_update_to_concat<1>;
+slice_of_dynamic_update<1>;
+slice_elementwise<1>;
+slice_pad<1>;
+dot_reshape_dot<1>;
+concat_const_prop<1>;
+concat_fuse<1>;
+pad_reshape_pad<1>;
+pad_pad<1>;
+concat_push_binop_add<1>;
+concat_push_binop_mul<1>;
+scatter_to_dynamic_update_slice<1>;
+reduce_concat<1>;
+slice_concat<1>;
+
+bin_broadcast_splat_add<1>;
+bin_broadcast_splat_subtract<1>;
+bin_broadcast_splat_div<1>;
+bin_broadcast_splat_mul<1>;
+slice_reshape<1>;
+
+dot_reshape_pad<1>;
+pad_dot_general<1>(1);
+broadcast_reduce<1>;
+ },
+ transform-interpreter,
+ enzyme-hlo-remove-transform,cse"""
+)
+
+pipelines = [
+    ("JaX ", None),
+    ("JaXPipe", JaXPipeline()),
+    (
+        "HLOOpt",
+        JaXPipeline(
+            "inline{default-pipeline=canonicalize max-iterations=4},"
+            + "canonicalize,cse,enzyme-hlo-opt,cse"
+        ),
+    ),
+    ("PartOpt", JaXPipeline(partialopt)),
+    ("DefOpt", JaXPipeline(hlo_opts())),
+]
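+# Each of the pipelines above is applied by wrapping the function under test,
+# roughly (illustrative sketch; `fn` stands for any jit-able function):
+#
+#   opt_fn = enzyme_jax_ir(pipeline_options=JaXPipeline(partialopt), argv=argv)(fn)
+#
+# A pipeline of None runs the function under stock JaX with no Enzyme rewrite.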
+
+
+class MaxText(absltest.TestCase):
+    def setUp(self):
+        import MaxText
+        import MaxText.pyconfig
+
+        MaxText.pyconfig.initialize(
+            [
+                None,
+                "test/maxtext_configs/base.yml",
+                "dataset_type=synthetic",
+                "steps=10",
+            ]
+        )
+
+    def test(self):
+        import MaxText
+        import MaxText.pyconfig
+        import MaxText.train
+
+        config = MaxText.pyconfig.config
+
+        for name, pipeline in pipelines:
+            print("name=", name)
+
+            def rewrite(fn):
+                if pipeline is None:
+                    return fn
+                else:
+                    return enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(fn)
+
+            res1 = MaxText.train.train_loop(config, prejit=rewrite)
+            print("name=", name, res1)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/test/maxtext_configs/base.yml b/test/maxtext_configs/base.yml
new file mode 100644
index 00000000..40bef0b5
--- /dev/null
+++ b/test/maxtext_configs/base.yml
@@ -0,0 +1,468 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This sentinel is a reminder to choose a real run name.
+# If there is already a checkpoint under this run, that checkpoint will auto-resume.
+run_name: ""
+
+model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this!
+normalization_layer_epsilon: 1.e-05
+
+################################## CHECKPOINTING ##################################
+# Checkpointing makes the following choices in the following order, starting with (1):
+#   (1) If there is already a checkpoint for this run_name, we load the latest entire checkpoint.
+#       This ensures that if we're resuming a run after preemption or hardware failure we lose the minimum amount of state.
+#   (2) Same priority and mutually exclusive -- you can't set both!
+#       * If load_parameters_path is set, we load a parameter-only checkpoint from that path.
+#       * If load_full_state_path is set, we load a full state checkpoint from that path.
+#   (3) Otherwise, we don't load a checkpoint and initialize state instead!
+
+# Loads just the parameters from a specific directory
+# e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/items/NUMBER or NUMBER/items
+load_parameters_path: ""
+# Loads a full checkpoint including optimizer state and step count from a specific directory
+# e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/items/NUMBER or NUMBER/items
+load_full_state_path: ""
+
+# If enable_checkpointing is true, an asynchronous checkpointer will be used if
+# async_checkpointing is true, else a synchronous one is used. If you have
+# problems with the checkpointer we recommend trying the synchronous one.
+enable_checkpointing: True
+async_checkpointing: True
+checkpoint_period: 10_000
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False
+
+force_unroll: False # during generate_param_only_checkpoint should we unroll the loop?
+
+# checkpointing using orbax has two important parameters: the array driver
+# and its underlying storage - the kvstore (preferably ocdbt)
+# orbax supports setting a target file size; chunking a single
+# large array into small physical files (<2GB) can speed up distributed and
+# over-the-network loading enormously
+checkpoint_storage_target_data_file_size_bytes: 2147483648
+checkpoint_storage_use_ocdbt: True
+checkpoint_storage_use_zarr3: True
+############################### END CHECKPOINTING ##################################
+
+reuse_example_batch: 0 # for testing TPU performance, this option repeatedly uses the same batch.
+
+metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
+# If true, save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
+gcs_metrics: False
+
+# If true, save config to GCS in {base_output_directory}/{run_name}/
+save_config_to_gcs: False
+
+# Activation dtypes.
+dtype: "bfloat16"
+# Used to configure quantization in the transformer layers, defaults to null implying bf16.
+# Possible alternative settings are as follows:
+# 'int8' for dynamic range quantization using 8-bits
+# 'int8w' for weights only quantization using 8-bits
+# 'int4w' for weights only quantization using 4-bits
+# 'intmp' for mixed precision weight only quantization based on config file
+# 'fp8' for 8-bit floating-point GeMMs on NVIDIA GPUs.
+quantization: ""
+# Choose one of default, high, and highest.
+# https://kolonist26-jax-kr.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
+matmul_precision: "default"
+activations_in_float32: False # Sets activations to float32 before the nonlinearity if true, else dtype
+# Path to file with quantization config - only used for mixed precision.
+# Example configs in ../Maxtext/configs/quantization
+# Allows us to configure different bits, tiling and scale for quantizing selected weights.
+# Bits represents the number of bits to quantize to,
+# tile-size represents the tiling size used in AQT tiled_dot_general,
+# and the value of scale is used to scale the abs_max value used for AQT quantization.
+# Default values are 8 bits, tile-size=-1 (no tiling) and scale=1.
+quant_cfg_path: ""
+quantize_kvcache: False # Set to True to quantize KV Cache values, defaults to False
+# Valid kv_quant_axis values:
+#   - "" is valid only when quantize_kvcache is False
+#   - "dkv" quantizes the kv cache over the cache_kv (i.e. kv dimension) axis;
+#     expected to give better accuracy but slower computation
+#   - "heads_and_dkv" quantizes the kv cache over the cache_heads and cache_kv axes
+# Defaults to "heads_and_dkv" for faster computation; kv_quant_axis is unused when quantize_kvcache is False
+kv_quant_axis: "heads_and_dkv"
+kv_quant_dtype: "int8"
+checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint
+# Saves params quantized on the fly at the following path
+save_quantized_params_path: ""
+
+# Shard the range finding operation for quantization. By default this is set to the number of slices.
+quantization_local_shard_count: -1
+
+decoder_block: "llama2" # which style of DecoderBlock to use.
+# Global parameter scale needs to be a power of 2. If you want finer grained control of the model sizes
+# then you should explicitly set base_embed_dim, base_num_query_heads, base_num_kv_heads,
+# base_mlp_dim, base_num_decoder_layers and/or head_dim.
+weight_dtype: float32
+global_parameter_scale: 1
+base_emb_dim: 512
+# base_emb_dim: 2048
+base_num_query_heads: 1
+# base_num_query_heads: 16
+base_num_kv_heads: 1
+# base_num_kv_heads: 16
+base_mlp_dim: 7168
+base_num_decoder_layers: 1
+# base_num_decoder_layers: 16
+head_dim: 128
+mlp_activations: ["silu", "linear"]
+dropout_rate: 0
+logits_via_embedding: False
+normalize_embedding_logits: True # whether to normalize pre-softmax logits if logits_via_embedding is true
+logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embedding dot product for stability
+cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but the benefit can vary slightly.
+
+# mixture of experts (moe)
+num_experts: 1
+num_experts_per_tok: 1
+megablox: True
+capacity_factor: -1.0 # a factor to decide expert capacity for token dropping; no dropping by default
+load_balance_loss_weight: 0.01 # weight for the load balance loss
+
+# pipeline parallelism
+# The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats.
+# There is a tradeoff between num_layers_per_pipeline_stage and num_pipeline_repeats: the more layers per stage, the easier
+# it is to hide the pipeline communication behind the compute, since there is more compute per stage; however, there will be a larger bubble
+# since there are fewer repeats. There is a similar tradeoff for num_pipeline_microbatches: more microbatches lead to a smaller bubble,
+# but a smaller size per microbatch, which may hurt per-stage performance. Additionally, note that when microbatches > num_stages we have the opportunity to
+# perform the circular transfer (last stage to first) asynchronously.
+# The bubble fraction is (num_stages - 1) / (num_pipeline_repeats * num_pipeline_microbatches + num_stages - 1)
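+# For illustration (hypothetical values): with 4 stages, 1 repeat and 8
+# microbatches the bubble fraction is (4 - 1) / (1 * 8 + 4 - 1) = 3/11 ≈ 27%;
+# doubling the microbatches to 16 shrinks it to 3/19 ≈ 16%.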
+num_layers_per_pipeline_stage: 1
+# The number of repeats will be set to num_decoder_layers / (num_pipeline_stages * num_layers_per_pipeline_stage)
+num_pipeline_repeats: -1
+# num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages.
+# Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices
+num_pipeline_microbatches: -1
+scan_pipeline_iterations: True # This can be set independently of scan_layers, which is relevant when num_layers_per_pipeline_stage > 1.
+pipeline_delay_activation_forwarding: False # This delays the activation forwarding by one loop iteration, simplifying XLA's task of overlapping, since
+# the communication and compute in each iteration are now independent. However, this comes at the cost of doubling the pipeline bubble,
+# and you must set the number of microbatches to at least 2 * num_stages (the minimum 2 * num_stages is set by default with this delay).
+
+# Choose 'remat_policy' between 'minimal', 'save_dot_except_mlpwi', 'save_dot_except_mlp', 'save_qkv_proj', 'qkv_proj_offloaded', 'minimal_offloaded', 'save_out_proj' and 'full'.
+# These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest)
+remat_policy: 'full'
+scan_layers: True
+param_scan_axis: 1
+
+# The attention parameter dictates the specific algorithm/methodology used to compute the attention scores
+# The attention_type parameter determines the variants of attention, e.g. global or local_sliding
+attention: 'dot_product' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te
+attention_type: 'global' # Supported attention_type: global, local_sliding
+sliding_window_size: 0
+attn_logits_soft_cap: 0.0
+final_logits_soft_cap: 0.0
+use_post_attn_norm: False
+use_post_ffw_norm: False
+
+
+# Combine matmuls for QKV and MLP
+fused_qkv: False
+fused_mlp: False
+
+record_internal_nn_metrics: 0
+
+# Output directory
+# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
+base_output_directory: ""
+
+# Whether or not to enable emergency checkpoint. If True, `local_checkpoint_directory` and a non-zero `local_checkpoint_period` must also be specified.
+# Emergency checkpoint is an experimental Orbax feature that periodically saves to persistent storage and, with a larger interval, saves to a local directory.
+# During restore, if a local copy is available in any slice, it will be broadcast to other slices without having to fetch from persistent storage.
+# See more details on https://github.com/google/orbax/tree/main/checkpoint/orbax/checkpoint/experimental/emergency.
+enable_emergency_checkpoint: False
+
+# It should be specified if and only if `enable_emergency_checkpoint` is True.
+local_checkpoint_directory: ""
+
+# It should be a positive number if and only if `enable_emergency_checkpoint` is True.
+local_checkpoint_period: 0
+
+# Jax cache directory
+jax_cache_dir: "~/jax_cache"
+
+# Hardware
+hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'
+
+# Parallelism
+mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']
+logical_axis_rules: [
+                      ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+                      ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
+                      # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
+                      # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
+                      # The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
+                      ['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose', 'expert']],
+                      ['activation_heads', ['tensor', 'sequence']],
+                      ['activation_kv_heads', ['tensor', 'sequence']],
+                      ['activation_length', 'sequence'],
+                      ['activation_embed', 'tensor'],
+                      ['activation_mlp', 'tensor'],
+                      ['activation_kv', 'tensor'],
+                      ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+                      ['activation_kv_head_dim', 'tensor'],
+                      ['activation_vocab', ['tensor', 'sequence']],
+                      ['activation_vocab', 'tensor'],
+                      ['activation_vocab', 'sequence'],
+                      ['activation_stage', 'stage'],
+                      ['activation_exp', 'expert'],
+                      ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
+                      ['vocab', ['tensor', 'autoregressive']],
+                      ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
+                      ['embed', ['fsdp', 'sequence', 'expert']],
+                      ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence']],
+                      ['embed_no_exp', ['fsdp', 'sequence']],
+                      ['norm', 'tensor'],
+                      ['heads', ['tensor', 'autoregressive']],
+                      ['layers', 'stage'],
+                      ['kv', []],
+                      ['kv_heads', ['tensor', 'autoregressive']],
+                      ['kv_head_dim', []],
+                      ['cache_batch', []],
+                      ['cache_heads', ['autoregressive', 'tensor']],
+                      ['cache_kv', []],
+                      ['cache_sequence', []],
+                      ['exp', 'expert'],
+                    ]
+# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']]
+
+# One axis for each parallelism type may hold a placeholder (-1)
+# value to auto-shard based on available slices and devices.
+# By default, the product of the DCN axes should equal the number of slices
+# and the product of the ICI axes should equal the number of devices per slice.
+dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
+dcn_fsdp_parallelism: 1
+dcn_fsdp_transpose_parallelism: 1
+dcn_sequence_parallelism: 1 # never recommended
+dcn_tensor_parallelism: 1 # never recommended
+dcn_pipeline_parallelism: 1
+dcn_expert_parallelism: 1
+dcn_autoregressive_parallelism: 1 # never recommended
+ici_data_parallelism: 1
+ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
+ici_fsdp_transpose_parallelism: 1
+ici_sequence_parallelism: 1
+ici_tensor_parallelism: 1
+ici_autoregressive_parallelism: 1
+ici_pipeline_parallelism: 1
+ici_expert_parallelism: 1
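+# For illustration (hypothetical topology): on 2 slices of 8 devices each,
+# leaving dcn_data_parallelism and ici_fsdp_parallelism at -1 auto-shards
+# them to 2 and 8 respectively, since the DCN axes must multiply to the
+# number of slices and the ICI axes to the devices per slice.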
"gs://my-maxtext-dataset/" +dataset_name: 'c4/en:3.0.1' +eval_dataset_name: 'c4/en:3.0.1' +eval_split: 'validation' +# for HuggingFace input pipeline (dataset_type=hf) +hf_path: '' +hf_data_dir: '' +hf_train_files: '' +hf_eval_split: '' +hf_eval_files: '' +hf_access_token: '' +# for Grain input pipeline (dataset_type=grain) +grain_train_files: '' +grain_eval_files: '' +grain_worker_count: 1 + +# Training loop +steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps +log_period: 100 # Flushes Tensorboard + +# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 +# Learning rate schedule has either two or three parts: +# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction] +# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps +# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps. +# The zero learning rate section can be used to more accurately measure the fully trained model's performance. +learning_rate: 3.e-5 +cosine_learning_rate_final_fraction: 0.1 +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +max_target_length: 2048 # Maximum sequence length +max_prefill_predict_length: 64 # Maximum length for the prefill when doing autoregression +prompt: "I love to" # Prompt for language model sampling. +load_from_prefill_dir: False # If true, decode.py doesn't "prefill" but just reads from directory +prefill_cache_dir: "" # If set and load_from_prefill_dir, decode.py reads from directory. If set, decode.py writes to directory +autoregressive_decode_assert: "" + +# For nsys profiler, pass the training command to nsys command +# e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command} +profiler: "" # Supported profiler: '', xplane, nsys +# If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host. +upload_all_profiler_results: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 1 +# Profile for a small number of steps to avoid a large profile file size. +profiler_steps: 5 + +# When dropout is false the model is a deterministic function of the +# data_shuffle_seed and init_weights_seed (i.e. reproducible losses) +enable_dropout: True +enable_data_shuffling: True +data_shuffle_seed: 0 +init_weights_seed: 0 + +# You may disable clipping by setting gradient_clipping_threshold to zero. +gradient_clipping_threshold: 1.0 + +# Instead of updating the weights every step, you may effectively use a larger +# batch by accumulating the gradient over a set of steps. +gradient_accumulation_steps: 1 + +# AdamW optimizer parameters +# We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 +opt_type: "adamw" # one of "adam_pax" or "adamw" +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. 
+
+max_target_length: 2048 # Maximum sequence length
+max_prefill_predict_length: 64 # Maximum length for the prefill when doing autoregression
+prompt: "I love to" # Prompt for language model sampling.
+load_from_prefill_dir: False # If true, decode.py doesn't "prefill" but just reads from directory
+prefill_cache_dir: "" # If set and load_from_prefill_dir, decode.py reads from the directory. If set, decode.py writes to the directory
+autoregressive_decode_assert: ""
+
+# For the nsys profiler, pass the training command to the nsys command
+# e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command}
+profiler: "" # Supported profiler: '', xplane, nsys
+# If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host.
+upload_all_profiler_results: False
+# Skip first n steps for profiling, to omit things like compilation and to give
+# the iteration time a chance to stabilize.
+skip_first_n_steps_for_profiler: 1
+# Profile for a small number of steps to avoid a large profile file size.
+profiler_steps: 5
+
+# When dropout is false the model is a deterministic function of the
+# data_shuffle_seed and init_weights_seed (i.e. reproducible losses)
+enable_dropout: True
+enable_data_shuffling: True
+data_shuffle_seed: 0
+init_weights_seed: 0
+
+# You may disable clipping by setting gradient_clipping_threshold to zero.
+gradient_clipping_threshold: 1.0
+
+# Instead of updating the weights every step, you may effectively use a larger
+# batch by accumulating the gradient over a set of steps.
+gradient_accumulation_steps: 1
+
+# AdamW optimizer parameters
+# We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
+opt_type: "adamw" # one of "adam_pax" or "adamw"
+adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
+adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients.
+adam_eps: 1.e-8 # A small constant applied to the denominator outside of the square root.
+adam_eps_root: 0. # A small constant applied to the denominator inside the square root.
+adam_weight_decay: 0.1 # AdamW weight decay
+
+# Stack trace parameters
+collect_stack_trace: False
+stack_trace_to_cloud: False # Uploads to cloud logging if True, else prints to the console if False.
+stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds.
+
+# Use iota operator in Embed
+use_iota_embed: False
+# use positional embedding
+use_untrainable_positional_embedding: False
+trainable_position_size: -1 # enable gpt3 position embedding with a positive trainable_position_size
+# Rope parameters
+rope_min_timescale: 1
+rope_max_timescale: 10_000
+
+# Ahead of time Compilation (aka AOT)
+# Only set these arguments if you are running train_compile or loading a compiled train step.
+compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g. compiled_train_v5e-256.pickle
+compile_topology: '' # Target hardware version, e.g. 'v5e-256'
+compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
+
+decode_sampling_strategy: "greedy" # decode_sampling_strategy should be one of greedy, weighted, nucleus, or topk
+decode_sampling_nucleus_p: -1 # set if you're doing nucleus / top-p
+decode_sampling_top_k: 0 # set if you're doing top-k
+decode_sampling_temperature: 1.
+
+eval_interval: -1 # number of train steps between eval steps
+eval_steps: -1 # only run this number of batches for eval, for debugging use
+target_eval_loss: 0. # early stop once reaching target eval_loss
+
+# Goodput parameters
+enable_goodput_recording: False
+monitor_goodput: False
+goodput_upload_interval_seconds: 60
+
+# Vertex AI Tensorboard Configurations - https://github.com/google/maxtext/tree/main/getting_started/Use_Vertex_AI_Tensorboard.md
+# Set to True for GCE, False if running via XPK
+use_vertex_tensorboard: False
+# Project to create Vertex AI Tensorboard in for GCE, blank if project is set using 'gcloud config set project'
+# Set this to blank if running via XPK
+vertex_tensorboard_project: ""
+# Region to create Vertex AI Tensorboard in for GCE, blank if running via XPK
+# Vertex AI supported regions: https://cloud.google.com/vertex-ai/docs/general/locations#available-regions
+vertex_tensorboard_region: ""
+
+# If set to True, MaxText will perform extra checks using jax.checkify. Note that this will affect performance.
+max_checkify: False
+
+# Inference
+inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
+inference_microbenchmark_stages: "prefill,generate"
+inference_microbenchmark_loop_iters: 10
+inference_microbenchmark_log_file_path: ""
+inference_metadata_file: "" # path to a json file
+enable_model_warmup: False
+
+
+# KV Cache layout control
+# Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
+# Default layout: 1,2,0,3 ; CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV
+prefill_cache_axis_order: "1,2,0,3"
+ar_cache_axis_order: "1,2,0,3"
+
+# Compute layout control
+# Default layout: 0,1,2,3 ; BATCH, LENGTH, HEAD, D_KV
+# Currently only support compute layout: 0,1,2,3 and 0,2,1,3
+compute_axis_order: "0,1,2,3"
+
+reshape_q: False
+
+# Maxengine Metrics
+prometheus_port: 0
+
+# Maxengine server
+enable_jax_profiler: False
+jax_profiler_port: 9999
+
+# Checkpoint Structured logging
+enable_checkpoint_cloud_logger: False
+enable_checkpoint_standard_logger: False
+
+# Single-controller
+enable_single_controller: True
+
+# Split physical axes for https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.mesh_utils.create_device_mesh.html
+allow_split_physical_axes: False
+
+use_ragged_attention: False
+ragged_block_size: 256
+
+### Splash attention block sizes
+# These can be tuned for specific hardware generations, and can be set up to
+# the model's sequence length.
+sa_block_q: 512
+sa_block_q_dkv: 512
+sa_block_q_dq: 512