Skip to content

Commit

Permalink
Clean up FusionDefinition::execute
Browse files — browse the repository at this point in the history
  • Loading branch information
wujingyue committed Jan 17, 2025
1 parent bf66a0c commit 7584c8e
Showing 1 changed file with 55 additions and 48 deletions.
103 changes: 55 additions & 48 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ std::vector<at::Tensor> FusionDefinition::execute(

auto scheds = fusionCache()->queryFusionSchedules(id().value());

std::vector<at::Tensor> outputs;
if (profile) {
ProfilerOptionsGuard::getCurOptions().set(ProfilerOption::Enable);
}
Expand All @@ -380,64 +379,72 @@ std::vector<at::Tensor> FusionDefinition::execute(
DisableOptionsGuard::getCurOptions().set(opt.value());
}

if (!override_user_schedule) {
auto find_user_schedule = [&]() -> const UserSchedule* {
if (override_user_schedule) {
return nullptr;
}

auto user_sched_id = fusionCache()->queryUserScheduleId(scheds, inputs);
if (!user_sched_id.has_value()) {
return nullptr;
}

auto device = getCommonDeviceCUDA(inputs, selected_device);
NVF_CHECK(
inputs.empty() || device > -1,
"Inputs are not all on the same device or don't match selection!");
auto user_sched_id = fusionCache()->queryUserScheduleId(scheds, inputs);
if (user_sched_id.has_value()) {
if (isProfilerEnabledWithCupti()) {
FusionProfiler::start();
FusionProfiler::createSegments(1);
const UserSchedule& user_sched =
fusionCache()->queryUserSchedule(scheds, user_sched_id.value(), device);
return &user_sched;
};
const auto* user_sched = find_user_schedule();

std::vector<at::Tensor> outputs;
if (user_sched == nullptr) {
outputs = scheds->auto_gen_schedules->runFusionWithInputs(
inputs, std::nullopt, selected_device);
} else {
if (isProfilerEnabledWithCupti()) {
FusionProfiler::start();
FusionProfiler::createSegments(1);
}
scheds->last_user_def_scheduled_ir = user_sched->scheduled_fusion.get();
scheds->last_user_def_executor = user_sched->executor.get();

if (user_sched->heuristic_params == nullptr) {
// Manual schedule
if (!user_sched->executor->isCompiled()) {
user_sched->executor->compile(
user_sched->scheduled_fusion.get(), inputs);
}
auto& user_sched = fusionCache()->queryUserSchedule(
scheds, user_sched_id.value(), device);
scheds->last_user_def_scheduled_ir = user_sched.scheduled_fusion.get();
scheds->last_user_def_executor = user_sched.executor.get();

if (user_sched.heuristic_params == nullptr) {
// Manual schedule
if (!user_sched.executor->isCompiled()) {
user_sched.executor->compile(
user_sched.scheduled_fusion.get(), inputs);
}
outputs = user_sched.executor->run(inputs);
} else {
// Automatic scheduler was used for UserSchedule.
// Pass launch and compile params to compileFusion and runFusion.
if (!user_sched.executor->isCompiled()) {
user_sched.executor->compile(
user_sched.scheduled_fusion.get(),
KernelArgumentHolder::createKernelArgumentHolder(
inputs, getCommonDeviceCUDA(inputs)),
user_sched.heuristic_params->lparams,
user_sched.heuristic_params->cparams,
user_sched.heuristic_params->scheduler_type);
}
outputs = user_sched.executor->run(
inputs,
user_sched.heuristic_params->lparams,
user_sched.heuristic_params->cparams);
outputs = user_sched->executor->run(inputs);
} else {
// Automatic scheduler was used for UserSchedule.
// Pass launch and compile params to compileFusion and runFusion.
if (!user_sched->executor->isCompiled()) {
user_sched->executor->compile(
user_sched->scheduled_fusion.get(),
KernelArgumentHolder::createKernelArgumentHolder(
inputs, getCommonDeviceCUDA(inputs)),
user_sched->heuristic_params->lparams,
user_sched->heuristic_params->cparams,
user_sched->heuristic_params->scheduler_type);
}
outputs = user_sched->executor->run(
inputs,
user_sched->heuristic_params->lparams,
user_sched->heuristic_params->cparams);
}

if (isProfilerEnabledWithCupti()) {
FusionProfiler::segment(0).scheduler("user");
FusionProfiler::stop();
if (isProfilerPrintingEnabled()) {
debug() << FusionProfiler::profile();
}
if (isProfilerEnabledWithCupti()) {
FusionProfiler::segment(0).scheduler("user");
FusionProfiler::stop();
if (isProfilerPrintingEnabled()) {
debug() << FusionProfiler::profile();
}
}
}

// When `override_user_schedule` is false, the user-scheduled kernel *could*
// already have produced outputs at this point; in that case we do not want
// to overwrite them with the outputs of the auto-generated schedule.
if (outputs.empty()) {
outputs = scheds->auto_gen_schedules->runFusionWithInputs(
inputs, std::nullopt, selected_device);
}
if (profile) {
ProfilerOptionsGuard::getCurOptions().unset(ProfilerOption::Enable);
}
Expand Down

0 comments on commit 7584c8e

Please sign in to comment.