Skip to content

Commit

Permalink
first try of loopback
Browse files (browse the repository at this point in the history)
  • Loading branch information
jeffnvidia committed Dec 12, 2024
1 parent 9850aa4 commit 2e97c40
Showing 1 changed file with 33 additions and 17 deletions.
50 changes: 33 additions & 17 deletions src/components/tl/ucp/allgather/allgather_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
ucc_status_t status;
size_t extra_count;

uint32_t USE_CUDA = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_cuda;
if(!USE_CUDA){
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)){
// TODO: decide whether ucc_tl_ucp_test_with_etasks should be used here instead of ucc_tl_ucp_test.
return;
}
}

EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test",
task->allgather_kn.etask);
task->allgather_kn.etask = NULL;
Expand Down Expand Up @@ -210,23 +218,29 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
&task->allgather_kn.p);
offset = ucc_buffer_block_offset(args->dst.info.count, size, rank) *
ucc_dt_size(args->dst.info.datatype);
if (!UCC_IS_INPLACE(*args)) {
status = ucc_coll_task_get_executor(&task->super, &exec);
if (ucc_unlikely(status != UCC_OK)) {
task->super.status = status;
return status;
}
eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.dst = PTR_OFFSET(args->dst.info.buffer, offset);
eargs.copy.src = args->src.info.buffer;
eargs.copy.len = args->src.info.count *
ucc_dt_size(args->src.info.datatype);
status = ucc_ee_executor_task_post(exec, &eargs,
&task->allgather_kn.etask);
if (ucc_unlikely(status != UCC_OK)) {
task->super.status = status;
return status;
}
if(USE_CUDA){
status = ucc_coll_task_get_executor(&task->super, &exec);
if (ucc_unlikely(status != UCC_OK)) {
task->super.status = status;
return status;
}
eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.dst = PTR_OFFSET(args->dst.info.buffer, offset);
eargs.copy.src = args->src.info.buffer;
eargs.copy.len = args->src.info.count *
ucc_dt_size(args->src.info.datatype);
status = ucc_ee_executor_task_post(exec, &eargs,
&task->allgather_kn.etask);
if (ucc_unlikely(status != UCC_OK)) {
task->super.status = status;
return status;
}
} else {
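/* Loopback: instead of an executor copy, send the source buffer to `rank` and
 * receive it from `rank` into the destination buffer at `offset`. */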
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(args->src.info.buffer, args->src.info.count * ucc_dt_size(args->src.info.datatype),
args->src.info.mem_type, rank, team, task),task, out);
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(args->dst.info.buffer, offset), args->src.info.count * ucc_dt_size(args->src.info.datatype),
args->dst.info.mem_type, rank, team, task),task, out);
}
} else {
ucc_kn_agx_pattern_init(size, rank, radix, args->dst.info.count,
Expand Down Expand Up @@ -430,6 +444,8 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
task->super.post = ucc_tl_ucp_allgather_knomial_start;
task->super.progress = ucc_tl_ucp_allgather_knomial_progress;
task->super.finalize = ucc_tl_ucp_allgather_knomial_finalize;
/* trigger_post */
/* trigger_progress */
status = register_memory(&task->super);
if (status < 0){
tl_error(UCC_TASK_LIB(task),
Expand Down

0 comments on commit 2e97c40

Please sign in to comment.