Support FP8 FA from Quark format #388

Open · wants to merge 1 commit into base: main
8 changes: 7 additions & 1 deletion vllm/model_executor/layers/quantization/quark/quark.py
@@ -297,14 +297,20 @@ def get_cache_scale(self, name: str) -> Optional[str]:
             kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
             return name.replace(kv_output_scale_name, ".attn.k_scale")
 
-        elif len(kv_proj_names) == 2:
+        elif len(kv_proj_names) in {2, 3}:
             for kv_proj_name in kv_proj_names:
                 if kv_proj_name in name and kv_proj_name == "k_proj":
                     return name.replace(".k_proj.output_scale",
                                         ".attn.k_scale")
                 elif kv_proj_name in name and kv_proj_name == "v_proj":
                     return name.replace(".v_proj.output_scale",
                                         ".attn.v_scale")
+                elif kv_proj_name in name and kv_proj_name == "q_proj":
+                    return name.replace(".q_proj.output_scale",
+                                        ".attn.q_scale")
+
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
 
         # If no matches, return None
         return None
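
For reference, the name remapping this hunk extends can be summarized in a standalone sketch. The helper below is illustrative only (it is not the vLLM API and simplifies the matching to `endswith`), but it mirrors the replace calls above, including the new q_scale and prob_scale cases:

```python
from typing import Optional

def remap_quark_scale(name: str) -> Optional[str]:
    """Illustrative version of the mapping above: Quark output_scale
    parameter names -> the per-attention scale names vLLM looks up."""
    for proj, scale in (("k_proj", "k_scale"),
                        ("v_proj", "v_scale"),
                        ("q_proj", "q_scale")):
        suffix = f".{proj}.output_scale"
        if name.endswith(suffix):
            return name.replace(suffix, f".attn.{scale}")
    # New in this PR: softmax-probability scale for FP8 flash attention.
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None

# "model.layers.0.self_attn.q_proj.output_scale"
#     -> "model.layers.0.self_attn.attn.q_scale"
```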
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
@@ -21,6 +21,7 @@ def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.out_dtype = torch.get_default_dtype()
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -134,6 +135,7 @@ def apply_weights(self,
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
+            out_dtype=self.out_dtype,
             input_scale=layer.input_scale,
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
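On the out_dtype change: capturing torch.get_default_dtype() when the scheme is built and passing it through means the FP8 GEMM result comes back in the model's activation dtype rather than whatever the kernel would default to. A minimal sketch of the idea with a simulated dequantization path (helper name and shapes are illustrative, not the fused apply_fp8_linear kernel):

```python
import torch

def fp8_linear_sim(x: torch.Tensor, weight_fp8: torch.Tensor,
                   weight_scale: torch.Tensor,
                   out_dtype: torch.dtype) -> torch.Tensor:
    """Simulated FP8-weight linear: dequantize, matmul, cast to out_dtype."""
    w = weight_fp8.to(torch.float32) * weight_scale  # dequantize the weight
    y = x.to(torch.float32) @ w.t()
    return y.to(out_dtype)                           # honor the requested dtype

# Mirrors self.out_dtype above: capture the model dtype once at init time.
out_dtype = torch.get_default_dtype()                # e.g. torch.bfloat16
x = torch.randn(4, 8, dtype=out_dtype)
w_fp8 = torch.randn(16, 8).to(torch.float8_e4m3fn)   # needs PyTorch >= 2.1
w_scale = torch.tensor(0.05)
assert fp8_linear_sim(x, w_fp8, w_scale, out_dtype).dtype == out_dtype
```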
5 changes: 4 additions & 1 deletion vllm/model_executor/models/grok1.py
@@ -37,6 +37,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -197,7 +198,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
+        self.use_fp8 = isinstance(
+            quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
+                                         and quant_config._is_fp8_w8a8)
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.attn = Grok1Attention(hidden_size=self.hidden_size,
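The Fp8Config-or-Quark-FP8 test is now repeated in grok1.py and llama.py; written out once, the check amounts to the following. The helper name is hypothetical and not part of this PR, and it uses quant_config._is_fp8_w8a8 exactly as the diff does:

```python
from typing import Optional

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig


def is_fp8_w8a8(quant_config: Optional[QuantizationConfig]) -> bool:
    """True for FP8 W8A8 checkpoints in either the native fp8 or Quark format."""
    if isinstance(quant_config, Fp8Config):
        return True
    return (isinstance(quant_config, QuarkConfig)
            and quant_config._is_fp8_w8a8)


# Usage as in the constructors above (ROCm, non-Navi only):
#     self.use_fp8 = (is_fp8_w8a8(quant_config)
#                     if current_platform.is_rocm() and not is_navi() else False)
```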
14 changes: 11 additions & 3 deletions vllm/model_executor/models/llama.py
@@ -40,6 +40,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -84,7 +85,9 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
         )
-        self.use_fp8 = (isinstance(quant_config, Fp8Config)
+        self.use_fp8 = (isinstance(quant_config, Fp8Config) or
+                        (isinstance(quant_config, QuarkConfig)
+                         and quant_config._is_fp8_w8a8)
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         if hidden_act != "silu":
@@ -196,10 +199,13 @@ def __init__(self,
             sliding_window = None
 
         # For CUDA devices and Navi4x, attn_fp8 will be set to false.
+        use_fp8 = isinstance(
+            quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
+                                         and quant_config._is_fp8_w8a8)
         self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
                         and current_platform.is_rocm() \
                         and not is_navi() \
-                        and isinstance(quant_config, Fp8Config)
+                        and use_fp8
 
         self.attn = Attention(
             self.num_heads,
@@ -240,7 +246,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.use_fp8 = (isinstance(quant_config, Fp8Config)
+        self.use_fp8 = (isinstance(quant_config, Fp8Config) or
+                        (isinstance(quant_config, QuarkConfig)
+                         and quant_config._is_fp8_w8a8)
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         rope_theta = getattr(config, "rope_theta", 10000)
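For testing the whole path on this branch, something along these lines should exercise the Quark FP8 weights, the FP8 KV-cache scales, and the FP8 attention output gate. The model path is a placeholder, and VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT is the environment flag checked in llama.py above:

```python
import os

# Enable the ROCm custom paged-attention FP8 output path gated on in llama.py.
os.environ["VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT"] = "1"

from vllm import LLM, SamplingParams

# Hypothetical Quark-quantized FP8 Llama checkpoint; the quantization method
# is detected from the checkpoint's quantization_config.
llm = LLM(model="/path/to/llama-fp8-quark", kv_cache_dtype="fp8")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```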