Skip to content

Commit

Permalink
create SegmentationState
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Nov 2, 2024
1 parent d4b720a commit 920dec6
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 66 deletions.
66 changes: 39 additions & 27 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,32 @@ std::vector<std::pair<double, double>> FusionDefinition::getValTolerances(
return get_val_constants(preschedFusion(), inputs);
}

void FusionDefinition::prepareGroupOrder() {
int64_t FusionDefinition::setupSegmentation(
const at::ArrayRef<c10::IValue>& inputs) {
NVF_CHECK(id().has_value(), "FusionDefinition definition does not exist!");
NVF_ERROR(
segmentation_state_ == nullptr, "SegmentationState already exists!");
segmentation_state_ = std::make_unique<SegmentationState>();
return segmentation_state_->setupSegmentation(
preschedFusion(), map_value_to_fid_, inputs);
}

std::unordered_map<int64_t, int64_t> FusionDefinition::buildSegment(
FusionDefinition& other,
int64_t segment_id) {
NVF_CHECK(id().has_value(), "FusionDefinition definition does not exist!");
NVF_CHECK(
segmentation_state_ != nullptr,
"Run setupSegmenation first before trying to build segments!");
return segmentation_state_->buildSegment(other, segment_id);
}

void FusionDefinition::finalizeSegmentation() {
// Destroy SegmentedState
segmentation_state_.reset();
}

void SegmentationState::prepareGroupOrder() {
NVF_ERROR(segmented_fusion_ != nullptr);

// Setup group run order
Expand Down Expand Up @@ -736,24 +761,25 @@ void FusionDefinition::prepareGroupOrder() {
}
}

int64_t FusionDefinition::setupSegmentation(
int64_t SegmentationState::setupSegmentation(
Fusion* fusion,
const std::unordered_map<const Val*, int64_t>& map_value_to_original_fid,
const at::ArrayRef<c10::IValue>& inputs) {
NVF_CHECK(id().has_value(), "FusionDefinition definition does not exist!");
int8_t device = getCommonDeviceCUDA(inputs);
NVF_CHECK(
inputs.empty() || device > -1, "Inputs are not all on the same device!");

// Check segmentation state
// Check state
NVF_ERROR(fusion != nullptr);
NVF_ERROR(segment_fusion_ == nullptr);
NVF_ERROR(segmented_fusion_ == nullptr);
NVF_ERROR(group_run_order_.empty());
NVF_ERROR(map_cloned_value_to_fid_.empty());
NVF_ERROR(cloned_extents_.empty());

int8_t device = getCommonDeviceCUDA(inputs);
NVF_CHECK(
inputs.empty() || device > -1, "Inputs are not all on the same device!");

// Clone CPP Fusion
segment_fusion_ = std::make_unique<Fusion>();
IrCloner original_to_cloned_map =
Fusion::copy(preschedFusion(), segment_fusion_.get());
IrCloner original_to_cloned_map = Fusion::copy(fusion, segment_fusion_.get());

// Get arguments
KernelArgumentHolder args =
Expand All @@ -764,14 +790,10 @@ int64_t FusionDefinition::setupSegmentation(
std::unordered_map<Val*, Val*> symbolic_to_concrete_map =
DynamicTransform::concretizeFusion(segment_fusion_.get(), args);

// NOTE: The following tests require using the MarkAliasesPreparePass before
// segmentation, but not running AllocationDomainPass when running each
// segment. See test_issue1953 and test_unpadded_catop_issue2275_repro1.

// Track mapping from cloned CPP fusion and FusionDefinition indices.
std::transform(
map_value_to_fid_.begin(),
map_value_to_fid_.end(),
map_value_to_original_fid.begin(),
map_value_to_original_fid.end(),
std::inserter(map_cloned_value_to_fid_, map_cloned_value_to_fid_.end()),
[&](const auto& item) {
const Val* original_value = item.first;
Expand Down Expand Up @@ -804,10 +826,9 @@ int64_t FusionDefinition::setupSegmentation(
return (int64_t)segmented_fusion_->groups().size();
}

std::unordered_map<int64_t, int64_t> FusionDefinition::buildSegment(
std::unordered_map<int64_t, int64_t> SegmentationState::buildSegment(
FusionDefinition& other,
int64_t segment_id) {
NVF_CHECK(id().has_value(), "FusionDefinition definition does not exist!");
NVF_ERROR(
!other.completed(),
"Expected an incomplete definition before translation.");
Expand Down Expand Up @@ -966,13 +987,4 @@ std::unordered_map<int64_t, int64_t> FusionDefinition::buildSegment(
return segment_fid_to_original_fid_map;
}

void FusionDefinition::finalizeSegmentation() {
// Destroy SegmentedFusion
segmented_fusion_.reset(nullptr);
segment_fusion_.reset(nullptr);
group_run_order_.clear();
map_cloned_value_to_fid_.clear();
cloned_extents_.clear();
}

} // namespace nvfuser::python_frontend
58 changes: 42 additions & 16 deletions csrc/python_frontend/fusion_definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,45 @@ struct Vector {
FusionDefinition* fusion_definition;
};

class SegmentationState {
public:
//! Run segmentation algorithm on FusionDefinition. Returns the number of
//! segments.
int64_t setupSegmentation(
Fusion* fusion,
const std::unordered_map<const Val*, int64_t>& map_value_to_original_fid,
const at::ArrayRef<c10::IValue>& inputs);

//! Given SegmentedFusion and vector of FusionDefinition objects for the
//! fusion segments, create the fusion segments and clone their state to the
//! FusionDefinitions.
NVF_API std::unordered_map<int64_t, int64_t> buildSegment(
FusionDefinition& other,
int64_t segment_id);

//! Perform a topological sort on SegmentedFusion to segment order.
void prepareGroupOrder();

private:
//! Clone of original fusion for segmentation
std::unique_ptr<Fusion> segment_fusion_ = nullptr;

//! This FusionDefinition may require multiple kernels if it cannot be handled
//! by a single heuristic scheduler. SegmentedFusion takes a fusion and runs
//! the segmentation algorithm.
std::unique_ptr<SegmentedFusion> segmented_fusion_ = nullptr;

//! Pre-determined order to run the segmented groups
std::vector<SegmentedGroup*> group_run_order_;

//! Create copy of fusion for segmentation algorithm. IrCloner is a map
//! between values in original and cloned fusions.
std::unordered_map<const Val*, int64_t> map_cloned_value_to_fid_;

//! Extents for TensorView input arguments for cloned Fusion
std::vector<Val*> cloned_extents_;
};

//! FusionDefinition defines the C++ side of a Python Context manager to
//! encapsulate the definition of fusion operations.
//!
Expand Down Expand Up @@ -281,8 +320,6 @@ class NVF_API FusionDefinition : public FusionState {
// Check that the NvFuser TensorView and the Python Tensor dimensions match.
// Apply after buildFusionIr
void verifyTensorDimensions();
//! Perform a topological sort on SegmentedFusion to segment order.
void prepareGroupOrder();

//! Holds the defined maximum length of a FusionDefinition in order to
//! prevent a run away error. The user should feel free to increase this
Expand All @@ -304,20 +341,9 @@ class NVF_API FusionDefinition : public FusionState {
UserSchedule* user_sched_;
//! Number of recording_states_ before applying user schedule
int64_t num_recording_states_presched_ = 0;

//! Clone of original fusion for segmentation
std::unique_ptr<Fusion> segment_fusion_ = nullptr;
//! This FusionDefinition may require multiple kernels if it cannot be handled
//! by a single heuristic scheduler. SegmentedFusion takes a fusion and runs
//! the segmentation algorithm.
std::unique_ptr<SegmentedFusion> segmented_fusion_ = nullptr;
//! Pre-determined order to run the segmented groups
std::vector<SegmentedGroup*> group_run_order_;
//! Create copy of fusion for segmentation algorithm. IrCloner is a map
//! between values in original and cloned fusions.
std::unordered_map<const Val*, int64_t> map_cloned_value_to_fid_;
//! Extents for TensorView input arguments for cloned Fusion
std::vector<Val*> cloned_extents_;
//! Data member that creates SegmentedFusion from cloned, prescheduled Fusion
//! then translates the segments to python FusionDefinitions.
std::unique_ptr<SegmentationState> segmentation_state_;

public:
//! The Operators are not directly defined in this header. They are defined
Expand Down
42 changes: 21 additions & 21 deletions csrc/python_frontend/fusion_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,27 @@ std::ostream& operator<<(std::ostream& os, const State& state) {
return os;
}

std::vector<Val*> getExtents(Fusion* fusion) {
NVF_CHECK(fusion != nullptr, "Fusion is undefined.");

std::vector<Val*> extents;
for (Val* v : fusion->inputs()) {
// short-circuit: skip if not TensorView
if (!v->isA<TensorView>()) {
continue;
}
TensorView* tv = v->as<TensorView>();
std::vector<IterDomain*> logical_dom =
TensorDomain::noReductions(tv->getLogicalDomain());
std::transform(
logical_dom.begin(),
logical_dom.end(),
std::back_inserter(extents),
[](IterDomain* id) { return id->getMaybeExpandedExtent(); });
}
return extents;
}

FusionState::FusionState()
: end_record_(new EndRecord()),
recording_(),
Expand Down Expand Up @@ -249,27 +270,6 @@ const std::vector<int64_t>& FusionState::extents() const {
return extents_fid_;
}

std::vector<Val*> FusionState::getExtents(Fusion* fusion) {
NVF_CHECK(fusion != nullptr, "Fusion is undefined.");

std::vector<Val*> extents;
for (Val* v : fusion->inputs()) {
// short-circuit: skip if not TensorView
if (!v->isA<TensorView>()) {
continue;
}
TensorView* tv = v->as<TensorView>();
std::vector<IterDomain*> logical_dom =
TensorDomain::noReductions(tv->getLogicalDomain());
std::transform(
logical_dom.begin(),
logical_dom.end(),
std::back_inserter(extents),
[](IterDomain* id) { return id->getMaybeExpandedExtent(); });
}
return extents;
}

void FusionState::addExtents() {
NVF_CHECK(fusion_ != nullptr, "Fusion is undefined.");

Expand Down
5 changes: 3 additions & 2 deletions csrc/python_frontend/fusion_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ struct State {

NVF_API std::ostream& operator<<(std::ostream& os, const State& state);

//! Get extents for TensorView inputs in Fusion
std::vector<Val*> getExtents(Fusion* fusion);

//! FusionState contains the information used to build a new cpp Fusion object.
//! Unlike FusionDefinition, it does not modify the FusionCache Trie structure.
class FusionState {
Expand Down Expand Up @@ -93,8 +96,6 @@ class FusionState {
NVF_API const std::vector<int64_t>& outputs() const;
//! Get indicies for the extents of TensorView inputs of FusionState
NVF_API const std::vector<int64_t>& extents() const;
//! Get extents for TensorView inputs in Fusion
std::vector<Val*> getExtents(Fusion* fusion);

//! Add a Record
void addRecord(RecordFunctor* record);
Expand Down

0 comments on commit 920dec6

Please sign in to comment.