Skip to content

Commit

Permalink
apply code suggestion, fix black
Browse files Browse the repository at this point in the history
Signed-off-by: luarss <[email protected]>
  • Loading branch information
luarss committed Oct 16, 2024
1 parent dde2995 commit 7c60e29
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions tools/AutoTuner/src/autotuner/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,11 @@ def parse_arguments():
help="Perturbation interval for PopulationBasedTraining.",
)
tune_parser.add_argument(
"--seed", type=int, metavar="<int>", default=42, help="Random seed. (0 means no seed.)"
"--seed",
type=int,
metavar="<int>",
default=42,
help="Random seed. (0 means no seed.)",
)

# Workload
Expand Down Expand Up @@ -873,10 +877,16 @@ def set_algorithm(experiment_name, config):
"""
# Pre-set seed if user sets seed to 0
if args.seed == 0:
print("Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)")
print(
"Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)"
)
if input().lower() != "y":
sys.exit(0)
args.seed = None
else:
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

if args.algorithm == "hyperopt":
algorithm = HyperOptSearch(
Expand All @@ -896,10 +906,7 @@ def set_algorithm(experiment_name, config):
)
algorithm = AxSearch(ax_client=ax_client, points_to_evaluate=best_params)
elif args.algorithm == "optuna":
algorithm = OptunaSearch(
points_to_evaluate=best_params,
seed=args.seed
)
algorithm = OptunaSearch(points_to_evaluate=best_params, seed=args.seed)
elif args.algorithm == "pbt":
print("Warning: PBT does not support seed values. args.seed will be ignored.")
algorithm = PopulationBasedTraining(
Expand All @@ -911,18 +918,13 @@ def set_algorithm(experiment_name, config):
elif args.algorithm == "random":
algorithm = BasicVariantGenerator(
max_concurrent=args.jobs,
random_state=args.seed,)
random_state=args.seed,
)

# A wrapper algorithm for limiting the number of concurrent trials.
if args.algorithm not in ["random", "pbt"]:
algorithm = ConcurrencyLimiter(algorithm, max_concurrent=args.jobs)

# Self seed
if args.seed is not None:
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

return algorithm


Expand Down

0 comments on commit 7c60e29

Please sign in to comment.