Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix restart #94

Open
wants to merge 4 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions cosipy/cpkernel/init.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Union
import numpy as np
from numba import njit

Expand Down Expand Up @@ -129,9 +130,9 @@ def load_snowpack(GRID_RESTART):
layer_LWC = GRID_RESTART.LAYER_LWC.values[0:num_layers]
layer_IF = GRID_RESTART.LAYER_IF.values[0:num_layers]

new_snow_height = np.float64(GRID_RESTART.new_snow_height.values)
new_snow_timestamp = np.float64(GRID_RESTART.new_snow_timestamp.values)
old_snow_timestamp = np.float64(GRID_RESTART.old_snow_timestamp.values)
new_snow_height = np.float64(GRID_RESTART.NEW_SNOW_HEIGHT.to_numpy())
new_snow_timestamp = np.float64(GRID_RESTART.NEW_SNOW_TIMESTAMP.to_numpy())
old_snow_timestamp = np.float64(GRID_RESTART.OLD_SNOW_TIMESTAMP.to_numpy())

GRID = create_grid_jitted(
layer_heights,
Expand All @@ -154,9 +155,9 @@ def create_grid_jitted(
layer_T: np.ndarray,
layer_LWC: np.ndarray,
layer_IF: np.ndarray,
new_snow_height: float,
new_snow_timestamp: float,
old_snow_timestamp: float,
new_snow_height: Union[float, np.float64],
new_snow_timestamp: Union[float, np.float64],
old_snow_timestamp: Union[float, np.float64],
):
"""Create Grid with JIT.

Expand Down
182 changes: 114 additions & 68 deletions cosipy/cpkernel/io.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""
Read the input data (model forcing) and write the output to netCDF file.
"""
"""Read the input data (model forcing) and write the output to netCDF file."""

import os
import warnings
from datetime import datetime
from typing import Union

import numpy as np
import xarray as xr
Expand All @@ -14,7 +13,6 @@


class IOClass:

def __init__(self, DATA=None):
"""Initialise the IO Class.

Expand Down Expand Up @@ -97,7 +95,10 @@ def set_full_field_attribute(self, name: str, value: np.ndarray, x: int, y: int)
getattr(self, f"LAYER_{name}")[:, y, x, :] = value

def get_datetime(
self, timestamp: str, use_np: bool = True, fmt: str = "%Y-%m-%dT%H:%M"
self,
timestamp: Union[str, datetime, np.datetime64],
use_np: bool = True,
fmt: str = "%Y-%m-%dT%H:%M",
):
"""Get datetime object from a string.

Expand All @@ -115,8 +116,45 @@ def get_datetime(
return np.datetime64(timestamp)
else:
return datetime.strptime(timestamp, fmt)
else:
return timestamp
if isinstance(timestamp, datetime) and use_np:
return np.datetime64(timestamp)
if isinstance(timestamp, np.datetime64) and not use_np:
return timestamp.astype(datetime)
return timestamp

def load_restart_file(self) -> None:
print(f"{'-'*62}\n\tRESTART FROM PREVIOUS STATE\n{'-'*62}\n")

# Load the restart file
time_start = Config.time_start
time_end = Config.time_end
start_timestamp = self.get_datetime(time_start, use_np=False)
end_timestamp = self.get_datetime(time_end, use_np=False)
timestamp = start_timestamp.strftime("%Y-%m-%dT%H-%M")
restart_path = os.path.join(
Config.data_path, "restart", f"restart_{timestamp}.nc"
)
if not os.path.isfile(restart_path):
raise FileNotFoundError
elif start_timestamp == end_timestamp:
raise IndexError
try:
self.GRID_RESTART = xr.open_dataset(restart_path)
"""Get time of the last calculation and add one time step.

GRID_RESTART.time is an array of np.datetime64 objects.
"""
self.restart_date = self.GRID_RESTART.time.values + np.timedelta64(
Constants.dt, "s"
)
# Read data from the last date to the end of the data file
self.init_data_dataset()
except FileNotFoundError:
raise SystemExit(
f"No restart file available for the given date: {timestamp}"
)
except IndexError:
raise SystemExit(f"Start date {time_start} equals end date {time_end}\n")

def create_data_file(self) -> xr.Dataset:
"""Create the input data and read the restart file if necessary.
Expand All @@ -126,37 +164,7 @@ def create_data_file(self) -> xr.Dataset:
"""

if Config.restart:
print(f"{'-'*62}\n\tRESTART FROM PREVIOUS STATE\n{'-'*62}\n")

# Load the restart file
time_start = Config.time_start
time_end = Config.time_end
start_timestamp = self.get_datetime(time_start)
end_timestamp = self.get_datetime(time_end)
timestamp = start_timestamp.strftime("%Y-%m-%dT%H-%M")
restart_path = os.path.join(
Config.data_path, "restart", f"restart_{timestamp}.nc"
)
try:
if not os.path.isfile(restart_path):
raise FileNotFoundError
elif start_timestamp == end_timestamp:
raise IndexError
else:
self.GRID_RESTART = xr.open_dataset(restart_path)
"""Get time of the last calculation and add one time
step. GRID_RESTART.time is an array of np.datetime64
objects.
"""
self.restart_date = self.GRID_RESTART.time.values + np.timedelta64(
Constants.dt, "s"
)
# Read data from the last date to the end of the data file
self.init_data_dataset()
except FileNotFoundError:
raise SystemExit(f"No restart file available for the given date: {timestamp}")
except IndexError:
raise SystemExit(f"Start date {time_start} equals end date {time_end}\n")
self.load_restart_file()
else:
# If no restart, read data according to the dates defined in config file
self.restart_date = None
Expand Down Expand Up @@ -234,15 +242,15 @@ def check_field(self, field, _max, _min) -> bool:
def check_input_data(self) -> bool:
"""Check the input data is within valid bounds."""
print(f"{'-'*62}\nChecking input data ....\n")

data_bounds = {
"T2": (313.16, 243.16),
"RH2": (100.0, 0.0),
"G": (1600.0, 0.0),
"U2": (50.0, 0.0),
"RRR": (20.0, 0.0),
"N": (1.0, 0.0),
"PRES": (1080.0, 400.0),
"PRESS": (1080.0, 400.0),
"LWin": (400.0, 200.0),
"SNOWFALL": (0.1, 0.0),
"SLOPE": (0.0, 90.0),
Expand All @@ -259,7 +267,7 @@ def init_data_dataset(self):
"""Read and store the input netCDF data.

The input data should contain the following variables:
:PRES: Air pressure [hPa].
:PRESS: Air pressure [hPa].
:N: Cloud cover fraction [-].
:RH2: 2m relative humidity [%].
:RRR: Precipitation per time step [mm].
Expand All @@ -276,7 +284,6 @@ def init_data_dataset(self):
except FileNotFoundError:
raise SystemExit(f"Input file not found at: {input_path}")


self.DATA["time"] = np.sort(self.DATA["time"].values)
minimum_time = str(self.DATA.time.values[0])[0:16]
maximum_time = str(self.DATA.time.values[-1])[0:16]
Expand All @@ -297,11 +304,11 @@ def init_data_dataset(self):
raise IndexError("Selected period not available in input data.\n")
if start_time < start_interval:
warnings.warn(
"\nWARNING! Selected startpoint before first timestep of input data\n",
"\nWARNING! Selected startpoint before first timestep of input data\n",
)
if end_time > end_interval:
warnings.warn(
"\nWARNING! Selected endpoint after last timestep of input data\n",
"\nWARNING! Selected endpoint after last timestep of input data\n",
)

if self.restart_date is None: # Check if restart option is set
Expand Down Expand Up @@ -341,7 +348,7 @@ def get_input_metadata(self) -> tuple:
"T2": ("K", "Air temperature at 2 m"),
"RH2": ("%", "Relative humidity at 2 m"),
"U2": ("m s\u207b\xb9", "Wind velocity at 2 m"),
"PRES": ("hPa", "Atmospheric pressure"),
"PRESS": ("hPa", "Atmospheric pressure"),
"G": ("W m\u207b\xb2", "Incoming shortwave radiation"),
"RRR": ("mm", "Total precipitation"),
"SNOWFALL": ("m", "Snowfall"),
Expand Down Expand Up @@ -374,7 +381,6 @@ def get_full_field_metadata(self) -> dict:
return metadata

def get_restart_metadata(self) -> dict:

field_metadata = self.get_full_field_metadata()
restart_metadata = {
"new_snow_height": ("m .w.e", "New snow height"),
Expand Down Expand Up @@ -471,31 +477,55 @@ def init_result_dataset(self) -> xr.Dataset:
self.RESULT.attrs["Densification_method"] = Constants.densification_method
self.RESULT.attrs["Penetrating_method"] = Constants.penetrating_method
self.RESULT.attrs["Roughness_method"] = Constants.roughness_method
self.RESULT.attrs["Saturation_water_vapour_method"] = Constants.saturation_water_vapour_method
self.RESULT.attrs["Saturation_water_vapour_method"] = (
Constants.saturation_water_vapour_method
)

self.RESULT.attrs["Initial_snowheight"] = Constants.initial_snowheight_constant
self.RESULT.attrs["Initial_snow_layer_heights"] = Constants.initial_snow_layer_heights
self.RESULT.attrs["Initial_snow_layer_heights"] = (
Constants.initial_snow_layer_heights
)
self.RESULT.attrs["Initial_glacier_height"] = Constants.initial_glacier_height
self.RESULT.attrs["Initial_glacier_layer_heights"] = Constants.initial_glacier_layer_heights
self.RESULT.attrs["Initial_top_density_snowpack"] = Constants.initial_top_density_snowpack
self.RESULT.attrs["Initial_bottom_density_snowpack"] = Constants.initial_bottom_density_snowpack
self.RESULT.attrs["Initial_glacier_layer_heights"] = (
Constants.initial_glacier_layer_heights
)
self.RESULT.attrs["Initial_top_density_snowpack"] = (
Constants.initial_top_density_snowpack
)
self.RESULT.attrs["Initial_bottom_density_snowpack"] = (
Constants.initial_bottom_density_snowpack
)
self.RESULT.attrs["Temperature_bottom"] = Constants.temperature_bottom
self.RESULT.attrs["Const_init_temp"] = Constants.const_init_temp

self.RESULT.attrs["Center_snow_transfer_function"] = Constants.center_snow_transfer_function
self.RESULT.attrs["Spread_snow_transfer_function"] = Constants.spread_snow_transfer_function
self.RESULT.attrs["Multiplication_factor_for_RRR_or_SNOWFALL"] = Constants.mult_factor_RRR
self.RESULT.attrs["Minimum_snow_layer_height"] = Constants.minimum_snow_layer_height
self.RESULT.attrs["Center_snow_transfer_function"] = (
Constants.center_snow_transfer_function
)
self.RESULT.attrs["Spread_snow_transfer_function"] = (
Constants.spread_snow_transfer_function
)
self.RESULT.attrs["Multiplication_factor_for_RRR_or_SNOWFALL"] = (
Constants.mult_factor_RRR
)
self.RESULT.attrs["Minimum_snow_layer_height"] = (
Constants.minimum_snow_layer_height
)
self.RESULT.attrs["Minimum_snowfall"] = Constants.minimum_snowfall

self.RESULT.attrs["Remesh_method"] = Constants.remesh_method
self.RESULT.attrs["First_layer_height_log_profile"] = Constants.first_layer_height
self.RESULT.attrs["First_layer_height_log_profile"] = (
Constants.first_layer_height
)
self.RESULT.attrs["Layer_stretching_log_profile"] = Constants.layer_stretching

self.RESULT.attrs["Merge_max"] = Constants.merge_max
self.RESULT.attrs["Layer_stretching_log_profile"] = Constants.layer_stretching
self.RESULT.attrs["Density_threshold_merging"] = Constants.density_threshold_merging
self.RESULT.attrs["Temperature_threshold_merging"] = Constants.temperature_threshold_merging
self.RESULT.attrs["Density_threshold_merging"] = (
Constants.density_threshold_merging
)
self.RESULT.attrs["Temperature_threshold_merging"] = (
Constants.temperature_threshold_merging
)

self.RESULT.attrs["Density_fresh_snow"] = Constants.constant_density
self.RESULT.attrs["Albedo_fresh_snow"] = Constants.albedo_fresh_snow
Expand Down Expand Up @@ -556,7 +586,7 @@ def init_result_dataset(self) -> xr.Dataset:
spatiotemporal["N"][0],
)

print(f"\nOutput dataset ... ok")
print("\nOutput dataset ... ok")

return self.RESULT

Expand Down Expand Up @@ -693,15 +723,15 @@ def write_results_to_file(self):
)

if Config.full_field and self.full:
for full_field_var in self.full:
layer_name = f"LAYER_{full_field_var}"
self.add_variable_along_latlonlayertime(
self.RESULT,
getattr(self, layer_name),
layer_name,
metadata[layer_name][0],
metadata[layer_name][1],
)
for full_field_var in self.full:
layer_name = f"LAYER_{full_field_var}"
self.add_variable_along_latlonlayertime(
self.RESULT,
getattr(self, layer_name),
layer_name,
metadata[layer_name][0],
metadata[layer_name][1],
)

def create_empty_restart(self) -> xr.Dataset:
"""Create an empty dataset for the RESTART attribute.
Expand Down Expand Up @@ -843,55 +873,71 @@ def write_restart_to_file(self):
@property
def RAIN(self):
return self.__RAIN

@property
def SNOWFALL(self):
return self.__SNOWFALL

@property
def LWin(self):
return self.__LWin

@property
def LWout(self):
return self.__LWout

@property
def H(self):
return self.__H

@property
def LE(self):
return self.__LE

@property
def B(self):
return self.__B

@property
def QRR(self):
return self.__QRR

@property
def MB(self):
return self.__MB

@RAIN.setter
def RAIN(self, x):
self.__RAIN = x

@SNOWFALL.setter
def SNOWFALL(self, x):
self.__SNOWFALL = x

@LWin.setter
def LWin(self, x):
self.__LWin = x

@LWout.setter
def LWout(self, x):
self.__LWout = x

@H.setter
def H(self, x):
self.__H = x

@LE.setter
def LE(self, x):
self.__LE = x

@B.setter
def B(self, x):
self.__B = x

@QRR.setter
def QRR(self, x):
self.__QRR = x

@MB.setter
def MB(self, x):
self.__MB = x
Expand Down
Binary file not shown.
Binary file not shown.
Loading
Loading