"""Collection of functions to preprocess climate data."""
import os
import numpy as np
import scipy as sp
import scipy.interpolate as interpolate
import xarray as xr
from tqdm import tqdm
from joblib import Parallel, delayed


def check_dimensions(ds, sort=True):
    """Rename common coordinate names to ('time', 'lat', 'lon') and check
    that all required dimensions exist.
    """
    dims = list(ds.dims)
    rename_dic = {
        'longitude': 'lon',
        'nav_lon': 'lon',
        'xt_ocean': 'lon',
        'xu_ocean': 'lon',
        'latitude': 'lat',
        'nav_lat': 'lat',
        'yt_ocean': 'lat',
        'yu_ocean': 'lat',
        'time_counter': 'time',
    }
    for c_old, c_new in rename_dic.items():
        if c_old in dims:
            print(f"Rename dimension '{c_old}' to '{c_new}'.")
            ds = ds.rename({c_old: c_new})
            dims = list(ds.dims)
    # Check that all required dimensions are present
    clim_dims = ['time', 'lat', 'lon']
    for dim in clim_dims:
        if dim not in dims:
            raise ValueError(
                f"Dimension '{dim}' is missing; required dims are {clim_dims}!")
    # If lon runs from 0 to 360, shift to -180 to 180
    if max(ds.lon) > 180:
        print("Shift longitude!")
        ds = ds.assign_coords(lon=(((ds.lon + 180) % 360) - 180))
    if sort:
        print('Sort longitudes and latitudes in ascending order.')
        ds = ds.sortby('lon')
        ds = ds.sortby('lat')
    # if 'time' in ds.dims:
    #     ds = ds.transpose('time', 'lat', 'lon')
    return ds
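

# A minimal usage sketch: a tiny synthetic dataset with non-standard
# dimension names is passed through check_dimensions. The variable name
# 'sst' and all coordinate values are illustrative assumptions.
def _example_check_dimensions():
    ds = xr.Dataset(
        {'sst': (('time_counter', 'latitude', 'longitude'),
                 np.zeros((2, 2, 4)))},
        coords={
            'time_counter': np.array(['2000-01-01', '2000-01-02'],
                                     dtype='datetime64[ns]'),
            'latitude': [10.0, 0.0],
            'longitude': [0.0, 90.0, 180.0, 270.0],
        },
    )
    # Dims are renamed to ('time', 'lat', 'lon'), longitudes are shifted
    # to [-180, 180) and both horizontal axes are sorted ascending.
    return check_dimensions(ds)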


def save_to_file(da, filepath, var_name=None):
    """Save dataset or dataarray to a netCDF file."""
    # Convert xr.DataArray to xr.Dataset if needed
    if var_name is not None:
        ds = da.to_dataset(name=var_name)
    else:
        ds = da
    # Do not overwrite an existing file; append '_new' to the path instead
    if os.path.exists(filepath):
        print(f"File '{filepath}' already exists!")
        filepath = filepath + "_new"
    try:
        ds.to_netcdf(filepath)
        print(f"File is stored to '{filepath}'!")
    except OSError:
        print(f"Could not write to '{filepath}'!")
    return None
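

# A minimal usage sketch (assumes write access to the working directory;
# the file name 'example_ts.nc' is an arbitrary choice for illustration):
def _example_save_to_file():
    da = xr.DataArray(np.arange(3.0), dims='time',
                      coords={'time': np.arange(3)})
    # The dataarray is wrapped into a dataset under the name 'ts' before
    # writing; an existing file is not overwritten, the data are written
    # to '<filepath>_new' instead.
    save_to_file(da, 'example_ts.nc', var_name='ts')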


def time_average(ds, group='1D'):
    """Downsample the time dimension by averaging.

    Args:
    -----
    ds: xr.Dataset
        Dataset to downsample.
    group: str
        Time group, e.g. '1D' for daily averages from hourly data.
    """
    ds_average = ds.resample(time=group, label='left').mean(skipna=True)
    # Shift time stamps to the first of the month
    if group == '1M':
        new_time = ds_average.time.data + np.timedelta64(1, 'D')
        new_coords = {}
        for dim in ds_average.dims:
            new_coords[dim] = ds_average[dim].data
        new_coords['time'] = new_time
        ds_average = ds_average.assign_coords(coords=new_coords)
    return ds_average
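

# A minimal usage sketch: downsampling a synthetic hourly series to daily
# means. The 48-hour toy series is an illustrative assumption.
def _example_time_average():
    time = np.arange('2000-01-01', '2000-01-03',
                     dtype='datetime64[h]').astype('datetime64[ns]')
    da = xr.DataArray(np.arange(len(time), dtype=float),
                      dims='time', coords={'time': time})
    # Two daily means: 11.5 (hours 0-23) and 35.5 (hours 24-47)
    return time_average(da, group='1D')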


def set_grid(ds, step_lat=1, step_lon=1,
             lat_range=None, lon_range=None):
    """Interpolate onto a regular lat-lon grid.

    Args:
        ds (xr.Dataset): Dataset or dataarray to interpolate.
        step_lat (float, optional): Latitude grid step. Defaults to 1.
        step_lon (float, optional): Longitude grid step. Defaults to 1.
        lat_range (list, optional): [min, max] of latitudes. Defaults to the
            range of the input data.
        lon_range (list, optional): [min, max] of longitudes. Defaults to the
            range of the input data.

    Returns:
        da (xr.Dataset): Interpolated dataset or dataarray.
        grid (dict): Grid used for interpolation, dict(lat=[...], lon=[...]).
    """
    lat_min = ds['lat'].min().data if lat_range is None else lat_range[0]
    lat_max = ds['lat'].max().data if lat_range is None else lat_range[1]
    lon_min = ds['lon'].min().data if lon_range is None else lon_range[0]
    lon_max = ds['lon'].max().data if lon_range is None else lon_range[1]
    # Note: latitudes include the upper bound, longitudes exclude it
    init_lat = np.arange(lat_min, (lat_max + step_lat), step_lat)
    init_lon = np.arange(lon_min, lon_max, step_lon)
    grid = {'lat': init_lat, 'lon': init_lon}
    # Interpolate
    da = ds.interp(grid, method='nearest')
    return da, grid
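

# A minimal usage sketch: regridding a 0.5-degree toy field onto a
# 2-degree grid with nearest-neighbour interpolation. All ranges are
# illustrative assumptions.
def _example_set_grid():
    lat = np.arange(-10.0, 10.5, 0.5)
    lon = np.arange(120.0, 160.5, 0.5)
    da = xr.DataArray(np.random.rand(len(lat), len(lon)),
                      dims=('lat', 'lon'), coords={'lat': lat, 'lon': lon})
    da_coarse, grid = set_grid(da, step_lat=2, step_lon=2)
    return da_coarse, grid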


def interp_points(i, da, points_origin, points_grid):
    """Interpolate one time slice of a dataarray to a new set of points.

    Args:
        i (int): Time index.
        da (xr.Dataarray): Dataarray.
        points_origin (np.ndarray): Array of origin locations.
        points_grid (np.ndarray): Array of locations to interpolate on.

    Returns:
        i (int): Time index.
        values_grid_flat (np.ndarray): Values on the new points.
    """
    values_origin = da[i].data.flatten()
    values_grid_flat = interpolate.griddata(
        points_origin, values_origin, xi=points_grid, method='nearest'
    )
    return i, values_grid_flat
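

# A minimal usage sketch for this helper: nearest-neighbour interpolation
# of a single time slice from four scattered origin points onto two target
# points. All coordinates are illustrative assumptions.
def _example_interp_points():
    da = xr.DataArray(np.array([[1.0, 2.0, 3.0, 4.0]]),
                      dims=('time', 'points'))
    points_origin = np.array([[0.0, 0.0], [1.0, 0.0],
                              [0.0, 1.0], [1.0, 1.0]])
    points_grid = np.array([[0.1, 0.1], [0.9, 0.9]])
    i, values = interp_points(0, da, points_origin, points_grid)
    return values  # array([1., 4.])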


def interp_points2mercato(da, grid, n_cpus=1):
    """Interpolate a dataarray on a non-rectangular grid to a regular
    ('mercato') grid.

    Args:
        da (xr.Dataarray): Dataarray with non-rectangular grid.
        grid (dict): Grid to interpolate on, dict(lat=[...], lon=[...]).
        n_cpus (int, optional): Number of parallel jobs. Defaults to 1.

    Returns:
        da_grid (xr.Dataarray): Dataarray interpolated on the regular grid.
    """
    print(f"Interpolate data with non-rectangular grid to mercato grid. n_cpus={n_cpus}.",
          flush=True)
    # Create array of points from the regular grid
    xx, yy = np.meshgrid(grid['lon'], grid['lat'])
    points_grid = np.array([xx.flatten(), yy.flatten()]).T
    points_origin = np.array(
        [da['nav_lon'].data.flatten(), da['nav_lat'].data.flatten()]).T
    # Interpolate each time step in parallel
    n_times = len(da['time_counter'])
    results = Parallel(n_jobs=n_cpus)(
        delayed(interp_points)(i, da, points_origin, points_grid)
        for i in tqdm(range(n_times))
    )
    # Read results
    ids = []
    values_grid_flat = []
    for r in results:
        i, data = r
        ids.append(i)
        values_grid_flat.append(data)
    ids = np.array(ids)
    values_grid_flat = np.array(values_grid_flat)
    # Store to a new dataarray
    values_grid = np.reshape(
        values_grid_flat,
        (len(values_grid_flat), len(grid['lat']), len(grid['lon']))
    )
    times = da['time_counter'].data[ids]
    da_grid = xr.DataArray(
        data=values_grid,
        dims=['time', 'lat', 'lon'],
        coords=dict(time=times, lat=grid['lat'], lon=grid['lon']),
        name=da.name
    )
    return da_grid
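

# A minimal usage sketch: a tiny curvilinear toy field with 2-D
# nav_lon/nav_lat coordinates (as in NEMO output) is interpolated onto a
# 1-degree regular grid. All values are illustrative assumptions.
def _example_interp_points2mercato():
    yy, xx = np.meshgrid(np.arange(0.0, 3.0), np.arange(0.0, 4.0),
                         indexing='ij')
    da = xr.DataArray(
        np.random.rand(2, 3, 4),
        dims=('time_counter', 'y', 'x'),
        coords={
            'time_counter': np.array(['2000-01-01', '2000-02-01'],
                                     dtype='datetime64[ns]'),
            'nav_lon': (('y', 'x'), xx),
            'nav_lat': (('y', 'x'), yy),
        },
        name='sst',
    )
    grid = {'lat': np.arange(0.0, 3.0), 'lon': np.arange(0.0, 4.0)}
    return interp_points2mercato(da, grid, n_cpus=1)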


def cut_map(ds, lon_range=None, lat_range=None, shortest=True):
    """Cut an area out of the map. By default the shortest longitudinal
    range is used, which accounts for regions (e.g. the Pacific) that
    cross the -180/180 boundary.

    Args:
    ----------
    lon_range: list [min, max]
        Range of longitudes.
    lat_range: list [min, max]
        Range of latitudes.
    shortest: boolean
        Use the shortest range in longitude (e.g. the range [-170, 170]
        then contains all points from 170 to 180 and -180 to -170, not
        the points between -170 and 170). Default is True.

    Return:
    -------
    ds_area: xr.dataset
        Dataset cut to the given ranges.
    """
    if lon_range is not None:
        if (max(lon_range) - min(lon_range) <= 180) or shortest is False:
            ds = ds.sel(lon=slice(np.min(lon_range), np.max(lon_range)))
        else:
            # Account for areas that cross the -180/180 border
            ds = ds.sel(lon=ds.lon[(ds.lon < min(lon_range)) |
                                   (ds.lon > max(lon_range))])
    if lat_range is not None:
        ds = ds.sel(lat=slice(np.min(lat_range), np.max(lat_range)))
    return ds
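

# A minimal usage sketch: cutting a Pacific box that crosses the -180/180
# border. With shortest=True the selection keeps only longitudes below
# -170 or above 170. The toy grid is an illustrative assumption.
def _example_cut_map():
    lat = np.arange(-20.0, 21.0)
    lon = np.arange(-180.0, 180.0)
    da = xr.DataArray(np.zeros((len(lat), len(lon))),
                      dims=('lat', 'lon'), coords={'lat': lat, 'lon': lon})
    return cut_map(da, lon_range=[-170, 170], lat_range=[-5, 5])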


def select_months(ds, months=[12, 1, 2]):
    """Select only the given months from the data.

    Args:
        ds ([xr.DataSet, xr.DataArray]): Dataset or dataarray.
        months (list, optional): Indices of months to select.
            Defaults to [12, 1, 2]=DJF.
    """
    ds_months = ds.sel(time=np.isin(ds['time.month'], months))
    return ds_months
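

# A minimal usage sketch: restricting a monthly toy series to DJF, which
# here keeps January, February and December.
def _example_select_months():
    time = np.arange('2000-01', '2001-01',
                     dtype='datetime64[M]').astype('datetime64[ns]')
    da = xr.DataArray(np.arange(12.0), dims='time', coords={'time': time})
    return select_months(da, months=[12, 1, 2])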


def select_time_snippets(ds, time_snippets):
    """Cut time snippets from a dataset and concatenate them.

    Parameters:
    -----------
    time_snippets: np.datetime64 (n, 2)
        Array of n time snippets with dimension (n, 2).

    Returns:
    --------
    xr.Dataset with concatenated times.
    """
    ds_lst = []
    for time_range in time_snippets:
        ds_lst.append(ds.sel(time=slice(time_range[0], time_range[1])))
    ds_snip = xr.concat(ds_lst, dim='time')
    return ds_snip
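

# A minimal usage sketch: concatenating two separate periods of a daily
# toy series. Dates are illustrative assumptions.
def _example_select_time_snippets():
    time = np.arange('2000-01-01', '2000-03-01',
                     dtype='datetime64[D]').astype('datetime64[ns]')
    da = xr.DataArray(np.arange(len(time), dtype=float),
                      dims='time', coords={'time': time})
    snippets = np.array([['2000-01-05', '2000-01-10'],
                         ['2000-02-05', '2000-02-10']],
                        dtype='datetime64[ns]')
    return select_time_snippets(da, snippets)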


def average_time_periods(ds, time_snippets):
    """Select time snippets from a dataset and average each of them.

    Parameters:
    -----------
    time_snippets: np.datetime64 (n, 2)
        Array of n time snippets with dimension (n, 2).

    Returns:
    --------
    xr.Dataset with averaged times.
    """
    ds_lst = []
    for time_range in time_snippets:
        temp_mean = ds.sel(time=slice(time_range[0], time_range[1])).mean('time')
        # Stamp each mean with the midpoint of its time period
        temp_mean['time'] = time_range[0] + 0.5 * (time_range[1] - time_range[0])
        ds_lst.append(temp_mean)
    ds_snip = xr.concat(ds_lst, dim='time')
    return ds_snip
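

# A minimal usage sketch: one mean per period, stamped at the period's
# midpoint. Dates are illustrative assumptions.
def _example_average_time_periods():
    time = np.arange('2000-01-01', '2000-03-01',
                     dtype='datetime64[D]').astype('datetime64[ns]')
    da = xr.DataArray(np.arange(len(time), dtype=float),
                      dims='time', coords={'time': time})
    periods = np.array([['2000-01-01', '2000-01-31'],
                        ['2000-02-01', '2000-02-29']],
                       dtype='datetime64[ns]')
    return average_time_periods(da, periods)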


def get_mean_time_series(da, lon_range, lat_range, time_roll=0):
    """Get the mean time series of a selected area.

    Parameters:
    -----------
    da: xr.DataArray
        Data.
    lon_range: list
        [min, max] of longitudinal range.
    lat_range: list
        [min, max] of latitudinal range.
    time_roll: int
        Window size of an optional centered rolling mean.
        Defaults to 0 (no rolling mean).

    Returns:
    --------
    ts_mean, ts_std: Mean and standard deviation over the area.
    """
    da_area = cut_map(da, lon_range, lat_range)
    ts_mean = da_area.mean(dim=('lon', 'lat'), skipna=True)
    ts_std = da_area.std(dim=('lon', 'lat'), skipna=True)
    if time_roll > 0:
        ts_mean = ts_mean.rolling(time=time_roll, center=True).mean()
        ts_std = ts_std.rolling(time=time_roll, center=True).mean()
    return ts_mean, ts_std
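

# A minimal usage sketch: an area-mean time series over a Nino3.4-like box
# (lon [-170, -120], lat [-5, 5]) with a 5-day rolling mean. The random
# field is an illustrative assumption.
def _example_get_mean_time_series():
    time = np.arange('2000-01-01', '2000-02-01',
                     dtype='datetime64[D]').astype('datetime64[ns]')
    lat = np.arange(-20.0, 21.0)
    lon = np.arange(-180.0, 180.0)
    da = xr.DataArray(np.random.rand(len(time), len(lat), len(lon)),
                      dims=('time', 'lat', 'lon'),
                      coords={'time': time, 'lat': lat, 'lon': lon})
    return get_mean_time_series(da, lon_range=[-170, -120],
                                lat_range=[-5, 5], time_roll=5)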


def normalize(da, method='zscore'):
    """Normalize a dataarray by a given method.

    Args:
        da (xr.DataArray): Dataarray to normalize.
        method (str, optional): Normalization method. 'minmax' scales the
            data to 0-1, 'zscore' standardizes the data. Defaults to 'zscore'.

    Returns:
        xr.DataArray: Normalized data with the normalization parameters
            stored in its attributes.
    """
    print(f'Normalize data by {method}!')
    flatten = da.stack(z=da.dims)
    if method == 'minmax':
        min_val = flatten.min(skipna=True)
        max_val = flatten.max(skipna=True)
        norm_data = (flatten - min_val) / (max_val - min_val)
        attr = dict(norm=method, min=min_val.data, max=max_val.data)
    elif method == 'zscore':
        mean = flatten.mean(skipna=True)
        std = flatten.std(skipna=True)
        norm_data = (flatten - mean) / std
        attr = dict(norm=method, mean=mean.data, std=std.data)
    else:
        raise ValueError(
            f'Your selected normalization method "{method}" does not exist.')
    norm_data = norm_data.unstack('z')
    for key, val in attr.items():
        norm_data.attrs[key] = val
    return norm_data


def unnormalize(dmap, attr):
    """Unnormalize data.

    Args:
        dmap (xr.Dataarray): Datamap.
        attr (dict): Dictionary containing normalization information,
            attr = {'norm': 'minmax', 'min': ..., 'max': ...}
            or attr = {'norm': 'zscore', 'mean': ..., 'std': ...}

    Returns:
        rec_map (xr.Dataarray): Unnormalized map.
    """
    if attr['norm'] == 'minmax':
        rec_map = dmap * (attr['max'] - attr['min']) + attr['min']
    elif attr['norm'] == 'zscore':
        rec_map = dmap * attr['std'] + attr['mean']
    else:
        raise ValueError(
            f'Your selected normalization method {attr["norm"]} does not exist.')
    return rec_map
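

# A minimal usage sketch: z-score normalization followed by the inverse
# transform; the round trip recovers the original values up to floating
# point error.
def _example_normalize_roundtrip():
    da = xr.DataArray(np.random.rand(4, 5), dims=('lat', 'lon'))
    da_norm = normalize(da, method='zscore')
    da_rec = unnormalize(da_norm, da_norm.attrs)
    np.testing.assert_allclose(da.data, da_rec.data)
    return da_norm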


def rotate_matrix(M, Theta):
    """Rotate a 2d matrix by the angle Theta.

    Args:
        M (np.ndarray): (2, 2) 2d matrix.
        Theta (float): Angle in radians.

    Returns:
        (np.ndarray): (2, 2) rotated matrix.
    """
    R = np.array(
        [[np.cos(Theta), -np.sin(Theta)], [np.sin(Theta), np.cos(Theta)]]
    )
    return R @ M @ R.T
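

# A minimal usage sketch: rotating a diagonal matrix by 90 degrees swaps
# its principal axes (the similarity transform R @ M @ R.T).
def _example_rotate_matrix():
    M = np.diag([2.0, 1.0])
    M_rot = rotate_matrix(M, np.pi / 2)
    np.testing.assert_allclose(M_rot, np.diag([1.0, 2.0]), atol=1e-12)
    return M_rot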