diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index f404aa9c00..c63780c3f7 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -186,18 +186,34 @@ def start_prediction() -> None: self._input_download_pool.submit(item.convert) for item in v ] to_await += futs[k] - futures.wait(to_await, return_when=futures.FIRST_EXCEPTION) + done, not_done = futures.wait( + to_await, return_when=futures.FIRST_EXCEPTION + ) + + if len(not_done) > 0: + # if any future isn't done, this is because one of the + # futures raised an exception. first we cancel outstanding + # work + for fut in not_done: + fut.cancel() + # then we find an exception to raise + for fut in done: + fut.result() # raises if the future finished with an exception + # we should never get here + raise Exception( + "Internal error: lost track of exception while downloading input files" + ) + + # all futures are done. some might still have raised an + # exception, but when we call fut.result() that will re-raise + # and do the right thing for k, v in futs.items(): if isinstance(v, list): payload[k] = [] for fut in v: - # the future may not be done if and only if another - # future finished with an exception - if fut.done(): - payload[k].append(fut.result()) + payload[k].append(fut.result()) elif isinstance(v, Future): - if v.done(): - payload[k] = v.result() + payload[k] = v.result() # send the prediction to the child to start self._events.send( Envelope(