added support for Deep-R to EventProp compiler
neworderofjamie committed Sep 5, 2024
1 parent 6a3db4b commit 6cc4838
Showing 2 changed files with 47 additions and 6 deletions.
ml_genn/ml_genn/compilers/eprop_compiler.py (5 changes: 2 additions & 3 deletions)
@@ -606,9 +606,8 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
base_train_callbacks.append(CustomUpdateOnBatchBegin("Reset"))
base_validate_callbacks.append(CustomUpdateOnBatchBegin("Reset"))

-# If Deep-R is required on any connections,
-# trigger Deep-R callbacks at end of batch
-if len(self.deep_r_exc_conns) > 0 or len(self.deep_r_inh_conns) > 0:
+# If Deep-R is required, trigger Deep-R callbacks at end of batch
+if deep_r_required:
base_train_callbacks.append(CustomUpdateOnTrainBegin("DeepRInit"))
base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR1"))
base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR2"))
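Note: deep_r_required is not defined in the hunk above; presumably it is computed earlier in eprop_compiler.py, in the same way this commit computes it in event_prop_compiler.py below. A hypothetical sketch, mirroring the lines added there:

# Hypothetical reconstruction: mirrors the flag added to event_prop_compiler.py
# in this commit; the actual definition in eprop_compiler.py lies outside this hunk.
deep_r_required = (len(self.deep_r_exc_conns) > 0
                   or len(self.deep_r_inh_conns) > 0)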
ml_genn/ml_genn/compilers/event_prop_compiler.py (48 changes: 45 additions & 3 deletions)
@@ -8,6 +8,7 @@

from .compiler import Compiler
from .compiled_training_network import CompiledTrainingNetwork
from .deep_r import RewiringRecord
from .. import Connection, Population, Network
from ..callbacks import (BatchProgressBar, Callback, CustomUpdateOnBatchBegin,
CustomUpdateOnBatchEnd, CustomUpdateOnEpochEnd,
@@ -30,6 +31,7 @@
from copy import deepcopy
from pygenn import create_egp_ref, create_var_ref, create_wu_var_ref
from .compiler import create_reset_custom_update
from .deep_r import add_deep_r
from ..utils.module import get_object, get_object_mapping
from ..utils.network import get_underlying_conn, get_underlying_pop
from ..utils.value import is_value_array, is_value_constant
@@ -486,6 +488,10 @@ def __init__(self, example_timesteps: int, losses, optimiser="adam",
communicator: Communicator = None,
delay_optimiser=None,
delay_learn_conns: Sequence = [],
deep_r_exc_conns: Sequence = [],
deep_r_inh_conns: Sequence = [],
deep_r_l1_strength: float = 0.01,
deep_r_record_rewirings = {},
**genn_kwargs):
supported_matrix_types = [SynapseMatrixType.TOEPLITZ,
SynapseMatrixType.PROCEDURAL_KERNELG,
@@ -511,6 +517,12 @@ def __init__(self, example_timesteps: int, losses, optimiser="adam",
Optimiser, "Optimiser", default_optimisers)
self.delay_learn_conns = set(get_underlying_conn(c)
for c in delay_learn_conns)
self.deep_r_exc_conns = set(get_underlying_conn(c)
for c in deep_r_exc_conns)
self.deep_r_inh_conns = set(get_underlying_conn(c)
for c in deep_r_inh_conns)
self.deep_r_l1_strength = deep_r_l1_strength
self.deep_r_record_rewirings = deep_r_record_rewirings

def pre_compile(self, network: Network,
genn_model, **kwargs) -> CompileState:
@@ -1081,16 +1093,33 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
# Loop through connections that require optimisers
weight_optimiser_cus = []
delay_optimiser_cus = []
deep_r_record_rewirings_ccus = []
for i, (c, w, d) in enumerate(compile_state.optimiser_connections):
genn_pop = connection_populations[c]

# If weight optimisation is required
gradient_vars = []
if w:
gradient_var_ref = create_wu_var_ref(genn_pop, "Gradient")
weight_var_ref = create_wu_var_ref(genn_pop, "g")

# If connection is in list of those to use Deep-R on
if c in self.deep_r_inh_conns or c in self.deep_r_exc_conns:
# Add infrastructure
excitatory = (c in self.deep_r_exc_conns)
deep_r_2_ccu = add_deep_r(genn_pop, genn_model, self,
self.deep_r_l1_strength,
gradient_var_ref, weight_var_ref,
excitatory)

# If we should record rewirings from
# this connection, add to list with key
if c in self.deep_r_record_rewirings:
deep_r_record_rewirings_ccus.append(
(deep_r_2_ccu, self.deep_r_record_rewirings[c]))

# Create weight optimiser custom update
cu_weight = self._create_optimiser_custom_update(
f"Weight{i}", create_wu_var_ref(genn_pop, "g"),
create_wu_var_ref(genn_pop, "Gradient"),
f"Weight{i}", weight_var_ref, gradient_var_ref,
self._optimiser, genn_model)

# Add custom update to list of optimisers
@@ -1150,6 +1179,13 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
# Build list of base callbacks
base_train_callbacks = []
base_validate_callbacks = []
deep_r_required = (len(self.deep_r_exc_conns) > 0
or len(self.deep_r_inh_conns) > 0)

# If Deep-R and L1 regularisation are required, add callback
if deep_r_required and self.deep_r_l1_strength > 0.0:
base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepRL1"))

if len(weight_optimiser_cus) > 0 or len(delay_optimiser_cus) > 0:
if self.full_batch_size > 1:
base_train_callbacks.append(
@@ -1186,7 +1222,13 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
if compile_state.is_reset_custom_update_required:
base_train_callbacks.append(CustomUpdateOnBatchBegin("Reset"))
base_validate_callbacks.append(CustomUpdateOnBatchBegin("Reset"))


# If Deep-R is required, trigger Deep-R callbacks at end of batch
if deep_r_required:
base_train_callbacks.append(CustomUpdateOnTrainBegin("DeepRInit"))
base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR1"))
base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR2"))

# Build list of optimisers and their custom updates
optimisers = []
if len(weight_optimiser_cus) > 0:
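For orientation, a minimal usage sketch of the new Deep-R constructor arguments. Only the parameter names and defaults come from this diff; the Connection objects (exc_conn, inh_conn) and the timestep, loss and optimiser settings are illustrative assumptions:

from ml_genn.compilers import EventPropCompiler

# exc_conn and inh_conn stand in for ml_genn Connection objects created elsewhere;
# example_timesteps, losses and optimiser values are placeholders, not from this commit.
compiler = EventPropCompiler(
    example_timesteps=500,
    losses="sparse_categorical_crossentropy",
    optimiser="adam",
    deep_r_exc_conns=[exc_conn],   # connections rewired under an excitatory sign constraint
    deep_r_inh_conns=[inh_conn],   # connections rewired under an inhibitory sign constraint
    deep_r_l1_strength=0.01,       # strength of the L1 term applied by the "DeepRL1" update
    deep_r_record_rewirings={exc_conn: "exc"})  # assumed: per-connection key for rewiring records

With this configuration, the callback wiring above schedules "DeepRInit" once at the start of training, and the "DeepR1" and "DeepR2" custom updates (plus "DeepRL1" when the L1 strength is positive) at the end of every batch.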
