feature(models): Add model comm group to predict_step #77
base: develop
Looks good to me
@@ -94,7 +96,7 @@ def _build_model(self) -> None:
         # Use the forward method of the model directly
         self.forward = self.model.forward
 
-    def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
+    def predict_step(self, batch: torch.Tensor, model_comm_group: Optional[ProcessGroup] = None) -> torch.Tensor:
Can we also add a `**kwargs` here for future-proofing? Then we don't need any logic in inference to maintain backwards compatibility when new optionals are added.
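For illustration, this is roughly the kind of compatibility logic a caller would otherwise need. It is a minimal sketch, not code from this PR: `call_predict_step`, `old_predict_step`, and `new_predict_step` are hypothetical names, and plain Python objects stand in for tensors and process groups.

```python
import inspect
from typing import Any, Callable, Optional


def call_predict_step(
    predict_step: Callable[..., Any],
    batch: Any,
    model_comm_group: Optional[Any] = None,
) -> Any:
    """Hypothetical caller-side shim: pass model_comm_group only if it is accepted."""
    params = inspect.signature(predict_step).parameters
    # A **kwargs parameter means any keyword argument is accepted.
    accepts_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if "model_comm_group" in params or accepts_var_kwargs:
        return predict_step(batch, model_comm_group=model_comm_group)
    # Older signature: fall back to the positional-only call.
    return predict_step(batch)


def old_predict_step(batch):
    # Stand-in for a checkpoint whose predict_step predates the new argument.
    return ("old", batch)


def new_predict_step(batch, model_comm_group=None):
    # Stand-in for the signature proposed in this PR.
    return ("new", batch, model_comm_group)
```

If `predict_step` accepts `**kwargs`, the caller can always pass new optionals by keyword and a shim like this becomes unnecessary.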
It would be my pleasure. And then pass that to `self()`?
I would not pass it to `self()`, so that the call remains explicit. Passing kwargs blindly to sub-functions is something we should try to avoid in general.
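A minimal sketch of the shape being suggested, under stated assumptions: `ToyModel` and `ToyInterface` are hypothetical stand-ins for the project's classes, and plain lists replace `torch.Tensor` so the snippet runs without PyTorch. `predict_step` accepts `**kwargs` but forwards only the arguments it names.

```python
from typing import Any, List, Optional


class ToyModel:
    """Hypothetical stand-in for the wrapped model; not the project's class."""

    def forward(
        self, batch: List[float], model_comm_group: Optional[Any] = None
    ) -> List[float]:
        # A real model would use the group for cross-rank communication.
        return [2.0 * x for x in batch]


class ToyInterface:
    """Hypothetical stand-in for the class that defines predict_step."""

    def __init__(self) -> None:
        self.model = ToyModel()
        # Use the forward method of the model directly, as in the diff.
        self.forward = self.model.forward

    def predict_step(
        self,
        batch: List[float],
        model_comm_group: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[float]:
        # **kwargs absorbs optionals that newer callers may send.
        # It is deliberately not forwarded, so the inner call stays explicit.
        return self.forward(batch, model_comm_group=model_comm_group)
```

With this shape, a caller that passes an argument the interface does not know yet, for example `ToyInterface().predict_step([1.0], some_future_option=True)`, still works, and `forward` only ever receives the arguments listed in its own signature.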
Small change to allow `predict_step` to take an optional `model_comm_group`, like `model.forward` does.
This is needed for parallel-inference, which calls `predict_step`.