_____________________________


For convenience, and to make the syntax a bit more expressive,
some operators are overloaded. For instance, the sum operator for `GraphTuples` is overloaded,
which makes it easy to compute residual connections, as follows:

```python
def eval_full(G, core_steps=CORE_STEPS):
    # The actual computation; graph_tuple_eval happens in-place.
    gi_str.graph_tuple_eval(G)

    # Core steps with residual connections, using the overloaded "+" on GraphTuples.
    for ncore in range(core_steps):
        G += gcore.graph_tuple_eval(G.copy())

    emb, node_output = node_regressor(G.nodes)  # a function operating only on the nodes.
    G.nodes = emb
    gcore_out = gt_to_global(G)  # a function containing node-to-global and edge-to-global aggregators.

    return gcore_out, node_output
```
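
As a usage sketch, assuming `G` is a `GraphTuple` that has already been assembled from data and that `gi_str`, `gcore`, `node_regressor` and `gt_to_global` are the blocks referenced above, a full forward pass is then a single call:

```python
# Hypothetical usage sketch; `G` and the model blocks are assumed to be defined as above.
gcore_out, node_output = eval_full(G, core_steps=CORE_STEPS)
# `gcore_out` is the graph-level (global) output and `node_output` the per-node output.
```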

Currently, global blocks are not implemented, since it is easy to write them in a few lines when needed.
Here is an example of implementing a graph-to-global function with the respective aggregators:

```python
import numpy as np
import tensorflow_probability as tfp
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from ibk_gnns import make_keras_simple_agg

def make_graph_tuple_to_global(insize=GN_STATE, agg_type='mean', global_state_out=3, type_="node", local_bnn=LOCALBNN):

    # It would have been much cleaner if I supported this in the library...
    agg_fcn = make_keras_simple_agg(insize, agg_type)
    agg_fcn = agg_fcn[1]

    # Constructing the node+edge -> global function.
    xx = Input(shape=(insize,))

    if not local_bnn:
        out = Dense(insize, 'relu')(xx)
        out = Dense(insize, 'relu')(out)
    else:
        out = Dense(insize)(xx)
        out = tfp.layers.DenseLocalReparameterization(insize, 'relu')(out)
        out = Dense(insize, 'relu')(out)

    out = Dense(global_state_out, activation=GLOBAL_OUTPUT_ACTIVATION, use_bias=False)(out)

    global_fcn = Model(inputs=xx, outputs=out)
    bnnlosses = global_fcn.losses  # losses added by the layers (e.g. the KL term of the BNN layer).

    def fcn_node_and_edge(gt):
        # Map every node/edge to the index of the graph it belongs to.
        graph_indices_nodes = []
        for k_, k in enumerate(gt.n_nodes):
            graph_indices_nodes.extend(np.ones(k).astype("int") * k_)

        graph_indices_edges = []
        for k_, k in enumerate(gt.n_edges):
            graph_indices_edges.extend(np.ones(k).astype("int") * k_)

        o1 = agg_fcn(gt.nodes, graph_indices_nodes, gt.n_graphs)  # node-to-global aggregation
        o2 = agg_fcn(gt.edges, graph_indices_edges, gt.n_graphs)  # edge-to-global aggregation
        return global_fcn(o1 + o2)  # either concat or add the aggregated information.

    def fcn_node(gt):
        graph_indices_nodes = []
        for k_, k in enumerate(gt.n_nodes):
            graph_indices_nodes.extend(np.ones(k).astype("int") * k_)

        o1 = agg_fcn(gt.nodes, graph_indices_nodes, gt.n_graphs)
        return global_fcn(o1)

    fcn_dict = {'node': fcn_node, 'node_and_edge': fcn_node_and_edge}
    return fcn_dict[type_], global_fcn, bnnlosses
```
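
For completeness, a minimal sketch of how the returned function could be wired into the pipeline above (the names below are assumptions: `GN_STATE`, `LOCALBNN` and the graph tuple `G` are taken to be defined in the surrounding script):

```python
# Hypothetical wiring: build a node+edge -> global readout and use it as the
# `gt_to_global` function called at the end of `eval_full` above.
gt_to_global, global_model, kl_losses = make_graph_tuple_to_global(
    insize=GN_STATE,
    agg_type='mean',
    global_state_out=3,
    type_='node_and_edge',
    local_bnn=LOCALBNN,
)

global_out = gt_to_global(G)  # G: a GraphTuple with populated node and edge features.
# When local_bnn is True, the entries of `kl_losses` should be added to the training loss.
```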


