Refactoring Fusion Executor, pulling out compiled kernel #3468

Open · wants to merge 41 commits into base: main

Conversation

@csarofeen (Collaborator) commented Nov 24, 2024

Pull kernel compilation out of KernelExecutor, separating the two concepts as we move toward a world where kernel execution is done through HostIr. A rough sketch of the intended split follows the change list below.

  • Added a CompiledKernel class in compiled_kernel.h/cpp to hold compilation information
  • Moved code:
    • from runtime/executor.h to runtime/compiled_kernel.h
    • from runtime/executor.cpp to runtime/compiled_kernel.cpp
    • from runtime/executor_utils.cpp to runtime/compiled_kernel.cpp (these are functions only used in compiled_kernel)
    • from sys/utils.cpp to runtime/compiled_kernel.cpp (these are functions only used in compiled_kernel)
  • Moved executor::compileRTC and executor::runRTC into their own class (RtcKernel), which shares compilation logic with CompiledKernel and lives in compiled_kernel.h/cpp
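
A rough sketch of the intended split (shapes approximated from the bullets above and the excerpts below, not copied from the actual headers):

#include <memory>

// Hypothetical stand-ins for the real nvFuser types.
struct Fusion {};
struct CudaExecutable { /* CUmodule, CUfunction, cubin/ptx, logs, ... */ };

// runtime/compiled_kernel.h: owns code generation and NVRTC compilation,
// and knows nothing about how the kernel is launched.
class CompiledKernel {
 public:
  explicit CompiledKernel(Fusion* fusion) : fusion_(fusion) {}
  void compile() { /* lower -> codegen -> NVRTC -> CudaExecutable */ }
  bool hasCompiledKernel() const { return executable_ != nullptr; }
  const CudaExecutable* cudaExecutable() const { return executable_.get(); }
 private:
  Fusion* fusion_ = nullptr;
  std::unique_ptr<CudaExecutable> executable_;
};

// runtime/executor.h: keeps launch/run logic and delegates compilation.
// RtcKernel (not shown) follows the same pattern for raw CUDA strings and
// also owns a CudaExecutable, which is why the second bullet under the
// "improvements" list below questions whether the two types should merge.
class KernelExecutor {
 public:
  void compile(Fusion* fusion) {
    compiled_kernel_ = std::make_unique<CompiledKernel>(fusion);
    compiled_kernel_->compile();
  }
  void run() { /* compute launch params, cuLaunchKernel on the executable */ }
 private:
  std::unique_ptr<CompiledKernel> compiled_kernel_;
};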

Some improvements left for another time:

  • Don't disable the parameter cache completely when the size of a tensor is a function of an input scalar. I don't think this is particularly common, as Thunder mostly uses static shapes, but it might be good to support for resize ops (see the illustration after this list).
  • Merge executor_utils::CudaExecutable and CompiledKernel. I'm not sure this is the right thing to do, partly because both RtcKernel and CompiledKernel own an executor_utils::CudaExecutable.
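
To illustrate the first bullet (a minimal, self-contained sketch; the function and numbers are hypothetical, not nvFuser code): the parameter cache is keyed on input tensor shapes, so an output extent computed from a scalar input is invisible to the cache key, and two calls with identical tensor shapes but different scalars would collide.

#include <cstdint>
#include <iostream>

// Hypothetical resize-style extent computation: the output extent depends
// on the scalar `pad`, which a shape-only cache key never sees.
int64_t outputExtent(int64_t input_extent, int64_t pad) {
  return input_extent + 2 * pad;
}

int main() {
  // Same input shape, different scalar -> different output extent, but the
  // same shape-keyed cache entry; hence the cache is disabled entirely today.
  std::cout << outputExtent(/*input_extent=*/8, /*pad=*/1) << '\n';  // 10
  std::cout << outputExtent(/*input_extent=*/8, /*pad=*/3) << '\n';  // 14
}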

@csarofeen commented:

!test

@csarofeen left a comment:

Quite a few TODOs in this PR; I might not take all of them on here.

csrc/fusion.cpp (outdated, resolved)
buffer << cuda_src.rdbuf();
return buffer.str();
}
} // namespace
@csarofeen commented:

Everything above this is only code motion.

csrc/runtime/compiled_kernel.cpp (resolved)
csrc/runtime/compiled_kernel.cpp (outdated, resolved)
csrc/runtime/compiled_kernel.cpp (outdated, resolved)
csrc/runtime/executor_utils.h (resolved)
@@ -58,20 +47,11 @@ struct CompiledKernel : public NonCopyable {
int register_spills = -1;
};

// Returns executable function and the ptxas log from compilation
@csarofeen commented:

Moved to compiled_kernel.cpp

@@ -253,12 +233,5 @@ void validateCircularBuffering(
kir::Kernel* kernel,
ExpressionEvaluator& expr_eval);

//! Query the target GPU version number NVRTC compiles CUDA kernels for
@csarofeen commented:

Moved to compiled_kernel.cpp

@csarofeen commented Dec 22, 2024:

Actually, these are likely just removed, as they should now be contained in runtime/compiled_kernel.cpp.

@@ -32,117 +30,6 @@
#include <cstdlib>

namespace nvfuser {

@csarofeen commented:

Moved to compiled_kernel.cpp

@@ -194,14 +81,6 @@ bool detectComputeSanitizer() {

namespace nvfuser {

namespace executor_utils {
@csarofeen commented:

Moved to compiled_kernel.cpp

@csarofeen commented Dec 22, 2024:

TODO list:

  • Move pre-lowering hooks to the constructor of CompiledKernel (clean up executor.cpp:279)
  • Move post-lowering hooks to the constructor of CompiledKernel as well
  • Rename compileFusion to compile
  • When CompiledKernel checks whether to disable the parameter cache, change the check so that an extent depending on a TensorView input must go through a metadata op, or an error is thrown (see the sketch after these lists)
  • Not sure what this TODO means; it's at compiled_kernel.cpp:1131 ("TODO: These properties should be set as part of the constructor so that it can be const")
  • Check whether compiled_kernel.cpp:1233 ("TODO: high water mark should be computed via occupancy API after") is an old TODO. This could be done with the occupancy calculator, but when I tried, the linker failed. It's also a minor optimization, so I simply removed the TODO.
  • Remove compiled_kernel.h::CompileOptions; it only holds a device and should likely live in CompileParams
  • Remove the TODO "TODO: Consider split out compileRtc and runRtc to a different"; whether that split makes sense will need to be evaluated once compilation and execution are completely separate concepts
  • Check whether the default constructor of CompiledKernel can be removed
  • executor.cpp:241 "TODO: Is this necessary?"
  • Check whether executor.h still needs disableLaunchParamCache, since CompiledKernel has its own
  • executor.h:384 "TODO: Should this be removed?" (SchedulerType scheduler_type_ = SchedulerType::None;)

I'm leaving these to consider in the future:

  • Don't disable the parameter cache completely when a scalar input is used in the computation of an ID extent
  • Check whether executor_utils::CudaExecutable can be merged with CompiledKernel
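
The sketch referenced by the parameter-cache TODO above: requiresDisabledParamCache in the compiled_kernel.cpp listing further down already does the underlying walk. It collects all exprs between fusion inputs and output extents, subtracts those before GetMetaData outputs and those after them, and treats any remainder as a path that bypasses metadata. The TODO would turn a non-empty remainder into a hard error instead of silently disabling the cache. A minimal, self-contained sketch of that final check (stand-in types; not the real IR API):

#include <set>
#include <stdexcept>

using ExprId = int;  // stand-in for nvFuser's Expr*

// Mirrors the computeSubtract logic in requiresDisabledParamCache: whatever
// survives the two subtractions is a dependency path from an input to an
// output extent that does not pass through a GetMetaData op.
void checkParamCacheSafety(
    std::set<ExprId> all_exprs,
    const std::set<ExprId>& before_meta_data,
    const std::set<ExprId>& after_meta_data) {
  for (ExprId e : before_meta_data) {
    all_exprs.erase(e);
  }
  for (ExprId e : after_meta_data) {
    all_exprs.erase(e);
  }
  if (!all_exprs.empty()) {
    throw std::runtime_error(
        "An output extent depends on a TensorView input without going "
        "through a GetMetaData op");
  }
}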

@csarofeen commented:

!test

@csarofeen commented:

47 successful checks
https://nv/e2E/130642066

Commit: …untime_id, group_id, and device constant in CompiledKernel.
@csarofeen commented:

!test

@csarofeen commented:

!test

@csarofeen commented:

All checks have passed
48 successful checks
https://nv/e2E/131360124

@csarofeen csarofeen marked this pull request as ready for review January 1, 2025 22:20
@csarofeen csarofeen requested a review from naoyam January 1, 2025 22:20
@naoyam (Collaborator) left a comment:

Just took a pass. Left several comments/questions.

As far as I can see, there's no change in the underlying logic; it's just code movement. Am I missing anything?

csrc/runtime/executor_params.h (resolved)
csrc/runtime/compiled_kernel.h (outdated, resolved)
csrc/runtime/compiled_kernel.h (resolved)
tests/cpp/utils.h (resolved)
csrc/runtime/executor.h (outdated, resolved)
@csarofeen commented:

!test

@csarofeen commented:

!test

@csarofeen commented Jan 12, 2025:

Getting a bunch of Thunder failures: https://gitlab-master.nvidia.com/dl/pytorch/fuser-gh-mirror/-/jobs/133353224. I was able to reproduce one of them on main, so I'm uncertain what's going on.

The Clang build was the only other test to fail.

Follow-up: the Thunder failures reproduce on main at this point: #3698

@naoyam (Collaborator) left a comment:

LGTM. The Thunder failures are unlikely to have anything to do with this PR.

@csarofeen commented:

!test

@csarofeen commented:

!test


PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 No relevant tests
⚡ Recommended focus areas for review

Potential Logic Change

The logic of the KernelExecutor class has been modified to use a CompiledKernel object instead of directly accessing its members. This change may affect the behavior of the class and its methods.

namespace nvfuser {

std::unique_ptr<PrecomputedValues>& KernelExecutor::
    evaluatorPrecomputedValues() {
  if (!evaluator_precomputed_values_) {
    evaluator_precomputed_values_ =
        std::make_unique<PrecomputedValues>(compiledKernel()->kernel());
  }
  return evaluator_precomputed_values_;
}

bool ExprEvalExecutor::supported(Fusion* fusion) {
  FUSER_PERF_SCOPE("ExprEvalExecutor::supported");
  return std::all_of(
      fusion->outputs().begin(), fusion->outputs().end(), [&fusion](Val* out) {
        return fusion->getOutputAlias(out).type == AllocationType::Evaluate;
      });
}

void ExprEvalExecutor::compile(Fusion* fusion) {
  FUSER_PERF_SCOPE("ExprEvalExecutor::compile");
  if (isProfilerEnabled()) {
    FusionProfiler::segment(group_id_).startCompile();
  }
  NVF_ERROR(
      supported(fusion),
      "ExprEvalExecutor does not support the Fusion provided.");
  fusion_ = std::make_unique<Fusion>(*fusion);
  if (isProfilerEnabled()) {
    FusionProfiler::segment(group_id_).stopCompile();
  }
}

bool ExprEvalExecutor::isCompiled() const {
  return fusion_ != nullptr;
}

std::vector<at::Tensor> ExprEvalExecutor::run(
    KernelArgumentHolder& args,
    std::vector<at::Tensor> outputs) {
  FUSER_PERF_SCOPE("ExprEvalExecutor::run");

  if (isProfilerEnabled()) {
    NVF_CHECK(
        group_id_ >= 0,
        "An invalid segment id is passed to FusionProfiler!:",
        group_id_);
    SegmentProfiler& sprof = FusionProfiler::segment(group_id_);
    sprof.inputBytesAccessed(computeBytes(args));
    sprof.scheduler(toString(SchedulerType::ExprEval));
    sprof.startKernel();
  }

  NVF_ERROR(fusion_, "Need to compile before you can run.");
  // Bind fusion inputs
  auto expr_eval = executor_utils::bindInputs(args, fusion_.get());
  {
    NVF_ERROR(
        outputs.empty(),
        "Fusion executor is using expression evaluator,",
        " and expects that the outputs are not populated, which they were.");
    if (outputs.empty()) {
      for (const auto& out_val : fusion_->outputs()) {
        auto out_tensor =
            expr_eval.evaluate(out_val->as<TensorView>()).as<at::Tensor>();
        expr_eval.bind(out_val, out_tensor);
        outputs.emplace_back(out_tensor);
      }
    }
  }
  if (isProfilerEnabled()) {
    FusionProfiler::segment(group_id_).stopKernel();
    FusionProfiler::segment(group_id_).setDevice(args.getDeviceIndex());
  }
  return outputs;
}
Function Signature Change

The compile method of the KernelExecutor class has been modified to take a CompileParams object instead of individual parameters. This change may affect the usage of the method and its compatibility with existing code.

  const LaunchParams& launch_constraints,
  CompileParams compile_params,
  SchedulerType scheduler_type) {
FUSER_PERF_SCOPE("KernelExecutor::compile");

NVF_ERROR(
    supported(fusion),
    "KernelExecutor does not support the Fusion provided.");

NVF_ERROR(
    !fusion->outputs().empty(), "No output found for this kernel, aborting.");

auto device = c10::Device(c10::DeviceType::CUDA, args.getDeviceIndex());

if (isProfilerEnabled()) {
  NVF_CHECK(
      group_id_ >= 0,
      "An invalid segment id is passed to FusionProfiler!:",
      group_id_);
  FusionProfiler::segment(group_id_).setDevice(device.index());
  FusionProfiler::segment(group_id_).startCompile();
}

//! Force index_type to int and disable magic zero if we detect that the
//! kernel contains any TMA memory operations.
std::vector<Expr*> exprs = fusion->exprs();
bool has_cp_async_bulk = std::any_of(exprs.begin(), exprs.end(), [](Expr* e) {
  return ir_utils::isCpAsyncBulk(e);
});

// Disable magic zero if there are any TMA operations in Fusion
if (has_cp_async_bulk) {
  compile_params.enable_magic_zero = false;
}

// Set the index type of compile params if not already set. If set,
// make sure the compile param type is valid with the given kernel
// arguments.
auto arg_index_type = args.getSmallestIndexTypeOfArguments();
if (compile_params.index_type.has_value()) {
  // If the int32 compilation is requested, but the arguments demand
  // int64, that's an error
  NVF_ERROR(
      !(compile_params.index_type.value() == PrimDataType::Int32 &&
        arg_index_type == PrimDataType::Int),
      "Compilation with int32 is requested but int64 is required for the arguments");
  NVF_ERROR(
      !has_cp_async_bulk ||
          (compile_params.index_type.value() == PrimDataType::Int32),
      "Compilation with int64 is requested but int32 is required because ",
      "of TMA operations.");

} else if (arg_index_type == PrimDataType::Int) {
  // If the given compile option doesn't specify the index type, and
  // the arguments require 64-bit indexing, we need to use 64-bit
  // indexing. Note that if the arg type is 32-bit, it doesn't mean
  // it's safe to use 32-bit for the whole kernel, so unless it's
  // specified through CompileParams, we do not use 32-bit indexing.
  compile_params.index_type = arg_index_type;
  NVF_ERROR(
      !has_cp_async_bulk,
      "Compilation with int64 is required based on input arguments, but ",
      "int32 is required because of TMA operations.");
} else if (has_cp_async_bulk) {
Potential Logic Change

The logic of the run method of the KernelExecutor class has been modified to use a CompiledKernel object instead of directly accessing its members. This change may affect the behavior of the method and its compatibility with existing code.

              << ", occupancy=" << oss.str() << std::endl;
    }

    if (!compiled_kernel_->kernel()->summary().has_cooperative_grid_reduction) {
      FUSER_PERF_SCOPE("ExecutorRunFusion::cuLaunchKernel");
      NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel(
          compiled_kernel_->cudaExecutable()->function,
          launch_params_.gdimx(),
          launch_params_.gdimy(),
          launch_params_.gdimz(),
          launch_params_.bdimx(),
          launch_params_.bdimy(),
          launch_params_.bdimz(),
          launch_params_.smem(),
          stream,
          executor_entry->arg_ptrs.data(),
          nullptr));
    } else {
      FUSER_PERF_SCOPE("ExecutorRunFusion::cuLaunchCooperativeKernel");
      NVFUSER_CUDA_SAFE_CALL(cuLaunchCooperativeKernel(
          compiled_kernel_->cudaExecutable()->function,
          launch_params_.gdimx(),
          launch_params_.gdimy(),
          launch_params_.gdimz(),
          launch_params_.bdimx(),
          launch_params_.bdimy(),
          launch_params_.bdimz(),
          launch_params_.smem(),
          stream,
          executor_entry->arg_ptrs.data()));
    }
  }

  releaseZeroedMemory();

  if (isOptionEnabled(EnableOption::KernelProfile)) {
    debug() << compiled_kernel_->kernel()->profile().toString(profile_buffer);
  }

  if (isProfilerEnabled()) {
    auto& sprof = FusionProfiler::segment(group_id_);
    sprof.stopKernel();
    sprof.outputBytesAccessed(computeBytes(outputs));
  }

  return outputs;
}

flatbuffers::Offset<serde::KernelExecutor> KernelExecutor::serialize(
    flatbuffers::FlatBufferBuilder& builder) const {
  // See table definition for KernelExecutor in serde/fusion_cache.fbs
  using fb_executor_entry = flatbuffers::Offset<serde::ExecutorEntry>;

  // Separate unordered_map for executor_entry_lookup into key and value
  // vectors. The key value is the cache_id value in the KernelArgumentHolder.
  std::vector<size_t> executor_entry_lookup_keys_fb;
  std::vector<fb_executor_entry> executor_entry_lookup_values_fb;
  for (const auto& [key, value] : executor_entry_lookup_) {
    executor_entry_lookup_keys_fb.push_back(key);
    executor_entry_lookup_values_fb.push_back(serialize(builder, value));
  }

  // When compilation is skipped, avoid serializing cubin because it doesn't
  // exist. The remaining fields are also not necessary in this case.
  if (!compiledKernel()->hasCompiledKernel()) {
    return serde::CreateKernelExecutorDirect(builder);
  }

  return serde::CreateKernelExecutorDirect(
      builder,
      device_smem_limit_,
      compiledKernel()->blockSizeHighWaterMark(),
      compiledKernel()->maxrregcountHighWaterMark(),
      warp_size_,
      toUnderlying(compiledKernel()->schedulerType()),
      fusion_id_,
      concrete_id_,
      runtime_id_,
      group_id_,
      compiledKernel()->kernelCode().c_str(),
      &executor_entry_lookup_keys_fb,
      &executor_entry_lookup_values_fb,
      toUnderlying(compiledKernel()->kernel()->indexType()),
      serialize(builder, compiledKernel()->cudaExecutable().get()));
}

flatbuffers::Offset<serde::CudaKernel> KernelExecutor::serialize(
    flatbuffers::FlatBufferBuilder& builder,
    const executor_utils::CudaExecutable* compiled_kernel) const {
  NVF_ERROR(
      compiledKernel()->cudaExecutable() != nullptr &&
          (!compiled_kernel->cubin.empty() || !compiled_kernel->ptx.empty()),
      "Expected compiled cuda kernel before serializing KernelExecutor.");

  auto fb_kernel_name = builder.CreateString(compiled_kernel->kernel_name);
  auto fb_compile_args = builder.CreateString(compiled_kernel->compile_args);

  flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_cubin = 0;
  flatbuffers::Offset<flatbuffers::String> fb_cubin_filename = 0;
  if (!compiled_kernel->cubin.empty()) {
    uint8_t* cubin_ptr = nullptr;
    fb_cubin = builder.CreateUninitializedVector(
        compiled_kernel->cubin.size(), &cubin_ptr);
    std::copy(
        compiled_kernel->cubin.begin(),
        compiled_kernel->cubin.end(),
        cubin_ptr);
    fb_cubin_filename = builder.CreateString(compiled_kernel->cubin_filename);
  }

  flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_ptx = 0;
  flatbuffers::Offset<flatbuffers::String> fb_ptx_filename = 0;
  if (!compiled_kernel->ptx.empty()) {
    uint8_t* ptx_ptr = nullptr;
    fb_ptx = builder.CreateUninitializedVector(
        compiled_kernel->ptx.size(), &ptx_ptr);
    std::copy(
        compiled_kernel->ptx.begin(), compiled_kernel->ptx.end(), ptx_ptr);
    fb_ptx_filename = builder.CreateString(compiled_kernel->ptx_filename);
  }

  serde::CudaKernelBuilder ckb(builder);
  ckb.add_cubin(fb_cubin);
  ckb.add_cubin_filename(fb_cubin_filename);
  ckb.add_ptx(fb_ptx);
  ckb.add_ptx_filename(fb_ptx_filename);
  ckb.add_kernel_name(fb_kernel_name);
  ckb.add_compile_args(fb_compile_args);
  ckb.add_block_size(compiled_kernel->block_size);
  return ckb.Finish();
}

flatbuffers::Offset<serde::ExecutorEntry> KernelExecutor::serialize(
    flatbuffers::FlatBufferBuilder& builder,
    const ExecutorEntry& data) const {
  // See table definition for ExecutorEntry in serde/fusion_cache.fbs

  // Serialize GlobalBufferInfo for outputs.
  // We map the output TensorView pointer to its corresponding position in
  // fusion outputs assuming that the output ordering is consistent.
  using fb_global_buffer_info = flatbuffers::Offset<serde::GlobalBufferInfo>;
  std::vector<fb_global_buffer_info> outputs_fb;
  outputs_fb.reserve(data.outputs.size());
  for (const auto& buffer : data.outputs) {
    auto tv_iter = std::find(
        compiledKernel()->kernel()->outputs().cbegin(),
        compiledKernel()->kernel()->outputs().cend(),
        buffer.tv);
    auto tv_position = (tv_iter == compiledKernel()->kernel()->outputs().cend())
        ? -1
        : std::distance(
              compiledKernel()->kernel()->outputs().cbegin(), tv_iter);
    outputs_fb.push_back(
        serialize(builder, buffer, tv_position, true /* is_fusion_output */));
  }

  // Serialize GlobalBufferInfo for intermediates.
  // We map the intermediate TensorView pointer to its corresponding position in
  // KernelSummary global allocations. We assume that the ordering is consistent
  // between GpuLower objects with the same scheduled fusion.
  std::vector<fb_global_buffer_info> intermediates_fb;
  intermediates_fb.reserve(data.intermediates.size());
  for (const auto& buffer : data.intermediates) {
    auto match_tv_predicate = [buffer_tv = buffer.tv](const kir::Allocate* a) {
      return a->buffer() == buffer_tv;
    };
    auto tv_iter = std::find_if(
        compiledKernel()->kernel()->summary().global_allocations.cbegin(),
        compiledKernel()->kernel()->summary().global_allocations.cend(),
        match_tv_predicate);
    auto tv_position =
        (tv_iter ==
         compiledKernel()->kernel()->summary().global_allocations.cend())
        ? -1
        : std::distance(
              compiledKernel()->kernel()->summary().global_allocations.cbegin(),
              tv_iter);
    intermediates_fb.push_back(
        serialize(builder, buffer, tv_position, false /* is_fusion_output */));
  }

  return serde::CreateExecutorEntryDirect(
      builder,
      data.init,
      data.launch_params.serialize(builder),
      &outputs_fb,
      &intermediates_fb);
}

flatbuffers::Offset<serde::GlobalBufferInfo> KernelExecutor::serialize(
    flatbuffers::FlatBufferBuilder& builder,
    const GlobalBufferInfo& data,
    int64_t tv_position,
    bool is_fusion_output) const {
  // See table definition for GlobalBufferInfo in serde/fusion_cache.fbs
  return serde::CreateGlobalBufferInfoDirect(
      builder,
      tv_position,
      &data.sizes,
      &data.strides,
      nvfuser::toUnderlying(data.type),
      data.zero_init,
      data.resets_to_zero,
      data.is_profile_buffer,
      is_fusion_output);
}
New Class

A new class, CompiledKernel, has been added to the codebase. It encapsulates the compilation of a kernel, separating it from execution.
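
Before the listing, a minimal usage sketch. The accessors are the ones KernelExecutor::serialize calls in the excerpt above; the constructor arguments and the compileFusion entry point are assumptions (per the TODO list, compileFusion is slated to be renamed compile):

// Sketch only: how KernelExecutor consumes the new class, inferred from the
// excerpts above. Constructor arguments are assumed, not the real API.
auto ck = std::make_unique<CompiledKernel>(/* fusion, compile params, ... */);
ck->compileFusion(/* ... */);  // to be renamed compile() per the TODO list
if (ck->hasCompiledKernel()) {
  CUfunction f = ck->cudaExecutable()->function;  // passed to cuLaunchKernel
  PrimDataType idx = ck->kernel()->indexType();   // used when serializing
  const std::string& src = ck->kernelCode();      // generated CUDA source
}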

// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on

#include <runtime/compiled_kernel.h>

#include <codegen.h>
#include <cuda_utils.h>
#include <debug.h>
#include <device_lower/analysis/bank_conflict.h>
#include <disjoint_set.h>
#include <driver_api.h>
#include <fusion_profiler.h>
#include <global_allocator.h>
#include <instrumentation.h>
#include <ir/all_nodes.h>
#include <ir/utils.h>
#include <iter_visitor.h>
#include <kernel_db/kernel_db.h>
#include <kernel_ir.h>
#include <multidevice/communication.h>
#include <multidevice/communicator.h>
#include <multidevice/utils.h>
#include <options.h>
#include <polymorphic_value.h>
#include <runtime/allocations.h>
#include <runtime/executor_kernel_arg.h>
#include <runtime/executor_utils.h>
#include <serde/utils.h>
#include <tensor_metadata.h>
#include <utils.h>

#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/llvm_jit_strings.h>
#include <ATen/native/cuda/jit_utils.h>
#include <c10/core/DeviceGuard.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/resource_guard.h>

#include <array>
#include <cmath>
#include <cstring>
#include <fstream>
#include <vector>

#include <cuda_runtime.h>

#include <nvfuser_resources/array.h>
#include <nvfuser_resources/basic_type_traits.h>
#include <nvfuser_resources/bf16_support.h>
#include <nvfuser_resources/bit.h>
#include <nvfuser_resources/block_reduction.h>
#include <nvfuser_resources/block_sync_atomic.h>
#include <nvfuser_resources/block_sync_default.h>
#include <nvfuser_resources/block_welford_outer.h>
#include <nvfuser_resources/broadcast.h>
#include <nvfuser_resources/complex_number.h>
#include <nvfuser_resources/fp16_support.h>
#include <nvfuser_resources/fp8_support.h>
#include <nvfuser_resources/fused_reduction.h>
#include <nvfuser_resources/fused_welford_helper.h>
#include <nvfuser_resources/fused_welford_impl.h>
#include <nvfuser_resources/fused_welford_impl_outer.h>
#include <nvfuser_resources/grid_broadcast.h>
#include <nvfuser_resources/grid_reduction.h>
#include <nvfuser_resources/grid_sync.h>
#include <nvfuser_resources/helpers.h>
#include <nvfuser_resources/index_utils.h>
#include <nvfuser_resources/mbarrier.h>
#include <nvfuser_resources/memory.h>
#include <nvfuser_resources/random_numbers.h>
#include <nvfuser_resources/tensor.h>
#include <nvfuser_resources/tuple.h>
#include <nvfuser_resources/type_traits.h>
#include <nvfuser_resources/warp.h>
#include <nvfuser_resources/welford.h>

namespace nvfuser {

namespace {

// Include all the functions we might need in generated code
std::string kernelPreamble() {
  std::stringstream ss;
  ss << nvfuser_resources::basic_type_traits_cu;
  ss << nvfuser_resources::bit_cu;
  ss << nvfuser_resources::complex_number_cu;

  ss << nvfuser_resources::fp16_support_cu;
  ss << nvfuser_resources::bf16_support_cu;
  ss << nvfuser_resources::fp8_support_cu;

  // Base classes and helpers
  ss << nvfuser_resources::type_traits_cu;
  ss << nvfuser_resources::array_cu;
  ss << nvfuser_resources::tensor_cu;
  ss << nvfuser_resources::random_numbers_cu;
  ss << nvfuser_resources::helpers_cu;
  ss << nvfuser_resources::index_utils_cu;
  ss << nvfuser_resources::tuple_cu;

  // Synchronization classes
  if (getNvFuserEnv("USE_BLOCK_SYNC_ATOMIC")) {
    ss << nvfuser_resources::block_sync_atomic_cu;
  } else {
    ss << nvfuser_resources::block_sync_default_cu;
  }
  ss << nvfuser_resources::grid_sync_cu;
  ss << nvfuser_resources::mbarrier_cu;

  // Communication classes
  ss << nvfuser_resources::block_reduction_cu;
  ss << nvfuser_resources::grid_reduction_cu;
  ss << nvfuser_resources::grid_broadcast_cu;
  ss << nvfuser_resources::broadcast_cu;
  ss << nvfuser_resources::welford_cu;
  ss << nvfuser_resources::warp_cu;
  ss << nvfuser_resources::memory_cu;
  ss << nvfuser_resources::fused_welford_helper_cu;
  ss << nvfuser_resources::fused_reduction_cu;
  ss << nvfuser_resources::fused_welford_impl_cu;
  ss << nvfuser_resources::block_welford_outer_cu;
  ss << nvfuser_resources::fused_welford_impl_outer_cu;

  return ss.str();
}

//! Utility class to invoke nvrtcCompileProgram. Mainly for setting up
//! the c-str options.
//! TODO: Revisit if we should remove or restructure this utility function
class NvrtcCompileDriver {
 public:
  void setOption(const std::string& opt) {
    options_.push_back(opt);
  }

  const std::vector<std::string>& options() const {
    return options_;
  }

  std::string invoke(nvrtcProgram program, const std::string& src) const {
    FUSER_PERF_SCOPE("executor_utils::Nvrtc::CompileProgram");
    auto opts = getOptions();
    auto result = nvrtcCompileProgram(
        program, static_cast<int>(opts.size()), opts.data());
    size_t logsize = 0;
    NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(program, &logsize));
    // The log size, as returned by 'nvrtcGetProgramLogSize', appears larger
    // than its actual size by 2. This discrepancy was noticed in NVRTC
    // version 12.1. The log returned from 'nvrtcGetProgramLog' terminates with
    // a NULL character, ensuring it's safe to use 'std::vector<char>' for
    // storage before converting it to 'std::string'.
    std::vector<char> log_backing_buf(logsize);
    char* log_buf = log_backing_buf.data();
    NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(program, log_buf));
    if (result != NVRTC_SUCCESS) {
      // Print CUDA starting at first global function
      size_t kernel_start = src.find("__global__");
      NVF_THROW(
          "\n",
          src.substr(kernel_start),
          "\nCUDA NVRTC compile error: ",
          log_buf);
    }
    if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
      debug() << log_buf << std::endl;
    }
    return std::string(log_buf);
  }

 private:
  // Get options that can be passed to nvrtcCompileProgram
  std::vector<const char*> getOptions() const {
    std::vector<const char*> opts(options_.size());
    for (const auto i : c10::irange(options_.size())) {
      opts.at(i) = options_.at(i).c_str();
    }
    return opts;
  }

 private:
  std::vector<std::string> options_;
};

// Query the target GPU version number NVRTC compiles CUDA kernels for
void queryTargetGPUVersion(
    const cudaDeviceProp* const prop,
    int64_t& major,
    int64_t& minor,
    bool& compile_to_sass) {
  using CudaVersion = std::pair<int, int>;
  CudaVersion nvrtc_version;
  NVFUSER_NVRTC_SAFE_CALL(
      nvrtcVersion(&nvrtc_version.first, &nvrtc_version.second));

  NVF_CHECK(
      nvrtc_version.first >= 6,
      "NVRTC versions less than 6 are not supported. Is: ",
      nvrtc_version.first);

  // Version supported by device
  // Usually any lower version works too but is less efficient
  const CudaVersion dev_version = CudaVersion(prop->major, prop->minor);
  // Maximum version supported by the driver, cap dev_version to this
  CudaVersion max_dev_version;
  if (nvrtc_version.first <= 7) { // 7 supports 2-5.x
    max_dev_version = CudaVersion(5, 0);
  } else if (nvrtc_version.first <= 8) { // 8 supports 2-6.x
    max_dev_version = CudaVersion(6, 0);
  } else if (nvrtc_version.first <= 9) { // 9 supports 3-7.2
    max_dev_version = CudaVersion(7, 2);
  } else if (nvrtc_version.first <= 10) { // 10 supports 3-7.5
    max_dev_version = CudaVersion(7, 5);
  } else if (nvrtc_version == CudaVersion(11, 0)) { // 11.0 supports 3-8.0
    max_dev_version = CudaVersion(8, 0);
  } else if (nvrtc_version.first == 11 && nvrtc_version.second < 8) {
    max_dev_version = CudaVersion(8, 6);
  } else {
    // If the driver version is unknown (i.e. newer than this code)
    // assume the driver supports this device
    max_dev_version = dev_version;
  }
  if (dev_version > max_dev_version) {
    major = max_dev_version.first;
    minor = max_dev_version.second;
    // if we are clamping major/minor, sass is not compatible
    compile_to_sass = false;
  } else {
    major = dev_version.first;
    minor = dev_version.second;
    compile_to_sass = true;
  }
}

#if defined(__linux__)
std::string disassembleBinary(
    const std::vector<char>& cubin,
    const std::string& nvdisasm_args) {
  const char* err = "Failed to disassemble cubin";

  // Reference:
  // https://stackoverflow.com/a/3469651
  // https://linuxhint.com/dup2_system_call_c/

  constexpr int READ = 0, WRITE = 1;
  std::array<int, 2> cubin_pipe{-1, -1};
  std::array<int, 2> disasm_pipe = {-1, -1};
  std::array<int, 2> err_pipe = {-1, -1};

  NVF_ERROR(
      pipe(cubin_pipe.data()) == 0 && pipe(disasm_pipe.data()) == 0 &&
          pipe(err_pipe.data()) == 0,
      err);

  pid_t pid = fork();
  NVF_ERROR(pid != -1, err);

  if (pid) { // I am the parent
    // Parent only write cubin and read disasm, close unused pipe end
    NVF_ERROR(close(cubin_pipe[READ]) == 0, err);
    NVF_ERROR(close(disasm_pipe[WRITE]) == 0, err);
    NVF_ERROR(close(err_pipe[WRITE]) == 0, err);

    // Wrap pipe fileno as C file stream
    FILE* cubin_fp = fdopen(cubin_pipe[WRITE], "wb");
    FILE* disasm_fp = fdopen(disasm_pipe[READ], "r");
    FILE* err_fp = fdopen(err_pipe[READ], "r");
    NVF_ERROR(cubin_fp != nullptr, err);
    NVF_ERROR(disasm_fp != nullptr, err);
    NVF_ERROR(err_fp != nullptr, err);

    // Write cubin to nvdisasm
    size_t written = fwrite(cubin.data(), 1, cubin.size(), cubin_fp);
    NVF_ERROR(written == cubin.size(), err);
    fclose(cubin_fp);

    int ch = -1;

    // read disassembly result
    std::string result;
    result.reserve(cubin.size());
    while ((ch = fgetc(disasm_fp)) != EOF) {
      result.push_back((char)ch);
    }
    fclose(disasm_fp);

    // read error message
    std::string error;
    while ((ch = fgetc(err_fp)) != EOF) {
      error.push_back((char)ch);
    }
    fclose(err_fp);
    NVF_CHECK(error.empty(), error);

    return result;
  } else { // I am the child
    // For easier understanding, we can consider the fileno as a smart pointer
    // pointing to an underlying IO object in the kernel. Both the pointer and
    // the underlying objects are owned by the kernel, and multiple pointers
    // can point to the same object. `close` destroy the pointer, which does
    // not necessarily destroy the object.

    // Modify the stdin, stdout and stderr pointer to point to the pipe object
    NVF_ERROR(close(STDIN_FILENO) == 0, err);
    NVF_ERROR(close(STDOUT_FILENO) == 0, err);
    NVF_ERROR(close(STDERR_FILENO) == 0, err);
    NVF_ERROR(dup2(cubin_pipe[READ], STDIN_FILENO) != -1, err);
    NVF_ERROR(dup2(disasm_pipe[WRITE], STDOUT_FILENO) != -1, err);
    NVF_ERROR(dup2(err_pipe[WRITE], STDERR_FILENO) != -1, err);

    // Now we have stdin, stdout and stderr pointing to the pipe object, we no
    // longer need the original pointers.
    NVF_ERROR(close(cubin_pipe[READ]) == 0, err);
    NVF_ERROR(close(cubin_pipe[WRITE]) == 0, err);
    NVF_ERROR(close(disasm_pipe[READ]) == 0, err);
    NVF_ERROR(close(disasm_pipe[WRITE]) == 0, err);
    NVF_ERROR(close(err_pipe[READ]) == 0, err);
    NVF_ERROR(close(err_pipe[WRITE]) == 0, err);

    // If execl succeed, then the current process will be replaced by nvdisasm,
    // and all the remaining code in the current process will not be executed.
    // So, execl only returns when it fails.
    //
    // TODO: I was planning to use `nvdisasm /dev/stdin` which could avoid
    // creating temporary file, but unfortunately, that fails with:
    //   nvdisasm fatal   : Memory allocation failure
    // so I have to dump the stdin to a temp file and let nvdisasm read it. I am
    // hoping that nvdisasm will support reading from stdin one day.
    std::stringstream ss;
    ss << "export PATH=$PATH:/usr/local/cuda/bin;"
       << "TMPFILE=$(mktemp);"
       << "cat>$TMPFILE;"
       << "nvdisasm $TMPFILE " << nvdisasm_args << "; rm $TMPFILE";
    auto command = ss.str();
    execl("/bin/bash", "bash", "-c", command.c_str(), NULL);

    // only reachable when execl fails
    NVF_THROW(err);
  }
}
#else // #if defined(__linux__)
std::string disassembleBinary(const std::vector<char>& binary) {
  NVF_CHECK(false, "disassembling cubin is only supported on Linux");
}
#endif // #if defined(__linux__)

//! Utility class to invoke cuModuleLoadDataEx. Similar to
//! NvrtcCompileDriver, the main task is to set up the option lists
//! of type void**
class CuModuleLoadDataDriver {
 public:
  //! Valid option type is either int or char*
  using OptionType = std::variant<int, char*>;

  template <typename OptionValType>
  void setOption(CUjit_option key, OptionValType val) {
    options_.push_back(key);
    option_vals_.push_back(val);
  }

  //! Enable logging of cuModuleLoadData
  void enableLogging() {
    logging_enabled_ = true;
    log_.resize(kLogSize);
  }

  const std::string& log() const {
    NVF_ERROR(logging_enabled_, "Logging not enabled");
    return log_;
  }

  //! Invoke cuModuleLoadDataEx with ptx or cubin. Dump logging output
  //! if enabled
  std::string invoke(CUmodule& module, const void* image) {
    FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadPTX");

    auto [opts, opt_vals] = getOptions();

    NVFUSER_CUDA_SAFE_CALL(cuModuleLoadDataEx(
        &module, image, opts.size(), opts.data(), opt_vals.data()));

    if (logging_enabled_) {
      debug() << log_ << std::endl;
    }

    return log_;
  }

 private:
  // Get options that can be passed to cuModuleLoadDataEx
  std::pair<std::vector<CUjit_option>, std::vector<void*>> getOptions() {
    auto opts = options_;
    auto opt_vals = option_vals_;

    // Append options for saving log message to log_
    if (logging_enabled_) {
      opts.push_back(CU_JIT_LOG_VERBOSE);
      opt_vals.emplace_back(1);

      opts.push_back(CU_JIT_INFO_LOG_BUFFER);
      opt_vals.emplace_back(log_.data());

      opts.push_back(CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES);
      opt_vals.emplace_back(kLogSize);
    }

    // Convert the options to void**. This is ugly, but that's how
    // cuModuleLoadDataEx works. See initCUDA in the
    // matrixMulDynlinkJIT sample
    // https://github.com/NVIDIA/cuda-samples/blob/master/Samples/0_Introduction/matrixMulDynlinkJIT/matrixMulDynlinkJIT.cpp#L169-L204.
    std::vector<void*> opt_val_voidp(opt_vals.size());
    for (const auto i : c10::irange(opt_vals.size())) {
      auto opt_val = opt_vals.at(i);
      if (std::holds_alternative<int>(opt_val)) {
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        opt_val_voidp.at(i) = (void*)(int64_t)std::get<int>(opt_val);
      } else if (std::holds_alternative<char*>(opt_val)) {
        opt_val_voidp.at(i) = std::get<char*>(opt_val);
      } else {
        NVF_THROW("Invalid option");
      }
    }

    return std::make_pair(opts, opt_val_voidp);
  }

 private:
  static constexpr int kLogSize = 8196;
  //! cuModuleLoadDataEx options
  std::vector<CUjit_option> options_;
  //! Option parameters
  std::vector<OptionType> option_vals_;
  //! Save log to log_ if true
  bool logging_enabled_ = false;
  std::string log_;
};

// Get the max register count passed as -maxrregcount ptxas
// option. The count is determined based on block sizes, an optional
// heuristic and an environment variable.
std::optional<int64_t> getMaxRegCount(
    std::optional<int64_t> opt_block_size,
    const int64_t max_register_heuristic) {
  // The maximum possible count allowed by ptxas is 255
  constexpr int64_t max_register_limit = 255;

  // Temporary set the max register count to be larger than the
  // limit.
  int64_t max_register = max_register_limit + 1;

  // If the block size is known, set the maximum that at least allows
  // one block to be resident on an SM
  if (opt_block_size.has_value() && opt_block_size.value() > 0) {
    constexpr int64_t block_per_sm = 1;
    max_register = std::min(
        max_register_limit,
        getRegPerThreadGivenThreadsPerSM(
            opt_block_size.value() * block_per_sm));
  }

  // If a heuristic value is given, i.e., max_register_heuristic is
  // less than the limit, use that value if it's smaller than the
  // block-size based count
  if (max_register_heuristic < max_register_limit) {
    max_register = std::min(max_register, max_register_heuristic);
  }

  // Overwrite the count by the environment variable
  if (auto env_count = getNvFuserEnv("MAX_REG_COUNT")) {
    auto env_max_reg_count = std::atoi(env_count);
    NVF_CHECK(
        env_max_reg_count > 0 && env_max_reg_count <= max_register_limit,
        "Invalid max register count specified by NVFUSER_MAX_REG_COUNT: ",
        env_max_reg_count);
    max_register = env_max_reg_count;
  }

  // only need to set this option when we want to limit the register usage,
  // otherwise compiler with cuda-12.7 may use more registers than needed,
  // which may cause lower occupancy and performance regression.
  if (max_register < max_register_limit) {
    return max_register;
  } else {
    return std::optional<int64_t>();
  }
}

// Fill options for nvrtcCompileProgram and cuModuleLoadDataEx
void fillCompileOptions(
    NvrtcCompileDriver& nvrtc_compile_driver,
    CuModuleLoadDataDriver& module_load_driver,
    bool compile_to_sass,
    int64_t major,
    int64_t minor,
    const CompileParams& compile_params,
    std::optional<int64_t> opt_block_size) {
  nvrtc_compile_driver.setOption("--std=c++17");
  if (isOptionEnabled(EnableOption::KernelDebug)) {
    nvrtc_compile_driver.setOption("-G");
  }

  // Suppress warnings for functions that are defined but unused, since we have
  // many unused functions in the preamble.
  nvrtc_compile_driver.setOption("--diag-suppress=177");

  // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
  // which gives better backwards compatibility to work on older driver,
  // (since older driver doesn't necessarily recognize PTX emitted by new
  // toolkit);
  // Meanwhile, for forward compatibility (future device with
  // `unsupported_arch==True`), since SASS are not necessarily compatible,
  // we fallback to PTX instead.
  std::string compute = std::string("--gpu-architecture=") +
      (compile_to_sass ? "sm_" : "compute_") + std::to_string(major) +
      std::to_string(minor);
  if (major == 9) {
    // Hopper MMAs require 90a instead of 90
    compute += "a";
  }
  nvrtc_compile_driver.setOption(compute);

  nvrtc_compile_driver.setOption("-default-device");

  if (isOptionDisabled(DisableOption::Fma)) {
    nvrtc_compile_driver.setOption("--fmad=false");
  } else {
    nvrtc_compile_driver.setOption("--fmad=true");
  }

  // Add line info to generated kernels
  if (isOptionEnabled(EnableOption::KernelLineInfo)) {
    nvrtc_compile_driver.setOption("-lineinfo");
  }

#ifdef NDEBUG
  // Avoid excessive register usage from assertion
  nvrtc_compile_driver.setOption("-DNDEBUG");
#endif

  if (isOptionEnabled(EnableOption::KernelProfile)) {
    nvrtc_compile_driver.setOption("-DNVFUSER_PROFILE_KERNEL");
  }
  if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog) ||
      isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose) ||
      isOptionEnabled(EnableOption::WarnRegisterSpill) ||
      compile_params.enable_ptxas_verbose) {
    // show register usage in compilation log
    if (compile_to_sass) {
      nvrtc_compile_driver.setOption("--ptxas-options");
      nvrtc_compile_driver.setOption("--verbose");
    } else {
      module_load_driver.enableLogging();
    }
  }

  const char* ptxas_opt_level = getNvFuserEnv("JIT_OPT_LEVEL");

  if (ptxas_opt_level) {
    int val = atoi(ptxas_opt_level);
    if (val <= 4 && val >= 0) {
      if (val < 4) {
        TORCH_WARN(
            "ptxas optimization level manually set as ",
            val,
            ", which could negatively affect performance. Try removing env variable NVFUSER_JIT_OPT_LEVEL for optimal performance.");
      }
      if (compile_to_sass) {
        nvrtc_compile_driver.setOption("--ptxas-options");
        nvrtc_compile_driver.setOption("-O" + std::to_string(val));
      } else {
        module_load_driver.setOption(CU_JIT_OPTIMIZATION_LEVEL, val);
      }
    } else {
      TORCH_WARN_ONCE(
          "acceptable range for NVFUSER_JIT_OPT_LEVEL is between 0 and 4, but received ",
          val,
          ", ignoring the option");
    }
  }

  const auto max_register =
      getMaxRegCount(opt_block_size, compile_params.maxrregcount);

  // If the max register count is set
  if (max_register.has_value()) {
    if (compile_to_sass) {
      nvrtc_compile_driver.setOption(
          "--maxrregcount=" + std::to_string(*max_register));
    } else {
      module_load_driver.setOption(CU_JIT_MAX_REGISTERS, (int)*max_register);
    }
  }
}

// Dump ptxas output if register spill is detected
int warnRegisterSpill(const std::string& compile_log) {
  auto getRegisterSpillInfo = [](const std::string& log, const char* subStr) {
    auto it_end =
        std::search(log.begin(), log.end(), subStr, subStr + strlen(subStr)) -
        1;
    auto it_beg = it_end - 1;
    while (!std::isspace(*(it_beg - 1))) {
      it_beg--;
    }
    std::string str(it_beg, it_end);
    return std::stoi(str);
  };

  const char* str_stack = "bytes stack frame";
  const char* str_store = "bytes spill stores";
  const char* str_load = "bytes spill loads";
  int stack_count = getRegisterSpillInfo(compile_log, str_stack);
  int store_count = getRegisterSpillInfo(compile_log, str_store);
  int load_count = getRegisterSpillInfo(compile_log, str_load);
  int allowed_spill = 0;
  if (isOptionEnabled(EnableOption::WarnRegisterSpill)) {
    auto optionArgs = getEnableOptionArguments(EnableOption::WarnRegisterSpill);
    if (!optionArgs.empty()) {
      try {
        allowed_spill = std::stoi(optionArgs[0]);
      } catch (const std::exception& e) {
        debug() << "skip invalid argument for WarnRegisterSpill, arg = "
                << optionArgs[0] << std::endl;
      }
    }
  }
  if (stack_count > allowed_spill || store_count > allowed_spill ||
      load_count > allowed_spill) {
    debug() << "WARNING: Register spill detected\n" << compile_log << std::endl;
  }
  return store_count + load_count;
}

void createNvrtcProgram(
    nvrtcProgram& program,
    const std::string& id,
    const std::string& full_src_code) {
  std::stringstream ss;
  ss << "__tmp_kernel_" << id << ".cu";
  std::string name = ss.str();
  FUSER_PERF_SCOPE("executor_utils::NvrtcCreateProgram");
  NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
      &program, full_src_code.c_str(), name.c_str(), 0, nullptr, nullptr));
}

std::vector<char> compileNvrtcProgramToCubin(const nvrtcProgram& program) {
#if CUDA_VERSION < 11010
  NVF_THROW("SASS not supported in CUDA versions older than 11.1");
#endif

  size_t size = 0;
  NVFUSER_NVRTC_SAFE_CALL(nvrtcGetCUBINSize(program, &size));
  std::vector<char> code(size);
  NVFUSER_NVRTC_SAFE_CALL(nvrtcGetCUBIN(program, code.data()));
  return code;
}

// Returns the name of the dumped file.
std::string dumpCompiledCodeToFile(
    const std::vector<char>& code,
    const std::string& id,
    const std::string& suffix) {
  std::stringstream file_name;
  file_name << "__tmp_kernel_" << id << suffix;
  debug() << "PRINTING: " << file_name.str() << std::endl;
  std::ofstream out(file_name.str());
  NVF_ERROR(out.is_open());
  out.write(code.data(), (std::streamsize)code.size());
  out.close();
  return file_name.str();
}

std::vector<char> compileNvrtcProgramToPtx(const nvrtcProgram& program) {
  size_t size = 0;
  NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(program, &size));
  std::vector<char> code(size);
  NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(program, code.data()));
  return code;
}

// Compile the given source code with the NVRTC compiler driver.
std::unique_ptr<executor_utils::CudaExecutable> compileSource(
    const std::string& full_src_code,
    const std::string& func_name,
    const std::string& id,
    const bool compile_to_sass,
    NvrtcCompileDriver& nvrtc_compile) {
  std::stringstream log;

  nvrtcProgram program; // NOLINT(cppcoreguidelines-init-variables)
  torch::jit::ResourceGuard holdProgram([&] {
    FUSER_PERF_SCOPE("executor_utils::NvrtcDestroyProgram");
    NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&program));
  });

  createNvrtcProgram(program, id, full_src_code);

  NVFUSER_NVRTC_SAFE_CALL(nvrtcAddNameExpression(program, func_name.c_str()));
  log << nvrtc_compile.invoke(program, full_src_code) << std::endl;

  auto compiled_kernel = std::make_unique<executor_utils::CudaExecutable>();
  const char* lowered_kernel_name = nullptr;
  NVFUSER_NVRTC_SAFE_CALL(
      nvrtcGetLoweredName(program, func_name.c_str(), &lowered_kernel_name));
  compiled_kernel->kernel_name = lowered_kernel_name;
  compiled_kernel->compile_log = log.str();

  if (compile_to_sass) {
    compiled_kernel->cubin = compileNvrtcProgramToCubin(program);
    if (isDebugDumpEnabled(DebugDumpOption::Cubin)) {
      compiled_kernel->cubin_filename =
          dumpCompiledCodeToFile(compiled_kernel->cubin, id, ".cubin");
    }
  }

  if (!compile_to_sass || isDebugDumpEnabled(DebugDumpOption::Ptx)) {
    compiled_kernel->ptx = compileNvrtcProgramToPtx(program);
    if (isDebugDumpEnabled(DebugDumpOption::Ptx)) {
      compiled_kernel->ptx_filename =
          dumpCompiledCodeToFile(compiled_kernel->ptx, id, ".ptx");
    }
  }

  return compiled_kernel;
}

// Compile the source if no existing compiled binary is found in KernelDB
std::unique_ptr<executor_utils::CudaExecutable> getCudaExecutable(
    std::optional<std::reference_wrapper<const std::string>> kernel_code,
    const std::string& full_src_code,
    const std::string& func_name,
    const std::string& id,
    const CompileParams& compile_params = CompileParams(),
    std::optional<int64_t> opt_block_size = std::nullopt) {
  FUSER_PERF_SCOPE("executor_utils::NVRTC");

  at::cuda::jit::initializeCudaContext();

  // The above initialization works in some cases. However, it seems to
  // occasionally fail to initialize a primary context. Here we check for that
  // and if we detect that no context exists, we create one manually.
  int device = 0;
  cudaGetDevice(&device);
  if (!at::detail::getCUDAHooks().hasPrimaryContext((c10::DeviceIndex)device)) {
    // CUDA>=12 creates a context when cudaSetDevice is called. However, before
    // cu12, that context is not necessarily created. In that case, we create
    // one here implicitly. See https://github.com/NVIDIA/Fuser/issues/429
    cudaFree(nullptr);
  }

  const auto prop = at::cuda::getCurrentDeviceProperties();

  int64_t major = 0, minor = 0;
  bool compile_to_sass = false;
  queryTargetGPUVersion(prop, major, minor, compile_to_sass);

#if CUDA_VERSION < 11010
  // compile to sass is not allowed prior to CUDA 11.1
  compile_to_sass = false;
#endif

  if (isOptionDisabled(DisableOption::CompileToSass)) {
    compile_to_sass = false;
  }

  NvrtcCompileDriver nvrtc_compile_driver;
  CuModuleLoadDataDriver module_load_driver;

  fillCompileOptions(
      nvrtc_compile_driver,
      module_load_driver,
      compile_to_sass,
      major,
      minor,
      compile_params,
      opt_block_size);

  std::stringstream log;

  if (compile_to_sass) {
    log << "\nCompile options: ";
    for (const auto& opt : nvrtc_compile_driver.options()) {
      log << opt << " ";
    }
    if (opt_block_size.has_value()) {
      log << " ; block size=" << opt_block_size.value() << "\n";
    }
  }

  auto compiled_kernel = std::make_unique<executor_utils::CudaExecutable>();
  const auto compile_args =
      toDelimitedString(nvrtc_compile_driver.options(), " ");

  auto& kernel_db = KernelDb::get();
  const auto use_kernel_db = kernel_db.enabled() && kernel_code.has_value();

  // If the Kernel Query fails, the Kernel is recompiled
  if (!(use_kernel_db &&
        kernel_db.query(
            kernel_code.value(),
            compile_args,
            compiled_kernel->kernel_name,
            (compile_to_sass ? compiled_kernel->cubin
                             : compiled_kernel->ptx)))) {
    compiled_kernel = compileSource(
        full_src_code, func_name, id, compile_to_sass, nvrtc_compile_driver);
    log << compiled_kernel->compile_log << std::endl;
    if (use_kernel_db) {
      auto result = kernel_db.write(
          kernel_code.value(),
          compile_args,
          compiled_kernel->kernel_name,
          (compile_to_sass ? compiled_kernel->cubin : compiled_kernel->ptx));
      if (!result) {
        TORCH_WARN(
            "kernel_db was unable to write kernel: ",
            compiled_kernel->kernel_name);
      }
    }
  }

  log << module_load_driver.invoke(
             compiled_kernel->module,
             (compile_to_sass ? compiled_kernel->cubin.data()
                              : compiled_kernel->ptx.data()))
      << std::endl;
  compiled_kernel->compile_log = log.str();
  compiled_kernel->compile_args = compile_args;

  if (isOptionEnabled(EnableOption::WarnRegisterSpill) ||
      compile_params.enable_ptxas_verbose) {
    compiled_kernel->register_spills =
        warnRegisterSpill(compiled_kernel->compile_log);
  }

  NVFUSER_CUDA_SAFE_CALL(cuModuleGetFunction(
      &(compiled_kernel->function),
      compiled_kernel->module,
      compiled_kernel->kernel_name.c_str()));

  // Store block size used to generate compile arguments
  if (opt_block_size.has_value()) {
    compiled_kernel->block_size = opt_block_size.value();
  }

  return compiled_kernel;
}

std::unique_ptr<executor_utils::CudaExecutable> getCudaExecutable(
    const serde::CudaKernel* buffer,
    const CompileParams& compile_params) {
  FUSER_PERF_SCOPE("executor_utils::serde_NVRTC");

  NVF_ERROR(buffer != nullptr, "serde::CudaKernel is nullptr.");

  // Deserialize flatbuffer into CudaExecutable
  auto compiled_kernel = std::make_unique<executor_utils::CudaExecutable>();
  compiled_kernel->kernel_name = buffer->kernel_name()->str();
  compiled_kernel->compile_args = buffer->compile_args()->str();
  compiled_kernel->block_size = buffer->block_size();

  if (buffer->cubin() != nullptr) {
    compiled_kernel->cubin.reserve(buffer->cubin()->size());
    std::copy(
        buffer->cubin()->begin(),
        buffer->cubin()->end(),
        std::back_inserter(compiled_kernel->cubin));
    compiled_kernel->cubin_filename = buffer->cubin_filename()->str();
  }

  if (buffer->ptx() != nullptr) {
    compiled_kernel->ptx.reserve(buffer->ptx()->size());
    std::copy(
        buffer->ptx()->begin(),
        buffer->ptx()->end(),
        std::back_inserter(compiled_kernel->ptx));
    compiled_kernel->ptx_filename = buffer->ptx_filename()->str();
  }

  at::cuda::jit::initializeCudaContext();

  // The above initialization works in some cases. However, it seems to
  // occasionally fail to initialize a primary context. Here we check for that
  // and if we detect that no context exists, we create one manually.
  int device = 0;
  cudaGetDevice(&device);
  if (!at::detail::getCUDAHooks().hasPrimaryContext((c10::DeviceIndex)device)) {
    // CUDA>=12 creates a context when cudaSetDevice is called. However, before
    // cu12, that context is not necessarily created. In that case, we create
    // one here implicitly. See https://github.com/NVIDIA/Fuser/issues/429
    cudaFree(nullptr);
  }

  const auto prop = at::cuda::getCurrentDeviceProperties();

  // Generate compile args and compare against saved args in compiled_kernel
  NvrtcCompileDriver nvrtc_compile_driver;
  CuModuleLoadDataDriver module_load_driver;

  int64_t major = 0, minor = 0;
  bool compile_to_sass = false;
  queryTargetGPUVersion(prop, major, minor, compile_to_sass);

  std::optional<int64_t> opt_block_size;
  if (compiled_kernel->block_size >= -1) {
    opt_block_size = compiled_kernel->block_size;
  }

  fillCompileOptions(
      nvrtc_compile_driver,
      module_load_driver,
      compile_to_sass,
      major,
      minor,
      compile_params,
      opt_block_size);

  const auto latest_compile_args =
      toDelimitedString(nvrtc_compile_driver.options(), " ");
  NVF_ERROR(
      latest_compile_args == compiled_kernel->compile_args,
      "The compile arguments for the serialized cuda kernel does not ",
      "match the latest generated compile args.\t",
      latest_compile_args,
      "\t",
      compiled_kernel->compile_args);

  NVF_ERROR(
      !compile_to_sass || !compiled_kernel->cubin.empty(),
      "Expected compiled cubin after deserializing CudaExecutable.");

  NVF_ERROR(
      compile_to_sass || !compiled_kernel->ptx.empty(),
      "Expected compiled ptx after deserializing CudaExecutable.");

  std::stringstream log;
  log << module_load_driver.invoke(
             compiled_kernel->module,
             (compile_to_sass ? compiled_kernel->cubin.data()
                              : compiled_kernel->ptx.data()))
      << std::endl;
  compiled_kernel->compile_log = log.str();

  NVFUSER_CUDA_SAFE_CALL(cuModuleGetFunction(
      &(compiled_kernel->function),
      compiled_kernel->module,
      compiled_kernel->kernel_name.c_str()));

  return compiled_kernel;
}

static const char* defineIndexType(PrimDataType index_type) {
  if (index_type == DataType::Int32) {
    return "typedef int nvfuser_index_t;\n";
  } else if (index_type == DataType::Int) {
    return "typedef int64_t nvfuser_index_t;\n";
  } else {
    NVF_THROW("invalid indexing type: ", index_type);
  }
}

static const char* defineTypes() {
  return R"(
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = short int;
using uint16_t = unsigned short int;
using int32_t = int;
using uint32_t = unsigned int;
using int64_t = long long int;
using uint64_t = unsigned long long int;

// Modified from cuda.h
struct TensorMap {
  alignas(64)
  uint64_t opaque[16];
};
)";
}

static const std::string& defineStdComplex() {
  static std::string result = std::string(R"ESCAPE(
#ifdef __NVCC__
#include <complex>
#endif // __NVCC__
)ESCAPE");
  return result;
}

// When executing nvFuser with: NVFUSER_EXTERNAL_SRC=file1.cu,file2.cu
// This function retrieves structured code from the specified files.
// The files should be comma-separated, and their order corresponds to the
// fusion_id order. If the provided number of files is fewer than the fusion
// segments, the function will resort to the available files in sequence
// and issue a warning.
std::string getStructuredCodeFromExternalFiles(const int64_t fusion_id) {
  auto external_code_path = getNvFuserEnv("EXTERNAL_SRC");
  if (!external_code_path) {
    return "";
  }
  std::string all_external_code_paths(external_code_path);
  if (all_external_code_paths.empty() || fusion_id < 1) {
    return "";
  }
  auto getExternalCodeFile =
      [fusion_id](const std::string& input) -> std::string {
    std::stringstream ss(input);
    std::string token;
    int64_t count = 0;
    while (std::getline(ss, token, ',')) {
      if (++count == fusion_id) {
        return token;
      }
    }
    debug()
        << "Didn't find requested external source code. Will use generated code!\n"
        << "Number of source code files should equal the number of fusion segments.\n"
        << "External source code filenames should be delineated with commas, e.g.: file1.cu,file2.cu.\n";
    return "";
  };

  std::string single_code_path = getExternalCodeFile(all_external_code_paths);
  if (single_code_path.empty()) {
    return "";
  }
  std::ifstream cuda_src(single_code_path);
  if (!cuda_src.is_open()) {
    debug() << "Failed to open external source file: " << single_code_path
            << std::endl;
    return "";
  }
  debug() << "--------> Compiling external CUDA code: " << single_code_path
          << std::endl;

  std::stringstream buffer;
  buffer << cuda_src.rdbuf();
  return buffer.str();
}
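
As a standalone illustration of the comma-delimited lookup above (the paths and fusion id below are made up for the example, not taken from a real run):

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

int main() {
  std::string paths = "file1.cu,file2.cu"; // as passed via NVFUSER_EXTERNAL_SRC
  int64_t fusion_id = 2;                   // 1-based position in the list
  std::stringstream ss(paths);
  std::string token;
  int64_t count = 0;
  while (std::getline(ss, token, ',')) {
    if (++count == fusion_id) {
      std::cout << token << '\n'; // prints "file2.cu"
      return 0;
    }
  }
  std::cout << "no match; generated code will be used\n";
  return 0;
}
```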

bool requiresDisabledParamCache(const Fusion* fusion) {
  std::vector<Val*> output_extents;
  for (auto out : fusion->outputs()) {
    const auto logical_domain = out->as<TensorView>()->getLogicalDomain();
    // Walk through outputs to see if output shapes depend on non-tensor
    // inputs. In that case the parameter cache must be disabled, since the
    // caching id only looks at tensor shapes.
    // See issue https://github.com/csarofeen/pytorch/issues/2002
    for (const auto id : logical_domain) {
      Val* extent = nullptr;
      if (id->isReduction() || id->isStride() || id->isDeviceDim()) {
        continue;
      } else if (id->isBroadcast() && id->hasExpandedExtent()) {
        extent = id->expandedExtent();
      } else {
        extent = id->extent();
      }
      output_extents.emplace_back(extent);
    }
  }

  VectorOfUniqueEntries<Val*> input_dependencies;
  for (auto inp : InputsOf::outputs(output_extents)) {
    if (inp->isFusionInput()) {
      input_dependencies.pushBack(inp);
    }
  }
  if (std::any_of(
          input_dependencies.begin(), input_dependencies.end(), [](Val* inp) {
            return inp->isScalar();
          })) {
    return true;
  } else if (!input_dependencies.empty()) {
    VectorOfUniqueEntries<Expr*> all_exprs(DependencyCheck::getAllExprsBetween(
        input_dependencies.set(), output_extents));

    VectorOfUniqueEntries<Val*> meta_data_outputs;
    for (auto meta_data_op :
         ir_utils::filterByType<GetMetaData>(all_exprs.vector())) {
      meta_data_outputs.pushBack(
          meta_data_op->outputs().begin(), meta_data_op->outputs().end());
    }

    VectorOfUniqueEntries<Expr*> before_meta_data_exprs(
        DependencyCheck::getAllExprsBetween(
            input_dependencies.set(), meta_data_outputs.vector()));

    VectorOfUniqueEntries<Expr*> after_meta_data_exprs(
        DependencyCheck::getAllExprsBetween(
            meta_data_outputs.set(), output_extents));

    auto subtraction = all_exprs;
    subtraction = subtraction.computeSubtract(before_meta_data_exprs);
    subtraction = subtraction.computeSubtract(after_meta_data_exprs);
    if (!subtraction.empty()) {
      return true;
    }
  }
  return false;
}
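
For intuition, here is a minimal sketch of a fusion this predicate flags: the output extent is a function of a scalar input, which the caching id cannot see. It assumes the repo's test helper `makeSymbolicTensor` and the `broadcast`/`expand` ops from `ops/alias.h`:

```cpp
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeSymbolicTensor(1); // [i0]
Val* s0 = IrBuilder::create<Val>(DataType::Int);
fusion.addInput(tv0);
fusion.addInput(s0); // scalar input

// Expand the broadcast dim to extent s0: the output shape now depends on a
// non-tensor input, so requiresDisabledParamCache(&fusion) should return true.
TensorView* tv1 = broadcast(tv0, {true, false});
TensorView* tv2 = expand(tv1, {s0, tv0->axis(0)->extent()});
fusion.addOutput(tv2);
```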

std::string _getStructuredCode(
    const std::string& kernel_str,
    PrimDataType index_type,
    std::string kernel_name) {
  // Generate the CUDA code
  std::string code = "";
  code += defineStdComplex();
  code += std::string("namespace {\n") + defineTypes() +
      defineIndexType(index_type) + kernelPreamble() + kernel_str + "}\n";

  if (isDebugDumpEnabled(DebugDumpOption::CudaKernel)) {
    debug() << "\n======= Codegen output for kernel: " << kernel_name
            << " =======\n\n"
            << kernel_str << "\n======================================\n\n";
  } else if (isDebugDumpEnabled(DebugDumpOption::CudaFull)) {
    debug() << "\n======= Codegen output for kernel: " << kernel_name
            << " =======\n\n"
            << code << "\n======================================\n\n";
  }
  if (isDebugDumpEnabled(DebugDumpOption::CudaToFile)) {
    std::stringstream file_name;
    file_name << "__tmp_" << kernel_name << ".cu";
    debug() << "PRINTING: " << file_name.str() << std::endl;
    std::ofstream out(file_name.str());
    out << code << std::endl;
    out.close();
  }

  return code;
}
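
The string assembled above has roughly this shape (annotated sketch of the emitted translation unit):

```cpp
// #ifdef __NVCC__
// #include <complex>
// #endif // __NVCC__            <- defineStdComplex()
// namespace {
// using int8_t = signed char;   <- defineTypes()
// ...
// typedef int nvfuser_index_t;  <- defineIndexType(DataType::Int32)
// <kernel preamble>             <- kernelPreamble()
// <generated __global__ kernel> <- kernel_str
// }
```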

} // namespace

NVF_API CompiledKernel::CompiledKernel(
    Fusion* fusion,
    CompileParams compile_params,
    c10::Device device,
    SchedulerType scheduler_type,
    int64_t fusion_id,
    int64_t concrete_id,
    int64_t runtime_id,
    int64_t group_id,
    const std::vector<std::function<void(GpuLower*)>>& pre_lowering_hooks,
    const std::vector<std::function<void(kir::Kernel*)>>& post_lowering_hooks)
    : compile_params_(compile_params),
      scheduler_type_(scheduler_type),
      fusion_id_(fusion_id),
      concrete_id_(concrete_id),
      runtime_id_(runtime_id),
      group_id_(group_id),
      lowered_(std::make_unique<GpuLower>(fusion, compile_params)),
      device_(device) {
  FUSER_PERF_SCOPE("CompiledKernel::CompiledKernel");
  // TODO: Hooks can only be provided at construction time because lowering
  // runs in the constructor; they cannot be registered afterwards.
  for (const auto& hook : pre_lowering_hooks) {
    hook(lowered_.get());
  }
  lowered_->run();
  for (const auto& hook : post_lowering_hooks) {
    hook(lowered_->kernel());
  }
}

NVF_API CompiledKernel::CompiledKernel(
    Fusion* fusion,
    CompileParams compile_params,
    c10::Device device,
    SchedulerType scheduler_type,
    int64_t fusion_id,
    int64_t concrete_id,
    int64_t runtime_id,
    int64_t group_id)
    : CompiledKernel(
          fusion,
          compile_params,
          device,
          scheduler_type,
          fusion_id,
          concrete_id,
          runtime_id,
          group_id,
          {},
          {}) {}
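
A hypothetical construction/compile sequence (sketch only; `buildScheduledFusion` is a made-up helper standing in for whatever produces a scheduled fusion, and the scheduler enumerator and ids are illustrative):

```cpp
// buildScheduledFusion() is hypothetical: it stands in for whatever produces
// an already-scheduled Fusion*.
Fusion* fusion = buildScheduledFusion();

CompiledKernel ck(
    fusion,
    CompileParams{},
    c10::Device(c10::DeviceType::CUDA, 0),
    SchedulerType::PointWise, // illustrative scheduler
    /*fusion_id=*/1,
    /*concrete_id=*/0,
    /*runtime_id=*/0,
    /*group_id=*/0);

// Lowering already ran in the constructor; compile() does codegen + NVRTC.
ck.compile(/*block_size=*/128);
```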

void CompiledKernel::compile(int64_t block_size) {
  FUSER_PERF_SCOPE("CompiledKernel::compile");

  NVF_ERROR(
      !fusion()->outputs().empty(),
      "No output found for this kernel, aborting.");

  // The parameter cache does not key on input scalars, so if one is used as a
  // dynamic size of a tensor, the cache doesn't work correctly. Scalar inputs
  // should eventually be supported by the cache, but until they are we disable
  // it under these circumstances.
  disable_parameter_cache_ = requiresDisabledParamCache(fusion());

  if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) {
    fusion()->print();
  } else if (isDebugDumpEnabled(DebugDumpOption::FusionIrMath)) {
    fusion()->printMath();
  }

  if (isDebugDumpEnabled(DebugDumpOption::FusionIrGraph)) {
    std::stringstream file_name;
    file_name << "__tmp_fusion_ir_graph_" << kernel_id_ << ".dot";
    IrGraphGenerator::print(
        fusion(),
        file_name.str().c_str(),
        IrGraphGenerator::DetailLevel::ComputeOnly);
  }

  c10::DeviceGuard dg(device_);

  NVF_ERROR(device_.is_cuda(), "Provided device to CUDA fuser is the CPU.");
  auto properties = at::cuda::getDeviceProperties(device_.index());
  // TODO: These properties should be set as part of the constructor so that
  // they can be const
  warp_size_ = properties->warpSize;
  kir::Kernel* kernel = lowered_->kernel();

  createKernelId();
  setUsedTVs();

  if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) {
    kernel->print();
  }

  if (isDebugDumpEnabled(DebugDumpOption::BankConflictInfo)) {
    auto bank_conflict_info = getBankConflictInfo(kernel);
    if (bank_conflict_info.empty()) {
      debug() << "===== No bank confliction =====" << std::endl;
    } else {
      debug() << "======= Bank confliction =======" << std::endl;
      for (auto info : bank_conflict_info) {
        debug() << "Expr: " << info.first->toString() << std::endl;
        auto conflict = info.second;
        if (conflict.first > 1) {
          debug() << "input conflict: " << conflict.first << " way, ";
        }
        if (conflict.second > 1) {
          debug() << "output conflict: " << conflict.second << " way";
        }
        debug() << std::endl;
      }
      debug() << "================================" << std::endl;
    }
  }

  kernel_code_ = codegen::generateCudaKernel(kernel, kernelName(), block_size);

  // If NVFUSER_EXTERNAL_SRC is set, utilize the external source code.
  // If the loaded external source code is empty, revert to the default
  // codegen. Note: we index these with getGlobalFusionCount() instead of
  // fusion_id_ in order to match the numbering of files output with
  // NVFUSER_DUMP=cuda_to_file
  auto structured_code =
      getStructuredCodeFromExternalFiles(getGlobalFusionCount());
  if (structured_code.empty()) {
    structured_code = getStructuredCode();
  }

  const kir::KernelSummary& kernel_summary = kernel->summary();

  std::pair<int64_t, int64_t> target_arch;
  bool compile_to_sass = false;
  queryTargetGPUVersion(
      properties,
      std::ref(target_arch.first),
      std::ref(target_arch.second),
      compile_to_sass);

  NVF_CHECK(
      target_arch >= kernel_summary.min_device_version,
      "Target compute capability is ",
      target_arch.first,
      ".",
      target_arch.second,
      " but this fusion requires at least ",
      kernel_summary.min_device_version.first,
      ".",
      kernel_summary.min_device_version.second,
      ". Reason: ",
      kernel_summary.min_device_version_reason);

  // We currently shouldn't statically allocate any more shared memory
  // tensors, but keep this path in case it's needed in later development.
  if (!kernel_summary.static_smem_allocations.empty()) {
    ExpressionEvaluator static_evaluator;
    const auto static_smem_size = computeSharedMemory(
        static_evaluator,
        kernel_summary.static_smem_allocations,
        kernel->indexType());
    NVF_ERROR(
        static_smem_size < max_static_smem_,
        "The static shared memory allocation is larger than available memory.");
  }

  if (kernel_summary.has_dynamic_local_memory_allocations) {
    std::stringstream ss;
    ss << "Allocations must be based on constant integers for local memory. However, found: ";
    for (auto alloc : kernel_summary.dynamic_lmem_allocations) {
      ss << alloc->buffer()->toString() << ", ";
    }
    ss << " have dynamic allocations but are placed in local memory.";
    NVF_THROW(ss.str());
  }

  NVF_ERROR(block_size > 0, "Launch param inferred block size must be > 0.");

  // When no args are provided for compilation, the high water mark is
  // effectively 1, so we just generate a kernel that gets ditched at the
  // first run - not great. We should have better heuristics.
  block_size_high_water_mark_ =
      std::max<int64_t>(block_size, block_size_high_water_mark_);
  maxrregcount_high_water_mark_ = compile_params_.maxrregcount;
  compiled_kernel_ = getCudaExecutable(
      kernel_code_,
      structured_code,
      kernelName(),
      kernel_id_,
      compile_params_,
      block_size);

  NVF_ERROR(validKernelId(), "Invalid kernel id for CompiledKernel.");

  if (isDebugDumpEnabled(DebugDumpOption::Sass)) {
    debug() << disassembledKernelSASS() << std::endl;
  }
}

std::string CompiledKernel::getStructuredCode() const {
  return _getStructuredCode(
      kernelString(), kernel()->indexType(), kernelName());
}

std::string CompiledKernel::disassembledKernelSASS() const {
  return disassembleBinary(compiled_kernel_->cubin, "-fun 1 -c");
}

void CompiledKernel::createKernelId() {
  NVF_ERROR(fusion_id_ > -1, "Invalid fusion_id.");
  NVF_ERROR(concrete_id_ > -1, "Invalid concrete_id.");
  NVF_ERROR(runtime_id_ > -1, "Invalid runtime_id.");
  NVF_ERROR(group_id_ > -1, "Invalid group_id.");
  ++global_fusion_count_;
  std::stringstream ss;
  if (isOptionEnabled(EnableOption::StaticFusionCount)) {
    ss << global_fusion_count_.load();
  } else {
    ss << toString(scheduler_type_);
    ss << "_f" << fusion_id_;
    ss << "_c" << concrete_id_;
    ss << "_r" << runtime_id_;
    ss << "_g" << group_id_;
  }
  kernel_id_ = ss.str();
}
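
For example, with the static-count option disabled, a pointwise kernel with fusion_id_ 1, concrete_id_ 0, runtime_id_ 0, and group_id_ 2 would produce (assuming `toString(scheduler_type_)` yields "pointwise"):

```cpp
// kernel_id_ == "pointwise_f1_c0_r0_g2"
```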

void RtcKernel::compile(
    const std::string& code,
    const std::string& name,
    bool structured,
    PrimDataType index_type,
    int64_t device_index) {
  FUSER_PERF_SCOPE("RtcKernel::compile");
  NVF_ERROR(
      index_type == PrimDataType::Int || index_type == PrimDataType::Int32,
      "Invalid index type: ",
      index_type);
  device_index_ = device_index;

  std::string scode;
  if (!structured) {
    scode = _getStructuredCode(code, index_type, name);
  } else {
    scode = code;
  }
  CompileParams cp;
  cp.device =
      c10::Device(c10::DeviceType::CUDA, (c10::DeviceIndex)device_index_);
  compiled_kernel_ = getCudaExecutable(std::nullopt, scode, name, "0", cp);
}

float RtcKernel::run(
    const LaunchParams& launch_params,
    const std::vector<at::Tensor>& args,
    PrimDataType index_type) {
  FUSER_PERF_SCOPE("RtcKernel::run");

  auto device =
      c10::Device(c10::DeviceType::CUDA, (c10::DeviceIndex)device_index_);
  c10::DeviceGuard dg(device);
  auto stream = at::cuda::getCurrentCUDAStream();

  cudaEvent_t start_event = {};
  cudaEvent_t finish_event = {};

  NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&start_event));
  NVFUSER_CUDA_RT_SAFE_CALL(cudaEventCreate(&finish_event));

  NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(start_event, stream));

  std::vector<std::vector<std::byte>> data;
  std::vector<void*> pointers;

  for (const auto& input : args) {
    auto dtype =
        std::get<PrimDataType>(aten_to_data_type(input.scalar_type()).type);
    DataType metadata_type = globalTensorMetaData(dtype, input.dim());

    std::shared_ptr<Struct> struct_ = std::make_shared<TensorMetaData>();
    auto* metadata = static_cast<TensorMetaData*>(struct_.get());
    metadata->dtype = dtype;
    metadata->data = input.data_ptr();
    metadata->logical_size = input.sizes();
    metadata->logical_stride = input.strides();
    metadata->alloc_size = input.sizes();
    metadata->alloc_stride = input.strides();

    data.emplace_back(polymorphicValueToBytes(
        PolymorphicValue(std::move(struct_)), metadata_type, index_type));
    pointers.emplace_back(data.back().data());
  }

  NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel(
      compiled_kernel_->function,
      launch_params.gdimx(),
      launch_params.gdimy(),
      launch_params.gdimz(),
      launch_params.bdimx(),
      launch_params.bdimy(),
      launch_params.bdimz(),
      launch_params.smem(),
      stream,
      pointers.data(),
      nullptr));

  NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(finish_event, stream));
  NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(start_event));
  NVFUSER_CUDA_RT_SAFE_CALL(cudaEventSynchronize(finish_event));

  float kernel_time_ms = 0;
  NVFUSER_CUDA_RT_SAFE_CALL(
      cudaEventElapsedTime(&kernel_time_ms, start_event, finish_event));
  NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(start_event));
  NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(finish_event));

  return kernel_time_ms;
}
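
A minimal end-to-end sketch of the RtcKernel path, modeled on the pre-existing compileRtc/runRtc test pattern; the kernel string, name lookup, and launch dimensions are illustrative:

```cpp
RtcKernel rtc;
const std::string kernel_str = R"(
__global__ void kernel1(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T1) {
  if (threadIdx.x == 0) {
    T1[0] = T0[0];
  }
}
)";
// structured=false: the source gets wrapped via _getStructuredCode first.
rtc.compile(
    kernel_str, "kernel1", /*structured=*/false, PrimDataType::Int,
    /*device_index=*/0);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto in = at::randn({256}, options);
auto out = at::empty_like(in);

LaunchParams lp(1, 1, 1, 1, 1, 1); // gdim{x,y,z}, bdim{x,y,z}
float ms = rtc.run(lp, {in, out}, PrimDataType::Int);
```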

void CompiledKernel::deserialize(const serde::KernelExecutor* buffer) {
  // Initialize CompileOptions
  c10::DeviceGuard dg(device_);

  // Initialize internal fields
  maxrregcount_high_water_mark_ = buffer->maxrregcount_high_water_mark();
  warp_size_ = buffer->warp_size();
  kernel...

return lowered_->kernel();
}

Fusion* fusion() const {

I believe this isn't any more useful than kernel() so I removed this in #3725.


// function to query whether a `CompiledKernel` has a compiled kernel to
// execute
bool hasCompiledKernel() const {

https://github.com/NVIDIA/Fuser/pull/3725/files#diff-31a5ef26405804394f573b42c15512e1fb87f930fe7e5bfd95ad4034d867c30fR147 has merged isCompiled and hasCompiledKernel. Having both used to be important for a FusionExecutor which may or may not be a kernel executor -- no longer after your ExecutorAbstract work.

}
// Lowered is needed to compute launch parameters as it uses the CA map. We
// could modify that, but for now we simply generate that part first.
compiled_kernel_ = std::make_unique<CompiledKernel>(

I'm trying to understand the intended split of responsibility between KernelExecutor::compile and CompiledKernel::compile. With this PR, KernelExecutor::compile appears to become a thin wrapper of CompiledKernel::compile that merely does profiling and overrides some compilation/launch parameters. Do you plan to get rid of KernelExecutor::compile so FusionKernelRuntime doesn't need to create a KernelExecutor to compile?
