diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f182892ff1..77f5a7ef1d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -547,6 +547,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_circular_buffering.cpp ${NVFUSER_ROOT}/tests/cpp/test_abstract_tensor.cpp ${NVFUSER_ROOT}/tests/cpp/test_dynamic_transform.cpp + ${NVFUSER_ROOT}/tests/cpp/test_embedding_node.cpp ${NVFUSER_ROOT}/tests/cpp/test_evaluator.cpp ${NVFUSER_ROOT}/tests/cpp/test_exceptions.cpp ${NVFUSER_ROOT}/tests/cpp/test_expr_simplifier.cpp diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 66baab289db..8a423c2e3ba 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -159,6 +159,7 @@ bool isTvOp(const Expr* expr) { LinearOp, SdpaFwdOp, SdpaBwdOp, + EmbeddingFwdOp, BroadcastOp, SqueezeOp, ExpandOp, diff --git a/csrc/dispatch.h b/csrc/dispatch.h index c2a2816fb7e..ee47464a6fb 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -111,6 +111,7 @@ class Val; f(LinearOp); \ f(SdpaFwdOp); \ f(SdpaBwdOp); \ + f(EmbeddingFwdOp); \ f(Communication); \ f(ForLoop); \ f(P2PCommunication); diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 6aebcb3c457..fb89716e9bf 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -2715,4 +2715,79 @@ class SdpaBwdOp : public Expr { const std::vector& inputs) const override; }; +class EmbeddingFwdOp : public Expr { + public: + using Expr::Expr; + + EmbeddingFwdOp( + IrBuilderPasskey, + TensorView* output, + TensorView* input, + TensorView* weight, + Val* padding_idx, + Val* max_norm, + Val* norm_type, + Val* scale_grad_by_freq, + Val* sparse); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "EmbeddingFwdOp"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + TensorView* out() const { + return output(0)->as(); + } + + TensorView* in() const { + return 
input(0)->as(); + } + + TensorView* weight() const { + return input(1)->as(); + } + + Val* norm_type() const { + return input(2); + } + + Val* scale_grad_by_freq() const { + return input(3); + } + + Val* sparse() const { + return input(4); + } + + Val* padding_idx() const { + if (has_padding_idx()) { + return input(5); + } + return nullptr; + } + + Val* max_norm() const { + if (has_max_norm()) { + return input(5 + has_padding_idx()); + } + return nullptr; + } + + bool has_padding_idx() const { + return attribute(0); + } + + bool has_max_norm() const { + return attribute(1); + } + + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; +}; + } // namespace nvfuser diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 3087fe4e34e..ddceff5a474 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -5336,4 +5337,94 @@ std::vector SdpaBwdOp::evaluate( slice_last_dim(grad_value)}; } +EmbeddingFwdOp::EmbeddingFwdOp( + IrBuilderPasskey passkey, + TensorView* output, + TensorView* input, + TensorView* weight, + Val* padding_idx, + Val* max_norm, + Val* norm_type, + Val* scale_grad_by_freq, + Val* sparse) + : Expr(passkey) { + addOutput(output); + + addInput(input); + addInput(weight); + addInput(norm_type); + addInput(scale_grad_by_freq); + addInput(sparse); + if (padding_idx != nullptr) { + addInput(padding_idx); + addDataAttribute(true); + } else { + addDataAttribute(false); + } + if (max_norm != nullptr) { + addInput(max_norm); + addDataAttribute(true); + } else { + addDataAttribute(false); + } +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(EmbeddingFwdOp) + +std::string EmbeddingFwdOp::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << out()->toString() << ",\n"; + indent(ss, indent_size + 1) << " = embedding(" << in()->toString() << ",\n"; + indent(ss, indent_size + 1) << " " << weight()->toString() << ",\n"; + if (padding_idx() 
!= nullptr) { + indent(ss, indent_size + 1) + << " padding_idx = " << padding_idx()->toString() << ",\n"; + } + if (max_norm() != nullptr) { + indent(ss, indent_size + 1) + << " max_norm = " << max_norm()->toString() << ",\n"; + } + indent(ss, indent_size + 1) + << " norm_type = " << norm_type()->toString() << ",\n"; + indent(ss, indent_size + 1) + << " scale_grad_by_freq = " + << scale_grad_by_freq()->toInlineString() << ",\n"; + indent(ss, indent_size + 1) + << " sparse = " << sparse()->toInlineString() << ")\n"; + return ss.str(); +} + +std::string EmbeddingFwdOp::toInlineString(int indent_size) const { + NVF_CHECK(false, "Tensor op can not be printed inline"); +} + +std::vector EmbeddingFwdOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + auto input = inputs.at(0).as(); + auto weight = inputs.at(1).as(); + auto norm_type = inputs.at(2).as(); + auto scale_grad_by_freq = inputs.at(3).as(); + auto sparse = inputs.at(4).as(); + std::optional padding_idx = std::nullopt; + if (has_padding_idx()) { + padding_idx = inputs.at(5).as(); + } + std::optional max_norm = std::nullopt; + if (has_max_norm()) { + auto idx = 5 + has_padding_idx(); + max_norm = inputs.at(idx).as(); + } + + namespace F = torch::nn::functional; + return {F::embedding( + input, + weight, + F::EmbeddingFuncOptions() + .padding_idx(padding_idx) + .max_norm(max_norm) + .norm_type(norm_type) + .scale_grad_by_freq(scale_grad_by_freq) + .sparse(sparse))}; +} } // namespace nvfuser diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index f49d0a9466f..a94a3a3f613 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -325,6 +325,27 @@ std::unordered_map PairwiseLogicalDomainMap::map( return dom_map; } + if (EmbeddingFwdOp* op = + dynamic_cast(consumer_tv_->definition())) { + // Producers: + // input = [*] + // weight = [V, embedding_dim] + // Consumers: + // output = [*, embedding_dim] + auto ndims_out = consumer_root.size(); + if 
(producer_tv_->sameAs(op->in())) { + for (auto idx : c10::irange(ndims_out - 1)) { + updatePairwiseLogicalDomainMap( + producer_logical.at(idx), consumer_root.at(idx)); + } + } + if (producer_tv_->sameAs(op->weight())) { + updatePairwiseLogicalDomainMap( + producer_logical.back(), consumer_root.back()); + } + return dom_map; + } + size_t itc = 0, itp = 0; while (itc < consumer_root.size() && itp < producer_logical.size()) { IterDomain* producer_id = producer_logical.at(itp); diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index ed8986ff817..fe5caa02d16 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -662,4 +662,71 @@ SdpfaBwdResult sdpfa_bwd( return {grad_query, grad_key, grad_value}; } +TensorView* embedding_fwd( + TensorView* input, + TensorView* weight, + Val* padding_idx, + Val* max_norm, + Val* norm_type, + Val* scale_grad_by_freq, + Val* sparse) { + auto input_domain = TensorDomain::noReductions(input->getLogicalDomain()); + auto weight_domain = TensorDomain::noReductions(weight->getLogicalDomain()); + NVF_CHECK( + !input_domain.empty(), + "Expected input to be at least 1D, got: ", + input_domain.size()); + NVF_CHECK( + weight_domain.size() == 2, + "Expected weight to be 2D, got: ", + weight_domain.size()); + + NVF_CHECK( + !padding_idx || padding_idx->isScalar(), + "Expected padding_idx to be a scalar int."); + NVF_CHECK( + !max_norm || max_norm->isScalar(), + "Expected max_norm to be a scalar double."); + NVF_CHECK( + !norm_type || norm_type->isScalar(), + "Expected norm_type to be a scalar double."); + NVF_CHECK( + !scale_grad_by_freq || scale_grad_by_freq->isScalar(), + "Expected scale_grad_by_freq to be a scalar bool."); + NVF_CHECK( + !sparse || sparse->isScalar(), "Expected sparse to be a scalar bool."); + + auto ndims_out = input_domain.size() + 1; + std::vector out_domain(ndims_out, nullptr); + + for (auto idx : c10::irange(ndims_out - 1)) { + out_domain[idx] = ops::newOutputIterDomain({input_domain[idx]}); + } + out_domain[ndims_out - 
1] = ops::newOutputIterDomain({weight_domain.back()}); + TensorDomain* out_td = IrBuilder::create( + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)); + TensorView* output = IrBuilder::create(out_td, weight->dtype()); + + if (norm_type == nullptr) { + norm_type = IrBuilder::create(2.0, DataType::Double); + } + if (scale_grad_by_freq == nullptr) { + scale_grad_by_freq = IrBuilder::create(false, DataType::Bool); + } + if (sparse == nullptr) { + sparse = IrBuilder::create(false, DataType::Bool); + } + IrBuilder::create( + output, + input, + weight, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse); + + return output; +} + } // namespace nvfuser diff --git a/csrc/ops/composite.h b/csrc/ops/composite.h index b67015b994d..cb505190e7c 100644 --- a/csrc/ops/composite.h +++ b/csrc/ops/composite.h @@ -118,4 +118,13 @@ SdpfaBwdResult sdpfa_bwd( TensorView* philox_offset, Val* scale); +TensorView* embedding_fwd( + TensorView* input, + TensorView* weight, + Val* padding_idx, + Val* max_norm, + Val* norm_type, + Val* scale_grad_by_freq, + Val* sparse); + } // namespace nvfuser diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index 0bc33244da1..43bb4d098b9 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -3052,6 +3052,49 @@ struct SdpaBwdOpRecord : RecordFunctor { } }; +struct EmbeddingFwdOpRecord : RecordFunctor { + EmbeddingFwdOpRecord(std::vector args, std::vector outputs) + : RecordFunctor( + std::move(args), + std::move(outputs), + "ops.embedding_fwd", + serde::RecordType::EmbeddingFwdOp) {} + ~EmbeddingFwdOpRecord() override = default; + RecordFunctor* clone() final { + return new EmbeddingFwdOpRecord(*this); + } + + void operator()(FusionState& fd) final { + auto input = fd.getFusionState(args_.at(0).index)->as(); + auto weight = fd.getFusionState(args_.at(1).index)->as(); + auto padding_idx = (args_.at(2).stype == serde::StateType::Scalar) + ? 
fd.getFusionState(args_.at(2).index)->as() + : nullptr; + auto max_norm = (args_.at(3).stype == serde::StateType::Scalar) + ? fd.getFusionState(args_.at(3).index)->as() + : nullptr; + auto norm_type = (args_.at(4).stype == serde::StateType::Scalar) + ? fd.getFusionState(args_.at(4).index)->as() + : nullptr; + auto scale_grad_by_freq = (args_.at(5).stype == serde::StateType::Scalar) + ? fd.getFusionState(args_.at(5).index)->as() + : nullptr; + auto sparse = (args_.at(6).stype == serde::StateType::Scalar) + ? fd.getFusionState(args_.at(6).index)->as() + : nullptr; + + auto output = embedding_fwd( + input, + weight, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse); + fd.setFusionState(outputs_.at(0).index, output); + } +}; + } // namespace nvfuser::python_frontend //! Creating the template specialized hash and equal_to functions for a diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index d20105e990a..b696914f9f9 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -3621,6 +3621,59 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("scale").none(true) = py::none(), py::return_value_policy::reference); + nvf_ops.def( + "embedding_fwd", + [](FusionDefinition::Operators& self, + Tensor input, + Tensor weight, + std::optional padding_idx, + std::optional max_norm, + std::optional norm_type, + std::optional scale_grad_by_freq, + std::optional sparse) -> decltype(auto) { + FUSER_PERF_SCOPE("Operators.embedding_fwd"); + NVF_CHECK( + self.validUse(), "Attempting to add to a completed definition!"); + FusionDefinition* fd = self.fusion_definition; + size_t ndims = input.dims + 1; + Tensor output = fd->defineTensor(/*dims=*/ndims); + + auto padding_idx_state = padding_idx.has_value() + ? fd->recordingState(padding_idx.value()()) + : State(/*_index=*/0, /*_stype=*/serde::StateType::None); + auto max_norm_state = max_norm.has_value() + ? 
fd->recordingState(max_norm.value()()) + : State(/*_index=*/0, /*_stype=*/serde::StateType::None); + auto norm_type_state = norm_type.has_value() + ? fd->recordingState(norm_type.value()()) + : State(/*_index=*/0, /*_stype=*/serde::StateType::None); + auto scale_grad_by_freq_state = scale_grad_by_freq.has_value() + ? fd->recordingState(scale_grad_by_freq.value()()) + : State(/*_index=*/0, /*_stype=*/serde::StateType::None); + auto sparse_state = sparse.has_value() + ? fd->recordingState(sparse.value()()) + : State(/*_index=*/0, /*_stype=*/serde::StateType::None); + + fd->defineRecord(new EmbeddingFwdOpRecord( + {fd->recordingState(input()), + fd->recordingState(weight()), + padding_idx_state, + max_norm_state, + norm_type_state, + scale_grad_by_freq_state, + sparse_state}, + {fd->recordingState(output())})); + return output; + }, + py::arg("input"), + py::arg("weight"), + py::arg("padding_idx").none(true) = py::none(), + py::arg("max_norm").none(true) = py::none(), + py::arg("norm_type").none(true) = py::none(), + py::arg("scale_grad_by_freq").none(true) = py::none(), + py::arg("sparse").none(true) = py::none(), + py::return_value_policy::reference); + bindSchedule(fusion_def); bindCommunicator(nvfuser); diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index a40e118b1fd..affecd14fbd 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -29,7 +29,7 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { return false; } - if (exprs.front()->isOneOf()) { + if (exprs.front()->isOneOf()) { return true; } diff --git a/csrc/serde/fusion_cache.fbs b/csrc/serde/fusion_cache.fbs index b21e4ea4f82..562f35d9e59 100644 --- a/csrc/serde/fusion_cache.fbs +++ b/csrc/serde/fusion_cache.fbs @@ -42,6 +42,7 @@ enum RecordType: int { CastTv, CastVal, CatOp, + EmbeddingFwdOp, End, ExpandOp, FullOp, diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index 253270f0b9d..ff003f20734 
100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -321,6 +321,13 @@ void RecordFunctorFactory::registerAllParsers() { parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); }; registerParser(RecordType::SdpaBwdOp, deserializeSdpaBwdRecord); + + auto deserializeEmbeddingFwdRecord = [&](const RecordFunctor* buffer) { + return new python_frontend::EmbeddingFwdOpRecord( + parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); + }; + registerParser(RecordType::EmbeddingFwdOp, deserializeEmbeddingFwdRecord); + // END OpRecord Parsers // START Reduction Parsers diff --git a/tests/cpp/test_embedding_node.cpp b/tests/cpp/test_embedding_node.cpp new file mode 100644 index 00000000000..b4b6c3293b7 --- /dev/null +++ b/tests/cpp/test_embedding_node.cpp @@ -0,0 +1,50 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include +#include +#include +#include +#include + +namespace nvfuser { + +using EmbeddingTest = NVFuserTest; + +constexpr int64_t n = 5, s = 2; + +TEST_F(EmbeddingTest, EmbeddingFwdNode) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + std::vector inp_shape({s}); + std::vector weight_shape({n, s}); + + auto tv_inp = makeConcreteTensor(inp_shape, DataType::Int); + auto tv_weight = makeConcreteTensor(weight_shape, DataType::Half); + + fusion->addInput(tv_inp); + fusion->addInput(tv_weight); + + auto tv_output = embedding_fwd( + tv_inp, tv_weight, nullptr, nullptr, nullptr, nullptr, nullptr); + fusion->addOutput(tv_output); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::randint(n, inp_shape, options.dtype(at::kLong)); + at::Tensor weight = at::randn(weight_shape, options.dtype(at::kHalf)); + + namespace F = torch::nn::functional; + auto aten_out = F::embedding(input, weight); 
+ + FusionExecutorCache executor_cache(std::move(fusion)); + auto nvf_out = executor_cache.runFusionWithInputs({input, weight}); + EXPECT_TRUE(at::allclose(nvf_out[0], aten_out)); +} +} // namespace nvfuser diff --git a/tests/python/test_embedding.py b/tests/python/test_embedding.py new file mode 100644 index 00000000000..c517ddc65f2 --- /dev/null +++ b/tests/python/test_embedding.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# Owner(s): ["module: nvfuser"] + +import torch +from nvfuser import FusionDefinition, DataType +import pytest +import torch.nn.functional as F + + +@pytest.mark.parametrize("padding_idx", [None, -2]) +@pytest.mark.parametrize("max_norm", [None, 1e-5]) +@pytest.mark.parametrize("norm_type", [None, 1.0]) +@pytest.mark.parametrize("scale_grad_by_freq", [None, True]) +@pytest.mark.parametrize("sparse", [None, True]) +def test_embedding( + padding_idx: None | int, + max_norm: None | float, + norm_type: None | float, + scale_grad_by_freq: None | bool, + sparse: None | bool, +): + def fusion_func( + fd: FusionDefinition, + has_optional_inputs: list[bool], + optional_inputs_dtypes: list[DataType], + ): + input = fd.define_tensor( + shape=[-1], + contiguity=[True], + dtype=DataType.Int, + is_cpu=False, + ) + weight = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + ) + # padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse + optional_inputs = [None] * 5 + for idx in range(len(optional_inputs)): + if has_optional_inputs[idx]: + optional_inputs[idx] = fd.define_scalar( + value=None, dtype=optional_inputs_dtypes[idx] + ) + out = fd.ops.embedding_fwd(input, weight, *optional_inputs) + fd.add_output(out) + + N, S = 10, 3 + input = torch.randint( + N, (S,), dtype=torch.int64, device="cuda", requires_grad=False + ) + weight = torch.randn(N, S, dtype=torch.bfloat16, 
device="cuda", requires_grad=True) + optional_inputs_dtypes = [ + DataType.Int, + DataType.Float, + DataType.Float, + DataType.Bool, + DataType.Bool, + ] + + # This is not in pytest_ops.py since the torch API does not accept None values for some arguments. + # Different inputs for nvfuser and torch API cannot be handled within OpInfo + optional_inputs = [padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse] + has_optional_inputs = [None] * 5 + inputs = [input, weight] + for idx, param in enumerate(optional_inputs): + if param is not None: + has_optional_inputs[idx] = True + inputs.append(param) + + with FusionDefinition() as fd: + fusion_func( + fd, + has_optional_inputs=has_optional_inputs, + optional_inputs_dtypes=optional_inputs_dtypes, + ) + nvf_out = fd.execute(inputs) + + norm_type = 2.0 if norm_type is None else norm_type + scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq + sparse = False if sparse is None else sparse + ref_out = F.embedding( + input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse + ) + torch.testing.assert_close(nvf_out[0], ref_out) diff --git a/version.txt b/version.txt index ac16615536e..8cf7c24c6b0 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.2.24 +0.2.25