Commit

make debug automap optional
Jackmin801 committed Nov 4, 2024
1 parent 3bd6c27 commit aa38bbc
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions scripts/export_dcp.py
@@ -20,6 +20,7 @@
 class ExportConfig(Config):
     save_format: Literal["pt", "safetensors"] = "safetensors"
     torch_dtype: Literal["float32", "bfloat16"] = "bfloat16"
+    with_debug_automap: bool = False


 def remap_keys_llama(k: str) -> str:
@@ -45,7 +46,7 @@ def _get_ffn_dim(hidden_dim: int, ffn_dim_multiplier: float, multiple_of: int) -
     return hidden_dim


-def convert_config_zb_to_hf(zb_config: ModelArgs) -> LlamaConfig:
+def convert_config_zb_to_hf(zb_config: ModelArgs, with_debug_automap: bool = False) -> LlamaConfig:
     """Convert ZeroBand config to HuggingFace config"""
     config = LlamaConfig()
     config.hidden_size = zb_config.dim
@@ -67,10 +68,11 @@ def convert_config_zb_to_hf(zb_config: ModelArgs) -> LlamaConfig:
         "rope_type": "default",
     }

-    config.auto_map = {
-        "AutoConfig": "PrimeIntellect/prime-llama-debug--configuration_llama.LlamaConfig",
-        "AutoModelForCausalLM": "PrimeIntellect/prime-llama-debug--modeling_llama.LlamaForCausalLM"
-    }
+    if with_debug_automap:
+        config.auto_map = {
+            "AutoConfig": "PrimeIntellect/prime-llama-debug--configuration_llama.LlamaConfig",
+            "AutoModelForCausalLM": "PrimeIntellect/prime-llama-debug--modeling_llama.LlamaForCausalLM"
+        }

     return config

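When with_debug_automap is set, the exported config.json carries an auto_map that tells transformers to fetch the debug configuration and modeling code from the PrimeIntellect/prime-llama-debug repository at load time, which requires trust_remote_code. A minimal loading sketch (the local path is a placeholder, not part of this commit):

    from transformers import AutoConfig, AutoModelForCausalLM

    # "path/to/export" stands in for the save_path written by export_dcp.py.
    cfg = AutoConfig.from_pretrained("path/to/export", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("path/to/export", trust_remote_code=True)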
@@ -122,7 +124,7 @@ def main(config: ExportConfig):
     )

     # Convert ZeroBand config to HuggingFace config
-    hf_config = convert_config_zb_to_hf(model_config)
+    hf_config = convert_config_zb_to_hf(model_config, with_debug_automap=config.with_debug_automap)
     hf_config.to_json_file(save_path / "config.json")

     # Load checkpoint
@@ -145,14 +147,15 @@ def main(config: ExportConfig):
     index_json = {}
     total_size = 0
     state_dict = {remap_keys_llama(k): v for k, v in state_dict.items()}
-    with torch.no_grad():
-        for i in range(len(hf_config.num_hidden_layers)):
-            old_q = state_dict[f"model.layers.{i}.self_attn.q_proj.weight"]
-            old_k = state_dict[f"model.layers.{i}.self_attn.k_proj.weight"]
-            new_q = convert_qk_from_complex_to_rotate_half(old_q, 128)
-            new_k = convert_qk_from_complex_to_rotate_half(old_k, 128)
-            state_dict[f"model.layers.{i}.self_attn.q_proj.weight"].copy_(new_q)
-            state_dict[f"model.layers.{i}.self_attn.k_proj.weight"].copy_(new_k)
+    if not config.with_debug_automap:  # The debug uses complex rotary impl
+        with torch.no_grad():
+            for i in range(hf_config.num_hidden_layers):
+                old_q = state_dict[f"model.layers.{i}.self_attn.q_proj.weight"]
+                old_k = state_dict[f"model.layers.{i}.self_attn.k_proj.weight"]
+                new_q = convert_qk_from_complex_to_rotate_half(old_q, 128)
+                new_k = convert_qk_from_complex_to_rotate_half(old_k, 128)
+                state_dict[f"model.layers.{i}.self_attn.q_proj.weight"].copy_(new_q)
+                state_dict[f"model.layers.{i}.self_attn.k_proj.weight"].copy_(new_k)
     if "model.freqs_cis" in state_dict:  # This should not be persisted
         del state_dict["model.freqs_cis"]
     if config.torch_dtype == "bfloat16":
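The body of convert_qk_from_complex_to_rotate_half is not shown in this diff. Assuming it performs the usual re-pairing of rotary dimensions, moving each head's rows from the interleaved (complex) RoPE layout to the even-then-odd split that HF's rotate_half expects, a hypothetical sketch could look like the following (the _sketch name, shapes, and comments are illustrative, not taken from the repo):

    import torch

    def convert_qk_from_complex_to_rotate_half_sketch(w: torch.Tensor, head_dim: int) -> torch.Tensor:
        # Hypothetical reimplementation for illustration only.
        # Complex/interleaved RoPE pairs rows (0, 1), (2, 3), ...; rotate_half pairs
        # rows (0, d/2), (1, d/2 + 1), ... within each head. Reordering every head's
        # rows to [even rows, then odd rows] converts one layout into the other.
        out_dim, in_dim = w.shape
        n_heads = out_dim // head_dim
        w = w.view(n_heads, head_dim // 2, 2, in_dim)  # (head, pair, component, in_features)
        w = w.transpose(1, 2)                          # (head, component, pair, in_features)
        return w.reshape(out_dim, in_dim)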
@@ -162,6 +165,7 @@ def main(config: ExportConfig):
     state_keys = list(state_dict.keys())
     shard_size = int(math.ceil(len(state_keys) / num_shards))
     logger.info("Saving model to %d shards", num_shards)
+
     for i in range(num_shards):
         _file = save_path / f"model-{i:04}-of-{num_shards:04}.{config.save_format}"
         start = i * shard_size

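The remainder of the shard loop is collapsed in this view. The visible part ceil-divides the list of tensor names into contiguous slices, one per model-XXXX-of-YYYY file; a small standalone illustration of that arithmetic (the tensor count is only an example):

    import math

    def shard_slices(n_keys: int, num_shards: int) -> list[tuple[int, int]]:
        # Same arithmetic as above: every shard except possibly the last
        # receives ceil(n_keys / num_shards) consecutive keys.
        shard_size = int(math.ceil(n_keys / num_shards))
        return [(i * shard_size, min((i + 1) * shard_size, n_keys)) for i in range(num_shards)]

    # e.g. 291 tensors split over 4 shards:
    print(shard_slices(291, 4))  # [(0, 73), (73, 146), (146, 219), (219, 291)]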