Skip to content

Commit

Permalink
Fix merge conflicts.
Browse files Browse the repository at this point in the history
  • Loading branch information
csarofeen committed Jan 20, 2025
1 parent c5939f0 commit 9163101
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
19 changes: 19 additions & 0 deletions csrc/host_ir/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <ir/utils.h>
#include <kernel_ir.h>
#include <ops/all_ops.h>
#include <runtime/executor.h>

namespace nvfuser {

Expand All @@ -32,6 +33,24 @@ std::ostream& HostIrContainer::print(std::ostream& os) const {
return os;
}

const std::vector<Expr*>& HostIrContainer::topLevelExprs() const {
return top_level_exprs_;
}

void HostIrContainer::pushBackTopLevelExprs(Expr* expr) {
assertInContainer(expr, "Cannot add expr, ");
return top_level_exprs_.push_back(expr);
}

void HostIrContainer::pushBackKernelExecutor(
std::unique_ptr<KernelExecutor> ke) {
return kernel_executors_.push_back(std::move(ke));
}

KernelExecutor* HostIrContainer::getKernelExecutor(int64_t index) const {
return kernel_executors_.at(index).get();
}

} // namespace hir

} // namespace nvfuser
20 changes: 6 additions & 14 deletions csrc/host_ir/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

#include <fusion.h>
#include <host_ir/host_ir.h>
#include <runtime/executor.h>

namespace nvfuser {

class KernelExecutor;

namespace hir {

/*
Expand All @@ -33,22 +34,13 @@ class HostIrContainer final : public Fusion {
//! Print to an output stream
std::ostream& print(std::ostream& os) const;

const auto& topLevelExprs() const {
return top_level_exprs_;
}
const std::vector<Expr*>& topLevelExprs() const;

void pushBackTopLevelExprs(Expr* expr) {
assertInContainer(expr, "Cannot add expr, ");
return top_level_exprs_.push_back(expr);
}
void pushBackTopLevelExprs(Expr* expr);

void pushBackKernelExecutor(std::unique_ptr<KernelExecutor> ke) {
return kernel_executors_.push_back(std::move(ke));
}
void pushBackKernelExecutor(std::unique_ptr<KernelExecutor> ke);

KernelExecutor* getKernelExecutor(int64_t index) const {
return kernel_executors_.at(index).get();
}
KernelExecutor* getKernelExecutor(int64_t index) const;

Stream* getDefaultStream();

Expand Down
4 changes: 3 additions & 1 deletion tests/cpp/test_rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,9 @@ TEST_F(RopeTest, EndingRepeat) {
runtime->schedulerHeuristics()->heuristicsList().front();
EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize);
Fusion* scheduled_fusion =
runtime->executors().at(0)->as<KernelExecutor>()->kernel();
dynamic_cast<KernelExecutor*>(runtime->executors().at(0).get())
->compiledKernel()
->fusion();

// Check the loop domain of the reference. It should look like:
//
Expand Down

0 comments on commit 9163101

Please sign in to comment.