Skip to content

Commit

Permalink
Fix collective-permute host memory not being unregistered.
Browse files Browse the repository at this point in the history
CUDA host memory was registered in Initialize() and unregistered in Cleanup(), but Cleanup() is never called. Now the host memory is instead stored as a stream_executor::MemoryAllocation object, which automatically unregisters it in its destructor.

PiperOrigin-RevId: 723243458
  • Loading branch information
reedwm authored and Google-ML-Automation committed Feb 4, 2025
1 parent 519df88 commit a7c38a6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 19 deletions.
1 change: 1 addition & 0 deletions xla/backends/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,7 @@ cc_library(
"//xla/service:global_device_id",
"//xla/service/gpu:backend_configs_cc",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:memory_allocation",
"//xla/stream_executor:stream",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:errors",
Expand Down
24 changes: 7 additions & 17 deletions xla/backends/gpu/runtime/nccl_collective_permute_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/nccl_collective_permute_thunk.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
Expand Down Expand Up @@ -43,6 +44,7 @@ limitations under the License.
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/stream.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
Expand Down Expand Up @@ -176,10 +178,10 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(
GetCurrentId(params.collective_params, config_));
absl::MutexLock lock(&barrier_mutex_);
if (barrier_flags_.find(current_id) == barrier_flags_.end()) {
if (!params.stream->parent()->HostMemoryRegister(
&barrier_flags_[current_id], sizeof(uint8_t))) {
LOG(ERROR) << "Registering barrier flag failed.";
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<se::MemoryAllocation> alloc,
params.stream->parent()->HostMemoryAllocate(sizeof(uint8_t)));
barrier_flags_[current_id] = std::move(alloc);
}

TF_ASSIGN_OR_RETURN(
Expand Down Expand Up @@ -212,18 +214,6 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(
return absl::OkStatus();
}

// Unregisters the pinned-host barrier flag that Initialize() registered for
// this device's id. NOTE(review): per the commit message, this Cleanup()
// override is never actually invoked, which is why the registration leaked —
// the fix replaces manual register/unregister with an RAII
// se::MemoryAllocation; this function is deleted by the change.
absl::Status NcclCollectivePermuteStartThunk::Cleanup(
const CleanupParams& params) {
// Resolve the per-device id used as the key into barrier_flags_.
TF_ASSIGN_OR_RETURN(const int64_t current_id,
GetCurrentId(params.collective_params, config_));

// barrier_flags_ is shared across devices; guard map access.
absl::MutexLock lock(&barrier_mutex_);
// CAUTION: operator[] default-inserts a flag if current_id was never
// registered, so this can attempt to unregister memory that was never
// registered. Failure is logged but deliberately not propagated
// (best-effort cleanup).
if (!params.executor->HostMemoryUnregister(&barrier_flags_[current_id])) {
LOG(ERROR) << "Unregistering barrier flag failed.";
}
return absl::OkStatus();
}

absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
const ExecuteParams& params, se::Stream& stream,
CommunicatorHandle comm_handle) {
Expand All @@ -248,7 +238,7 @@ absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
if (use_memcpy) {
se::DeviceMemoryBase sync_var_address =
se::DeviceMemoryBase((void*)(&barrier_flags_[current_id]));
se::DeviceMemoryBase(barrier_flags_[current_id]->opaque());
TF_RETURN_IF_ERROR(comm_handle.comm->AllReduce(
sync_var_address, sync_var_address, PrimitiveType::U8, 1,
ReductionKind::MIN, GpuCollectives::On(stream)));
Expand Down
7 changes: 5 additions & 2 deletions xla/backends/gpu/runtime/nccl_collective_permute_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
#define XLA_BACKENDS_GPU_RUNTIME_NCCL_COLLECTIVE_PERMUTE_THUNK_H_

#include <cstdint>
#include <memory>
#include <unordered_map>

#include "absl/base/thread_annotations.h"
#include "absl/container/node_hash_map.h"
Expand All @@ -31,6 +33,7 @@ limitations under the License.
#include "xla/core/collectives/communicator.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/stream.h"
#include "xla/tsl/concurrency/async_value.h"
#include "xla/tsl/concurrency/async_value_ref.h"
Expand Down Expand Up @@ -104,7 +107,6 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
const std::vector<Buffer>& buffers,
bool p2p_memcpy_enabled);
absl::Status Initialize(const InitializeParams& params) override;
absl::Status Cleanup(const CleanupParams& params) override;

static const char* GetHloOpName() { return "collective-permute-start"; }

Expand All @@ -119,7 +121,8 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
std::vector<Buffer> buffers_;
RecvPtrMap recv_ptr_map_;
absl::Mutex barrier_mutex_;
std::unordered_map<int64_t, uint8_t> barrier_flags_;
std::unordered_map<int64_t, std::unique_ptr<se::MemoryAllocation>>
barrier_flags_;
bool p2p_memcpy_enabled_ = false;
int64_t device_count_;
};
Expand Down

0 comments on commit a7c38a6

Please sign in to comment.