Skip to content

Commit

Permalink
Fix state methods and actually fix mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
SvenDS9 committed Feb 16, 2023
1 parent 181f2f6 commit 606ff72
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions torchdata/datapipes/iter/util/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
self._depleted = False

def _load_map(self):
if self._map is None:
self._map = {}
self._itr = iter(self.datapipe)
while not self._depleted:
try:
self._load_next_item()
Expand All @@ -84,10 +81,7 @@ def _load_map(self):

def __getitem__(self, index):
try:
if self._map is None:
self._map = {}
self._itr = iter(self.datapipe)
else:
if self._map is not None:
return self._map[index]
except KeyError:
pass
Expand All @@ -101,7 +95,10 @@ def __getitem__(self, index):
raise IndexError(f"Index {index} is invalid for IterToMapConverter.")

def _load_next_item(self):
elem = next(self._itr)
if self._map is None:
self._map = {}
self._itr = iter(self.datapipe)
elem = next(self._itr) # type: ignore[arg-type]
inp = elem if self.key_value_fn is None else self.key_value_fn(elem)
try:
length = len(inp)
Expand Down Expand Up @@ -135,14 +132,10 @@ def __getstate__(self):
dill_key_value_fn = dill.dumps(self.key_value_fn)
else:
dill_key_value_fn = self.key_value_fn
return (
self.datapipe,
dill_key_value_fn,
self._map,
)
return (self.datapipe, dill_key_value_fn, self._map, self._itr, self._depleted)

def __setstate__(self, state):
(self.datapipe, dill_key_value_fn, self._map) = state
(self.datapipe, dill_key_value_fn, self._map, self._itr, self._depleted) = state
if DILL_AVAILABLE:
self.key_value_fn = dill.loads(dill_key_value_fn) # type: ignore[assignment]
else:
Expand Down

0 comments on commit 606ff72

Please sign in to comment.