diff --git a/src/components/tl/ucp/allgather/allgather_knomial.c b/src/components/tl/ucp/allgather/allgather_knomial.c index 360f510a65..f33f668109 100644 --- a/src/components/tl/ucp/allgather/allgather_knomial.c +++ b/src/components/tl/ucp/allgather/allgather_knomial.c @@ -60,6 +60,14 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) ucc_status_t status; size_t extra_count; + uint32_t USE_CUDA = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_cuda; + if(!USE_CUDA){ + if (UCC_INPROGRESS == ucc_tl_ucp_test(task)){ + // should I use ucc_tl_ucp_test_with_etasks ? + return; + } + } + EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test", task->allgather_kn.etask); task->allgather_kn.etask = NULL; @@ -210,23 +218,29 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task) &task->allgather_kn.p); offset = ucc_buffer_block_offset(args->dst.info.count, size, rank) * ucc_dt_size(args->dst.info.datatype); - if (!UCC_IS_INPLACE(*args)) { - status = ucc_coll_task_get_executor(&task->super, &exec); - if (ucc_unlikely(status != UCC_OK)) { - task->super.status = status; - return status; - } - eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; - eargs.copy.dst = PTR_OFFSET(args->dst.info.buffer, offset); - eargs.copy.src = args->src.info.buffer; - eargs.copy.len = args->src.info.count * - ucc_dt_size(args->src.info.datatype); - status = ucc_ee_executor_task_post(exec, &eargs, - &task->allgather_kn.etask); - if (ucc_unlikely(status != UCC_OK)) { - task->super.status = status; - return status; - } + if(USE_CUDA){ + status = ucc_coll_task_get_executor(&task->super, &exec); + if (ucc_unlikely(status != UCC_OK)) { + task->super.status = status; + return status; + } + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; + eargs.copy.dst = PTR_OFFSET(args->dst.info.buffer, offset); + eargs.copy.src = args->src.info.buffer; + eargs.copy.len = args->src.info.count * + ucc_dt_size(args->src.info.datatype); + status = ucc_ee_executor_task_post(exec, &eargs, + &task->allgather_kn.etask); + if (ucc_unlikely(status != UCC_OK)) { + task->super.status = status; + return status; + } + } else { + /*Loopback*/ + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(args->src.info.buffer, args->src.info.count * ucc_dt_size(args->src.info.datatype), + args->src.info.mem_type, rank, team, task),task, out); + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(args->dst.info.buffer, offset), args->src.info.count * ucc_dt_size(args->src.info.datatype), + args->dst.info.mem_type, rank, team, task),task, out); } } else { ucc_kn_agx_pattern_init(size, rank, radix, args->dst.info.count, @@ -430,6 +444,8 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r( task->super.post = ucc_tl_ucp_allgather_knomial_start; task->super.progress = ucc_tl_ucp_allgather_knomial_progress; task->super.finalize = ucc_tl_ucp_allgather_knomial_finalize; + # trigger_post + # trigger_progress status = register_memory(&task->super); if (status < 0){ tl_error(UCC_TASK_LIB(task),