Skip to content

Commit

Permalink
Varitional Inference : implement variational inference for BNNs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaouki-AI committed Jan 7, 2024
1 parent ab5db1f commit 435e048
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 102 deletions.
Binary file not shown.
Binary file added models/VAR_mauna_loa_model.pth
Binary file not shown.
Binary file modified models/international_airline_passengers_model.pth
Binary file not shown.
Binary file modified models/mauna_loa_model.pth
Binary file not shown.
161 changes: 59 additions & 102 deletions src/models/train_bnn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from src.models.BnnModel import BayesianModel
from src.models.BnnModel import BayesianModel, BcnnLayer1D, BnnLayer
from src.models.refactors import Refactor, Refactor_var

from src.data.data_loader import (
load_mauna_loa_atmospheric_co2,
load_international_airline_passengers,
Expand All @@ -28,64 +30,42 @@
X1_train_tensor = torch.from_numpy(X1_train).float()
y1_train_tensor = torch.from_numpy(y1_train).float()
X1_test_tensor = torch.from_numpy(X1_test).float()

# Define the Bayesian neural network model
input_size = X1_train.shape[1]
hidden_size = 20
output_size = 1
model = BayesianModel(input_size, hidden_size, output_size)

model = BayesianModel(input_size, hidden_size, output_size, kernel=3)
# ------------------------------------------
# Train the model with Loss Function MAE
# ------------------------------------------
# Define loss function and optimizer
loss_function = nn.MSELoss()
loss_function = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop
num_epochs = 1000
train_losses = []

for epoch in range(num_epochs):
# Forward pass
outputs = model(X1_train_tensor)
loss = loss_function(outputs, y1_train_tensor)

# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
fact1 = Refactor( [X1_train_tensor, y1_train_tensor],
[X1_test_tensor, y1_test],
model = model, criterion = loss_function,
optimizer = optimizer,
epochs = 1000)

train_losses.append(loss.item())
fact1.fit()
fact1.plot_loss()
fact1.eval()
fact1.save_model(path = "./models/", name = "international_airline_passengers_model")

# --------- Plot training losses ----------
# --------------------------------------------
# Train the model with ELBO Loss using Var Inf.
# ---------------------------------------------

plt.plot(range(1, num_epochs + 1), train_losses, label="Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Over Epochs")
plt.legend()
plt.show()

# --------- Plot Ground Truth vs Predictions ----------

# Evaluate the model on the test set
with torch.no_grad():
model.eval()
predictions_1 = model(X1_test_tensor)

# Convert predictions to NumPy array for plotting
predictions_np_1 = predictions_1.numpy()

# export model
torch.save(model, "./models/mauna_loa_model.pth")

plt.figure(figsize=(10, 6))
plt.plot(X1_test, y1_test, "b.", markersize=10, label="Ground Truth")
plt.plot(X1_test, predictions_1, "r.", markersize=10, label="Predictions")
plt.xlabel("Input Features (X_test)")
plt.ylabel("Ground Truth and Predictions (y_test, Predictions)")
plt.title("True Values vs Predictions")
plt.legend()
plt.show()
var_fact1 = Refactor_var( [X1_train_tensor, y1_train_tensor],
[X1_test_tensor , y1_test],
model = model,
optimizer = optimizer,
epochs = 1000)

var_fact1.fit()
var_fact1.plot_loss()
var_fact1.eval(plot=True)
var_fact1.save_model(path = "./models/", name = "VAR_mauna_loa_model")

# ------------------------------------------
# international-airline-passengers Dataset
Expand All @@ -108,57 +88,34 @@
input_size = X2_train.shape[1]
hidden_size = 20
output_size = 1
model = BayesianModel(input_size, hidden_size, output_size)


# Define loss function and optimizer
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop
num_epochs = 1000
train_losses = []

for epoch in range(num_epochs):
# Forward pass
outputs = model(X2_train_tensor)
loss = loss_function(outputs, y2_train_tensor)

# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()

train_losses.append(loss.item())

# --------- Plot training losses ----------

plt.plot(range(1, num_epochs + 1), train_losses, label="Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Over Epochs")
plt.legend()
plt.show()

# --------- Plot Ground Truth vs Predictions ----------

# Evaluate the model on the test set
with torch.no_grad():
model.eval()
predictions_2 = model(X2_test_tensor)


# Convert predictions to NumPy array for plotting
predictions_np_2 = predictions_2.numpy()

# export model
torch.save(model, "./models/international_airline_passengers_model.pth")

plt.figure(figsize=(10, 6))
plt.plot(X2_test, y2_test, "b.", markersize=10, label="Ground Truth")
plt.plot(X2_test, predictions_2, "r.", markersize=10, label="Predictions")
plt.xlabel("Input Features (X_test)")
plt.ylabel("Ground Truth and Predictions (y_test, Predictions)")
plt.title("True Values vs Predictions")
plt.legend()
plt.show()
# ------------------------------------------
# Train the model with Loss Function MAE
# ------------------------------------------
loss_function = nn.L1Loss()
fact2 = Refactor([X2_train_tensor, y2_train_tensor],
[X2_test_tensor, y2_test],
model = model, criterion = loss_function,
optimizer = optimizer,
epochs = 1000)
fact2.fit()
fact2.plot_loss()
fact2.eval()
fact2.save_model(path = "./models/", name = "international_airline_passengers_model")


# --------------------------------------------
# Train the model with ELBO Loss using Var Inf.
# ---------------------------------------------

var_fact2 = Refactor_var( [X2_train_tensor, y2_train_tensor],
[X2_test_tensor , y2_test],
model = model,
optimizer = optimizer,
epochs = 1000)

var_fact2.fit()
var_fact2.plot_loss()
var_fact2.eval(plot=True)
var_fact2.save_model(path = "./models/", name = "VAR_international_airline_passengers_model")

0 comments on commit 435e048

Please sign in to comment.