diff --git a/README.md b/README.md index 518b2ae..ca34e54 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,77 @@ _____________________________ For convenience, and to make the syntax a bit more expressive, -some operators are overloaded. For instance +some operators are overloaded. For instance, the sum operator for `GraphTuples` is overloaded. It is possible to compute residual connections +easilly as follows: + +```python +def eval_full(G, core_steps =CORE_STEPS): + ## The actual computation + gi_str.graph_tuple_eval(G) # happens in-place. + + for ncore in range(core_steps): + G += gcore.graph_tuple_eval(G.copy()) + + emb, node_output = node_regressor(G.nodes) # a function operating only on the nodes. + G.nodes = emb + gcore_out = gt_to_global(G) # A function containing node-to-global and edge-to-global aggregators. + + return gcore_out, node_output + +``` + +Currently global blocks are not implemented. +The reason for this is that it is easy to implement them in a few lines when needed. +Here is an example of implementing a graph-to-global function with the respective aggregators: + +```python +def make_graph_tuple_to_global(insize = GN_STATE, agg_type = 'mean', global_state_out = 3, type_ = "node", local_bnn = LOCALBNN): + + # It would have been much cleaner if I supported this in the library... + agg_fcn = make_keras_simple_agg(insize,agg_type) # from ibk_gnns import make_keras_simple_agg + agg_fcn = agg_fcn[1] + + # Constructing the node+edge -> global function. + xx = Input(shape = (insize,)) + + if local_bnn == False: + out = Dense(insize, 'relu')(xx) + out = Dense(insize, 'relu')(out) + else: + out = Dense(insize)(xx) + out = tfp.layers.DenseLocalReparameterization(insize,'relu')(out) + out = Dense(insize,'relu')(out) + + out = Dense(global_state_out, activation = GLOBAL_OUTPUT_ACTIVATION, use_bias = False)(out) + + global_fcn = Model(inputs = xx, outputs = out) + bnnlosses = global_fcn.losses + + def fcn_node_and_edge(gt): + graph_indices_nodes = [] + for k_,k in enumerate(gt.n_nodes): + graph_indices_nodes.extend(np.ones(k).astype("int")*k_) + + graph_indices_edges = [] + for k_,k in enumerate(gt.n_edges): + graph_indices_edges.extend(np.ones(k).astype("int")*k_) + + o1 = agg_fcn(gt.nodes,graph_indices_nodes, gt.n_graphs) # node_to_global aggregation + o2 = agg_fcn(gt.edges,graph_indices_edges, gt.n_graphs) # edge_to_global aggregation + return global_fcn(o1+o2) # either concat or add the aggregated information. + + def fcn_node(gt): + graph_indices_nodes = [] + for k_,k in enumerate(gt.n_nodes): + graph_indices_nodes.extend(np.ones(k).astype("int")*k_) + + o1 = agg_fcn(gt.nodes,graph_indices_nodes, gt.n_graphs) + return global_fcn(o1) + + fcn_dict = {'node': fcn_node,'node_and_edge' : fcn_node_and_edge} + return fcn_dict[type_], global_fcn, bnnlosses + +```