Skip to content

Commit

Permalink
Add more hyperparams for EvolutionOptimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
UsingtcNower committed Apr 16, 2024
1 parent 7105e8c commit e703048
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 6 deletions.
24 changes: 24 additions & 0 deletions nncf/config/schemata/algo/filter_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
from nncf.config.schemata.defaults import PRUNING_INTERLAYER_RANKING_TYPE
from nncf.config.schemata.defaults import PRUNING_LEGR_GENERATIONS
from nncf.config.schemata.defaults import PRUNING_LEGR_MAX_PRUNING
from nncf.config.schemata.defaults import PRUNING_LEGR_MUTATE_PERCENT
from nncf.config.schemata.defaults import PRUNING_LEGR_NUM_SAMPLES
from nncf.config.schemata.defaults import PRUNING_LEGR_POPULATION_SIZE
from nncf.config.schemata.defaults import PRUNING_LEGR_RANDOM_SEED
from nncf.config.schemata.defaults import PRUNING_LEGR_SIGMA_SCALE
from nncf.config.schemata.defaults import PRUNING_LEGR_TRAIN_STEPS
from nncf.config.schemata.defaults import PRUNING_NUM_INIT_STEPS
from nncf.config.schemata.defaults import PRUNING_SCHEDULE
Expand Down Expand Up @@ -162,6 +166,26 @@
description="Random seed for LeGR coefficients generation.",
default=PRUNING_LEGR_RANDOM_SEED,
),
"population_size": with_attributes(
NUMBER,
description="Size of population for the evolution algorithm.",
default=PRUNING_LEGR_POPULATION_SIZE,
),
"num_samples": with_attributes(
NUMBER,
description="Number of samples for the evolution algorithm.",
default=PRUNING_LEGR_NUM_SAMPLES,
),
"mutate_percent": with_attributes(
NUMBER,
description="Percent of mutate for the evolution algorithm.",
default=PRUNING_LEGR_MUTATE_PERCENT,
),
"scale_sigma": with_attributes(
NUMBER,
description="Scale sigma for the evolution algorithm.",
default=PRUNING_LEGR_SIGMA_SCALE,
),
},
"additionalProperties": False,
},
Expand Down
4 changes: 4 additions & 0 deletions nncf/config/schemata/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@
PRUNING_LEGR_TRAIN_STEPS = 200
PRUNING_LEGR_MAX_PRUNING = 0.8
PRUNING_LEGR_RANDOM_SEED = 42
PRUNING_LEGR_POPULATION_SIZE = 64
PRUNING_LEGR_NUM_SAMPLES = 16
PRUNING_LEGR_MUTATE_PERCENT = 0.1
PRUNING_LEGR_SIGMA_SCALE = 1

SPARSITY_INIT = 0.0
MAGNITUDE_SPARSITY_WEIGHT_IMPORTANCE = "normed_abs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
from torch import optim

from nncf.config.config import NNCFConfig
from nncf.config.schemata.defaults import PRUNING_LEGR_GENERATIONS
from nncf.config.schemata.defaults import PRUNING_LEGR_MUTATE_PERCENT
from nncf.config.schemata.defaults import PRUNING_LEGR_NUM_SAMPLES
from nncf.config.schemata.defaults import PRUNING_LEGR_POPULATION_SIZE
from nncf.config.schemata.defaults import PRUNING_LEGR_SIGMA_SCALE
from nncf.torch.utils import get_filters_num


Expand Down Expand Up @@ -48,11 +53,11 @@ def __init__(self, initial_filter_norms: Dict, hparams: Dict, random_seed: int):
"""
self.random_seed = random_seed
# Optimizer hyper-params
self.population_size = hparams.get("population_size", 64)
self.num_generations = hparams.get("num_generations", 400)
self.num_samples = hparams.get("num_samples", 16)
self.mutate_percent = hparams.get("mutate_percent", 0.1)
self.scale_sigma = hparams.get("sigma_scale", 1)
self.population_size = hparams.get("population_size", PRUNING_LEGR_POPULATION_SIZE)
self.num_generations = hparams.get("num_generations", PRUNING_LEGR_GENERATIONS)
self.num_samples = hparams.get("num_samples", PRUNING_LEGR_NUM_SAMPLES)
self.mutate_percent = hparams.get("mutate_percent", PRUNING_LEGR_MUTATE_PERCENT)
self.scale_sigma = hparams.get("sigma_scale", PRUNING_LEGR_SIGMA_SCALE)
self.max_reward = -np.inf
self.mean_rewards = []

Expand Down
20 changes: 19 additions & 1 deletion nncf/torch/pruning/filter_pruning/global_ranking/legr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
from nncf.common.logging import nncf_logger
from nncf.config.schemata.defaults import PRUNING_LEGR_GENERATIONS
from nncf.config.schemata.defaults import PRUNING_LEGR_MAX_PRUNING
from nncf.config.schemata.defaults import PRUNING_LEGR_MUTATE_PERCENT
from nncf.config.schemata.defaults import PRUNING_LEGR_NUM_SAMPLES
from nncf.config.schemata.defaults import PRUNING_LEGR_POPULATION_SIZE
from nncf.config.schemata.defaults import PRUNING_LEGR_RANDOM_SEED
from nncf.config.schemata.defaults import PRUNING_LEGR_SIGMA_SCALE
from nncf.config.schemata.defaults import PRUNING_LEGR_TRAIN_STEPS
from nncf.torch.pruning.filter_pruning.global_ranking.evolutionary_optimization import EvolutionOptimizer
from nncf.torch.pruning.filter_pruning.global_ranking.evolutionary_optimization import LeGREvolutionEnv
Expand All @@ -38,6 +42,10 @@ def __init__(
generations: int = PRUNING_LEGR_GENERATIONS,
max_pruning: float = PRUNING_LEGR_MAX_PRUNING,
random_seed: int = PRUNING_LEGR_RANDOM_SEED,
population_size: int = PRUNING_LEGR_POPULATION_SIZE,
num_samples: int = PRUNING_LEGR_NUM_SAMPLES,
mutate_percent: float = PRUNING_LEGR_MUTATE_PERCENT,
scale_sigma: float = PRUNING_LEGR_SIGMA_SCALE,
):
"""
Initializing all necessary structures for optimization- LeGREvolutionEnv environment and EvolutionOptimizer
Expand All @@ -53,10 +61,20 @@ def __init__(
self.num_generations = generations
self.max_pruning = max_pruning
self.train_steps = train_steps
self.population_size = population_size
self.num_samples = num_samples
self.mutate_percent = mutate_percent
self.scale_sigma = scale_sigma

self.pruner = LeGRPruner(pruning_ctrl, target_model)
init_filter_norms = self.pruner.init_filter_norms
agent_hparams = {"num_generations": self.num_generations}
agent_hparams = {
"num_generations": self.num_generations,
"population_size": self.population_size,
"num_samples": self.num_samples,
"mutate_percent": self.mutate_percent,
"sigma_scale": self.scale_sigma,
}
self.agent = EvolutionOptimizer(init_filter_norms, agent_hparams, random_seed)
self.env = LeGREvolutionEnv(
self.pruner,
Expand Down

0 comments on commit e703048

Please sign in to comment.