diff --git a/setup.py b/setup.py index 45d2f1af..876c9abd 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,7 @@ 'scipy', 'tabulate', 'tensorboardX', + 'wcwidth', ] extras = dict() diff --git a/src/dowel/tabular_input.py b/src/dowel/tabular_input.py index 43ca6625..5f86433e 100644 --- a/src/dowel/tabular_input.py +++ b/src/dowel/tabular_input.py @@ -1,6 +1,7 @@ """A `dowel.logger` input for tabular (key-value) data.""" import contextlib import warnings +import wcwidth import numpy as np import tabulate @@ -24,8 +25,23 @@ def __init__(self): def __str__(self): """Return a string representation of the table for the logger.""" - return tabulate.tabulate( - sorted(self.as_primitive_dict.items(), key=lambda x: x[0])) + # Sort first, then pad. + lines = sorted(self.as_primitive_dict.items(), key=lambda x: x[0]) + key_widths = [wcwidth.wcswidth(key) for key, value in lines] + max_key_width = max(key_widths, default=0) + padded_keys = [] + for line, (key, value) in enumerate(lines): + padded_key = key + if line % 2 == 1: + key_width = wcwidth.wcswidth(key) + if key_width % 2 == 1 and key_width < max_key_width: + padded_key += ' ' + key_width += 1 + pad_width = (max_key_width - key_width) // 2 + padded_key += ' .' * pad_width + padded_keys.append(padded_key) + values = [value for key, value in lines] + return tabulate.tabulate(zip(padded_keys, values)) def record(self, key, val): """Save key/value entries for the table. @@ -55,11 +71,11 @@ def record_misc_stat(self, key, values, placement='back'): :param placement: Whether to put the prefix in front or in the back. """ if placement == 'front': - front = "" + front = '' back = key else: front = key - back = "" + back = '' if values: self.record(front + 'Average' + back, np.average(values)) self.record(front + 'Std' + back, np.std(values)) diff --git a/tests/dowel/test_simple_outputs.py b/tests/dowel/test_simple_outputs.py index 04b6b11d..0797ec0a 100644 --- a/tests/dowel/test_simple_outputs.py +++ b/tests/dowel/test_simple_outputs.py @@ -125,6 +125,36 @@ def test_record_tabular(self, mock_datetime): self.str_out.seek(0) assert self.str_out.read() == tab + def test_record_tabular_line_markers(self, mock_datetime): + fake_timestamp(mock_datetime) + + self.tabular.record('a', 100) + self.tabular.record('bbbbbbb', 55) + self.tabular.record('ccccc', 55) + self.tabular.record('d', 55) + self.tabular.record('ee', 55) + self.tabular.record('ff', 55) + + with redirect_stdout(self.str_out): + self.std_output.record(self.tabular) + + self.std_output.dump() + + tab = ( + '------- ---\n' + 'a 100\n' + 'bbbbbbb 55\n' + 'ccccc 55\n' + 'd . . 55\n' + 'ee 55\n' + 'ff . . 55\n' + '------- ---\n' + ) # yapf: disable + self.str_out.seek(0) + output = self.str_out.read() + print(output) + assert output == tab + def test_record_with_timestamp(self, mock_datetime): fake_timestamp(mock_datetime)