Long nvFuser latencies seen in Inference Workload for 1-Layer of Llama 3.2 1B #3469
```python
import torch
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from typing import Callable
from functools import partial, wraps

LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}

# Build a 1-layer version of Llama 3.2 1B in bfloat16, inference only.
config = LlamaConfig(**LLAMA_3_2_1B_CFG)
config.num_hidden_layers = 1

with torch.device("cuda"):
    model = LlamaForCausalLM(config).to(torch.bfloat16).requires_grad_(False).eval()

args = dict(
    cache_positions=torch.arange(6, device="cuda"),
    input_ids=torch.tensor([[128000, 791, 1401, 311, 2324, 374]], device="cuda"),
    attention_mask=torch.ones(1, 6, dtype=torch.int64, device="cuda"),
    inputs_embeds=None,
    use_cache=True,
    return_dict=True,
)


def cuda_timer(warmup_iters: int = 2, timing_iters: int = 10):
    """Decorator that returns the average per-iteration time in ms, measured with CUDA events."""

    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> float:
            for _ in range(warmup_iters):
                fn(*args, **kwargs)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start.record()
            for _ in range(timing_iters):
                fn(*args, **kwargs)
            end.record()
            torch.cuda.synchronize()

            kernel_time = start.elapsed_time(end) / timing_iters
            return kernel_time

        return wrapper

    return decorator


import thunder
# from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

# jm = thunder.jit(model, transformers=[NvtxProfileTransform()])
jm = thunder.jit(model)


@cuda_timer()
def run_model():
    jm(**args)


kernel_time = run_model()
print(f"Thunder-nvFuser {kernel_time:.03f} ms")

tc = torch.compile(model)


@cuda_timer()
def run_tc_model():
    tc(**args)


tc_kernel_time = run_tc_model()
print(f"torch.compile {tc_kernel_time:.03f} ms")

thunder_eager = thunder.jit(
    model,
    executors=["sdpa", "torch"],
    # transformers=[NvtxProfileTransform()],
)  # , transforms=(CUDAGraphTransform(),)


@cuda_timer()
def run_te_model():
    thunder_eager(**args)


kernel_time = run_te_model()
print(f"Thunder-eager {kernel_time:.03f} ms")


@cuda_timer()
def run_eager_model():
    model(**args)


kernel_time = run_eager_model()
print(f"torch-eager {kernel_time:.03f} ms")
```
DGX H100-80GB Results:
One thing that is noticeable in Thunder is that there is ~250 us of startup time that is not occurring in
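As a rough, hedged way to put a number on that host-side cost, one could time how long the CPU takes to enqueue each call without synchronizing inside the loop. This is a sketch, not part of the original report; it reuses `jm` and `args` from the repro above and assumes the per-call work is small enough that the CUDA queue does not stall the CPU:

```python
import time
import torch

# Hedged sketch (not from the original report): approximate the per-call
# host-side cost by timing asynchronous enqueues only, assuming `jm` and
# `args` from the repro script above are in scope.
for _ in range(3):  # warm up so compilation/caching is excluded
    jm(**args)
torch.cuda.synchronize()

iters = 100
t0 = time.perf_counter()
for _ in range(iters):
    jm(**args)  # no synchronization here: this measures CPU-side issue time
t1 = time.perf_counter()
torch.cuda.synchronize()

print(f"~{(t1 - t0) * 1e6 / iters:.0f} us host-side per call")
```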
These are the issues I am seeing:
Closing in favor of enhancement #3507.
While running HF Llama 3.2 1B with 1 layer in Thunder under the PyTorch profiler, we noticed that nvFuser's run_fused_kernel apparently takes a lot of time.
Lightning-AI/lightning-thunder#1467
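For reference, a minimal sketch of how the model can be run under the PyTorch profiler to see where the host time goes; it reuses `jm` and `args` from the repro script above, and the iteration counts and row limit are arbitrary choices, not taken from the original report:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Hedged sketch: profile the Thunder-jitted 1-layer model after a warm-up,
# assuming `jm` and `args` from the repro script above are in scope.
for _ in range(3):
    jm(**args)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        jm(**args)
    torch.cuda.synchronize()

# Sort by CPU time so host-side entries (e.g. the fusion launch path)
# appear near the top of the table.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=25))
# prof.export_chrome_trace("thunder_llama_1layer.json")  # view in chrome://tracing / Perfetto
```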
To split things, one can use
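One option, hinted at by the commented-out `NvtxProfileTransform` in the repro script, is to mark regions with NVTX ranges so they show up as named spans in Nsight Systems. The sketch below uses plain `torch.cuda.nvtx` ranges with illustrative names, reusing `jm`, `thunder_eager`, and `args` from the repro above:

```python
import torch

# Hedged sketch: bracket each variant with NVTX ranges (names are
# illustrative) so they show up as separate spans in Nsight Systems.
torch.cuda.nvtx.range_push("thunder_nvfuser_forward")
jm(**args)
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("thunder_eager_forward")
thunder_eager(**args)
torch.cuda.nvtx.range_pop()

torch.cuda.synchronize()
```

Captured with e.g. `nsys profile python <script>.py`, these ranges separate the Thunder-nvFuser and Thunder-eager regions on the timeline.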