EmbeddingFwdOp node with same functionality as F.embedding #3649

Open · wants to merge 11 commits into base: main
Conversation

@Priya2698 Priya2698 commented Dec 26, 2024

This PR adds an EmbeddingFwdOp node with the same functionality as F.embedding.

  1. I am not using take_along_axis. F.embedding accepts optional parameters such as max_norm and padding_idx, which would require further processing if implemented with take_along_axis, so I opted for a new node to guarantee performance parity.
  2. Thunder uses prims.EMBEDDING when the optional parameters padding_idx/max_norm are specified and prims.TAKE otherwise, which prevents nvFuser from consuming the embedding operator in the other cases. Hence, in Thunder, nvFuser will directly execute ltorch.embedding. This requires a separate backward API to consume ltorch.embedding_backward and cannot reuse the grad rules for prims.EMBEDDING, hence the EmbeddingFwdOp naming instead of EmbeddingOp.
  3. I plan to plumb the forward-only embedding support into Thunder first while I draft the backward node, which should be very similar. Thunder reviews may surface another way of implementing this support.

@Priya2698 Priya2698 changed the title EmbeddingOp node with same functionality as F.embedding EmbeddingFwdOp node with same functionality as F.embedding Jan 16, 2025

github-actions bot commented Jan 16, 2025

PR Reviewer Guide 🔍

(Review updated until commit af25ce0)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Potential Logic Change

The new EmbeddingFwdOp node has been added with the same functionality as F.embedding. This may introduce logic changes, especially with regard to function signatures. Reviewers should verify that the new node behaves as expected and does not introduce any regressions.

EmbeddingFwdOp::EmbeddingFwdOp(
    IrBuilderPasskey passkey,
    TensorView* output,
    TensorView* input,
    TensorView* weight,
    Val* padding_idx,
    Val* max_norm,
    Val* norm_type,
    Val* scale_grad_by_freq,
    Val* sparse)
    : Expr(passkey) {
  addOutput(output);

  addInput(input);
  addInput(weight);
  addInput(norm_type);
  addInput(scale_grad_by_freq);
  addInput(sparse);
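  // Note: inputs 0-4 are always (input, weight, norm_type, scale_grad_by_freq,
  // sparse). The optional padding_idx and max_norm inputs are appended after
  // them, and the boolean data attributes below record whether each is
  // present, so that evaluate() can recover their positions.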
  if (padding_idx != nullptr) {
    addInput(padding_idx);
    addDataAttribute(true);
  } else {
    addDataAttribute(false);
  }
  if (max_norm != nullptr) {
    addInput(max_norm);
    addDataAttribute(true);
  } else {
    addDataAttribute(false);
  }
}

NVFUSER_DEFINE_CLONE_AND_CREATE(EmbeddingFwdOp)

std::string EmbeddingFwdOp::toString(int indent_size) const {
  std::stringstream ss;
  indent(ss, indent_size) << out()->toString() << ",\n";
  indent(ss, indent_size + 1) << " = embedding(" << in()->toString() << ",\n";
  indent(ss, indent_size + 1) << "          " << weight()->toString() << ",\n";
  if (padding_idx() != nullptr) {
    indent(ss, indent_size + 1)
        << "          padding_idx = " << padding_idx()->toString() << ",\n";
  }
  if (max_norm() != nullptr) {
    indent(ss, indent_size + 1)
        << "          max_norm = " << max_norm()->toString() << ",\n";
  }
  indent(ss, indent_size + 1)
      << "          norm_type = " << norm_type()->toString() << ",\n";
  indent(ss, indent_size + 1)
      << "          scale_grad_by_freq = "
      << scale_grad_by_freq()->toInlineString() << ",\n";
  indent(ss, indent_size + 1)
      << "          sparse = " << sparse()->toInlineString() << ")\n";
  return ss.str();
}

std::string EmbeddingFwdOp::toInlineString(int indent_size) const {
  NVF_CHECK(false, "Tensor op can not be printed inline");
}

std::vector<PolymorphicValue> EmbeddingFwdOp::evaluate(
    const ExpressionEvaluator& ee,
    const std::vector<PolymorphicValue>& inputs) const {
  auto input = inputs.at(0).as<at::Tensor>();
  auto weight = inputs.at(1).as<at::Tensor>();
  auto norm_type = inputs.at(2).as<double>();
  auto scale_grad_by_freq = inputs.at(3).as<bool>();
  auto sparse = inputs.at(4).as<bool>();
  std::optional<int64_t> padding_idx = std::nullopt;
  if (has_padding_idx()) {
    padding_idx = inputs.at(5).as<int64_t>();
  }
  std::optional<double> max_norm = std::nullopt;
  if (has_max_norm()) {
    auto idx = 5 + has_padding_idx();
    max_norm = inputs.at(idx).as<double>();
  }

  namespace F = torch::nn::functional;
  return {F::embedding(
      input,
      weight,
      F::EmbeddingFuncOptions()
          .padding_idx(padding_idx)
          .max_norm(max_norm)
          .norm_type(norm_type)
          .scale_grad_by_freq(scale_grad_by_freq)
          .sparse(sparse))};
}
} // namespace nvfuser
Potential Logic Change

The embedding_fwd function has been added, which creates a new EmbeddingFwdOp node. Reviewers should verify that this function behaves as expected and does not introduce any regressions.

TensorView* embedding_fwd(
    TensorView* input,
    TensorView* weight,
    Val* padding_idx,
    Val* max_norm,
    Val* norm_type,
    Val* scale_grad_by_freq,
    Val* sparse) {
  auto input_domain = TensorDomain::noReductions(input->getLogicalDomain());
  auto weight_domain = TensorDomain::noReductions(weight->getLogicalDomain());
  NVF_CHECK(
      !input_domain.empty(),
      "Expected input to be atleast 1D, got: ",
      input_domain.size());
  NVF_CHECK(
      weight_domain.size() == 2,
      "Expected weight to be 2D, got: ",
      weight_domain.size());

  NVF_CHECK(
      !padding_idx || padding_idx->isScalar(),
      "Expected padding_idx to be a scalar int.");
  NVF_CHECK(
      !max_norm || max_norm->isScalar(),
      "Expected max_norm to be a scalar double.");
  NVF_CHECK(
      !norm_type || norm_type->isScalar(),
      "Expected norm_type to be a scalar double.");
  NVF_CHECK(
      !scale_grad_by_freq || scale_grad_by_freq->isScalar(),
      "Expected scale_grad_by_freq to be a scalar bool.");
  NVF_CHECK(
      !sparse || sparse->isScalar(), "Expected sparse to be a scalar bool.");

  auto ndims_out = input_domain.size() + 1;
  std::vector<IterDomain*> out_domain(ndims_out, nullptr);

  for (auto idx : c10::irange(ndims_out - 1)) {
    out_domain[idx] = ops::newOutputIterDomain({input_domain[idx]});
  }
  out_domain[ndims_out - 1] = ops::newOutputIterDomain({weight_domain.back()});
  TensorDomain* out_td = IrBuilder::create<TensorDomain>(
      out_domain, TensorDomain::getContiguityFilledWith(out_domain, true));
  TensorView* output = IrBuilder::create<TensorView>(out_td, weight->dtype());

  if (norm_type == nullptr) {
    norm_type = IrBuilder::create<Val>(2.0, DataType::Double);
  }
  if (scale_grad_by_freq == nullptr) {
    scale_grad_by_freq = IrBuilder::create<Val>(false, DataType::Bool);
  }
  if (sparse == nullptr) {
    sparse = IrBuilder::create<Val>(false, DataType::Bool);
  }
  IrBuilder::create<EmbeddingFwdOp>(
      output,
      input,
      weight,
      padding_idx,
      max_norm,
      norm_type,
      scale_grad_by_freq,
      sparse);

  return output;
}
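
For illustration, a minimal sketch of how the new op could be exercised from a C++ test, assuming the usual nvFuser test helpers (makeSymbolicTensor, FusionGuard) and the EmbeddingTest fixture used elsewhere in this PR; this sketch is not part of the diff:

TEST_F(EmbeddingTest, EmbeddingFwdSketch) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  // 1D int64 indices and a 2D float weight table, matching the checks above.
  TensorView* indices = makeSymbolicTensor(1, DataType::Int);
  TensorView* weight = makeSymbolicTensor(2, DataType::Float);
  fusion->addInput(indices);
  fusion->addInput(weight);

  // Optional scalars left as nullptr pick up the defaults above
  // (norm_type = 2.0, scale_grad_by_freq = false, sparse = false).
  TensorView* out = embedding_fwd(
      indices,
      weight,
      /*padding_idx=*/nullptr,
      /*max_norm=*/nullptr,
      /*norm_type=*/nullptr,
      /*scale_grad_by_freq=*/nullptr,
      /*sparse=*/nullptr);
  fusion->addOutput(out);
}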
Potential Logic Change

The embedding_fwd function has been bound to the Python frontend. Reviewers should verify that this binding behaves as expected and does not introduce any regressions.

nvf_ops.def(
    "embedding_fwd",
    [](FusionDefinition::Operators& self,
       Tensor input,
       Tensor weight,
       std::optional<Scalar> padding_idx,
       std::optional<Scalar> max_norm,
       std::optional<Scalar> norm_type,
       std::optional<Scalar> scale_grad_by_freq,
       std::optional<Scalar> sparse) -> decltype(auto) {
      FUSER_PERF_SCOPE("Operators.embedding_fwd");
      NVF_CHECK(
          self.validUse(), "Attempting to add to a completed definition!");
      FusionDefinition* fd = self.fusion_definition;
      size_t ndims = input.dims + 1;
      Tensor output = fd->defineTensor(/*dims=*/ndims);

      auto padding_idx_state = padding_idx.has_value()
          ? fd->recordingState(padding_idx.value()())
          : State(/*_index=*/0, /*_stype=*/serde::StateType::None);
      auto max_norm_state = max_norm.has_value()
          ? fd->recordingState(max_norm.value()())
          : State(/*_index=*/0, /*_stype=*/serde::StateType::None);
      auto norm_type_state = norm_type.has_value()
          ? fd->recordingState(norm_type.value()())
          : State(/*_index=*/0, /*_stype=*/serde::StateType::None);
      auto scale_grad_by_freq_state = scale_grad_by_freq.has_value()
          ? fd->recordingState(scale_grad_by_freq.value()())
          : State(/*_index=*/0, /*_stype=*/serde::StateType::None);
      auto sparse_state = sparse.has_value()
          ? fd->recordingState(sparse.value()())
          : State(/*_index=*/0, /*_stype=*/serde::StateType::None);

      fd->defineRecord(new EmbeddingFwdOpRecord(
          {fd->recordingState(input()),
           fd->recordingState(weight()),
           padding_idx_state,
           max_norm_state,
           norm_type_state,
           scale_grad_by_freq_state,
           sparse_state},
          {fd->recordingState(output())}));
      return output;
    },
    py::arg("input"),
    py::arg("weight"),
    py::arg("padding_idx").none(true) = py::none(),
    py::arg("max_norm").none(true) = py::none(),
    py::arg("norm_type").none(true) = py::none(),
    py::arg("scale_grad_by_freq").none(true) = py::none(),
    py::arg("sparse").none(true) = py::none(),
    py::return_value_policy::reference);

@Priya2698 (Collaborator Author)

!test


constexpr int64_t n = 5, s = 2;

TEST_F(EmbeddingTest, EmbeddingFwdNode) {
@protonu (Collaborator) commented Jan 17, 2025


I wonder if it's possible to add a check to verify the output of the toString method as well.
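
For illustration only, a minimal sketch of what such a check might look like, assuming the test builds the op via embedding_fwd and keeps a handle to the output tensor; the substring is chosen to mirror EmbeddingFwdOp::toString and this is not part of the diff:

// Hypothetical check inside the test body.
std::string str = out->definition()->toString();
EXPECT_TRUE(str.find("embedding(") != std::string::npos);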

  }
  std::optional<double> max_norm = std::nullopt;
  if (has_max_norm()) {
    auto idx = 5 + has_padding_idx();
Collaborator

nit: this free-floating 5 bothers me a little bit, but I'm not sure what would be better.

Collaborator Author

It may not be ideal; however, we fetch the preceding variables by fixed indices as well. The positions of the variables are constant, so it should be safe.

  if (norm_type == nullptr) {
    norm_type = IrBuilder::create<Val>(2.0, DataType::Double);
  }
  if (scale_grad_by_freq == nullptr) {
Collaborator

nit: can we use IrContainer::falseVal() here?
input->fusion()->falseVal(), or get the current fusion and call the function on it?

Collaborator

similar pattern below.
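
A minimal sketch of the suggestion, assuming IrContainer::falseVal() is reachable through the input tensor's fusion as noted above; not part of the diff:

// Sketch only: reuse the container-owned false value instead of creating a
// new Bool scalar for each defaulted flag.
if (scale_grad_by_freq == nullptr) {
  scale_grad_by_freq = input->fusion()->falseVal();
}
if (sparse == nullptr) {
  sparse = input->fusion()->falseVal();
}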
