EmbeddingFwdOp node with same functionality as F.embedding #3649

Open · wants to merge 11 commits into base: main
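For context, the new node mirrors the lookup performed by torch.nn.functional.embedding: each integer index in the input selects a row of the weight matrix, producing an output of shape [*, embedding_dim]. Below is a minimal standalone C++ sketch of that semantics, for illustration only; it ignores padding_idx, max_norm, and the other options the op handles.

#include <cstdint>
#include <vector>

// input: flat indices of shape [*]; weight: row-major [num_embeddings, embedding_dim].
// Returns output of shape [*, embedding_dim], where output[i] = weight[input[i]].
std::vector<float> embedding_reference(
    const std::vector<int64_t>& input,
    const std::vector<float>& weight,
    int64_t embedding_dim) {
  std::vector<float> output(input.size() * embedding_dim);
  for (size_t i = 0; i < input.size(); ++i) {
    for (int64_t d = 0; d < embedding_dim; ++d) {
      output[i * embedding_dim + d] = weight[input[i] * embedding_dim + d];
    }
  }
  return output;
}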
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -547,6 +547,7 @@ list(APPEND JIT_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_circular_buffering.cpp
${NVFUSER_ROOT}/tests/cpp/test_abstract_tensor.cpp
${NVFUSER_ROOT}/tests/cpp/test_dynamic_transform.cpp
${NVFUSER_ROOT}/tests/cpp/test_embedding_node.cpp
${NVFUSER_ROOT}/tests/cpp/test_evaluator.cpp
${NVFUSER_ROOT}/tests/cpp/test_exceptions.cpp
${NVFUSER_ROOT}/tests/cpp/test_expr_simplifier.cpp
1 change: 1 addition & 0 deletions csrc/device_lower/utils.cpp
@@ -159,6 +159,7 @@ bool isTvOp(const Expr* expr) {
LinearOp,
SdpaFwdOp,
SdpaBwdOp,
EmbeddingFwdOp,
BroadcastOp,
SqueezeOp,
ExpandOp,
1 change: 1 addition & 0 deletions csrc/dispatch.h
@@ -111,6 +111,7 @@ class Val;
f(LinearOp); \
f(SdpaFwdOp); \
f(SdpaBwdOp); \
f(EmbeddingFwdOp); \
f(Communication); \
f(ForLoop); \
f(P2PCommunication);
75 changes: 75 additions & 0 deletions csrc/ir/internal_nodes.h
@@ -2715,4 +2715,79 @@ class SdpaBwdOp : public Expr {
const std::vector<PolymorphicValue>& inputs) const override;
};

class EmbeddingFwdOp : public Expr {
public:
using Expr::Expr;

EmbeddingFwdOp(
IrBuilderPasskey,
TensorView* output,
TensorView* input,
TensorView* weight,
Val* padding_idx,
Val* max_norm,
Val* norm_type,
Val* scale_grad_by_freq,
Val* sparse);

NVFUSER_DECLARE_CLONE_AND_CREATE

const char* getOpString() const override {
return "EmbeddingFwdOp";
}

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;

TensorView* out() const {
return output(0)->as<TensorView>();
}

TensorView* in() const {
return input(0)->as<TensorView>();
}

TensorView* weight() const {
return input(1)->as<TensorView>();
}

Val* norm_type() const {
return input(2);
}

Val* scale_grad_by_freq() const {
return input(3);
}

Val* sparse() const {
return input(4);
}

Val* padding_idx() const {
if (has_padding_idx()) {
return input(5);
}
return nullptr;
}

Val* max_norm() const {
if (has_max_norm()) {
return input(5 + has_padding_idx());
}
return nullptr;
}

bool has_padding_idx() const {
return attribute<bool>(0);
}

bool has_max_norm() const {
return attribute<bool>(1);
}

std::vector<PolymorphicValue> evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const override;
};

} // namespace nvfuser
91 changes: 91 additions & 0 deletions csrc/ir/nodes.cpp
@@ -24,6 +24,7 @@
#include <type.h>

#include <c10/util/irange.h>
#include <torch/nn/options/embedding.h>

#include <complex>
#include <iterator>
@@ -5336,4 +5337,94 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
slice_last_dim(grad_value)};
}

EmbeddingFwdOp::EmbeddingFwdOp(
IrBuilderPasskey passkey,
TensorView* output,
TensorView* input,
TensorView* weight,
Val* padding_idx,
Val* max_norm,
Val* norm_type,
Val* scale_grad_by_freq,
Val* sparse)
: Expr(passkey) {
addOutput(output);

addInput(input);
addInput(weight);
addInput(norm_type);
addInput(scale_grad_by_freq);
addInput(sparse);
if (padding_idx != nullptr) {
addInput(padding_idx);
addDataAttribute(true);
} else {
addDataAttribute(false);
}
if (max_norm != nullptr) {
addInput(max_norm);
addDataAttribute(true);
} else {
addDataAttribute(false);
}
}

NVFUSER_DEFINE_CLONE_AND_CREATE(EmbeddingFwdOp)

std::string EmbeddingFwdOp::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << out()->toString() << ",\n";
indent(ss, indent_size + 1) << " = embedding(" << in()->toString() << ",\n";
indent(ss, indent_size + 1) << " " << weight()->toString() << ",\n";
if (padding_idx() != nullptr) {
indent(ss, indent_size + 1)
<< " padding_idx = " << padding_idx()->toString() << ",\n";
}
if (max_norm() != nullptr) {
indent(ss, indent_size + 1)
<< " max_norm = " << max_norm()->toString() << ",\n";
}
indent(ss, indent_size + 1)
<< " norm_type = " << norm_type()->toString() << ",\n";
indent(ss, indent_size + 1)
<< " scale_grad_by_freq = "
<< scale_grad_by_freq()->toInlineString() << ",\n";
indent(ss, indent_size + 1)
<< " sparse = " << sparse()->toInlineString() << ")\n";
return ss.str();
}

std::string EmbeddingFwdOp::toInlineString(int indent_size) const {
NVF_CHECK(false, "Tensor op can not be printed inline");
}

std::vector<PolymorphicValue> EmbeddingFwdOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {
auto input = inputs.at(0).as<at::Tensor>();
auto weight = inputs.at(1).as<at::Tensor>();
auto norm_type = inputs.at(2).as<double>();
auto scale_grad_by_freq = inputs.at(3).as<bool>();
auto sparse = inputs.at(4).as<bool>();
std::optional<int64_t> padding_idx = std::nullopt;
if (has_padding_idx()) {
padding_idx = inputs.at(5).as<int64_t>();
}
std::optional<double> max_norm = std::nullopt;
if (has_max_norm()) {
auto idx = 5 + has_padding_idx();
Review comment (Collaborator):
nit: Having this free 5 bothers me a little bit, but I'm not sure what would be better.

Reply (Collaborator, author):
It may not be ideal; however, we fetch the previous variables by fixed indices as well. The positions of the variables are constant, so it should be safe. (See the input-layout sketch after this file's diff.)

max_norm = inputs.at(idx).as<double>();
}

namespace F = torch::nn::functional;
return {F::embedding(
input,
weight,
F::EmbeddingFuncOptions()
.padding_idx(padding_idx)
.max_norm(max_norm)
.norm_type(norm_type)
.scale_grad_by_freq(scale_grad_by_freq)
.sparse(sparse))};
}
} // namespace nvfuser
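As background for the review exchange above: the constructor appends the optional inputs after the five required ones, so the input positions are fixed by construction. The sketch below illustrates the resulting layout; the helper function is hypothetical and simply mirrors the arithmetic used in evaluate().

#include <cstdint>

// Input positions established by EmbeddingFwdOp's constructor:
//   0: input, 1: weight, 2: norm_type, 3: scale_grad_by_freq, 4: sparse,
//   5: padding_idx                    (present only if has_padding_idx()),
//   5 + has_padding_idx(): max_norm   (present only if has_max_norm()).
// Hypothetical helper showing why the fixed "5" is safe:
int64_t maxNormInputIndex(bool has_padding_idx) {
  return 5 + (has_padding_idx ? 1 : 0);
}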
21 changes: 21 additions & 0 deletions csrc/logical_domain_map.cpp
@@ -325,6 +325,27 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseLogicalDomainMap::map(
return dom_map;
}

if (EmbeddingFwdOp* op =
dynamic_cast<EmbeddingFwdOp*>(consumer_tv_->definition())) {
// Producers:
// input = [*]
// weight = [V, embedding_dim]
// Consumers:
// output = [*, embedding_dim]
auto ndims_out = consumer_root.size();
if (producer_tv_->sameAs(op->in())) {
for (auto idx : c10::irange(ndims_out - 1)) {
updatePairwiseLogicalDomainMap(
producer_logical.at(idx), consumer_root.at(idx));
}
}
if (producer_tv_->sameAs(op->weight())) {
updatePairwiseLogicalDomainMap(
producer_logical.back(), consumer_root.back());
}
return dom_map;
}

size_t itc = 0, itp = 0;
while (itc < consumer_root.size() && itp < producer_logical.size()) {
IterDomain* producer_id = producer_logical.at(itp);
67 changes: 67 additions & 0 deletions csrc/ops/composite.cpp
@@ -662,4 +662,71 @@ SdpfaBwdResult sdpfa_bwd(
return {grad_query, grad_key, grad_value};
}

TensorView* embedding_fwd(
TensorView* input,
TensorView* weight,
Val* padding_idx,
Val* max_norm,
Val* norm_type,
Val* scale_grad_by_freq,
Val* sparse) {
auto input_domain = TensorDomain::noReductions(input->getLogicalDomain());
auto weight_domain = TensorDomain::noReductions(weight->getLogicalDomain());
NVF_CHECK(
!input_domain.empty(),
"Expected input to be atleast 1D, got: ",
input_domain.size());
NVF_CHECK(
weight_domain.size() == 2,
"Expected weight to be 2D, got: ",
weight_domain.size());

NVF_CHECK(
!padding_idx || padding_idx->isScalar(),
"Expected padding_idx to be a scalar int.");
NVF_CHECK(
!max_norm || max_norm->isScalar(),
"Expected max_norm to be a scalar double.");
NVF_CHECK(
!norm_type || norm_type->isScalar(),
"Expected scale to be a scalar double.");
NVF_CHECK(
!scale_grad_by_freq || scale_grad_by_freq->isScalar(),
"Expected scale to be a scalar bool.");
NVF_CHECK(
!sparse || sparse->isScalar(), "Expected scale to be a scalar bool.");

auto ndims_out = input_domain.size() + 1;
std::vector<IterDomain*> out_domain(ndims_out, nullptr);

for (auto idx : c10::irange(ndims_out - 1)) {
out_domain[idx] = ops::newOutputIterDomain({input_domain[idx]});
}
out_domain[ndims_out - 1] = ops::newOutputIterDomain({weight_domain.back()});
TensorDomain* out_td = IrBuilder::create<TensorDomain>(
out_domain, TensorDomain::getContiguityFilledWith(out_domain, true));
TensorView* output = IrBuilder::create<TensorView>(out_td, weight->dtype());

if (norm_type == nullptr) {
norm_type = IrBuilder::create<Val>(2.0, DataType::Double);
}
if (scale_grad_by_freq == nullptr) {
Review comment (Collaborator):
nit: Can we use IrContainer::falseVal() here? E.g. input->fusion()->falseVal(), or get the current fusion and call the function on it. (A sketch of this suggestion follows this file's diff.)

Review comment (Collaborator):
Similar pattern below.

scale_grad_by_freq = IrBuilder::create<Val>(false, DataType::Bool);
}
if (sparse == nullptr) {
sparse = IrBuilder::create<Val>(false, DataType::Bool);
}
IrBuilder::create<EmbeddingFwdOp>(
output,
input,
weight,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse);

return output;
}

} // namespace nvfuser
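For reference, here is a sketch of the reviewer's suggestion above applied to the default-filling code. It assumes IrContainer exposes falseVal() as mentioned in the comment and is not part of this PR's diff.

// Reuse the fusion's cached Bool constant instead of creating a new Val
// for each defaulted argument.
if (scale_grad_by_freq == nullptr) {
  scale_grad_by_freq = input->fusion()->falseVal();
}
if (sparse == nullptr) {
  sparse = input->fusion()->falseVal();
}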
9 changes: 9 additions & 0 deletions csrc/ops/composite.h
@@ -118,4 +118,13 @@ SdpfaBwdResult sdpfa_bwd(
TensorView* philox_offset,
Val* scale);

TensorView* embedding_fwd(
TensorView* input,
TensorView* weight,
Val* padding_idx,
Val* max_norm,
Val* norm_type,
Val* scale_grad_by_freq,
Val* sparse);

} // namespace nvfuser
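A minimal sketch of how embedding_fwd might be exercised when building a fusion, in the style of the C++ tests this PR adds; helper names such as makeSymbolicTensor are assumed from the existing test utilities.

Fusion fusion;
FusionGuard fg(&fusion);

// Indices: integer tensor [batch, seq]; weight: [num_embeddings, embedding_dim].
TensorView* inp = makeSymbolicTensor(2, DataType::Int);
TensorView* weight = makeSymbolicTensor(2, DataType::Float);
fusion.addInput(inp);
fusion.addInput(weight);

// All optional arguments left null; embedding_fwd fills in the defaults
// (norm_type = 2.0, scale_grad_by_freq = false, sparse = false).
TensorView* out = embedding_fwd(
    inp,
    weight,
    /*padding_idx=*/nullptr,
    /*max_norm=*/nullptr,
    /*norm_type=*/nullptr,
    /*scale_grad_by_freq=*/nullptr,
    /*sparse=*/nullptr);
fusion.addOutput(out);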
43 changes: 43 additions & 0 deletions csrc/python_frontend/fusion_record.h
@@ -3052,6 +3052,49 @@ struct SdpaBwdOpRecord : RecordFunctor {
}
};

struct EmbeddingFwdOpRecord : RecordFunctor {
EmbeddingFwdOpRecord(std::vector<State> args, std::vector<State> outputs)
: RecordFunctor(
std::move(args),
std::move(outputs),
"ops.embedding_fwd",
serde::RecordType::EmbeddingFwdOp) {}
~EmbeddingFwdOpRecord() override = default;
RecordFunctor* clone() final {
return new EmbeddingFwdOpRecord(*this);
}

void operator()(FusionState& fd) final {
auto input = fd.getFusionState(args_.at(0).index)->as<TensorView>();
auto weight = fd.getFusionState(args_.at(1).index)->as<TensorView>();
auto padding_idx = (args_.at(2).stype == serde::StateType::Scalar)
? fd.getFusionState(args_.at(2).index)->as<Val>()
: nullptr;
auto max_norm = (args_.at(3).stype == serde::StateType::Scalar)
? fd.getFusionState(args_.at(3).index)->as<Val>()
: nullptr;
auto norm_type = (args_.at(4).stype == serde::StateType::Scalar)
? fd.getFusionState(args_.at(4).index)->as<Val>()
: nullptr;
auto scale_grad_by_freq = (args_.at(5).stype == serde::StateType::Scalar)
? fd.getFusionState(args_.at(5).index)->as<Val>()
: nullptr;
auto sparse = (args_.at(6).stype == serde::StateType::Scalar)
? fd.getFusionState(args_.at(6).index)->as<Val>()
: nullptr;

auto output = embedding_fwd(
input,
weight,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse);
fd.setFusionState(outputs_.at(0).index, output);
}
};

} // namespace nvfuser::python_frontend

//! Creating the template specialized hash and equal_to functions for a