Commit a2aa08a

write output contiguously; 🐛

dionhaefner committed Nov 3, 2020
1 parent 2db42eb
Showing 4 changed files with 43 additions and 47 deletions.
36 changes: 23 additions & 13 deletions fowd/cdip.py
@@ -184,6 +184,7 @@ def process_cdip_station(station_folder, out_folder, nproc=None):
     result_files = [None for _ in range(num_inputs)]
 
     do_work = functools.partial(get_cdip_wave_records, out_folder=out_folder, qc_outfile=qc_outfile)
+    num_waves_total = 0
 
     def handle_result(i, result, pbar):
         pbar.update(1)

@@ -195,11 +196,14 @@ def handle_result(i, result, pbar):
             logger.warning('Processing skipped for file %s', filename)
             return
 
+        nonlocal num_waves_total
         num_waves = 0
         for record_chunk in read_pickle_outfile_chunks(result_file):
             if record_chunk:
                 num_waves += len(record_chunk['wave_id_local'])
 
+        num_waves_total += num_waves
+
         result_files[i] = result_file
 
         # get QC information

@@ -260,22 +264,27 @@ def handle_result(i, result, pbar):
     # write output
     def generate_results():
         current_wave_id = 0
-        for result_file in tqdm.tqdm(result_files, desc='Writing output'):
-            if result_file is None:
-                continue
+        pbar = tqdm.tqdm(total=num_waves_total, desc='Writing output')
 
-            for record_chunk in read_pickle_outfile_chunks(result_file):
-                if not record_chunk:
-                    continue
+        with pbar:
+            for result_file in result_files:
+                if result_file is None:
+                    continue
 
-                # fix local id to be unique for the whole station
-                chunk_size = len(record_chunk['wave_id_local'])
-                record_chunk['wave_id_local'] = np.arange(
-                    current_wave_id, current_wave_id + chunk_size
-                )
-                current_wave_id += chunk_size
+                for record_chunk in read_pickle_outfile_chunks(result_file):
+                    if not record_chunk:
+                        continue
 
-                yield record_chunk
+                    # fix local id to be unique for the whole station
+                    chunk_size = len(record_chunk['wave_id_local'])
+                    record_chunk['wave_id_local'] = np.arange(
+                        current_wave_id, current_wave_id + chunk_size
+                    )
+                    current_wave_id += chunk_size
+
+                    yield record_chunk
+
+                    pbar.update(chunk_size)
 
     result_generator = generate_results()
     out_file = os.path.join(out_folder, f'fowd_cdip_{station_id}.nc')

@@ -284,5 +293,6 @@ def generate_results():
 
     write_records(
         result_generator, out_file, station_name,
-        include_direction=True, extra_metadata=EXTRA_METADATA
+        include_direction=True, extra_metadata=EXTRA_METADATA,
+        num_records=num_waves_total
     )
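Note on the pattern above: the total wave count is accumulated while results arrive (num_waves_total), so the writing phase can show a single wave-level progress bar instead of a per-file one. A minimal runnable sketch of the same manual-update tqdm idiom, with hypothetical stand-in chunks rather than FOWD's pickle outfiles:

# Sketch only: `chunks` is hypothetical stand-in data, not FOWD's real chunk format.
import numpy as np
import tqdm

chunks = [{'wave_id_local': np.arange(3)}, {'wave_id_local': np.arange(5)}]
num_waves_total = sum(len(c['wave_id_local']) for c in chunks)

def generate_results():
    current_wave_id = 0
    pbar = tqdm.tqdm(total=num_waves_total, desc='Writing output')
    with pbar:  # ensures the bar is closed even if the consumer stops early
        for chunk in chunks:
            chunk_size = len(chunk['wave_id_local'])
            # renumber so ids are unique across all chunks
            chunk['wave_id_local'] = np.arange(current_wave_id, current_wave_id + chunk_size)
            current_wave_id += chunk_size
            yield chunk
            pbar.update(chunk_size)

for chunk in generate_results():
    pass  # a real consumer would write each chunk to disk

Placing pbar.update(chunk_size) after the yield means the bar only advances once the consumer has taken the chunk, so it tracks write progress rather than read progress.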
7 changes: 5 additions & 2 deletions fowd/cli.py
@@ -174,6 +174,7 @@ def postprocess_cdip(cdip_files, out_folder):
         logger.info(f'Processing {infile}')
 
         station_name = str(ds.meta_station_name.values[0])
+        num_records = len(ds['wave_id_local'])
 
         assert station_name.startswith('CDIP_')
         if CDIP_DEPLOYMENT_BLACKLIST.get(station_name[5:]) == '*':

@@ -190,14 +191,16 @@ def postprocess_cdip(cdip_files, out_folder):
 
         record_generator = tqdm.tqdm(
             filter_cdip(ds, num_filtered, chunk_size=chunk_size),
-            total=math.ceil(len(ds['wave_id_local']) / chunk_size),
+            total=math.ceil(num_records / chunk_size),
             leave=False
         )
 
         write_records(
             record_generator,
             outfile, station_name,
-            extra_metadata=out_metadata, include_direction=True
+            num_records=num_records,
+            extra_metadata=out_metadata,
+            include_direction=True
         )
 
         logger.info(f'Filtered {num_filtered["blacklist"]} blacklisted seas')
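The cached num_records now feeds both the progress-bar total and write_records. The chunk-count arithmetic, with hypothetical numbers:

# Hypothetical values; in the real code num_records is len(ds['wave_id_local']).
import math

num_records = 12345
chunk_size = 1000
num_chunks = math.ceil(num_records / chunk_size)
assert num_chunks == 13  # 12 full chunks plus one partial chunk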
2 changes: 1 addition & 1 deletion fowd/generic_source.py
@@ -157,4 +157,4 @@ def process_file(input_file, out_folder, station_id=None):
     result_generator = filter(None, read_pickle_outfile_chunks(result_file))
     out_file = os.path.join(out_folder, f'fowd_{station_id}.nc')
     logger.info('Writing output to %s', out_file)
-    write_records(result_generator, out_file, station_id)
+    write_records(result_generator, out_file, station_id, num_records=num_waves)
45 changes: 14 additions & 31 deletions fowd/output.py
@@ -21,13 +21,6 @@
 FILL_VALUE_NUMBER = -9999
 FILL_VALUE_STR = 'MISSING'
 
-# chunk sizes to use for each dimension
-CHUNKSIZES = {
-    'meta_station_name': 1,
-    'wave_id_local': 1000,
-    'meta_frequency_band': len(FREQUENCY_INTERVALS),
-}
-
 DATASET_VARIABLES = dict(
     # metadata
     meta_source_file_name=dict(

@@ -494,16 +487,21 @@ def get_dataset_metadata(station_name, start_time, end_time, extra_metadata=None
 
 
 def write_records(wave_record_iterator, filename, station_name, extra_metadata=None,
-                  include_direction=False):
+                  include_direction=False, num_records=None):
     """Write given wave records in FOWD's netCDF4 output format.
 
     First argument is an iterable of chunks of wave records.
     """
+    if num_records is None:
+        wave_id_dim = None
+    else:
+        wave_id_dim = np.arange(num_records)
+
     dimension_data = (
         # (name, dtype, data)
         ('meta_station_name', str, np.array([np.string_(station_name)])),
-        ('wave_id_local', 'int64', None),
+        ('wave_id_local', 'int64', wave_id_dim),
         ('meta_frequency_band', 'uint8', np.arange(len(FREQUENCY_INTERVALS))),
     )

@@ -523,13 +521,7 @@ def write_records(wave_record_iterator, filename, station_name, extra_metadata=N
             else:
                 f.createDimension(dim, len(val))
 
-            extra_args = dict(
-                zlib=True,
-                fletcher32=True,
-                chunksizes=[CHUNKSIZES[dim]]
-            )
-
-            v = f.createVariable(dim, dtype, (dim,), **extra_args)
+            v = f.createVariable(dim, dtype, (dim,))
 
             if val is not None:
                 v[:] = val

@@ -538,19 +530,14 @@ def write_records(wave_record_iterator, filename, station_name, extra_metadata=N
             # add meta_station_name as additional scalar dimension
             dims = ('meta_station_name',) + meta['dims']
 
-            extra_args = dict(
-                zlib=True,
-                fletcher32=True,
-                chunksizes=[CHUNKSIZES[dim] for dim in dims]
-            )
-
             # determine dtype
             if meta['dtype'] == 'vlen':
                 dtype = vlen_type
             else:
                 dtype = meta['dtype']
 
             # add correct fill value
+            extra_args = {}
             is_number = np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.integer)
             if is_number and dtype is not vlen_type:
                 extra_args.update(fill_value=FILL_VALUE_NUMBER)

@@ -595,6 +582,10 @@ def write_records(wave_record_iterator, filename, station_name, extra_metadata=N
                 if np.issubdtype(data.dtype, np.datetime64):
                     data = (data - np.datetime64(TIME_ORIGIN)) / np.timedelta64(1, 'ms')
 
+                # convert timedelta64 to seconds
+                if np.issubdtype(data.dtype, np.timedelta64):
+                    data = data / np.timedelta64(1, 's')
+
                 v[0, chunk_slice, ...] = data
 
         # set global metadata

@@ -606,15 +597,7 @@ def write_records(wave_record_iterator, filename, station_name, extra_metadata=N
 
         # add extra variables
         for name, meta in EXTRA_VARIABLES.items():
-            extra_args = dict(
-                zlib=True,
-                fletcher32=True,
-                chunksizes=[CHUNKSIZES[dim] for dim in meta['dims']]
-            )
-            v = f.createVariable(
-                name, meta['data'].dtype, meta['dims'],
-                **extra_args
-            )
+            v = f.createVariable(name, meta['data'].dtype, meta['dims'])
             v[:] = meta['data']
             for attr, val in meta['attrs'].items():
                 setattr(v, attr, val)
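Why this makes writes contiguous: when num_records is known up front, wave_id_local can be created as a fixed-size dimension and filled in one shot, and the removed per-variable CHUNKSIZES hints (along with the zlib/fletcher32 settings they were bundled with) are no longer passed, so variables fall back to the library's default layout. A hedged sketch of the dimension-sizing idea plus the new timedelta64 conversion; the file name and values are illustrative, not FOWD's schema:

# Illustrative only: fixed-size vs. unlimited dimension creation with netCDF4,
# plus the numpy time conversions mirrored from the diff above.
import numpy as np
import netCDF4

num_records = 1000  # assumed known up front, as num_waves_total provides

with netCDF4.Dataset('example.nc', 'w') as f:
    if num_records is None:
        # unknown length: unlimited dimension, grown write by write
        f.createDimension('wave_id_local', None)
    else:
        # known length: fixed-size dimension that can be written contiguously
        f.createDimension('wave_id_local', num_records)
    v = f.createVariable('wave_id_local', 'int64', ('wave_id_local',))
    v[:] = np.arange(num_records)

# time handling as in the new timedelta64 branch:
t = np.array(['2020-11-03T00:00:30'], dtype='datetime64[s]')
dt = t - np.datetime64('2020-11-03')   # timedelta64 array
seconds = dt / np.timedelta64(1, 's')  # -> array([30.])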
