diff --git a/ml_genn/ml_genn/compilers/eprop_compiler.py b/ml_genn/ml_genn/compilers/eprop_compiler.py index a688e319..ff94d6bd 100644 --- a/ml_genn/ml_genn/compilers/eprop_compiler.py +++ b/ml_genn/ml_genn/compilers/eprop_compiler.py @@ -606,9 +606,8 @@ def create_compiled_network(self, genn_model, neuron_populations: dict, base_train_callbacks.append(CustomUpdateOnBatchBegin("Reset")) base_validate_callbacks.append(CustomUpdateOnBatchBegin("Reset")) - # If Deep-R is required on any connections, - # trigger Deep-R callbacks at end of batch - if len(self.deep_r_exc_conns) > 0 or len(self.deep_r_inh_conns) > 0: + # If Deep-R is required, trigger Deep-R callbacks at end of batch + if deep_r_required: base_train_callbacks.append(CustomUpdateOnTrainBegin("DeepRInit")) base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR1")) base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR2")) diff --git a/ml_genn/ml_genn/compilers/event_prop_compiler.py b/ml_genn/ml_genn/compilers/event_prop_compiler.py index d4c599de..da070bea 100644 --- a/ml_genn/ml_genn/compilers/event_prop_compiler.py +++ b/ml_genn/ml_genn/compilers/event_prop_compiler.py @@ -8,6 +8,7 @@ from .compiler import Compiler from .compiled_training_network import CompiledTrainingNetwork +from .deep_r import RewiringRecord from .. import Connection, Population, Network from ..callbacks import (BatchProgressBar, Callback, CustomUpdateOnBatchBegin, CustomUpdateOnBatchEnd, CustomUpdateOnEpochEnd, @@ -30,6 +31,7 @@ from copy import deepcopy from pygenn import create_egp_ref, create_var_ref, create_wu_var_ref from .compiler import create_reset_custom_update +from .deep_r import add_deep_r from ..utils.module import get_object, get_object_mapping from ..utils.network import get_underlying_conn, get_underlying_pop from ..utils.value import is_value_array, is_value_constant @@ -486,6 +488,10 @@ def __init__(self, example_timesteps: int, losses, optimiser="adam", communicator: Communicator = None, delay_optimiser=None, delay_learn_conns: Sequence = [], + deep_r_exc_conns: Sequence = [], + deep_r_inh_conns: Sequence = [], + deep_r_l1_strength: float = 0.01, + deep_r_record_rewirings = {}, **genn_kwargs): supported_matrix_types = [SynapseMatrixType.TOEPLITZ, SynapseMatrixType.PROCEDURAL_KERNELG, @@ -511,6 +517,12 @@ def __init__(self, example_timesteps: int, losses, optimiser="adam", Optimiser, "Optimiser", default_optimisers) self.delay_learn_conns = set(get_underlying_conn(c) for c in delay_learn_conns) + self.deep_r_exc_conns = set(get_underlying_conn(c) + for c in deep_r_exc_conns) + self.deep_r_inh_conns = set(get_underlying_conn(c) + for c in deep_r_inh_conns) + self.deep_r_l1_strength = deep_r_l1_strength + self.deep_r_record_rewirings = deep_r_record_rewirings def pre_compile(self, network: Network, genn_model, **kwargs) -> CompileState: @@ -1081,16 +1093,33 @@ def create_compiled_network(self, genn_model, neuron_populations: dict, # Loop through connections that require optimisers weight_optimiser_cus = [] delay_optimiser_cus = [] + deep_r_record_rewirings_ccus = [] for i, (c, w, d) in enumerate(compile_state.optimiser_connections): genn_pop = connection_populations[c] # If weight optimisation is required gradient_vars = [] if w: + # If connection is in list of those to use Deep-R on + gradient_var_ref = create_wu_var_ref(genn_pop, "Gradient") + weight_var_ref = create_wu_var_ref(genn_pop, "g") + if c in self.deep_r_inh_conns or c in self.deep_r_exc_conns: + # Add infrastructure + excitatory = (c in self.deep_r_exc_conns) + deep_r_2_ccu = add_deep_r(genn_pop, genn_model, self, + self.deep_r_l1_strength, + delta_g_var_ref, weight_var_ref, + excitatory) + + # If we should record rewirings from + # this connection, add to list with key + if c in self.deep_r_record_rewirings: + deep_r_record_rewirings_ccus.append( + (deep_r_2_ccu, self.deep_r_record_rewirings[c])) + # Create weight optimiser custom update cu_weight = self._create_optimiser_custom_update( - f"Weight{i}", create_wu_var_ref(genn_pop, "g"), - create_wu_var_ref(genn_pop, "Gradient"), + f"Weight{i}", weight_var_ref, gradient_var_ref, self._optimiser, genn_model) # Add custom update to list of optimisers @@ -1150,6 +1179,13 @@ def create_compiled_network(self, genn_model, neuron_populations: dict, # Build list of base callbacks base_train_callbacks = [] base_validate_callbacks = [] + deep_r_required = (len(self.deep_r_exc_conns) > 0 + or len(self.deep_r_inh_conns) > 0) + + # If Deep-R and L1 regularisation are required, add callback + if deep_r_required and self.deep_r_l1_strength > 0.0: + base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepRL1")) + if len(weight_optimiser_cus) > 0 or len(delay_optimiser_cus) > 0: if self.full_batch_size > 1: base_train_callbacks.append( @@ -1186,7 +1222,13 @@ def create_compiled_network(self, genn_model, neuron_populations: dict, if compile_state.is_reset_custom_update_required: base_train_callbacks.append(CustomUpdateOnBatchBegin("Reset")) base_validate_callbacks.append(CustomUpdateOnBatchBegin("Reset")) - + + # If Deep-R is required, trigger Deep-R callbacks at end of batch + if deep_r_required: + base_train_callbacks.append(CustomUpdateOnTrainBegin("DeepRInit")) + base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR1")) + base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR2")) + # Build list of optimisers and their custom updates optimisers = [] if len(weight_optimiser_cus) > 0: