diff --git a/python/array_record_data_source.py b/python/array_record_data_source.py index f855c09..5e184fa 100644 --- a/python/array_record_data_source.py +++ b/python/array_record_data_source.py @@ -217,6 +217,7 @@ def __init__( "a Sequence, String, pathlib.Path or FileInstruction." ) self._read_instructions = _get_read_instructions(paths) + self._paths = [ri.filename for ri in self._read_instructions] # We open readers lazily when we need to read from them. self._readers = [None] * len(self._read_instructions) self._num_records = sum( diff --git a/python/array_record_data_source_test.py b/python/array_record_data_source_test.py index 55a0761..4ac7c8e 100644 --- a/python/array_record_data_source_test.py +++ b/python/array_record_data_source_test.py @@ -174,6 +174,13 @@ def test_array_record_source_empty_sequence(self): with self.assertRaises(ValueError): array_record_data_source.ArrayRecordDataSource([]) + def test_repr(self): + ar = array_record_data_source.ArrayRecordDataSource([ + self.testdata_dir / "digits.array_record-00000-of-00002", + self.testdata_dir / "digits.array_record-00001-of-00002", + ]) + self.assertRegex(repr(ar), r"ArrayRecordDataSource\(hash_of_paths=[\w]+\)") + class RunInParallelTest(parameterized.TestCase):