From e291c518a6fcc6e34a12859cec1125e3bb9b68b9 Mon Sep 17 00:00:00 2001
From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com>
Date: Fri, 13 Dec 2024 21:11:25 +0100
Subject: [PATCH] add layer scan

---
 entropix/model.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/entropix/model.py b/entropix/model.py
index f9f0028..c9ad5e3 100644
--- a/entropix/model.py
+++ b/entropix/model.py
@@ -60,12 +60,19 @@ def feed_forward(x: jax.Array, layer_weights: LayerWeights) -> jax.Array:
   h = h1 * shard(jnp.dot(x, layer_weights.w3), PS(None, None, 'mp'))
   return shard(jnp.dot(h, layer_weights.w2), PS())
 
+
 def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: jax.Array, cur_pos: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array]=None) -> Tuple[jax.Array, KVCache]:
-  h = xfmr_weights.tok_embeddings[tokens]
-  for i in range(model_params.n_layers):
-    norm_x = rms_norm(h, xfmr_weights.layer_weights[i].attention_norm)
-    h_attn, kvcache, scores = attention(norm_x, xfmr_weights.layer_weights[i], model_params, cur_pos, i, freqs_cis, kvcache, attn_mask=attn_mask)
+  def step_fn(inp, wi):
+    h, kvcache = inp
+    w, i = wi
+    norm_x = rms_norm(h, w.attention_norm)
+    h_attn, kvcache, scores = attention(norm_x, w, model_params, cur_pos, i, freqs_cis, kvcache, attn_mask=attn_mask)
     h = h + h_attn
-    h = h + feed_forward(rms_norm(h, xfmr_weights.layer_weights[i].ffn_norm), xfmr_weights.layer_weights[i])
-  logits = jnp.dot(rms_norm(h, xfmr_weights.norm), xfmr_weights.output.T)
-  return logits, kvcache, scores
+    h = h + feed_forward(rms_norm(h, w.ffn_norm), w)
+    return (h, kvcache), None
+
+  x = xfmr_weights.tok_embeddings[tokens]
+  (x, kvcache), _ = jax.lax.scan(step_fn, (x, kvcache), (jax.tree_util.tree_map(lambda *x: jnp.stack(x, 0), *xfmr_weights.layer_weights), jnp.arange(model_params.n_layers)), model_params.n_layers)
+  logits = jnp.dot(rms_norm(x, xfmr_weights.norm), xfmr_weights.output.T)
+  return logits, kvcache, None
+
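
The patch above replaces the Python for-loop over transformer layers with a jax.lax.scan over stacked layer weights, so the traced program contains one layer body instead of n_layers copies. The following is a minimal, self-contained sketch of that stack-and-scan pattern, not entropix code; ToyLayer, init_layers, dim and n_layers are illustrative assumptions standing in for LayerWeights and model_params.

# Minimal sketch of the scan-over-layers pattern used in the patch.
# ToyLayer / init_layers / dim / n_layers are illustrative assumptions,
# not entropix definitions.
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyLayer(NamedTuple):
  w1: jax.Array
  w2: jax.Array


def init_layers(key: jax.Array, n_layers: int = 4, dim: int = 8):
  # One small weight pytree per layer, analogous to xfmr_weights.layer_weights.
  return [
      ToyLayer(w1=jax.random.normal(k, (dim, dim)) * 0.02,
               w2=jax.random.normal(k, (dim, dim)) * 0.02)
      for k in jax.random.split(key, n_layers)
  ]


def forward(layers, x):
  # Stack the per-layer pytrees leaf-wise so every leaf gains a leading
  # layer axis, then let lax.scan sweep that axis instead of unrolling
  # a Python loop (one traced layer body, smaller jaxpr).
  stacked = jax.tree_util.tree_map(lambda *w: jnp.stack(w, 0), *layers)

  def step_fn(h, w):
    h = h + jnp.tanh(h @ w.w1) @ w.w2  # toy residual block
    return h, None                     # carry forward, no per-step output

  h, _ = jax.lax.scan(step_fn, x, stacked)
  return h


x = jnp.ones((2, 8))
layers = init_layers(jax.random.PRNGKey(0))
print(jax.jit(forward)(layers, x).shape)  # (2, 8)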