Commit

review suggestions
Simon Kamuk Christiansen committed Jan 14, 2025
1 parent 18cb472 commit afeee02
Showing 3 changed files with 50 additions and 35 deletions.
50 changes: 16 additions & 34 deletions neural_lam/models/base_graph_model.py
@@ -108,10 +108,10 @@ def prepare_clamping_params(
         )

         # Constant parameters for clamping
-        self.sigmoid_sharpness = 1
-        self.softplus_sharpness = 1
-        self.sigmoid_center = 0
-        self.softplus_center = 0
+        sigmoid_sharpness = 1
+        softplus_sharpness = 1
+        sigmoid_center = 0
+        softplus_center = 0

         normalize_clamping_lim = (
             lambda x, feature_idx: (x - self.state_mean[feature_idx])
@@ -167,58 +167,40 @@ def prepare_clamping_params(
         self.clamp_lower_upper = lambda x: (
             self.sigmoid_lower_lims
             + (self.sigmoid_upper_lims - self.sigmoid_lower_lims)
-            * torch.sigmoid(self.sigmoid_sharpness * (x - self.sigmoid_center))
+            * torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center))
         )
         self.clamp_lower = lambda x: (
             self.softplus_lower_lims
             + torch.nn.functional.softplus(
-                x - self.softplus_center, beta=self.softplus_sharpness
+                x - softplus_center, beta=softplus_sharpness
             )
         )
         self.clamp_upper = lambda x: (
             self.softplus_upper_lims
             - torch.nn.functional.softplus(
-                self.softplus_center - x, beta=self.softplus_sharpness
+                softplus_center - x, beta=softplus_sharpness
             )
         )

         # Define inverse clamping functions
-        def inverse_softplus(x, beta=1, threshold=20):
-            # If x*beta is above threshold, returns linear function
-            # for numerical stability
-            non_linear_part = (
-                torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta
-            )
-            x = torch.where(x * beta <= threshold, non_linear_part, x)
-
-            return x
-
-        def inverse_sigmoid(x):
-            # Sigmoid output takes values in [0,1], this makes sure input is just within this interval
-            # Note that this torch.clamp will make gradients 0, but this is not a problem
-            # as values of x that are this close to 0 or 1 have gradient 0 anyhow.
-            x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6)
-            return torch.log(x_clamped / (1 - x_clamped))
-
         self.inverse_clamp_lower_upper = lambda x: (
-            self.sigmoid_center
-            + inverse_sigmoid(
+            sigmoid_center
+            + utils.inverse_sigmoid(
                 (x - self.sigmoid_lower_lims)
                 / (self.sigmoid_upper_lims - self.sigmoid_lower_lims)
             )
-            / self.sigmoid_sharpness
+            / sigmoid_sharpness
         )
         self.inverse_clamp_lower = lambda x: (
-            inverse_softplus(
-                x - self.softplus_lower_lims, beta=self.softplus_sharpness
+            utils.inverse_softplus(
+                x - self.softplus_lower_lims, beta=softplus_sharpness
             )
-            + self.softplus_center
+            + softplus_center
        )
         self.inverse_clamp_upper = lambda x: (
-            -inverse_softplus(
-                self.softplus_upper_lims - x, beta=self.softplus_sharpness
+            -utils.inverse_softplus(
+                self.softplus_upper_lims - x, beta=softplus_sharpness
             )
-            + self.softplus_center
+            + softplus_center
         )

     def get_clamped_new_state(self, state_delta, prev_state):
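Taken together, the lambdas above implement a smooth, invertible clamp: a scaled sigmoid squashes values into a two-sided interval, softplus enforces one-sided limits, and the inverse functions map already-clamped states back to unbounded space. Below is a minimal, self-contained sketch of the same round trip in plain PyTorch, with sharpness 1 and center 0; the names (lower, upper, clamp, inv_clamp, clamp_from_below, inv_clamp_from_below) are illustrative and are not the attributes used in the model.

import torch

# Illustrative limits, e.g. a feature normalized to lie in (0, 1)
lower, upper = torch.tensor(0.0), torch.tensor(1.0)

# Two-sided clamp: sigmoid maps R -> (0, 1), then rescale to (lower, upper)
def clamp(x):
    return lower + (upper - lower) * torch.sigmoid(x)

# Inverse: take the logit of the rescaled value to recover the unbounded x
def inv_clamp(y):
    z = (y - lower) / (upper - lower)
    return torch.log(z / (1 - z))

# One-sided clamp: softplus keeps values strictly above `lower`
def clamp_from_below(x):
    return lower + torch.nn.functional.softplus(x)

# Inverse of softplus: log(exp(y) - 1), written with expm1 for small y
def inv_clamp_from_below(y):
    return torch.log(torch.expm1(y - lower))

x = torch.tensor([-3.0, 0.0, 3.0])
print(torch.allclose(inv_clamp(clamp(x)), x, atol=1e-5))  # True
print(torch.allclose(inv_clamp_from_below(clamp_from_below(x)), x, atol=1e-5))  # True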
33 changes: 33 additions & 0 deletions neural_lam/utils.py
@@ -241,3 +241,36 @@ def init_wandb_metrics(wandb_logger, val_steps):
     experiment.define_metric("val_mean_loss", summary="min")
     for step in val_steps:
         experiment.define_metric(f"val_loss_unroll{step}", summary="min")
+
+
+def inverse_softplus(x, beta=1, threshold=20):
+    """
+    Inverse of torch.nn.functional.softplus
+    For x*beta above threshold, returns linear function for numerical
+    stability.
+    Input is clamped to x > ln(1+1e-6)/beta which is approximately positive
+    values of x.
+    Note that this torch.clamp_min will make gradients 0, but this is not a
+    problem as values of x that are this close to 0 have gradients of 0 anyhow.
+    """
+    non_linear_part = (
+        torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta
+    )
+    x = torch.where(x * beta <= threshold, non_linear_part, x)
+
+    return x
+
+
+def inverse_sigmoid(x):
+    """
+    Inverse of torch.sigmoid
+    Sigmoid output takes values in [0,1], this makes sure input is just within
+    this interval.
+    Note that this torch.clamp will make gradients 0, but this is not a problem
+    as values of x that are this close to 0 or 1 have gradients of 0 anyhow.
+    """
+    x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6)
+    return torch.log(x_clamped / (1 - x_clamped))
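As a quick sanity check of the relocated helpers, they should invert torch's own softplus and sigmoid. The snippet below is illustrative rather than part of the commit and assumes the neural_lam.utils import path used in the diff.

import torch
from neural_lam import utils

x = torch.linspace(-5.0, 5.0, steps=11)

# softplus round trip, including the beta keyword used by the clamping lambdas
beta = 2.0
y = torch.nn.functional.softplus(x, beta=beta)
print(torch.allclose(utils.inverse_softplus(y, beta=beta), x, atol=1e-5))  # True

# sigmoid round trip; these inputs stay well inside the 1e-6 clamp
z = torch.sigmoid(x)
print(torch.allclose(utils.inverse_sigmoid(z), x, atol=1e-4))  # True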
2 changes: 1 addition & 1 deletion tests/datastore_examples/mdp/danra_100m_winds/config.yaml
@@ -14,5 +14,5 @@ training:
       t2m: 0.0
       r2m: 0
     upper:
-      r2m: 100.0
+      r2m: 1.0
       u100m: 100.0
