diff --git a/nncf/config/schemata/algo/filter_pruning.py b/nncf/config/schemata/algo/filter_pruning.py index 31c184346ac..25efa70c930 100644 --- a/nncf/config/schemata/algo/filter_pruning.py +++ b/nncf/config/schemata/algo/filter_pruning.py @@ -27,7 +27,11 @@ from nncf.config.schemata.defaults import PRUNING_INTERLAYER_RANKING_TYPE from nncf.config.schemata.defaults import PRUNING_LEGR_GENERATIONS from nncf.config.schemata.defaults import PRUNING_LEGR_MAX_PRUNING +from nncf.config.schemata.defaults import PRUNING_LEGR_MUTATE_PERCENT +from nncf.config.schemata.defaults import PRUNING_LEGR_NUM_SAMPLES +from nncf.config.schemata.defaults import PRUNING_LEGR_POPULATION_SIZE from nncf.config.schemata.defaults import PRUNING_LEGR_RANDOM_SEED +from nncf.config.schemata.defaults import PRUNING_LEGR_SIGMA_SCALE from nncf.config.schemata.defaults import PRUNING_LEGR_TRAIN_STEPS from nncf.config.schemata.defaults import PRUNING_NUM_INIT_STEPS from nncf.config.schemata.defaults import PRUNING_SCHEDULE @@ -162,6 +166,26 @@ description="Random seed for LeGR coefficients generation.", default=PRUNING_LEGR_RANDOM_SEED, ), + "population_size": with_attributes( + NUMBER, + description="Size of population for the evolution algorithm.", + default=PRUNING_LEGR_POPULATION_SIZE, + ), + "num_samples": with_attributes( + NUMBER, + description="Number of samples for the evolution algorithm.", + default=PRUNING_LEGR_NUM_SAMPLES, + ), + "mutate_percent": with_attributes( + NUMBER, + description="Percent of mutate for the evolution algorithm.", + default=PRUNING_LEGR_MUTATE_PERCENT, + ), + "scale_sigma": with_attributes( + NUMBER, + description="Scale sigma for the evolution algorithm.", + default=PRUNING_LEGR_SIGMA_SCALE, + ), }, "additionalProperties": False, }, diff --git a/nncf/config/schemata/defaults.py b/nncf/config/schemata/defaults.py index d1a6471a768..3bf599fce95 100644 --- a/nncf/config/schemata/defaults.py +++ b/nncf/config/schemata/defaults.py @@ -59,6 +59,10 @@ PRUNING_LEGR_TRAIN_STEPS = 200 PRUNING_LEGR_MAX_PRUNING = 0.8 PRUNING_LEGR_RANDOM_SEED = 42 +PRUNING_LEGR_POPULATION_SIZE = 64 +PRUNING_LEGR_NUM_SAMPLES = 16 +PRUNING_LEGR_MUTATE_PERCENT = 0.1 +PRUNING_LEGR_SIGMA_SCALE = 1 SPARSITY_INIT = 0.0 MAGNITUDE_SPARSITY_WEIGHT_IMPORTANCE = "normed_abs" diff --git a/nncf/torch/pruning/filter_pruning/global_ranking/evolutionary_optimization.py b/nncf/torch/pruning/filter_pruning/global_ranking/evolutionary_optimization.py index def4266cbd9..9d9d69a8de4 100644 --- a/nncf/torch/pruning/filter_pruning/global_ranking/evolutionary_optimization.py +++ b/nncf/torch/pruning/filter_pruning/global_ranking/evolutionary_optimization.py @@ -20,6 +20,11 @@ from torch import optim from nncf.config.config import NNCFConfig +from nncf.config.schemata.defaults import PRUNING_LEGR_GENERATIONS +from nncf.config.schemata.defaults import PRUNING_LEGR_MUTATE_PERCENT +from nncf.config.schemata.defaults import PRUNING_LEGR_NUM_SAMPLES +from nncf.config.schemata.defaults import PRUNING_LEGR_POPULATION_SIZE +from nncf.config.schemata.defaults import PRUNING_LEGR_SIGMA_SCALE from nncf.torch.utils import get_filters_num @@ -48,11 +53,11 @@ def __init__(self, initial_filter_norms: Dict, hparams: Dict, random_seed: int): """ self.random_seed = random_seed # Optimizer hyper-params - self.population_size = hparams.get("population_size", 64) - self.num_generations = hparams.get("num_generations", 400) - self.num_samples = hparams.get("num_samples", 16) - self.mutate_percent = hparams.get("mutate_percent", 0.1) - self.scale_sigma = hparams.get("sigma_scale", 1) + self.population_size = hparams.get("population_size", PRUNING_LEGR_POPULATION_SIZE) + self.num_generations = hparams.get("num_generations", PRUNING_LEGR_GENERATIONS) + self.num_samples = hparams.get("num_samples", PRUNING_LEGR_NUM_SAMPLES) + self.mutate_percent = hparams.get("mutate_percent", PRUNING_LEGR_MUTATE_PERCENT) + self.scale_sigma = hparams.get("sigma_scale", PRUNING_LEGR_SIGMA_SCALE) self.max_reward = -np.inf self.mean_rewards = [] diff --git a/nncf/torch/pruning/filter_pruning/global_ranking/legr.py b/nncf/torch/pruning/filter_pruning/global_ranking/legr.py index 2949ded0469..d307151eec2 100644 --- a/nncf/torch/pruning/filter_pruning/global_ranking/legr.py +++ b/nncf/torch/pruning/filter_pruning/global_ranking/legr.py @@ -15,7 +15,11 @@ from nncf.common.logging import nncf_logger from nncf.config.schemata.defaults import PRUNING_LEGR_GENERATIONS from nncf.config.schemata.defaults import PRUNING_LEGR_MAX_PRUNING +from nncf.config.schemata.defaults import PRUNING_LEGR_MUTATE_PERCENT +from nncf.config.schemata.defaults import PRUNING_LEGR_NUM_SAMPLES +from nncf.config.schemata.defaults import PRUNING_LEGR_POPULATION_SIZE from nncf.config.schemata.defaults import PRUNING_LEGR_RANDOM_SEED +from nncf.config.schemata.defaults import PRUNING_LEGR_SIGMA_SCALE from nncf.config.schemata.defaults import PRUNING_LEGR_TRAIN_STEPS from nncf.torch.pruning.filter_pruning.global_ranking.evolutionary_optimization import EvolutionOptimizer from nncf.torch.pruning.filter_pruning.global_ranking.evolutionary_optimization import LeGREvolutionEnv @@ -38,6 +42,10 @@ def __init__( generations: int = PRUNING_LEGR_GENERATIONS, max_pruning: float = PRUNING_LEGR_MAX_PRUNING, random_seed: int = PRUNING_LEGR_RANDOM_SEED, + population_size: int = PRUNING_LEGR_POPULATION_SIZE, + num_samples: int = PRUNING_LEGR_NUM_SAMPLES, + mutate_percent: float = PRUNING_LEGR_MUTATE_PERCENT, + scale_sigma: float = PRUNING_LEGR_SIGMA_SCALE, ): """ Initializing all necessary structures for optimization- LeGREvolutionEnv environment and EvolutionOptimizer @@ -53,10 +61,20 @@ def __init__( self.num_generations = generations self.max_pruning = max_pruning self.train_steps = train_steps + self.population_size = population_size + self.num_samples = num_samples + self.mutate_percent = mutate_percent + self.scale_sigma = scale_sigma self.pruner = LeGRPruner(pruning_ctrl, target_model) init_filter_norms = self.pruner.init_filter_norms - agent_hparams = {"num_generations": self.num_generations} + agent_hparams = { + "num_generations": self.num_generations, + "population_size": self.population_size, + "num_samples": self.num_samples, + "mutate_percent": self.mutate_percent, + "sigma_scale": self.scale_sigma, + } self.agent = EvolutionOptimizer(init_filter_norms, agent_hparams, random_seed) self.env = LeGREvolutionEnv( self.pruner,