Skip to content

Commit

Permalink
[XLA:FFI] Add an FFI-compatible implementation of tsl::CountDownAsync…
Browse files Browse the repository at this point in the history
…ValueRef.

This supports the common pattern of enqueuing a specific number of async tasks within an FFI handler.

PiperOrigin-RevId: 722455855
  • Loading branch information
dfm authored and Google-ML-Automation committed Feb 3, 2025
1 parent 492a921 commit 8bc7fa2
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 0 deletions.
80 changes: 80 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include <iostream>
#include <limits>
#include <memory>
#include <mutex> // NOLINT
#include <numeric>
#include <optional>
#include <ostream>
Expand Down Expand Up @@ -322,10 +323,14 @@ class ErrorOr : public Expected<T, Error> {
// A promise to complete execution with a success or an error.
class Promise;

// A promise that completes when a specific number of count downs have occurred.
class CountDownPromise;

// A future that becomes available when a corresponding promise is completed.
class Future {
public:
explicit Future(const Promise& promise);
explicit Future(const CountDownPromise& promise);

Future(Future&&) = default;
Future& operator=(Future&&) = default;
Expand Down Expand Up @@ -377,6 +382,9 @@ class Promise {
public:
Promise() : data_(std::make_shared<Future::Data>()) {}

Promise(const Promise&) = default;
Promise& operator=(const Promise&) = default;

Promise(Promise&&) = default;
Promise& operator=(Promise&&) = default;

Expand All @@ -391,11 +399,83 @@ class Promise {
std::shared_ptr<Future::Data> data_;
};

// A simple implementation of `tsl::CountDownAsyncValueRef` that is compatible
// with `ffi::Future`.
//
// Wraps a `Promise` together with a counter. Every `CountDown` call decrements
// the counter, and the call that brings it to zero completes the promise:
// with an error if any count down reported one, and as available otherwise.
//
// Copies of a `CountDownPromise` share the same underlying state, so a copy
// can be handed to each task that is expected to count down.
//
// A default-constructed `CountDownPromise` has no state; it must be assigned
// from an initialized instance before `CountDown` is called or a `Future` is
// created from it.
class CountDownPromise {
 public:
  CountDownPromise() = default;

  // Creates a count down that completes `promise` after `count` count downs.
  // `count` must be positive.
  CountDownPromise(Promise promise, int64_t count)
      : state_(std::make_shared<State>(std::move(promise), count)) {
    assert(count > 0 && "Count must be positive");
  }

  // Same as above, but completes a freshly created `Promise`.
  explicit CountDownPromise(int64_t count)
      : CountDownPromise(Promise(), count) {}

  // Drops the count by `count` and returns true if the underlying promise
  // became available.
  //
  // If `error` is not a success, it is recorded and the promise is completed
  // with an error once the count reaches zero. When several count downs
  // record an error, the most recently stored one is reported.
  bool CountDown(size_t count, const Error& error = Error::Success()) {
    assert(state_ != nullptr && "CountDownPromise is not initialized");

    // The counter is a signed 64-bit integer: convert the unsigned argument
    // once so that the comparisons below do not mix signed and unsigned
    // operands.
    const int64_t delta = static_cast<int64_t>(count);
    assert(delta >= 0 && state_->count.load() >= delta &&
           "Invalid count down value");

    if (XLA_FFI_PREDICT_FALSE(!error.success())) {
      const std::lock_guard<std::mutex> lock(state_->mutex);
      state_->is_error.store(true, std::memory_order_release);
      state_->error = error;
    }

    // `fetch_sub` returns the value held before the subtraction, so the count
    // reached zero exactly when that previous value equals `delta`.
    const bool is_complete =
        state_->count.fetch_sub(delta, std::memory_order_acq_rel) == delta;
    if (XLA_FFI_PREDICT_FALSE(is_complete)) {
      const bool is_error = state_->is_error.load(std::memory_order_acquire);
      if (XLA_FFI_PREDICT_FALSE(is_error)) {
        // Read the recorded error under the same mutex that guards writes.
        auto take_error = [&] {
          const std::lock_guard<std::mutex> lock(state_->mutex);
          return state_->error;
        };
        state_->promise.SetError(take_error());
      } else {
        state_->promise.SetAvailable();
      }
      return true;
    }

    return false;
  }

  // Drops the count by `1` and returns true if the underlying promise became
  // available.
  bool CountDown(Error error = Error::Success()) { return CountDown(1, error); }

 private:
  friend class Future;

  // State shared between all copies of a `CountDownPromise`.
  struct State {
    State(Promise promise, int64_t count)
        : promise(std::move(promise)), count(count), is_error(false) {}

    Promise promise;              // completed when `count` reaches zero
    std::atomic<int64_t> count;   // remaining number of count downs
    std::atomic<bool> is_error;   // set once any count down records an error

    std::mutex mutex;  // guards `error`
    Error error;
  };

  std::shared_ptr<State> state_;

  // Returns the wrapped promise; used by `Future` to attach to it.
  const Promise& AsPromise() const {
    assert(state_ != nullptr && "CountDownPromise is not initialized");
    return state_->promise;
  }
};

// Creates a future that shares the state (`data_`) of `promise`.
//
// The assertion expects the shared state to have exactly two owners once the
// future holds it, and treats any other count as the promise having been used
// to create more than one future.
//
// NOTE(review): `Promise` is copyable and every copy holds another reference
// to the same state, so this check presumably also fires when a copy of
// `promise` is still alive at this point — confirm that this is intended.
inline Future::Future(const Promise& promise) : data_(promise.data_) {
assert(data_.use_count() == 2 &&
"Promise can be used to create at most one Future");
}

// Creates a future from the promise wrapped by a `CountDownPromise` by
// delegating to the `Future(const Promise&)` constructor.
inline Future::Future(const CountDownPromise& promise)
: Future(promise.AsPromise()) {}

template <typename F>
void Future::OnReady(F&& f) {
static_assert(std::is_invocable_v<F, const std::optional<Error>&>,
Expand Down
55 changes: 55 additions & 0 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,61 @@ TEST(FfiTest, FutureRace) {
}
}

// A counter of two completes on the second count down, and the future built
// from it reports no error.
TEST(FfiTest, CountDownSuccess) {
  CountDownPromise countdown(2);
  Future ready(countdown);

  // Only the count down that reaches zero reports completion.
  EXPECT_FALSE(countdown.CountDown());
  EXPECT_TRUE(countdown.CountDown());

  ready.OnReady([](const std::optional<Error>& status) {
    EXPECT_FALSE(status.has_value());
  });
}

// An error reported by an intermediate count down is delivered to the future
// once the counter reaches zero.
TEST(FfiTest, CountDownError) {
  CountDownPromise countdown(3);
  Future ready(countdown);

  // First count down succeeds, the second one carries the error, and the
  // third one completes the promise.
  EXPECT_FALSE(countdown.CountDown());
  EXPECT_FALSE(
      countdown.CountDown(Error(ErrorCode::kInternal, "Test error")));
  EXPECT_TRUE(countdown.CountDown());

  ready.OnReady([](const std::optional<Error>& status) {
    EXPECT_TRUE(status.has_value());
    EXPECT_THAT(status->message(), HasSubstr("Test error"));
  });
}

// Counts down from thread pool tasks; the callback registered up front
// expects the future to complete without an error.
TEST(FfiTest, CountDownSuccessFromThreadPool) {
  // Declared first so that it is destroyed after the counter and the future.
  tsl::thread::ThreadPool workers(tsl::Env::Default(), "ffi-test", 2);

  CountDownPromise countdown(2);
  Future ready(countdown);

  ready.OnReady([](const std::optional<Error>& status) {
    EXPECT_FALSE(status.has_value());
  });

  // Each task owns its own copy of the counter.
  for (int64_t task = 0; task < 2; ++task) {
    workers.Schedule([countdown]() mutable { countdown.CountDown(); });
  }
}

// Counts down from thread pool tasks where one of the tasks reports an error;
// the callback registered up front expects that error to reach the future.
TEST(FfiTest, CountDownErrorFromThreadPool) {
  // Declared first so that it is destroyed after the counter and the future.
  tsl::thread::ThreadPool workers(tsl::Env::Default(), "ffi-test", 2);

  CountDownPromise countdown(3);
  Future ready(countdown);

  ready.OnReady([](const std::optional<Error>& status) {
    EXPECT_TRUE(status.has_value());
    EXPECT_THAT(status->message(), HasSubstr("Test error"));
  });

  // Three tasks, each with its own copy of the counter: success, error,
  // success.
  workers.Schedule([countdown]() mutable { countdown.CountDown(); });
  workers.Schedule([countdown]() mutable {
    countdown.CountDown(Error(ErrorCode::kInternal, "Test error"));
  });
  workers.Schedule([countdown]() mutable { countdown.CountDown(); });
}

TEST(FfiTest, ReturnError) {
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
auto call_frame = builder.Build();
Expand Down

0 comments on commit 8bc7fa2

Please sign in to comment.