Skip to content

Commit

Permalink
working version on IL1 with Cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffnvidia committed Sep 19, 2024
1 parent 00a9097 commit 3ca742c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 32 deletions.
68 changes: 42 additions & 26 deletions src/components/tl/ucp/allgather/allgather_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
ucc_rank_t rank = VRANK(task->subset.myrank, broot, size);
size_t local = GET_LOCAL_COUNT(args, size, rank);
ucp_mem_h *mh_list = task->allgather_kn.mh_list;
int max_mh = task->allgather_kn.max_mh;
void *sbuf;
ptrdiff_t peer_seg_offset, local_seg_offset;
ucc_rank_t peer, peer_dist;
Expand All @@ -74,15 +73,15 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
INV_VRANK(peer,broot,size)),
team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh >= max_mh);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);

}
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(rbuf, data_size, mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer,broot,size)),
team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh >= max_mh);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);
}
if ((p->type != KN_PATTERN_ALLGATHERX) && (node_type == KN_NODE_PROXY)) {
peer = ucc_knomial_pattern_get_extra(p, rank);
Expand All @@ -92,7 +91,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
local * dt_size), extra_count * dt_size,
mem_type, peer, team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh >= max_mh);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);
}

UCC_KN_PHASE_EXTRA:
Expand Down Expand Up @@ -121,14 +120,13 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
continue;
}
}
printf("progress : count_mh: %d, mh: %lx\n", task->allgather_kn.count_mh, (unsigned long)mh_list[task->allgather_kn.count_mh]);
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(sbuf, local_seg_count * dt_size,
mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer, broot, size)),
team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh >= max_mh);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);
}

for (loop_step = 1; loop_step < radix; loop_step++) {
Expand All @@ -152,7 +150,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
INV_VRANK(peer, broot, size)),
team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh >= max_mh);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);
}
UCC_KN_PHASE_LOOP:
if (UCC_INPROGRESS == ucc_tl_ucp_test_recv_with_etasks(task)) {
Expand All @@ -170,7 +168,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
INV_VRANK(peer, broot, size)),
team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh >= max_mh);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);
}
UCC_KN_PHASE_PROXY:
if (UCC_INPROGRESS == ucc_tl_ucp_test_with_etasks(task)) {
Expand All @@ -179,7 +177,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
}

out:
ucc_assert(task->allgather_kn.count_mh-1 == max_mh);
ucc_assert(task->allgather_kn.count_mh-1 == task->allgather_kn.max_mh);
ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
task->super.status = UCC_OK;
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_done", 0);
Expand Down Expand Up @@ -252,6 +250,7 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
ucc_tl_ucp_task_t);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_type_t ct = args->coll_type;
ucc_kn_radix_t radix = task->allgather_kn.p.radix;
uint8_t node_type = task->allgather_kn.p.node_type;
ucc_knomial_pattern_t *p = &task->allgather_kn.p;
Expand All @@ -273,18 +272,28 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
ucc_status_t status;
size_t extra_count;

ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team);
ucp_mem_map_params_t mmap_params;
ucp_mem_h mh;
int size_of_list = 1;
int count_mh = 0;
ucp_mem_h *mh_list = (ucp_mem_h *)malloc(size_of_list * sizeof(ucp_mem_h));
ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team);
ucp_mem_map_params_t mmap_params;
// ucp_mem_h mh;
int size_of_list = 1;
int count_mh = 0;
ucp_mem_h *mh_list = (ucp_mem_h *)malloc(size_of_list * sizeof(ucp_mem_h));

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_start", 0);
task->allgather_kn.etask = NULL;
task->allgather_kn.phase = UCC_KN_PHASE_INIT;
if (ct == UCC_COLL_TYPE_ALLGATHER) {
ucc_kn_ag_pattern_init(size, rank, radix, args->dst.info.count,
&task->allgather_kn.p);
} else {
ucc_kn_agx_pattern_init(size, rank, radix, args->dst.info.count,
&task->allgather_kn.p);
}

mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE;
mmap_params.memory_type = ucc_memtype_to_ucs[mem_type];
printf("I'm in register memory");
if (KN_NODE_EXTRA == node_type) {
if (p->type != KN_PATTERN_ALLGATHERX) {
mmap_params.address = task->allgather_kn.sbuf;
Expand All @@ -310,13 +319,10 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
goto out;
}
while (!ucc_knomial_pattern_loop_done(p)) {
printf("in the while loop");
ucc_kn_ag_pattern_peer_seg(rank, p, &local_seg_count,
&local_seg_offset);
sbuf = PTR_OFFSET(rbuf, local_seg_offset * dt_size);

for (loop_step = radix - 1; loop_step > 0; loop_step--) {
printf("in the for loop");
peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step);
if (peer == UCC_KN_PEER_NULL)
continue;
Expand All @@ -329,7 +335,6 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
}
mmap_params.address = sbuf;
mmap_params.length = local_seg_count * dt_size;
printf("register memory : count_mh: %d, mh: %lx\n", count_mh, (unsigned long)mh_list[count_mh]);
MEM_MAP();
}

Expand Down Expand Up @@ -370,12 +375,23 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
/**
 * Finalize a knomial allgather task.
 *
 * Cleans up the task's etask-node memory pool, unmaps every UCP memory
 * handle held in allgather_kn.mh_list (indices 0 through
 * allgather_kn.max_mh, inclusive), frees the list, and then runs the
 * generic TL/UCP collective finalization.
 *
 * @param coll_task  Collective task embedded in a ucc_tl_ucp_task_t.
 *
 * @return UCC_OK on success, or the negative status reported by
 *         ucc_tl_ucp_coll_finalize() when it fails.
 */
ucc_status_t ucc_tl_ucp_allgather_knomial_finalize(ucc_coll_task_t *coll_task)
{
    ucc_tl_ucp_task_t    *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
    ucc_tl_ucp_team_t    *team = TASK_TEAM(task);
    ucc_tl_ucp_context_t *ctx  = UCC_TL_UCP_TEAM_CTX(team);
    ucc_status_t          status;
    int                   i;

    ucc_mpool_cleanup(&task->allgather_kn.etask_node_mpool, 1);

    /* max_mh is used as the index of the last handle, hence the inclusive
     * upper bound. */
    for (i = 0; i <= task->allgather_kn.max_mh; i++) {
        ucp_mem_unmap(ctx->worker.ucp_context, task->allgather_kn.mh_list[i]);
    }
    free(task->allgather_kn.mh_list);
    /* Do not leave a dangling pointer behind in the task. */
    task->allgather_kn.mh_list = NULL;

    /* Must come last: everything above still reads fields of the task. */
    status = ucc_tl_ucp_coll_finalize(&task->super);
    if (status < 0) {
        tl_error(UCC_TASK_LIB(task),
                 "failed to finalize collective");
        return status;
    }

    return UCC_OK;
}

ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
Expand All @@ -401,17 +417,17 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
task->subset.myrank = sbgp->group_rank;
task->subset.map = sbgp->map;
}
status = register_memory(&task->super);
if (status < 0){
tl_error(UCC_TASK_LIB(task),
"failed to register memory");
}
task->allgather_kn.etask_linked_list_head = NULL;
task->allgather_kn.p.radix = radix;
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
task->super.post = ucc_tl_ucp_allgather_knomial_start;
task->super.progress = ucc_tl_ucp_allgather_knomial_progress;
task->super.finalize = ucc_tl_ucp_allgather_knomial_finalize;
status = register_memory(&task->super);
if (status < 0){
tl_error(UCC_TASK_LIB(task),
"failed to register memory");
}
*task_h = &task->super;
return UCC_OK;
}
Expand Down
9 changes: 3 additions & 6 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,14 @@ void ucc_tl_ucp_team_default_score_str_free(
} while(0)

/*
 * MEM_MAP() - map the region described by mmap_params with UCP and store the
 * resulting handle in the next free slot of mh_list.
 *
 * The macro is not self-contained. It expands against these locals of the
 * enclosing function, which must return ucc_status_t:
 *   ctx          - TL/UCP context providing worker.ucp_context
 *   mmap_params  - ucp_mem_map_params_t, filled in by the caller
 *   mh_list      - heap array of ucp_mem_h, grown here with realloc()
 *   count_mh     - number of handles stored so far
 *   size_of_list - current capacity of mh_list, in elements
 *   status       - ucc_status_t scratch variable
 *
 * The handle is written directly into mh_list[count_mh], so the caller must
 * guarantee count_mh < size_of_list on entry; the macro restores that
 * invariant by doubling the capacity whenever the list becomes full.
 *
 * On failure the macro executes `return` from the enclosing function: with
 * the converted UCP status when the mapping fails, or with UCC_ERR_NO_MEMORY
 * when the list cannot grow. Nothing is released before returning; mh_list
 * and the handles already stored in it remain as they were.
 */
#define MEM_MAP() do {                                                        \
    status = ucs_status_to_ucc_status(                                        \
        ucp_mem_map(ctx->worker.ucp_context, &mmap_params,                    \
                    &mh_list[count_mh]));                                     \
    if (UCC_OK != status) {                                                   \
        return status;                                                        \
    }                                                                         \
    count_mh++;                                                               \
    if (count_mh == size_of_list) {                                           \
        ucp_mem_h *mem_map_grown_ =                                           \
            realloc(mh_list,                                                  \
                    (size_t)size_of_list * 2 * sizeof(ucp_mem_h));            \
        if (NULL == mem_map_grown_) {                                         \
            return UCC_ERR_NO_MEMORY;                                         \
        }                                                                     \
        mh_list       = mem_map_grown_;                                       \
        size_of_list *= 2;                                                    \
    }                                                                         \
} while(0)

#define EXEC_TASK_WAIT(_etask, ...) \
Expand Down Expand Up @@ -503,7 +502,7 @@ static inline ucc_status_t ucc_tl_ucp_test_recv_with_etasks(ucc_tl_ucp_task_t *t
while(current_node != NULL) {
status = ucc_ee_executor_task_test(current_node->etask);
if (status > 0) {
ucp_memcpy_device_complete(current_node->etask->completion, status);
ucp_memcpy_device_complete(current_node->etask->completion, ucc_status_to_ucs_status(status));
status_2 = ucc_ee_executor_task_finalize(current_node->etask);
ucc_mpool_put(current_node);
if (ucc_unlikely(status_2 < 0)){
Expand All @@ -517,9 +516,7 @@ static inline ucc_status_t ucc_tl_ucp_test_recv_with_etasks(ucc_tl_ucp_task_t *t
task->allgather_kn.etask_linked_list_head = current_node->next;
}
}
else {
prev_node = current_node;
}
prev_node = current_node;
current_node = current_node->next; //to iterate to next node
}
if (UCC_TL_UCP_TASK_RECV_COMPLETE(task) && task->allgather_kn.etask_linked_list_head==NULL) {
Expand Down

0 comments on commit 3ca742c

Please sign in to comment.