Skip to content

Commit

Permalink
add a user friendly exception for rf when some workers get fewer labe…
Browse files Browse the repository at this point in the history
…l values (#811)

Signed-off-by: Erik Ordentlich <[email protected]>
  • Loading branch information
eordentlich authored Dec 25, 2024
1 parent 4c2e232 commit 1f949ef
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion python/src/spark_rapids_ml/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,20 @@ def _single_fit(rf: cuRf) -> Dict[str, Any]:
all_tl_mod_handles = [
rf._tl_handle_from_bytes(i) for i in mod_bytes
]
rf._concatenate_treelite_handle(all_tl_mod_handles)

# tree concatenation raises a non-user friendly error if some workers didn't get all label values
try:
rf._concatenate_treelite_handle(all_tl_mod_handles)
except RuntimeError as err:
import traceback

exc_str = traceback.format_exc()
if "different num_class than the first model object" in exc_str:
raise RuntimeError(
"Some GPU workers did not receive all label values. Rerun with fewer workers or shuffle input data."
)
else:
raise err

from cuml.fil.fil import TreeliteModel

Expand Down

0 comments on commit 1f949ef

Please sign in to comment.