Skip to content

Commit

Permalink
rename to embedding_fwd
Browse files Browse the repository at this point in the history
  • Loading branch information
Priya2698 committed Jan 16, 2025
1 parent a0aa5dc commit e5b0594
Show file tree
Hide file tree
Showing 14 changed files with 33 additions and 33 deletions.
2 changes: 1 addition & 1 deletion csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ bool isTvOp(const Expr* expr) {
LinearOp,
SdpaFwdOp,
SdpaBwdOp,
EmbeddingOp,
EmbeddingFwdOp,
BroadcastOp,
SqueezeOp,
ExpandOp,
Expand Down
2 changes: 1 addition & 1 deletion csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class Val;
f(LinearOp); \
f(SdpaFwdOp); \
f(SdpaBwdOp); \
f(EmbeddingOp); \
f(EmbeddingFwdOp); \
f(Communication); \
f(ForLoop); \
f(P2PCommunication);
Expand Down
6 changes: 3 additions & 3 deletions csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2679,11 +2679,11 @@ class SdpaBwdOp : public Expr {
const std::vector<PolymorphicValue>& inputs) const override;
};

class EmbeddingOp : public Expr {
class EmbeddingFwdOp : public Expr {
public:
using Expr::Expr;

EmbeddingOp(
EmbeddingFwdOp(
IrBuilderPasskey,
TensorView* output,
TensorView* input,
Expand All @@ -2697,7 +2697,7 @@ class EmbeddingOp : public Expr {
NVFUSER_DECLARE_CLONE_AND_CREATE

const char* getOpString() const override {
return "EmbeddingOp";
return "EmbeddingFwdOp";
}

std::string toString(int indent_size = 0) const override;
Expand Down
10 changes: 5 additions & 5 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5257,7 +5257,7 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
slice_last_dim(grad_value)};
}

EmbeddingOp::EmbeddingOp(
EmbeddingFwdOp::EmbeddingFwdOp(
IrBuilderPasskey passkey,
TensorView* output,
TensorView* input,
Expand Down Expand Up @@ -5289,9 +5289,9 @@ EmbeddingOp::EmbeddingOp(
}
}

NVFUSER_DEFINE_CLONE_AND_CREATE(EmbeddingOp)
NVFUSER_DEFINE_CLONE_AND_CREATE(EmbeddingFwdOp)

std::string EmbeddingOp::toString(int indent_size) const {
std::string EmbeddingFwdOp::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << out()->toString() << ",\n";
indent(ss, indent_size + 1)
Expand All @@ -5314,11 +5314,11 @@ std::string EmbeddingOp::toString(int indent_size) const {
return ss.str();
}

std::string EmbeddingOp::toInlineString(int indent_size) const {
std::string EmbeddingFwdOp::toInlineString(int indent_size) const {
NVF_CHECK(false, "Tensor op can not be printed inline");
}

std::vector<PolymorphicValue> EmbeddingOp::evaluate(
std::vector<PolymorphicValue> EmbeddingFwdOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {

Expand Down
2 changes: 1 addition & 1 deletion csrc/logical_domain_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseLogicalDomainMap::map(
return dom_map;
}

if (EmbeddingOp* op = dynamic_cast<EmbeddingOp*>(consumer_tv_->definition())) {
if (EmbeddingFwdOp* op = dynamic_cast<EmbeddingFwdOp*>(consumer_tv_->definition())) {
// Producers:
// input = [*]
// weight = [V, embedding_dim]
Expand Down
4 changes: 2 additions & 2 deletions csrc/ops/composite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ SdpfaBwdResult sdpfa_bwd(
return {grad_query, grad_key, grad_value};
}

TensorView* embedding(
TensorView* embedding_fwd(
TensorView* input,
TensorView* weight,
Val* padding_idx,
Expand Down Expand Up @@ -656,7 +656,7 @@ TensorView* embedding(
if (sparse == nullptr){
sparse = IrBuilder::create<Val>(false, DataType::Bool);
}
IrBuilder::create<EmbeddingOp>(
IrBuilder::create<EmbeddingFwdOp>(
output,
input,
weight,
Expand Down
2 changes: 1 addition & 1 deletion csrc/ops/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ SdpfaBwdResult sdpfa_bwd(
TensorView* philox_offset,
Val* scale);

TensorView* embedding(
TensorView* embedding_fwd(
TensorView* input,
TensorView* weight,
Val* padding_idx,
Expand Down
14 changes: 7 additions & 7 deletions csrc/python_frontend/fusion_record.h
Original file line number Diff line number Diff line change
Expand Up @@ -3052,16 +3052,16 @@ struct SdpaBwdOpRecord : RecordFunctor {
}
};

struct EmbeddingOpRecord : RecordFunctor {
EmbeddingOpRecord(std::vector<State> args, std::vector<State> outputs)
struct EmbeddingFwdOpRecord : RecordFunctor {
EmbeddingFwdOpRecord(std::vector<State> args, std::vector<State> outputs)
: RecordFunctor(
std::move(args),
std::move(outputs),
"ops.embedding",
serde::RecordType::EmbeddingOp) {}
~EmbeddingOpRecord() override = default;
"ops.embedding_fwd",
serde::RecordType::EmbeddingFwdOp) {}
~EmbeddingFwdOpRecord() override = default;
RecordFunctor* clone() final {
return new EmbeddingOpRecord(*this);
return new EmbeddingFwdOpRecord(*this);
}

void operator()(FusionState& fd) final {
Expand All @@ -3083,7 +3083,7 @@ struct EmbeddingOpRecord : RecordFunctor {
? fd.getFusionState(args_.at(6).index)->as<Val>()
: nullptr;

auto output = embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse);
auto output = embedding_fwd(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse);
fd.setFusionState(outputs_.at(0).index, output);
}
};
Expand Down
6 changes: 3 additions & 3 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3570,7 +3570,7 @@ void initNvFuserPythonBindings(PyObject* module) {
py::return_value_policy::reference);

nvf_ops.def(
"embedding",
"embedding_fwd",
[](FusionDefinition::Operators& self,
Tensor input,
Tensor weight,
Expand All @@ -3579,7 +3579,7 @@ void initNvFuserPythonBindings(PyObject* module) {
std::optional<Scalar> norm_type,
std::optional<Scalar> scale_grad_by_freq,
std::optional<Scalar> sparse) -> decltype(auto) {
FUSER_PERF_SCOPE("Operators.embedding");
FUSER_PERF_SCOPE("Operators.embedding_fwd");
NVF_CHECK(
self.validUse(), "Attempting to add to a completed definition!");
FusionDefinition* fd = self.fusion_definition;
Expand All @@ -3602,7 +3602,7 @@ void initNvFuserPythonBindings(PyObject* module) {
? fd->recordingState(sparse.value()())
: State(/*_index=*/0, /*_stype=*/serde::StateType::None);

fd->defineRecord(new EmbeddingOpRecord(
fd->defineRecord(new EmbeddingFwdOpRecord(
{fd->recordingState(input()),
fd->recordingState(weight()),
padding_idx_state,
Expand Down
2 changes: 1 addition & 1 deletion csrc/scheduler/expr_eval_sched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) {
return false;
}

if (exprs.front()->isOneOf<SdpaFwdOp, SdpaBwdOp, EmbeddingOp>()) {
if (exprs.front()->isOneOf<SdpaFwdOp, SdpaBwdOp, EmbeddingFwdOp>()) {
return true;
}

Expand Down
2 changes: 1 addition & 1 deletion csrc/serde/fusion_cache.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ enum RecordType: int {
CastTv,
CastVal,
CatOp,
EmbeddingOp,
EmbeddingFwdOp,
End,
ExpandOp,
FullOp,
Expand Down
6 changes: 3 additions & 3 deletions csrc/serde/fusion_record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,11 @@ void RecordFunctorFactory::registerAllParsers() {
};
registerParser(RecordType::SdpaBwdOp, deserializeSdpaBwdRecord);

auto deserializeEmbeddingRecord = [&](const RecordFunctor* buffer) {
return new python_frontend::EmbeddingOpRecord(
auto deserializeEmbeddingFwdRecord = [&](const RecordFunctor* buffer) {
return new python_frontend::EmbeddingFwdOpRecord(
parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()));
};
registerParser(RecordType::EmbeddingOp, deserializeEmbeddingRecord);
registerParser(RecordType::EmbeddingFwdOp, deserializeEmbeddingFwdRecord);

// END OpRecord Parsers

Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_embedding_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using EmbeddingTest = NVFuserTest;

constexpr int64_t n = 5, s = 2;

TEST_F(EmbeddingTest, EmbeddingNode) {
TEST_F(EmbeddingTest, EmbeddingFwdNode) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
std::vector<int64_t> inp_shape({s});
Expand All @@ -32,7 +32,7 @@ TEST_F(EmbeddingTest, EmbeddingNode) {
fusion->addInput(tv_inp);
fusion->addInput(tv_weight);

auto tv_output = embedding(tv_inp, tv_weight, nullptr, nullptr, nullptr, nullptr, nullptr);
auto tv_output = embedding_fwd(tv_inp, tv_weight, nullptr, nullptr, nullptr, nullptr, nullptr);
fusion->addOutput(tv_output);

auto options = at::TensorOptions().device(at::kCUDA, 0);
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def fusion_func(
for idx in range(len(optional_inputs)):
if has_optional_inputs[idx]:
optional_inputs[idx] = fd.define_scalar(value=None, dtype=optional_inputs_dtypes[idx])
out = fd.ops.embedding(input, weight, *optional_inputs)
out = fd.ops.embedding_fwd(input, weight, *optional_inputs)
fd.add_output(out)

N, S = 10, 3
Expand All @@ -73,5 +73,5 @@ def fusion_func(
norm_type = 2.0 if norm_type is None else norm_type
scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq
sparse = False if sparse is None else sparse
ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
torch.testing.assert_close(nvf_out[0], ref_out)

0 comments on commit e5b0594

Please sign in to comment.