Skip to content

Commit

Permalink
[FRONTEND] Add env to override fp_fusion default behavior (triton-lang#4157)
Browse files Browse the repository at this point in the history

This is convenient for debugging potential floating point precision
differences.
  • Loading branch information
ThomasRaoux authored Jun 18, 2024
1 parent 150d199 commit 0ba5f0c
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
- `TRITON_ALWAYS_COMPILE=1` forces to compile kernels regardless of cache hit.
- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
- `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add -> fma); set it to `1` to allow fusion by default or `0` to disallow it. An explicit `enable_fp_fusion` kernel option still takes precedence.

# Changelog

Expand Down
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"MLIR_ENABLE_DIAGNOSTICS",
"MLIR_ENABLE_DUMP",
"MLIR_ENABLE_TIMING",
"TRITON_DEFAULT_FP_FUSION",
"TRITON_DISABLE_LINE_INFO",
"TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE",
"TRITON_ENABLE_LLVM_DEBUG",
Expand Down
9 changes: 7 additions & 2 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5185,7 +5185,8 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device):


@pytest.mark.parametrize("enable_fp_fusion", [False, True])
def test_enable_fp_fusion(enable_fp_fusion, device):
@pytest.mark.parametrize("default_override", [False, True])
def test_enable_fp_fusion(enable_fp_fusion, default_override, device):
if is_hip():
pytest.skip(
'test_enable_fp_fusion for HIP currently broken in https://github.com/triton-lang/triton. Use https://github.com/ROCmSoftwarePlatform/triton'
Expand All @@ -5198,7 +5199,11 @@ def mul_add(data):
tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0)

data = torch.randn((128, ), device=device, dtype=torch.float32)
h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion)
if default_override:
os.environ["TRITON_DEFAULT_FP_FUSION"] = "1" if enable_fp_fusion else "0"
h = mul_add[(1, )](data)
else:
h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion)

if not is_cuda():
return
Expand Down
2 changes: 2 additions & 0 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def __init__(self, target: GPUTarget) -> None:

def parse_options(self, opts) -> Any:
    """Build a HIPOptions instance from user-supplied kernel options.

    Starts from the backend's target architecture, applies the
    TRITON_DEFAULT_FP_FUSION env-var default for fp fusion, then lets any
    recognized keys in ``opts`` override both.

    :param opts: mapping of option names to values; only keys that are
        HIPOptions dataclass fields are honored.
    :return: a fully-populated HIPOptions.
    """
    args = {'arch': self.target.arch}
    # Env-var default for fp fusion (mul+add -> fma). The original guarded
    # this with `if "enable_fp_fusion" not in args`, but args only holds
    # 'arch' here, so the check was always true; a user-provided
    # enable_fp_fusion in opts still wins via the update() below.
    args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
    args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__ if k in opts})
    return HIPOptions(**args)

Expand Down
2 changes: 2 additions & 0 deletions third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
def parse_options(self, opts) -> Any:
    """Build a CUDAOptions instance from user-supplied kernel options.

    Copies recognized keys from ``opts``, then derives capability-dependent
    defaults (fp8 type availability, imprecise-accumulation threshold) and
    the TRITON_DEFAULT_FP_FUSION env-var default for fp fusion.

    :param opts: mapping of option names to values; only keys that are
        CUDAOptions dataclass fields are honored.
    :return: a fully-populated CUDAOptions.
    """
    args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__ if k in opts}
    # fp8e4nv needs sm_89+; fp8e4b15 is unsupported from sm_90 on.
    args["allow_fp8e4nv"] = self.capability >= 89
    args["allow_fp8e4b15"] = self.capability < 90
    # Apply the env-var default only when the caller did not pass
    # enable_fp_fusion explicitly (PEP 8: `x not in`, not `not x in`).
    if "enable_fp_fusion" not in args:
        args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
    # Hopper (sm_90) tensor cores tolerate a large imprecise-acc budget.
    args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
    return CUDAOptions(**args)

Expand Down

0 comments on commit 0ba5f0c

Please sign in to comment.