Skip to content

Commit

Permalink
improved torch generator exception
Browse files Browse the repository at this point in the history
  • Loading branch information
jkiesele committed Nov 7, 2024
1 parent 06f6a05 commit 65c399b
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/djcdata/torch_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,12 @@ def __iter__(self):
return self

def __next__(self):
data = next(self.iterator)
# Convert numpy arrays to torch tensors and move to device
return self._convert_to_tensors(data)
try:
data = next(self.iterator)
# Convert numpy arrays to torch tensors and move to device
return self._convert_to_tensors(data)
except StopIteration:
return None

def _convert_to_tensors(self, data):
# data can be (x, y) or (x, y, w)
Expand Down

0 comments on commit 65c399b

Please sign in to comment.