Long nvFuser latencies seen in Inference Workload for 1-Layer of Llama 3.2 1B #3469

Closed
t-vi opened this issue Nov 25, 2024 · 5 comments

t-vi commented Nov 25, 2024

While running HF Llama 3.2 1B with 1 layer in Thunder under the PyTorch profiler, we noticed that nvFuser's run_fused_kernel apparently takes a lot of time.
Lightning-AI/lightning-thunder#1467

To split things apart and profile each fusion separately, one can use:

for k, fdw in lt.python_ctx().items():
    if k.startswith('nvFusion'):
        fdw.store_inputs = True  # record the inputs of each nvFusion region

jm(**args)  # run once so that each fusion captures its last inputs

for k, fdw in lt.python_ctx().items():
    if k.startswith('nvFusion'):
        print(k)
        torch.cuda.synchronize()
        with torch.profiler.profile() as prof:
            fdw.last_used.execute(fdw.last_inputs)
            torch.cuda.synchronize()
        print(prof.key_averages().table())
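
For reference, `jm`, `args`, and `lt` above are assumed to come from a setup along the following lines (a minimal sketch; obtaining the execution trace via thunder.last_traces is an assumption, not necessarily what the linked repro does):

import thunder

jm = thunder.jit(model)   # Thunder-jitted HF Llama model (see the repro script below)
jm(**args)                # run once so that traces and fusion wrappers exist

# lt is taken to be the final execution trace; its python_ctx() maps symbol names
# such as 'nvFusion0' to the callable fusion wrappers used above
lt = thunder.last_traces(jm)[-1]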

kevinstephano commented Nov 30, 2024

import torch
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from typing import Callable
from functools import partial, wraps

LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}


config = LlamaConfig(**LLAMA_3_2_1B_CFG)
config.num_hidden_layers = 1

with torch.device("cuda"):
    model = LlamaForCausalLM(config).to(torch.bfloat16).requires_grad_(False).eval()

args = dict(
    cache_positions=torch.arange(6, device="cuda"),
    input_ids=torch.tensor([[128000, 791, 1401, 311, 2324, 374]], device="cuda"),
    attention_mask=torch.ones(1, 6, dtype=torch.int64, device="cuda"),
    inputs_embeds=None,
    use_cache=True,
    return_dict=True,
)

def cuda_timer(warmup_iters: int = 2, timing_iters: int = 10):
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> float:
            for _ in range(warmup_iters):
                fn(*args, **kwargs)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            start.record()
            for _ in range(timing_iters):
                fn(*args, **kwargs)
            end.record()
            torch.cuda.synchronize()

            kernel_time = start.elapsed_time(end) / timing_iters
            return kernel_time
        return wrapper
    return decorator

import thunder
# from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform
# jm = thunder.jit(model, transformers=[NvtxProfileTransform()])
jm = thunder.jit(model)

@cuda_timer()
def run_model():
    jm(**args)

kernel_time = run_model()

print(f"Thunder-nvFuser {kernel_time:.03f} ms")

tc = torch.compile(model)

@cuda_timer()
def run_tc_model():
    tc(**args)

tc_kernel_time = run_tc_model()

print(f"torch.compile {tc_kernel_time:.03f} ms")

thunder_eager = thunder.jit(model,
                    executors=["sdpa", "torch"],
                    # transformers=[NvtxProfileTransform()]),
                    ) #, transforms=(CUDAGraphTransform(),))

@cuda_timer()
def run_te_model():
    thunder_eager(**args)

kernel_time = run_te_model()

print(f"Thunder-eager {kernel_time:.03f} ms")

@cuda_timer()
def run_eager_model():
    model(**args)

kernel_time = run_eager_model()

print(f"torch-eager {kernel_time:.03f} ms")


kevinstephano commented Nov 30, 2024

DGX H100-80GB Results:

Thunder-nvFuser 1.281 ms
Thunder-torch 1.232 ms
torch.compile 0.455 ms
torch-eager 1.014 ms
Execution Type  | Wall Clock Time (ms) | CPU Overhead (ms) | Kernel Time (ms) | Kernels | Overhead / Kernel (us)
--------------- | -------------------- | ----------------- | ---------------- | ------- | ----------------------
Thunder-nvFuser | 1.281                | 0.984             | 0.297            | 32      | 30.7
Thunder-torch   | 1.232                | 0.814             | 0.413            | 77      | 10.5
torch.compile   | 0.455                | 0.163             | 0.292            | 24      | 6.8
torch-eager     | 1.014                | 0.616             | 0.398            | 65      | 9.4
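
For reference, a breakdown of this kind can be approximated from torch.profiler output; a minimal sketch (treating kernel time as summed self CUDA time and CPU overhead as wall clock minus kernel time is an assumption about how the numbers above were derived):

import torch

def breakdown(fn, iters: int = 10):
    # warm up, then time `iters` calls under the profiler
    fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
    with torch.profiler.profile(activities=activities) as prof:
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
    wall_ms = start.elapsed_time(end) / iters
    # summed per-kernel self CUDA time (reported in microseconds), averaged per iteration
    kernel_ms = sum(e.self_cuda_time_total for e in prof.key_averages()) / 1e3 / iters
    return wall_ms, wall_ms - kernel_ms, kernel_ms

# e.g. wall, cpu_overhead, kernel = breakdown(lambda: jm(**args))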

@kevinstephano

One noticeable thing in Thunder is that there is ~250 us of startup time per step that does not occur with torch.compile or torch-eager, and that prevents the remaining CPU overhead from overlapping with the large GEMM at the end of the previous step. This startup time may account for the difference in CPU overhead between Thunder-torch and torch-eager, although the two also differ in the number of kernels launched.

[profiler timeline image attachment]
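
One way to make this startup gap visible on a timeline is the NVTX transform that is commented out in the repro script above; a sketch, assuming thunder.jit accepts the transform via its transforms= keyword:

import thunder
from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform  # import path taken from the repro script

# Annotate Thunder's regions with NVTX ranges, then inspect the timeline with,
# for example, `nsys profile -t cuda,nvtx python repro.py`.
jm_nvtx = thunder.jit(model, transforms=[NvtxProfileTransform()])
jm_nvtx(**args)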


kevinstephano commented Nov 30, 2024

These are the issues I am seeing:

  1. There is a ~250 us startup overhead per step with Thunder, which appears to account for the difference between Thunder-torch and torch-eager.
  2. nvFuser has 20 to 30 us of overhead per kernel launch even with segmentation, and this is cumulative with Thunder's overhead. There is no quick fix: to remove it, nvFuser would need to change its relationship with Thunder and rely on Thunder's static-shape caching instead of recalculating launch parameters. We have not prioritized reducing this overhead, as our main focus has been training and, for latency-sensitive workloads, using CUDA Graphs (see the sketch after this list).
  3. It looks like Inductor chooses to re-order some factory functions, which might have been advantageous for reducing kernel count; you can see this in the mask generation at the beginning of the step.
  4. I did see a bug in nvFuser that surfaced when I reduced the executors to ['sdpa', 'torch', 'nvfuser'].
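
For reference, the CUDA Graphs path mentioned in item 2 corresponds to the CUDAGraphTransform that is commented out in the repro script; a minimal sketch of trying it (the thunder.transforms.cuda_graphs import path is an assumption):

import thunder
from thunder.transforms.cuda_graphs import CUDAGraphTransform  # import path is an assumption

# Wrap Thunder's fused regions in CUDA graphs so that per-kernel launch overhead is paid
# mostly at capture time; this relies on shapes staying static from step to step.
jm_graphs = thunder.jit(model, transforms=[CUDAGraphTransform()])
jm_graphs(**args)   # first call captures; later calls with the same shapes replay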

kevinstephano changed the title from "long self cpu time in run_fused_kernel" to "Long nvFuser latencies seen in Inference Workload for 1-Layer of Llama 3.2 1B" on Nov 30, 2024
@kevinstephano

Closing in favor of enhancement #3507.
