diff --git a/csrc/python_frontend/fusion_state.h b/csrc/python_frontend/fusion_state.h index b7adaf82a00..7a83886514a 100644 --- a/csrc/python_frontend/fusion_state.h +++ b/csrc/python_frontend/fusion_state.h @@ -94,9 +94,6 @@ class FusionState { //! Get indicies for the extents of TensorView inputs of FusionState NVF_API const std::vector& extents() const; - //! Add extents of TensorView inputs to FusionState - NVF_API void addExtents(); - //! Add a Record void addRecord(RecordFunctor* record); //! Builds an nvFuser Fusion IR object @@ -108,6 +105,8 @@ class FusionState { private: //! Get extents for TensorView inputs in Fusion std::vector getExtents(Fusion* fusion); + //! Add extents of TensorView inputs to FusionState + void addExtents(); //! Change the fusion ptr and reset its state void resetFusionState(Fusion* fusion, size_t size); diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index f0198bdc7b8..e0597757a9c 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -4610,13 +4610,13 @@ def test_fusion_information(self): def fusion_func(fd: FusionDefinition) -> None: t0 = fd.from_pytorch(inputs[0]) t1 = fd.from_pytorch(inputs[1]) - c0 = fd.define_scalar(3.0) + c2 = fd.define_scalar(3.0) - t2 = fd.ops.add(t0, t1) - t3 = fd.ops.mul(t2, c0) - t4 = fd.ops.sum(t3, [-1], False, DataType.Float) + t3 = fd.ops.add(t0, t1) + t4 = fd.ops.mul(t3, c2) + t5 = fd.ops.sum(t4, [-1], False, DataType.Float) - fd.add_output(t4) + fd.add_output(t5) nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) eager_out = torch.sum((inputs[0] + inputs[1]) * 3.0, dim=-1) @@ -4628,11 +4628,15 @@ def fusion_func(fd: FusionDefinition) -> None: nvf_out1 = fd.execute(inputs) self.assertEqual(eager_out, nvf_out1[0]) + # The input tensors are t0 and t1. self.assertEqual(fd.inputs(), [0, 1]) + # The output tensor is t5. 
self.assertEqual(fd.outputs(), [5]) + # The extents correspond to the dimensions of each input tensor. + # There are two input tensors with three dimensions each, so the + # extents range from -1 down to -6. self.assertEqual(fd.extents(), [idx for idx in range(-1, -7, -1)]) - def test_issue_3292(self): inputs = [ torch.testing.make_tensor( @@ -4705,4 +4709,4 @@ def fusion_func(fd: FusionDefinition) -> None: fd.add_output(T223) # is_clonable=False is because translation fails with missing ceilDiv - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=False) \ No newline at end of file + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=False)