Implement pass search #1562
base: main
Conversation
Force-pushed from 64d4138 to 534a692.
Can you also update the documents related to your changes?
Force-pushed from 3e6fbfd to 457d3b7.
@@ -8,7 +8,7 @@
 import torch
 import torchmetrics
 import transformers
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
Check failure: Code scanning / lintrunner (MYPY/import Error)
try:
    from datasets import load_metric
except ImportError:
    from evaluate import load as load_metric
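For context, load_metric was removed from newer releases of datasets and lives on as load in the separate evaluate package, which is why a fallback like the one above is needed. A minimal usage sketch of the shim; the metric name here is illustrative, not necessarily the one this example uses:

```python
try:
    from datasets import load_metric  # older datasets releases
except ImportError:
    from evaluate import load as load_metric  # newer: evaluate package

# "accuracy" is illustrative; both APIs expose the same interface.
metric = load_metric("accuracy")
metric.add_batch(predictions=[0, 1, 1], references=[0, 1, 0])
print(metric.compute())  # {'accuracy': 0.666...}
```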
Check failure: Code scanning / lintrunner (MYPY/import Error)
import re
from collections import OrderedDict

import pytest
Check failure: Code scanning / lintrunner (MYPY/import Error, test file)
Force-pushed from a87fd30 to 58fa43e.
Reimplement search logic to include passes in search space.
Force-pushed from 58fa43e to d63c35a.
Done!
why is this required?
please update the "BERT optimization with CUDA/TensorRT on GPU" section in the readme
 if not args.use_gptq:
-    template_json["pass_flows"] = [flow for flow in SUPPORTED_WORKFLOWS[device] if "gptq" not in flow[0]]
+    used_passes = [
I think we might have to create a new mapping from gptq/no-gptq, precision, etc. Previously this resulted in multiple pass flows, but now you are just flattening the pass flows into a single list. I think the passes used might look like "conversion_merged", "transformers_optimization_fp16", "conversion_merged", "transformers_optimization_fp16", "blockwise_quant_int4", which is not the intended behavior here. You can test this by using the --config_only option to dump the config.
We need to either make the mapping specific or generate multiple workflows and run them separately.
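To illustrate the first option, a rough sketch of a specific mapping. The keys and the gptq entry are hypothetical; only the non-gptq pass names come from the comment above:

```python
# Hypothetical mapping from (use_gptq, precision) to a single ordered
# pass flow, so each configuration yields one flow instead of a
# flattened union of all flows. The gptq entry is illustrative.
PASS_FLOWS = {
    (False, "fp16"): ["conversion_merged", "transformers_optimization_fp16"],
    (False, "int4"): ["conversion_merged", "transformers_optimization_fp16", "blockwise_quant_int4"],
    (True, "int4"): ["gptq", "conversion_merged"],
}

def used_passes_for(use_gptq: bool, precision: str) -> list:
    # Exactly one flow per configuration, in a deterministic order.
    return PASS_FLOWS[(use_gptq, precision)]
```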
@@ -186,25 +186,23 @@ def main(raw_args=None):
     legacy_optimization_setting(template_json)

-    # add pass flows
-    pass_flows = [[]]
+    used_passes = {}
Just in general, I think using sets might be a bit risky since they are unordered.
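One way to keep the de-duplication without losing order is to use a dict as an ordered set — a minimal sketch:

```python
# dict preserves insertion order (Python 3.7+), so dict.fromkeys()
# de-duplicates while keeping the order in which passes were listed.
used_passes = dict.fromkeys(["convert", "optimize_cuda", "convert"])
print(list(used_passes))  # ['convert', 'optimize_cuda']
```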
@@ -191,7 +191,10 @@ def use_passes(template_json, *passes):
     else:
         del template_json["data_configs"]

-    template_json["pass_flows"] = [passes]
+    for pass_name in set(template_json["passes"].keys()):
I think instead of popping the unused passes, it might be better to create a new dict with the used passes. Popping assumes the order in template_json["passes"] is the same as in passes.
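Concretely, something like this — a sketch of the suggestion, not the PR's code:

```python
# Rebuild the dict in the order given by `passes` rather than popping
# unused entries, which would silently depend on template_json's order.
template_json["passes"] = {name: template_json["passes"][name] for name in passes}
```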
@@ -23,7 +23,10 @@ def update_cuda_config(config_cuda: Dict):
     if version.parse(OrtVersion) < version.parse("1.17.0"):
         # disable skip_group_norm fusion since there is a shape inference bug which leads to invalid models
         config_cuda["passes"]["optimize_cuda"]["optimization_options"] = {"enable_skip_group_norm": False}
-    config_cuda["pass_flows"] = [["convert", "optimize_cuda"]]
+    used_passes = {"convert", "optimize_cuda"}
This part in stable_diffusion.py also needs to be updated to only use "convert", "optimize" in the if provider == "dml": branch.
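Presumably that would mirror the CUDA change above — a sketch, not the actual diff:

```python
if provider == "dml":
    # keep only the passes this flow actually uses (mirrors the CUDA update)
    used_passes = {"convert", "optimize"}
```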
# Initialize the searcher
self._sampler = self._create_sampler()
# TODO(olivedev): There is no absolute direction to set.
Will this be investigated later? I still think we need directions for the signals to make sense and for the sampler to choose the next option.
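For reference, Optuna lets you declare one direction per objective when creating the study, which is what its samplers use to rank completed trials. A minimal sketch, assuming two hypothetical signals such as accuracy and latency:

```python
import optuna

# Hypothetical two-signal setup: maximize accuracy, minimize latency.
study = optuna.create_study(directions=["maximize", "minimize"])

def objective(trial):
    x = trial.suggest_float("x", 0.0, 1.0)
    accuracy = x          # placeholder signal
    latency = 1.0 + x     # placeholder signal
    return accuracy, latency

study.optimize(objective, n_trials=10)
print(len(study.best_trials))  # Pareto-optimal trials
```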
if self.should_stop:
    return None

while True:
Is this to avoid sampling the same point? If so, could you add a comment about it? Does it do the same even with the current implementation of the Optuna sampler?
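If the loop is indeed there to skip duplicates, an explicit seen-set plus a comment would make the intent clearer — a sketch of the pattern, with hypothetical helper names:

```python
# Hypothetical sketch: keep asking the sampler until it yields a point
# we have not tried yet (or it signals exhaustion).
seen = set()

def next_point(sampler):
    while True:
        point = sampler.suggest()   # stand-in for the real sampler call
        if point is None:           # sampler exhausted / should stop
            return None
        if point not in seen:       # skip duplicate suggestions
            seen.add(point)
            return point
```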
@dataclass
class SearchWalkState:
Could you add some docstrings to describe what each class does/is used for? Thanks!
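For instance, something along these lines — the wording is a guess at the intent and would need correcting by the author:

```python
from dataclasses import dataclass

@dataclass
class SearchWalkState:
    """State carried along one walk through the search space.

    Placeholder sketch: the exact responsibilities (e.g. the path of
    choices taken so far) should be filled in by the author.
    """
```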
self._init_model_id: str = None

# State variables
self._path: List[int] = None
SearchWalkState also has a path attribute
suggestion_index = trial.suggest_categorical(suggestion_name, list(range(suggestion_len)))
suggestion = suggestions[suggestion_index]

if isinstance(suggestion, (SearchParameter, SearchSpace)):
Does this case happen? I thought the if block at line 87 was the base case, so it should only result in a fixed value?
spi = 0
for child_index, suggestions_len in reversed(indicies_lengths):
    spi *= suggestions_len
Does this guarantee uniqueness for the search point index? It behaves like a generic version of binary (some base-n) encoding?
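For what it's worth, the loop looks like a mixed-radix (generalized positional) encoding. Assuming it also adds child_index after each multiply, every tuple of child indices maps to a unique integer, since each digit satisfies 0 <= index < base. A self-contained sketch of that encoding and its inverse:

```python
# Mixed-radix encoding: each (index, base) pair is one "digit".
# Uniqueness holds because 0 <= index < base for every digit.
def encode(indices_and_bases):
    spi = 0
    for index, base in reversed(indices_and_bases):
        spi = spi * base + index
    return spi

def decode(spi, bases):
    indices = []
    for base in bases:
        spi, index = divmod(spi, base)
        indices.append(index)
    return indices

# Example: digits (1, 0, 2) with bases (3, 2, 4) round-trip uniquely.
assert decode(encode([(1, 3), (0, 2), (2, 4)]), [3, 2, 4]) == [1, 0, 2]
```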
index, values[name] = SearchSpace.get_suggestion(param, index, values)
return values

@staticmethod
Could you add some docstrings and comments to describe the functionality and logic? It's a bit hard to follow.
Implement pass search

Reimplement search logic to include passes in the search space.

Checklist before requesting a review
- lintrunner -a

(Optional) Issue link