EmbeddingFwdOp node with same functionality as F.embedding #3649

Open · Priya2698 wants to merge 13 commits into main

Conversation

@Priya2698 (Collaborator) commented on Dec 26, 2024

This PR adds an EmbeddingFwdOp with the same functionality as F.embedding.

  1. I am not using take_along_axis. F.embedding allows optional parameters such as max_norm and padding_idx, which would require further processing if implemented with take_along_axis, so I defaulted to creating a new node to guarantee performance parity.
  2. Thunder uses prims.EMBEDDING when the optional parameters padding_idx/max_norm are specified, and prims.TAKE otherwise, which prevents nvFuser from consuming the embedding operator in those other cases. Hence, in Thunder, nvFuser will directly execute ltorch.embedding. This requires a separate backward API to consume ltorch.embedding_backward and cannot reuse the grad rules for prims.EMBEDDING; hence the name EmbeddingFwdOp instead of EmbeddingOp.
  3. I plan to first plumb the forward-only embedding support through Thunder while I draft the backward node, which should be very similar. Thunder reviews may suggest another way of implementing this support.

@Priya2698 Priya2698 changed the title EmbeddingOp node with same functionality as F.embedding EmbeddingFwdOp node with same functionality as F.embedding Jan 16, 2025

github-actions bot commented Jan 16, 2025

PR Reviewer Guide 🔍

(Review updated until commit f50fe0c)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Potential Logic Change

The EmbeddingFwdOp constructor has been modified to accept additional parameters. Review the logic to ensure it aligns with the intended functionality.

EmbeddingFwdOp::EmbeddingFwdOp(
    IrBuilderPasskey passkey,
    TensorView* output,
    TensorView* input,
    TensorView* weight,
    Val* padding_idx,
    Val* max_norm,
    Val* norm_type,
    Val* scale_grad_by_freq,
    Val* sparse)
    : Expr(passkey) {
  addOutput(output);

  addInput(input);
  addInput(weight);
  addInput(norm_type);
  addInput(scale_grad_by_freq);
  addInput(sparse);
  if (padding_idx != nullptr) {
    addInput(padding_idx);
    addDataAttribute(true);
  } else {
    addDataAttribute(false);
  }
  if (max_norm != nullptr) {
    addInput(max_norm);
    addDataAttribute(true);
  } else {
    addDataAttribute(false);
  }
}
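
For reviewers tracing the optional-argument handling: because padding_idx and max_norm are appended after the fixed inputs, the op always carries a five-entry input prefix followed by the optionals, with a boolean data attribute recording each optional's presence. The layout below is a reading aid inferred from the constructor above, not code quoted from the PR:

// Input layout implied by the constructor (illustrative, not PR code):
//   input(0): input               token indices
//   input(1): weight              embedding table
//   input(2): norm_type
//   input(3): scale_grad_by_freq
//   input(4): sparse
//   input(5): padding_idx         present only if provided (data attribute 0 == true)
//   input(5/6): max_norm          present only if provided (data attribute 1 == true);
//                                 index 6 when padding_idx is also an input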
Function Signature Change

The embedding_fwd function signature has been updated to include new parameters. Verify that the changes are consistent with the function's purpose.

TensorView* embedding_fwd(
    TensorView* input,
    TensorView* weight,
    Val* padding_idx,
    Val* max_norm,
    Val* norm_type,
    Val* scale_grad_by_freq,
    Val* sparse) {
  auto input_domain = TensorDomain::noReductions(input->getLogicalDomain());
  auto weight_domain = TensorDomain::noReductions(weight->getLogicalDomain());
  NVF_CHECK(
      !input_domain.empty(),
      "Expected input to be atleast 1D, got: ",
      input_domain.size());
  NVF_CHECK(
      weight_domain.size() == 2,
      "Expected weight to be 2D, got: ",
      weight_domain.size());

  NVF_CHECK(
      !padding_idx || padding_idx->isScalar(),
      "Expected padding_idx to be a scalar int.");
  NVF_CHECK(
      !max_norm || max_norm->isScalar(),
      "Expected max_norm to be a scalar double.");
  NVF_CHECK(
      !norm_type || norm_type->isScalar(),
      "Expected norm_type to be a scalar double.");
  NVF_CHECK(
      !scale_grad_by_freq || scale_grad_by_freq->isScalar(),
      "Expected scale_grad_by_freq to be a scalar bool.");
  NVF_CHECK(
      !sparse || sparse->isScalar(), "Expected sparse to be a scalar bool.");

  auto ndims_out = input_domain.size() + 1;
  std::vector<IterDomain*> out_domain(ndims_out, nullptr);

  for (auto idx : c10::irange(ndims_out - 1)) {
    out_domain[idx] = ops::newOutputIterDomain({input_domain[idx]});
  }
  out_domain[ndims_out - 1] = ops::newOutputIterDomain({weight_domain.back()});
  TensorDomain* out_td = IrBuilder::create<TensorDomain>(
      out_domain, TensorDomain::getContiguityFilledWith(out_domain, true));
  TensorView* output = IrBuilder::create<TensorView>(out_td, weight->dtype());

  if (norm_type == nullptr) {
    norm_type = IrBuilder::create<Val>(2.0, DataType::Double);
  }

  if (scale_grad_by_freq == nullptr) {
    scale_grad_by_freq = input->fusion()->falseVal();
  }
  if (sparse == nullptr) {
    sparse = input->fusion()->falseVal();
  }
  IrBuilder::create<EmbeddingFwdOp>(
      output,
      input,
      weight,
      padding_idx,
      max_norm,
      norm_type,
      scale_grad_by_freq,
      sparse);

  return output;
}
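
For context, here is a minimal sketch of how the new API could be exercised from C++, loosely modeled on the EmbeddingFwdNode test in this PR; Fusion, FusionGuard, and makeSymbolicTensor are the usual nvFuser test utilities, and the exact calls are illustrative rather than quoted from the PR:

Fusion fusion;
FusionGuard fg(&fusion);

// Token indices: 1D integer tensor; embedding table: 2D float tensor.
TensorView* tv_input = makeSymbolicTensor(1, DataType::Int);
TensorView* tv_weight = makeSymbolicTensor(2, DataType::Float);
fusion.addInput(tv_input);
fusion.addInput(tv_weight);

// Leave every optional argument unset; embedding_fwd then fills in the
// defaults shown above (norm_type = 2.0, scale_grad_by_freq = false,
// sparse = false).
TensorView* tv_out = embedding_fwd(
    tv_input,
    tv_weight,
    /*padding_idx=*/nullptr,
    /*max_norm=*/nullptr,
    /*norm_type=*/nullptr,
    /*scale_grad_by_freq=*/nullptr,
    /*sparse=*/nullptr);
fusion.addOutput(tv_out);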
Binding Update

The embedding_fwd binding has been updated to reflect the changes in the C++ function. Ensure that the binding is correct and functional.

nvf_ops.def(
    "embedding_fwd",
    [](FusionDefinition::Operators& self,
       Tensor input,
       Tensor weight,
       std::optional<Scalar> padding_idx,
       std::optional<Scalar> max_norm,
       std::optional<Scalar> norm_type,
       std::optional<Scalar> scale_grad_by_freq,
       std::optional<Scalar> sparse) -> decltype(auto) {
      FUSER_PERF_SCOPE("Operators.embedding_fwd");
      NVF_CHECK(
          self.validUse(), "Attempting to add to a completed definition!");
      FusionDefinition* fd = self.fusion_definition;
      size_t ndims = input.dims + 1;
      Tensor output = fd->defineTensor(/*dims=*/ndims);

      auto padding_idx_state = padding_idx.has_value()
          ? fd->recordingState(padding_idx.value()())
          : State(/*_index=*/0, /*_stype=*/serde::StateType::None);
      auto max_norm_state = max_norm.has_value()
          ? fd->recordingState(max_norm.value()())
          : State(/*_index=*/0, /*_stype=*/serde::StateType::None);
      auto norm_type_state = norm_type.has_value()
          ? fd->recordingState(norm_type.value()())
          : State(/*_index=*/0, /*_stype=*/serde::StateType::None);
      auto scale_grad_by_freq_state = scale_grad_by_freq.has_value()
          ? fd->recordingState(scale_grad_by_freq.value()())
          : State(/*_index=*/0, /*_stype=*/serde::StateType::None);
      auto sparse_state = sparse.has_value()
          ? fd->recordingState(sparse.value()())
          : State(/*_index=*/0, /*_stype=*/serde::StateType::None);

      fd->defineRecord(new EmbeddingFwdOpRecord(
          {fd->recordingState(input()),
           fd->recordingState(weight()),
           padding_idx_state,
           max_norm_state,
           norm_type_state,
           scale_grad_by_freq_state,
           sparse_state},
          {fd->recordingState(output())}));
      return output;
    },
    py::arg("input"),
    py::arg("weight"),
    py::arg("padding_idx").none(true) = py::none(),
    py::arg("max_norm").none(true) = py::none(),
    py::arg("norm_type").none(true) = py::none(),
    py::arg("scale_grad_by_freq").none(true) = py::none(),
    py::arg("sparse").none(true) = py::none(),
    py::return_value_policy::reference);

@Priya2698 (Collaborator, Author) commented:

!test


constexpr int64_t n = 5, s = 2;

TEST_F(EmbeddingTest, EmbeddingFwdNode) {
@protonu (Collaborator) commented on Jan 17, 2025


I wonder if it's possible to add a check to verify the output of the toString method as well.

@Priya2698 (Collaborator, Author) replied:

I don't find verifying against a handwritten string to be very robust, since that representation can change with individual toString methods, so I did not add it.

}
std::optional<double> max_norm = std::nullopt;
if (has_max_norm()) {
auto idx = 5 + has_padding_idx();
A collaborator commented:

nit: having this bare 5 bothers me a little bit, but I'm not sure what would be better.

@Priya2698 (Collaborator, Author) replied:

It may not be ideal; however, we fetch the preceding variables by fixed indices as well: the first five inputs (input, weight, norm_type, scale_grad_by_freq, sparse) are always present, so the optional scalars start at index 5. The positions of the variables are constant, so it should be safe.

@Priya2698 Priya2698 requested a review from protonu January 22, 2025 04:17
@Priya2698 (Collaborator, Author) commented:

!test

@Priya2698 Priya2698 requested a review from jjsjann123 January 22, 2025 06:36