Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Oct 31, 2024
1 parent 946403a commit 09af663
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
5 changes: 2 additions & 3 deletions csrc/python_frontend/fusion_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ class FusionState {
//! Get indicies for the extents of TensorView inputs of FusionState
NVF_API const std::vector<int64_t>& extents() const;

//! Add extents of TensorView inputs to FusionState
NVF_API void addExtents();

//! Add a Record
void addRecord(RecordFunctor* record);
//! Builds an nvFuser Fusion IR object
Expand All @@ -108,6 +105,8 @@ class FusionState {
private:
//! Get extents for TensorView inputs in Fusion
std::vector<Val*> getExtents(Fusion* fusion);
//! Add extents of TensorView inputs to FusionState
void addExtents();
//! Change the fusion ptr and reset its state
void resetFusionState(Fusion* fusion, size_t size);

Expand Down
18 changes: 11 additions & 7 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4610,13 +4610,13 @@ def test_fusion_information(self):
def fusion_func(fd: FusionDefinition) -> None:
t0 = fd.from_pytorch(inputs[0])
t1 = fd.from_pytorch(inputs[1])
c0 = fd.define_scalar(3.0)
c2 = fd.define_scalar(3.0)

t2 = fd.ops.add(t0, t1)
t3 = fd.ops.mul(t2, c0)
t4 = fd.ops.sum(t3, [-1], False, DataType.Float)
t3 = fd.ops.add(t0, t1)
t4 = fd.ops.mul(t3, c2)
t5 = fd.ops.sum(t4, [-1], False, DataType.Float)

fd.add_output(t4)
fd.add_output(t5)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
eager_out = torch.sum((inputs[0] + inputs[1]) * 3.0, dim=-1)
Expand All @@ -4628,11 +4628,15 @@ def fusion_func(fd: FusionDefinition) -> None:
nvf_out1 = fd.execute(inputs)
self.assertEqual(eager_out, nvf_out1[0])

# The input tensors are t0 and t1.
self.assertEqual(fd.inputs(), [0, 1])
# The output tensors is t5.
self.assertEqual(fd.outputs(), [5])
# The extents correspond with the dimensions for each input tensor.
# There are two input tensors with three dimensions each, so the
# extents range from [-1, -6].
self.assertEqual(fd.extents(), [idx for idx in range(-1, -7, -1)])


def test_issue_3292(self):
inputs = [
torch.testing.make_tensor(
Expand Down Expand Up @@ -4705,4 +4709,4 @@ def fusion_func(fd: FusionDefinition) -> None:
fd.add_output(T223)

# is_clonable=False is because translation fails with missing ceilDiv
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=False)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=False)

0 comments on commit 09af663

Please sign in to comment.