From 5d16013ed7c3b95cffc9fb78c012cb86bfe4c267 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 22 Jan 2024 10:24:39 -0800 Subject: [PATCH] Pack bootstrapAllGathers in ncclCommSplit --- src/init.cc | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/init.cc b/src/init.cc index e82e64e14..7120f40af 100644 --- a/src/init.cc +++ b/src/init.cc @@ -1301,29 +1301,31 @@ struct ncclCommFinalizeAsyncJob { NCCL_PARAM(CommSplitShareResources, "COMM_SPLIT_SHARE_RESOURCES", NCCL_CONFIG_UNDEF_INT); static ncclResult_t commGetSplitInfo(struct ncclComm* comm, struct ncclComm* parent, int color, int key, int* nRanksRet, int* myRankRet, int* parentRanksRet) { - int* colors = NULL; - int* keys = NULL; int nRanks = 0, myRank = 0; ncclResult_t ret = ncclSuccess; - NCCLCHECKGOTO(ncclCalloc(&colors, parent->nRanks), ret, fail); - NCCLCHECKGOTO(ncclCalloc(&keys, parent->nRanks), ret, fail); + struct colorKeyPair{ + int color; + int key; + }; + struct colorKeyPair* ckPairs = NULL; + + NCCLCHECKGOTO(ncclCalloc(&ckPairs, parent->nRanks), ret, fail); // Compute nRanks, my rank and the ranks (of the original comm) before and after me - colors[parent->rank] = color; - keys[parent->rank] = key; - NCCLCHECKGOTO(bootstrapAllGather(parent->bootstrap, colors, sizeof(int)), ret, fail); - NCCLCHECKGOTO(bootstrapAllGather(parent->bootstrap, keys, sizeof(int)), ret, fail); + ckPairs[parent->rank].color = color; + ckPairs[parent->rank].key = key; + NCCLCHECKGOTO(bootstrapAllGather(parent->bootstrap, ckPairs, sizeof(struct colorKeyPair)), ret, fail); // Negative color does not create a new comm. Return now. if (color == NCCL_SPLIT_NOCOLOR) goto exit; memset(parentRanksRet, 0xff, sizeof(int) * parent->nRanks); for (int i = 0; i < parent->nRanks; i++) { - if (colors[i] != color) continue; + if (ckPairs[i].color != color) continue; // Find where to insert this rank int insert = 0; - while (insert < nRanks && keys[parentRanksRet[insert]] <= keys[i]) insert++; + while (insert < nRanks && ckPairs[parentRanksRet[insert]].key <= ckPairs[i].key) insert++; // Shift ranks by one after insert for (int r = nRanks; r > insert; r--) parentRanksRet[r] = parentRanksRet[r - 1]; // Insert our rank @@ -1339,8 +1341,7 @@ static ncclResult_t commGetSplitInfo(struct ncclComm* comm, struct ncclComm* par *myRankRet = myRank; exit: - free(colors); - free(keys); + free(ckPairs); return ret; fail: goto exit;