Skip to content

Commit

Permalink
Replace Input with InputLayer everywhere for the sake of consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorC committed Jan 7, 2023
1 parent b2a9a5c commit 59a5b18
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 51 deletions.
4 changes: 2 additions & 2 deletions examples/diffusion_1d_pidon.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@
),
model_args=ModelArgs(
branch_net=tf.keras.Sequential(
[tf.keras.layers.Input(np.prod(cp.y_vertices_shape).item())]
[tf.keras.layers.InputLayer(np.prod(cp.y_vertices_shape).item())]
+ [tf.keras.layers.Dense(50, activation="tanh") for _ in range(8)]
),
trunk_net=tf.keras.Sequential(
[tf.keras.layers.Input(diff_eq.x_dimension + 1)]
[tf.keras.layers.InputLayer(diff_eq.x_dimension + 1)]
+ [tf.keras.layers.Dense(50, activation="tanh") for _ in range(8)]
),
combiner_net=tf.keras.Sequential(
Expand Down
4 changes: 2 additions & 2 deletions examples/lotka_volterra_pidon.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@
),
model_args=ModelArgs(
branch_net=tf.keras.Sequential(
[tf.keras.layers.Input(np.prod(cp.y_vertices_shape).item())]
[tf.keras.layers.InputLayer(np.prod(cp.y_vertices_shape).item())]
+ [tf.keras.layers.Dense(100, activation="tanh") for _ in range(6)]
),
trunk_net=tf.keras.Sequential(
[tf.keras.layers.Input(diff_eq.x_dimension + 1)]
[tf.keras.layers.InputLayer(diff_eq.x_dimension + 1)]
+ [tf.keras.layers.Dense(100, activation="tanh") for _ in range(6)]
),
combiner_net=tf.keras.Sequential(
Expand Down
4 changes: 2 additions & 2 deletions examples/population_growth_pidon.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
),
model_args=ModelArgs(
branch_net=tf.keras.Sequential(
[tf.keras.layers.Input(np.prod(cp.y_vertices_shape).item())]
[tf.keras.layers.InputLayer(np.prod(cp.y_vertices_shape).item())]
+ [tf.keras.layers.Dense(100, activation="tanh") for _ in range(6)]
),
trunk_net=tf.keras.Sequential(
[tf.keras.layers.Input(diff_eq.x_dimension + 1)]
[tf.keras.layers.InputLayer(diff_eq.x_dimension + 1)]
+ [tf.keras.layers.Dense(100, activation="tanh") for _ in range(6)]
),
combiner_net=tf.keras.Sequential(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,23 +210,25 @@ def build_model():
model = DeepONet(
branch_net=tf.keras.Sequential(
[
tf.keras.layers.Input(np.prod(cp.y_shape(True)).item()),
tf.keras.layers.InputLayer(
np.prod(cp.y_shape(True)).item()
),
tf.keras.layers.Dense(100, activation="tanh"),
tf.keras.layers.Dense(50, activation="tanh"),
tf.keras.layers.Dense(diff_eq.y_dimension * 10),
]
),
trunk_net=tf.keras.Sequential(
[
tf.keras.layers.Input(diff_eq.x_dimension),
tf.keras.layers.InputLayer(diff_eq.x_dimension),
tf.keras.layers.Dense(50, activation="tanh"),
tf.keras.layers.Dense(50, activation="tanh"),
tf.keras.layers.Dense(diff_eq.y_dimension * 10),
]
),
combiner_net=tf.keras.Sequential(
[
tf.keras.layers.Input(3 * diff_eq.y_dimension * 10),
tf.keras.layers.InputLayer(3 * diff_eq.y_dimension * 10),
tf.keras.layers.Dense(
diff_eq.y_dimension,
kernel_regularizer=tf.keras.regularizers.L2(l2=1e-5),
Expand Down
24 changes: 12 additions & 12 deletions tests/operators/ml/pidon/test_pi_deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ def test_pi_deeponet_with_mismatched_branch_and_trunk_net_output_shapes():
PIDeepONet(
tf.keras.Sequential(
[
tf.keras.layers.Input(1),
tf.keras.layers.InputLayer(1),
tf.keras.layers.Dense(5),
]
),
tf.keras.Sequential(
[
tf.keras.layers.Input(1),
tf.keras.layers.InputLayer(1),
tf.keras.layers.Dense(3),
]
),
tf.keras.Sequential(
[
tf.keras.layers.Input(15),
tf.keras.layers.InputLayer(15),
tf.keras.layers.Dense(1),
]
),
Expand All @@ -43,19 +43,19 @@ def test_pi_deeponet_with_wrong_combiner_net_output_shape():
PIDeepONet(
tf.keras.Sequential(
[
tf.keras.layers.Input(1),
tf.keras.layers.InputLayer(1),
tf.keras.layers.Dense(5),
]
),
tf.keras.Sequential(
[
tf.keras.layers.Input(1),
tf.keras.layers.InputLayer(1),
tf.keras.layers.Dense(5),
]
),
tf.keras.Sequential(
[
tf.keras.layers.Input(15),
tf.keras.layers.InputLayer(15),
tf.keras.layers.Dense(2),
]
),
Expand All @@ -70,19 +70,19 @@ def test_pi_deeponet_with_wrong_loss_weight_length():
PIDeepONet(
tf.keras.Sequential(
[
tf.keras.layers.Input(3),
tf.keras.layers.InputLayer(3),
tf.keras.layers.Dense(3),
]
),
tf.keras.Sequential(
[
tf.keras.layers.Input(1),
tf.keras.layers.InputLayer(1),
tf.keras.layers.Dense(3),
]
),
tf.keras.Sequential(
[
tf.keras.layers.Input(9),
tf.keras.layers.InputLayer(9),
tf.keras.layers.Dense(3),
]
),
Expand All @@ -96,19 +96,19 @@ def test_pi_deeponet_loss_weight_broadcasting():
pidon = PIDeepONet(
tf.keras.Sequential(
[
tf.keras.layers.Input(3),
tf.keras.layers.InputLayer(3),
tf.keras.layers.Dense(3),
]
),
tf.keras.Sequential(
[
tf.keras.layers.Input(1),
tf.keras.layers.InputLayer(1),
tf.keras.layers.Dense(3),
]
),
tf.keras.Sequential(
[
tf.keras.layers.Input(9),
tf.keras.layers.InputLayer(9),
tf.keras.layers.Dense(3),
]
),
Expand Down
Loading

0 comments on commit 59a5b18

Please sign in to comment.