Skip to content

Commit

Permalink
Improve type handling and resolve MHKiT-Software#339 (MHKiT-Software#348
Browse files Browse the repository at this point in the history
)

* fix issue 339 - bug in wave performance MAEP matrix

* type_handling - special dataset to dataarray case and tests

---------

Co-authored-by: ssolson <[email protected]>
  • Loading branch information
akeeste and ssolson authored Aug 28, 2024
1 parent efcda0a commit 789cc06
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 61 deletions.
24 changes: 10 additions & 14 deletions examples/wave_example.ipynb

Large diffs are not rendered by default.

65 changes: 41 additions & 24 deletions mhkit/tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,27 @@ def test_convert_to_dataarray(self):
# test data formats
test_n = d1
test_s = pd.Series(d1, t)
test_df = pd.DataFrame({"d1": d1}, index=t)
test_df2 = pd.DataFrame({"d1": d1, "d1_duplicate": d1}, index=t)
test_df_1d = pd.DataFrame({"d1": d1}, index=t)
test_df_2d = pd.DataFrame({"d1": d1, "d1_duplicate": d1}, index=t)
test_da = xr.DataArray(
data=d1,
dims="time",
coords=dict(time=t),
)
test_ds = xr.Dataset(
test_ds_1d_1v = xr.Dataset(
data_vars={"d1": (["time"], d1)}, coords={"time": t, "index": i}
)
test_ds2 = xr.Dataset(
test_ds_2d_1v = xr.Dataset(
data_vars={
"d1": (["time"], d1),
"d2": (["ind"], d2),
"d1_duplicate": (["time"], d1),
},
coords={"time": t},
)
test_ds_2d_2v = xr.Dataset(
data_vars={
"d1": (["time"], d1),
"d2": (["index"], d2),
},
coords={"time": t, "index": i},
)
Expand All @@ -205,15 +212,33 @@ def test_convert_to_dataarray(self):
self.assertIsInstance(da, xr.DataArray)
self.assertTrue(all(da.data == d1))

# Dataframe
df = utils.convert_to_dataarray(test_df)
self.assertIsInstance(df, xr.DataArray)
self.assertTrue(all(df.data == d1))

# Dataset
ds = utils.convert_to_dataarray(test_ds)
self.assertIsInstance(ds, xr.DataArray)
self.assertTrue(all(ds.data == d1))
# 1D Dataframe
df_1d = utils.convert_to_dataarray(test_df_1d)
self.assertIsInstance(df_1d, xr.DataArray)
self.assertTrue(all(df_1d.data == d1))
self.assertTrue("variable" not in df_1d.dims)

# Multivariate Dataframe
df_2d = utils.convert_to_dataarray(test_df_2d)
self.assertIsInstance(df_2d, xr.DataArray)
self.assertTrue(all(df_2d.sel(variable="d1").data == d1))
self.assertTrue(all(df_2d.sel(variable="d1_duplicate").data == d1))

# 1D Dataset
ds_1d_1v = utils.convert_to_dataarray(test_ds_1d_1v)
self.assertIsInstance(ds_1d_1v, xr.DataArray)
self.assertTrue(all(ds_1d_1v.data == d1))
self.assertTrue("variable" not in ds_1d_1v.dims)

# Multivariate 1D Dataset
ds_2d_1v = utils.convert_to_dataarray(test_ds_2d_1v)
self.assertIsInstance(ds_2d_1v, xr.DataArray)
self.assertTrue(all(ds_2d_1v.sel(variable="d1").data == d1))
self.assertTrue(all(ds_2d_1v.sel(variable="d1_duplicate").data == d1))

# Multivariate 2D Dataset (error)
with self.assertRaises(ValueError):
utils.convert_to_dataarray(test_ds_2d_2v)

# int (error)
with self.assertRaises(TypeError):
Expand All @@ -223,14 +248,6 @@ def test_convert_to_dataarray(self):
with self.assertRaises(TypeError):
utils.convert_to_dataarray(test_n, 5)

# Multivariate Dataframe (error)
with self.assertRaises(ValueError):
utils.convert_to_dataarray(test_df2)

# Multivariate Dataset (error)
with self.assertRaises(ValueError):
utils.convert_to_dataarray(test_ds2)

def test_convert_to_dataset(self):
# test data
a = 5
Expand All @@ -242,7 +259,7 @@ def test_convert_to_dataset(self):
# test data formats
test_n = d1
test_s = pd.Series(d1, t)
test_df2 = pd.DataFrame({"d1": d1, "d2": d2}, index=t)
test_df_2d = pd.DataFrame({"d1": d1, "d2": d2}, index=t)
test_da = xr.DataArray(
data=d1,
dims="time",
Expand All @@ -267,7 +284,7 @@ def test_convert_to_dataset(self):
self.assertTrue(all(da["test_name"].data == d1))

# Dataframe
df = utils.convert_to_dataset(test_df2)
df = utils.convert_to_dataset(test_df_2d)
self.assertIsInstance(df, xr.Dataset)
self.assertTrue(all(df["d1"].data == d1))
self.assertTrue(all(df["d2"].data == d2))
Expand Down
62 changes: 39 additions & 23 deletions mhkit/utils/type_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,14 @@ def convert_to_dataarray(data, name="data"):
"""
Converts the given data to an xarray.DataArray.
This function is designed to handle inputs that can be either a numpy ndarray, pandas Series,
or an xarray DataArray. For convenience, pandas DataFrame and xarray Dataset can also be input
but may only contain a single variable. The function ensures that the output is consistently
an xarray.DataArray.
This function takes in a numpy ndarray, pandas Series, pandas Dataframe, or xarray Dataset
and outputs an equivalent xarray DataArray. DataArrays can be passed through with no changes.
Xarray datasets can only be input when all variables have the same dimensions.
Multivariate pandas Dataframes become 2D DataArrays, which is especially useful when IO
functions return Dataframes with an extremely large number of variables. Use the function
convert_to_dataset to change a multivariate Dataframe into a multivariate Dataset.
Parameters
----------
Expand Down Expand Up @@ -138,7 +142,7 @@ def convert_to_dataarray(data, name="data"):
data, (np.ndarray, pd.DataFrame, pd.Series, xr.DataArray, xr.Dataset)
):
raise TypeError(
"Input data must be of type np.ndarray, pandas.DataFrame, pandas.Series, "
"Input data must be of type np.ndarray, pandas.Series, pandas.DataFrame, "
f"xarray.DataArray, or xarray.Dataset. Got {type(data)}"
)

Expand All @@ -147,40 +151,52 @@ def convert_to_dataarray(data, name="data"):

# Checks pd.DataFrame input and converts to pd.Series if possible
if isinstance(data, pd.DataFrame):
if data.shape[1] > 1:
raise ValueError(
"If the input data is a pd.DataFrame or xr.Dataset, it must contain one variable. Got {data.shape[1]}"
)
else:
# use iloc instead of squeeze. For DataFrames/Series with only a
# single value, squeeze returns a scalar, which is unexpected.
# iloc will return a Series as expected
if data.shape[1] == 1:
# Convert the 1D, univariate case to a Series, which will be caught by the Series conversion below.
# This eliminates an unnecessary variable dimension and names the DataArray with the DataFrame variable name.
#
# Use iloc instead of squeeze. For DataFrames/Series with only a
# single value, squeeze returns a scalar which is unexpected.
# iloc returns a Series with one value as expected.
data = data.iloc[:, 0]
else:
index = data.index.values
columns = data.columns.values
data = xr.DataArray(
data=data.T,
dims=("variable", "index"),
coords={"variable": columns, "index": index},
)

# Checks xr.Dataset input and converts to xr.DataArray if possible
if isinstance(data, xr.Dataset):
keys = list(data.keys())
if len(keys) > 1:
raise ValueError(
"If the input data is a pd.DataFrame or xr.Dataset, it must contain one variable. Got {len(data.keys())}"
)
else:
if len(keys) == 1:
# if only one variable, remove the "variable" dimension and rename the DataArray to simplify
data = data.to_array()
data = data.sel(
variable=keys[0]
) # removes the variable dimension, further simplifying the dataarray
data = data.sel(variable=keys[0])
data.name = keys[0]
data.drop_vars("variable")
else:
# Allow multiple variables if they have the same dimensions
if all([data[keys[0]].dims == data[key].dims for key in keys]):
data = data.to_array()
else:
raise ValueError(
"Multivariate Datasets can only be input if all variables have the same dimensions."
)

# Converts pd.Series to xr.DataArray
if isinstance(data, pd.Series):
data = data.to_xarray()

# Converts np.ndarray to xr.DataArray. Assigns a simple 0-based dimension named index
# Converts np.ndarray to xr.DataArray. Assigns a simple 0-based dimension named index to match how pandas converts to xarray
if isinstance(data, np.ndarray):
data = xr.DataArray(
data=data, dims="index", coords={"index": np.arange(len(data))}
)

# If there's no data name, add one to prevent issues calling or converting the dataArray later one
# If there's no data name, add one to prevent issues calling or converting to a Dataset later on
if data.name == None:
data.name = name

Expand Down

0 comments on commit 789cc06

Please sign in to comment.