
Implement pass search #1562

Open · shaahji wants to merge 1 commit into main from shaahji/pass_search
Conversation

shaahji (Contributor) commented Jan 21, 2025:

Implement pass search

Reimplement search logic to include passes in search space.

Checklist before requesting a review

  • Add unit tests for this change.
  • Make sure all tests pass.
  • Update documentation if necessary.
  • Lint and apply fixes to your code by running lintrunner -a
  • Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
  • Does this PR include example changes? If yes, please remember to update example documentation in a follow-up PR.

(Optional) Issue link


Code scanning alerts addressed in:
  • examples/mobilenet/prepare_config.py (fixed)
  • olive/engine/engine.py (fixed)
  • olive/search/samplers/__init__.py (2 fixed)
  • olive/search/search_point.py (fixed)
  • olive/search/search_strategy.py (fixed)
  • olive/workflows/run/config.py (fixed)
  • test/unit_test/search/test_search_results.py (2 fixed)
  • test/unit_test/search/test_search_strategy.py (resolved)
shaahji force-pushed the shaahji/pass_search branch 3 times, most recently from 64d4138 to 534a692, on January 21, 2025 22:47
Review threads (resolved):
  • olive/search/search_parameter.py (3, outdated)
  • olive/search/search_results.py (outdated)
  • olive/search/search_strategy.py (2, outdated)
  • examples/test/local/test_bert_cuda_gpu.py
xiaoyu-work (Contributor) commented:

Can you also update documents related to your changes?

shaahji force-pushed the shaahji/pass_search branch 2 times, most recently from 3e6fbfd to 457d3b7, on January 23, 2025 09:52
@@ -8,7 +8,7 @@
 import torch
 import torchmetrics
 import transformers
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
Check failure (Code scanning / lintrunner): MYPY/import Error
Cannot find implementation or library stub for module named "datasets". To disable, use # type: ignore[import]
examples/bert/user_script.py (fixed)
try:
    from datasets import load_metric
except ImportError:
    from evaluate import load as load_metric

Check failure (Code scanning / lintrunner): MYPY/import Error
Cannot find implementation or library stub for module named "evaluate". To disable, use # type: ignore[import]
import re
from collections import OrderedDict

import pytest

Check failure (Code scanning / lintrunner): MYPY/import Error (test)
Cannot find implementation or library stub for module named "pytest". To disable, use # type: ignore[import]
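The same `# type: ignore[import]` suppression the lintrunner message suggests applies to each of these failures. A minimal, self-contained sketch of the fallback-import pattern, using json5/json as stand-ins so the example does not assume datasets or evaluate is installed:

```python
# Fallback import with mypy suppression comments. "json5" stands in for
# an optional third-party package; "json" is the stdlib fallback.
try:
    import json5 as jsonlib  # type: ignore[import]
except ImportError:
    import json as jsonlib  # stdlib fallback when json5 is absent

# Either module exposes loads() for plain JSON input.
data = jsonlib.loads('{"value": 1}')
```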
shaahji force-pushed the shaahji/pass_search branch 2 times, most recently from a87fd30 to 58fa43e, on January 23, 2025 10:27
Reimplement search logic to include passes in search space.
shaahji force-pushed the shaahji/pass_search branch from 58fa43e to d63c35a, on January 23, 2025 10:31
shaahji (Contributor, Author) commented Jan 23, 2025:

> Can you also update documents related to your changes?

Done!

Contributor:

why is this required?

Contributor:

please update the "BERT optimization with CUDA/TensorRT on GPU" section in the readme

if not args.use_gptq:
template_json["pass_flows"] = [flow for flow in SUPPORTED_WORKFLOWS[device] if "gptq" not in flow[0]]
used_passes = [
jambayk (Contributor) commented Jan 23, 2025:

I think we might have to create a new mapping from gptq/no-gptq, precision, etc.
Previously this resulted in multiple pass flows, but now you are just flattening the pass flows into a single list.

I think the passes used might look like "conversion_merged", "transformers_optimization_fp16", "conversion_merged", "transformers_optimization_fp16", "blockwise_quant_int4", which is not the intended behavior here. You can test this by using the --config_only option to dump the config.

Contributor:

we need to either make the mapping specific or generate multiple workflows and run them separately.
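The "make the mapping specific" option above could look like the following sketch. The pass names mirror the ones discussed in this thread, but the mapping keys and values are hypothetical, not the repo's actual code:

```python
# Hypothetical explicit mapping from (use_gptq, precision) to an ordered
# pass list, instead of flattening all pass flows into one list (which
# produced duplicates like two "conversion_merged" entries).
PASS_MAP = {
    (False, "fp16"): ["conversion_merged", "transformers_optimization_fp16"],
    (False, "int4"): [
        "conversion_merged",
        "transformers_optimization_fp16",
        "blockwise_quant_int4",
    ],
    (True, "int4"): ["gptq_quant_int4", "conversion_merged"],
}

def used_passes(use_gptq: bool, precision: str) -> list:
    """Return the ordered passes for one specific workflow."""
    return PASS_MAP[(use_gptq, precision)]
```

Each configuration then yields exactly one ordered pass list, with no cross-flow duplicates.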

@@ -186,25 +186,23 @@ def main(raw_args=None):
     legacy_optimization_setting(template_json)

     # add pass flows
-    pass_flows = [[]]
+    used_passes = {}
Contributor:

Just in general, I think using sets might be a bit risky since they are unordered.
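One common workaround for the ordering concern raised above: use a dict with ignored values instead of a set, since dict keys preserve insertion order in Python 3.7+. A small illustration:

```python
# A dict used as an ordered set: keys keep insertion order and
# deduplicate, unlike a plain set whose iteration order is arbitrary.
used_passes = {}
for name in ["convert", "optimize_cuda", "convert"]:
    used_passes[name] = None  # value is irrelevant; only keys matter

ordered = list(used_passes)  # deterministic: ["convert", "optimize_cuda"]
```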

@@ -191,7 +191,10 @@ def use_passes(template_json, *passes):
     else:
         del template_json["data_configs"]

-    template_json["pass_flows"] = [passes]
+    for pass_name in set(template_json["passes"].keys()):
Contributor:

I think instead of popping the unused passes, it might be better to create a new dict with the used passes. Popping assumes the order in template_json["passes"] is the same as in passes
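The suggestion above can be sketched as a dict comprehension keyed by the used passes, so the result follows the order of `passes` rather than the order of `template_json["passes"]` (the names here are taken from the snippet, but the surrounding values are illustrative):

```python
# Rebuild the passes dict from the used passes instead of popping unused
# entries, so ordering follows `used_passes` regardless of the original
# key order in template_json["passes"].
template_json = {"passes": {"quant": {}, "convert": {}, "optimize": {}}}
used_passes = ["convert", "optimize"]

template_json["passes"] = {
    name: template_json["passes"][name] for name in used_passes
}
```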

@@ -23,7 +23,10 @@ def update_cuda_config(config_cuda: Dict):
     if version.parse(OrtVersion) < version.parse("1.17.0"):
         # disable skip_group_norm fusion since there is a shape inference bug which leads to invalid models
         config_cuda["passes"]["optimize_cuda"]["optimization_options"] = {"enable_skip_group_norm": False}
-    config_cuda["pass_flows"] = [["convert", "optimize_cuda"]]
+    used_passes = {"convert", "optimize_cuda"}
jambayk (Contributor) commented Jan 23, 2025:

this part in stable_diffusion.py also needs to be updated to only use "convert", "optimize"


# Initialize the searcher
self._sampler = self._create_sampler()
# TODO(olivedev): There is no absolute direction to set.
Contributor:

will this be investigated later? Since I still think we need directions for the signals to make sense and the sampler to choose the next option.

if self.should_stop:
    return None

while True:
Contributor:

is this to avoid sampling the same point? If so, could you add a comment about it?
does it do the same even with the current implementation of optuna sampler?
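For context on the question above, a re-draw loop of this shape is a common way to avoid returning an already-sampled point. This is an illustrative sketch, not the PR's actual sampler code:

```python
import random

def next_unseen(num_points: int, seen: set, rng: random.Random):
    """Return a search point index not yet in `seen`, or None when the
    space is exhausted. The while loop re-draws on duplicates; it always
    terminates because an unseen index exists whenever we enter it."""
    if len(seen) >= num_points:
        return None
    while True:
        spi = rng.randrange(num_points)
        if spi not in seen:
            seen.add(spi)
            return spi
```

An explicit comment like the docstring above is what the reviewer is asking for in the real code.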



@dataclass
class SearchWalkState:
Contributor:

could you add some docstrings to describe what each class does/is used for? thanks!

self._init_model_id: str = None

# State variables
self._path: List[int] = None
Contributor:

SearchWalkState also has a path attribute

suggestion_index = trial.suggest_categorical(suggestion_name, list(range(suggestion_len)))
suggestion = suggestions[suggestion_index]

if isinstance(suggestion, (SearchParameter, SearchSpace)):
Contributor:

does this case happen? I thought the if block at line 87 is the base case, so it should only result in a fixed value?


spi = 0
for child_index, suggestions_len in reversed(indicies_lengths):
    spi *= suggestions_len
Contributor:

does this guarantee uniqueness for the search point index? behaves like a generic version of binary (some base_n) encoding?

index, values[name] = SearchSpace.get_suggestion(param, index, values)
return values

@staticmethod
Contributor:

Could you add some docstrings and comments to describe the functionality and logic? It's a bit hard to follow.

3 participants