fix: Record cudagraphs when weight streaming budget has changed #3309
core/runtime/execute_engine.cpp
Outdated
@@ -115,11 +115,15 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
  }

  // Whether cudagraphs needs to record the graph on this pass
- bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine)));
+ bool need_cudagraphs_record =
Seems like we need something more comprehensive than a bunch of booleans to decide when to re-record the CUDA graph.
I agree. I'm thinking of the structure below. Could you review PR #3276?
https://github.com/pytorch/TensorRT/pull/3276/files#diff-3d6304b0f21f64bdf7867ceac346866d505619401e9c7b6a696cfe3bb567254eR37-R63
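To illustrate the reviewer's point, here is a minimal sketch of how several independent booleans could be folded into one comparable state object. It is written in Python for brevity, and every name in it (`RuntimeState`, `needs_cudagraph_record`, the field names) is hypothetical; it is not the structure proposed in #3276 and not the code in `execute_engine.cpp`.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


# Hypothetical names throughout: an illustration, not the structure from #3276.
@dataclass(frozen=True)
class RuntimeState:
    cudagraphs_enabled: bool
    input_shapes: Tuple[Tuple[int, ...], ...]
    weight_streaming_budget: int


def needs_cudagraph_record(prev: Optional[RuntimeState], cur: RuntimeState) -> bool:
    """Return True when a previously captured graph can no longer be replayed."""
    if not cur.cudagraphs_enabled:
        return False
    # First run, new input shapes, or a new budget all show up as a state difference.
    return prev != cur


first = RuntimeState(True, ((1, 3, 224, 224),), 0)
rebudgeted = replace(first, weight_streaming_budget=1024)
```

With this shape, adding another trigger for re-recording means adding a field rather than threading one more boolean through the call site.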
7bb66da to 2050887
Will adopt a proper structure for runtime states after #3276 is reviewed.
2050887 to 7a2d58e
Functionality LGTM
def test_weight_streaming_cudagraphs(self, _, use_python_runtime):
    model = SampleModel().eval().cuda()
    input = [torch.randn(*INPUT_SIZE, dtype=torch.float32).cuda()]
    fx_graph = torch.fx.symbolic_trace(model)
Can you use torch.export.export() to export the model?
    name="x",
)
model = SampleModel().eval().cuda()
fx_graph = torch.fx.symbolic_trace(model)
Same here. Use torch.export.export().
Description
Add a flag indicating that the runtime settings must be re-evaluated when the weight streaming budget changes.
Fixes #3308
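The idea in the description can be sketched as a "dirty" flag that a budget change sets and the next execution consumes. This is a toy model in plain Python, not the Torch-TensorRT runtime: `ToyEngine`, `runtime_states_dirty`, and `record_count` are hypothetical names chosen for illustration.

```python
class ToyEngine:
    """Hypothetical stand-in for the runtime engine; not Torch-TensorRT code."""

    def __init__(self, budget: int = 0):
        self._budget = budget
        self._recorded_shape = None
        self.runtime_states_dirty = False
        self.record_count = 0

    @property
    def weight_streaming_budget(self) -> int:
        return self._budget

    @weight_streaming_budget.setter
    def weight_streaming_budget(self, value: int) -> None:
        # Only an actual change invalidates the previously recorded graph.
        if value != self._budget:
            self._budget = value
            self.runtime_states_dirty = True

    def execute(self, input_shape: tuple) -> bool:
        """Return True if this pass had to (re)record the graph."""
        need_record = self.runtime_states_dirty or input_shape != self._recorded_shape
        if need_record:
            self.record_count += 1
            self._recorded_shape = input_shape
            self.runtime_states_dirty = False  # flag is consumed by the re-record
        return need_record


engine = ToyEngine()
results = [engine.execute((1, 3))]          # first pass records
results.append(engine.execute((1, 3)))      # same shape, same budget: replay
engine.weight_streaming_budget = 1024       # budget change marks state dirty
results.append(engine.execute((1, 3)))      # re-record despite unchanged shape
results.append(engine.execute((1, 3)))      # replay again
```

Without the flag, the third call would replay a graph captured under the old budget, because a shape check alone sees nothing different.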