From 3b67d72e4c7c4b5f67c637d962e7cf600a2ef94b Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 17 Jan 2025 11:25:31 -0800 Subject: [PATCH] Clean up FusionDefinition::execute --- csrc/python_frontend/fusion_definition.cpp | 103 +++++++++++---------- 1 file changed, 55 insertions(+), 48 deletions(-) diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index 9003ec09351..23537b8b467 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -360,7 +360,6 @@ std::vector FusionDefinition::execute( auto scheds = fusionCache()->queryFusionSchedules(id().value()); - std::vector outputs; if (profile) { ProfilerOptionsGuard::getCurOptions().set(ProfilerOption::Enable); } @@ -380,64 +379,72 @@ std::vector FusionDefinition::execute( DisableOptionsGuard::getCurOptions().set(opt.value()); } - if (!override_user_schedule) { + auto find_user_schedule = [&]() -> const UserSchedule* { + if (override_user_schedule) { + return nullptr; + } + + auto user_sched_id = fusionCache()->queryUserScheduleId(scheds, inputs); + if (!user_sched_id.has_value()) { + return nullptr; + } + auto device = getCommonDeviceCUDA(inputs, selected_device); NVF_CHECK( inputs.empty() || device > -1, "Inputs are not all on the same device or don't match selection!"); - auto user_sched_id = fusionCache()->queryUserScheduleId(scheds, inputs); - if (user_sched_id.has_value()) { - if (isProfilerEnabledWithCupti()) { - FusionProfiler::start(); - FusionProfiler::createSegments(1); + const UserSchedule& user_sched = + fusionCache()->queryUserSchedule(scheds, user_sched_id.value(), device); + return &user_sched; + }; + const auto* user_sched = find_user_schedule(); + + std::vector outputs; + if (user_sched == nullptr) { + outputs = scheds->auto_gen_schedules->runFusionWithInputs( + inputs, std::nullopt, selected_device); + } else { + if (isProfilerEnabledWithCupti()) { + FusionProfiler::start(); + FusionProfiler::createSegments(1); + } + scheds->last_user_def_scheduled_ir = user_sched->scheduled_fusion.get(); + scheds->last_user_def_executor = user_sched->executor.get(); + + if (user_sched->heuristic_params == nullptr) { + // Manual schedule + if (!user_sched->executor->isCompiled()) { + user_sched->executor->compile( + user_sched->scheduled_fusion.get(), inputs); } - auto& user_sched = fusionCache()->queryUserSchedule( - scheds, user_sched_id.value(), device); - scheds->last_user_def_scheduled_ir = user_sched.scheduled_fusion.get(); - scheds->last_user_def_executor = user_sched.executor.get(); - - if (user_sched.heuristic_params == nullptr) { - // Manual schedule - if (!user_sched.executor->isCompiled()) { - user_sched.executor->compile( - user_sched.scheduled_fusion.get(), inputs); - } - outputs = user_sched.executor->run(inputs); - } else { - // Automatic scheduler was used for UserSchedule. - // Pass launch and compile params to compileFusion and runFusion. - if (!user_sched.executor->isCompiled()) { - user_sched.executor->compile( - user_sched.scheduled_fusion.get(), - KernelArgumentHolder::createKernelArgumentHolder( - inputs, getCommonDeviceCUDA(inputs)), - user_sched.heuristic_params->lparams, - user_sched.heuristic_params->cparams, - user_sched.heuristic_params->scheduler_type); - } - outputs = user_sched.executor->run( - inputs, - user_sched.heuristic_params->lparams, - user_sched.heuristic_params->cparams); + outputs = user_sched->executor->run(inputs); + } else { + // Automatic scheduler was used for UserSchedule. + // Pass launch and compile params to compileFusion and runFusion. + if (!user_sched->executor->isCompiled()) { + user_sched->executor->compile( + user_sched->scheduled_fusion.get(), + KernelArgumentHolder::createKernelArgumentHolder( + inputs, getCommonDeviceCUDA(inputs)), + user_sched->heuristic_params->lparams, + user_sched->heuristic_params->cparams, + user_sched->heuristic_params->scheduler_type); } + outputs = user_sched->executor->run( + inputs, + user_sched->heuristic_params->lparams, + user_sched->heuristic_params->cparams); + } - if (isProfilerEnabledWithCupti()) { - FusionProfiler::segment(0).scheduler("user"); - FusionProfiler::stop(); - if (isProfilerPrintingEnabled()) { - debug() << FusionProfiler::profile(); - } + if (isProfilerEnabledWithCupti()) { + FusionProfiler::segment(0).scheduler("user"); + FusionProfiler::stop(); + if (isProfilerPrintingEnabled()) { + debug() << FusionProfiler::profile(); } } } - // when `!override_user_schedule == true`, it *could* have produced an - // output already at this point and we would not want to overwrite - // generated output through user scheduled kernel. - if (outputs.empty()) { - outputs = scheds->auto_gen_schedules->runFusionWithInputs( - inputs, std::nullopt, selected_device); - } if (profile) { ProfilerOptionsGuard::getCurOptions().unset(ProfilerOption::Enable); }