From cd01f77647dc2f54d169b2dbfe2380637bd193f8 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 20 Jan 2025 14:26:03 -0800 Subject: [PATCH 1/3] Fix a typo (#3731) --- tests/python/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/utils.py b/tests/python/utils.py index 308412b9657..3c0185e0896 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -288,7 +288,7 @@ def check_cpp_translation( "(A failure here suggests a mismatch in functionality between the original and cloned definitions.)" ) print("Does FusionDefinition supports segmentation?\t", supports_segmentation) - print(fd.getReproErrorString("executing", inputs)) + print(fd._repro_error_str("executing", inputs)) raise err From fd23b1b99f88142e45bf4a0148581f0ddbda1abb Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 20 Jan 2025 14:26:48 -0800 Subject: [PATCH 2/3] Fix type annotation (#3728) --- nvfuser/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py index d65198c21a7..b4d0d91cc49 100644 --- a/nvfuser/__init__.py +++ b/nvfuser/__init__.py @@ -6,7 +6,7 @@ import os import re import sys -from typing import Callable, Optional, Union, List # noqa: F401 +from typing import Callable import warnings import torch @@ -556,8 +556,8 @@ def _repro_error_str(self, section: str, inputs: list | None = None): def validate( self, - inputs: List[torch.Tensor], - reference_outputs: List[torch.Tensor], + inputs: list[torch.Tensor], + reference_outputs: list[torch.Tensor], kwargs=None, ): """ From aa3c3d3b9d39e73c21bd5bdfee0a38b9bf7c06b6 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 20 Jan 2025 15:37:47 -0800 Subject: [PATCH 3/3] Clean up FusionDefinition::execute (#3726) --- csrc/python_frontend/fusion_definition.cpp | 103 +++++++++++---------- 1 file changed, 55 insertions(+), 48 deletions(-) diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index 9003ec09351..23537b8b467 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -360,7 +360,6 @@ std::vector FusionDefinition::execute( auto scheds = fusionCache()->queryFusionSchedules(id().value()); - std::vector outputs; if (profile) { ProfilerOptionsGuard::getCurOptions().set(ProfilerOption::Enable); } @@ -380,64 +379,72 @@ std::vector FusionDefinition::execute( DisableOptionsGuard::getCurOptions().set(opt.value()); } - if (!override_user_schedule) { + auto find_user_schedule = [&]() -> const UserSchedule* { + if (override_user_schedule) { + return nullptr; + } + + auto user_sched_id = fusionCache()->queryUserScheduleId(scheds, inputs); + if (!user_sched_id.has_value()) { + return nullptr; + } + auto device = getCommonDeviceCUDA(inputs, selected_device); NVF_CHECK( inputs.empty() || device > -1, "Inputs are not all on the same device or don't match selection!"); - auto user_sched_id = fusionCache()->queryUserScheduleId(scheds, inputs); - if (user_sched_id.has_value()) { - if (isProfilerEnabledWithCupti()) { - FusionProfiler::start(); - FusionProfiler::createSegments(1); + const UserSchedule& user_sched = + fusionCache()->queryUserSchedule(scheds, user_sched_id.value(), device); + return &user_sched; + }; + const auto* user_sched = find_user_schedule(); + + std::vector outputs; + if (user_sched == nullptr) { + outputs = scheds->auto_gen_schedules->runFusionWithInputs( + inputs, std::nullopt, selected_device); + } else { + if (isProfilerEnabledWithCupti()) { + FusionProfiler::start(); + FusionProfiler::createSegments(1); + } + scheds->last_user_def_scheduled_ir = user_sched->scheduled_fusion.get(); + scheds->last_user_def_executor = user_sched->executor.get(); + + if (user_sched->heuristic_params == nullptr) { + // Manual schedule + if (!user_sched->executor->isCompiled()) { + user_sched->executor->compile( + user_sched->scheduled_fusion.get(), inputs); } - auto& user_sched = fusionCache()->queryUserSchedule( - scheds, user_sched_id.value(), device); - scheds->last_user_def_scheduled_ir = user_sched.scheduled_fusion.get(); - scheds->last_user_def_executor = user_sched.executor.get(); - - if (user_sched.heuristic_params == nullptr) { - // Manual schedule - if (!user_sched.executor->isCompiled()) { - user_sched.executor->compile( - user_sched.scheduled_fusion.get(), inputs); - } - outputs = user_sched.executor->run(inputs); - } else { - // Automatic scheduler was used for UserSchedule. - // Pass launch and compile params to compileFusion and runFusion. - if (!user_sched.executor->isCompiled()) { - user_sched.executor->compile( - user_sched.scheduled_fusion.get(), - KernelArgumentHolder::createKernelArgumentHolder( - inputs, getCommonDeviceCUDA(inputs)), - user_sched.heuristic_params->lparams, - user_sched.heuristic_params->cparams, - user_sched.heuristic_params->scheduler_type); - } - outputs = user_sched.executor->run( - inputs, - user_sched.heuristic_params->lparams, - user_sched.heuristic_params->cparams); + outputs = user_sched->executor->run(inputs); + } else { + // Automatic scheduler was used for UserSchedule. + // Pass launch and compile params to compileFusion and runFusion. + if (!user_sched->executor->isCompiled()) { + user_sched->executor->compile( + user_sched->scheduled_fusion.get(), + KernelArgumentHolder::createKernelArgumentHolder( + inputs, getCommonDeviceCUDA(inputs)), + user_sched->heuristic_params->lparams, + user_sched->heuristic_params->cparams, + user_sched->heuristic_params->scheduler_type); } + outputs = user_sched->executor->run( + inputs, + user_sched->heuristic_params->lparams, + user_sched->heuristic_params->cparams); + } - if (isProfilerEnabledWithCupti()) { - FusionProfiler::segment(0).scheduler("user"); - FusionProfiler::stop(); - if (isProfilerPrintingEnabled()) { - debug() << FusionProfiler::profile(); - } + if (isProfilerEnabledWithCupti()) { + FusionProfiler::segment(0).scheduler("user"); + FusionProfiler::stop(); + if (isProfilerPrintingEnabled()) { + debug() << FusionProfiler::profile(); } } } - // when `!override_user_schedule == true`, it *could* have produced an - // output already at this point and we would not want to overwrite - // generated output through user scheduled kernel. - if (outputs.empty()) { - outputs = scheds->auto_gen_schedules->runFusionWithInputs( - inputs, std::nullopt, selected_device); - } if (profile) { ProfilerOptionsGuard::getCurOptions().unset(ProfilerOption::Enable); }