Skip to content

Commit

Permalink
[MSA] Microoptimizations in AsynchronousCopyResource.
Browse files Browse the repository at this point in the history
Based on profiling the memory-space assignment algorithm, this change makes the following small optimizations to `AsynchronousCopyResource`:

* Pass a pre-reserved `std::vector<std::pair<int64_t, float>>` instead of an `absl::flat_hash_map<int64_t, float>` to capture the changes to `delays`, because we do not need random access to the map, and a vector is faster to resize than a hash map.
* Cache the raw data pointers from `std::vector<float>` to avoid the overhead of bounds and null checking in the hardened `std::vector` implementation.
* Replace the simple functions in `time_utils.cc` with inline implementations in `time_utils.h`: since these boil down to adding or subtracting `1`, the resulting code will be smaller and more efficient (and less likely to spill FP registers to the stack).
* Refactor the inner loop that writes `delay_changes` so that the floating-point operations are not separated by a data-dependent call, and we can keep more `float`s in registers.

PiperOrigin-RevId: 719112419
  • Loading branch information
mrry authored and Google-ML-Automation committed Jan 24, 2025
1 parent 7333fdb commit 43b2c3c
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 65 deletions.
1 change: 0 additions & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6336,7 +6336,6 @@ cc_library(

cc_library(
name = "time_utils",
srcs = ["time_utils.cc"],
hdrs = ["time_utils.h"],
deps = [],
)
Expand Down
5 changes: 2 additions & 3 deletions xla/service/memory_space_assignment/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -588,18 +588,17 @@ cc_library(
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
56 changes: 40 additions & 16 deletions xla/service/memory_space_assignment/algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
Expand Down Expand Up @@ -3129,8 +3130,16 @@ bool AsynchronousCopyOrdering::ViolatesOrdering(int64_t exclusive_start_time,

bool AsynchronousCopyResource::ConsumeResource(
int64_t exclusive_start_time, int64_t end_time, float resource,
absl::flat_hash_map<int64_t, float>* delay_change_map,
std::vector<std::pair<int64_t, float>>* delay_changes,
float resource_to_free) {
// Cache the pointers to the arrays to avoid the overhead of `operator[]`
// size checks in hardened libc++.
//
// NOTE: Do not modify the vectors `initial_resources_` or `delay_` in this
// function, otherwise the pointers will become dangling.
float* initial_resources_ptr = initial_resources_.data();
float* delay_ptr = delay_.data();

std::list<AsynchronousCopy>::iterator current_copy = async_copies_.end();
// In order to propagate the resource to the next scheduled copy, we iterate
// over the copies in start time order until we either find enough free
Expand Down Expand Up @@ -3160,7 +3169,8 @@ bool AsynchronousCopyResource::ConsumeResource(
// this copy would have to be delayed because of an earlier copy that wasn't
// finished when this copy starts.
if (current_copy == async_copies_.end()) {
resource += delay_[ExclusiveToInclusiveStartTime(exclusive_start_time)];
resource +=
delay_ptr[ExclusiveToInclusiveStartTime(exclusive_start_time)];
}

// Find the copy that is right after this one. If there are leftover
Expand All @@ -3186,7 +3196,7 @@ bool AsynchronousCopyResource::ConsumeResource(
time < end_time && resource != 0; ++time) {
// Iterate over the logical times that this copy spans. Note that the
// start and end time ranges are exclusive.
float used_resource = std::min(resource, initial_resources_[time]);
float used_resource = std::min(resource, initial_resources_ptr[time]);
if (next_copy != async_copies_.end() &&
next_copy->exclusive_start_time ==
InclusiveToExclusiveStartTime(time)) {
Expand All @@ -3199,15 +3209,17 @@ bool AsynchronousCopyResource::ConsumeResource(
if (!delay_for_next_copy.has_value()) {
// Update the delay_ vector and resource_freed variable with the amount
// that was freed when removing the copy.
float old_delay = delay_ptr[time];
float old_resource =
std::max(0.0f, initial_resources_[time] - delay_[time]);
if (delay_change_map) {
delay_change_map->emplace(time, delay_[time]);
}
delay_[time] = std::max(0.0f, resource - resource_to_free);
std::max(0.0f, initial_resources_ptr[time] - old_delay);
float new_delay = std::max(0.0f, resource - resource_to_free);
float new_resource =
std::max(0.0f, initial_resources_[time] - delay_[time]);
std::max(0.0f, initial_resources_ptr[time] - new_delay);
resource_freed += std::max(0.0f, new_resource - old_resource);
delay_ptr[time] = new_delay;
if (delay_changes) {
delay_changes->emplace_back(time, old_delay);
}
}
// Update the resource with the used amount in this logical time.
resource -= used_resource;
Expand Down Expand Up @@ -3303,7 +3315,7 @@ void AsynchronousCopyResource::RemoveCopy(
copy_it->exclusive_start_time);
CHECK(ConsumeResource(copy_it->exclusive_start_time, copy_it->end_time,
/*resource=*/0,
/*delay_change_map=*/nullptr,
/*delay_changes=*/nullptr,
/*resource_to_free=*/copy_it->resource));
// If the copy to be removed is the value pointed by async_copy_time_map_, we
// make the next copy with the same start time to be pointed by
Expand All @@ -3325,24 +3337,36 @@ void AsynchronousCopyResource::RemoveCopy(
// Runs ConsumeResource() for the given time range and resource amount,
// collecting the delay_ changes it reports, writes the recorded previous
// values back into delay_, and returns ConsumeResource()'s result.
//
// NOTE(review): this listing appears to be a rendered diff without +/-
// markers, so the flat_hash_map declaration and the forward range-for below
// are presumably the removed pre-change lines -- TODO confirm against the
// repository.
bool AsynchronousCopyResource::HasEnoughResource(int64_t exclusive_start_time,
int64_t end_time,
float resource) {
absl::flat_hash_map<int64_t, float> delay_changes;
std::vector<std::pair<int64_t, float>> delay_changes;
// Reserve room for one recorded change per delay_ entry before
// ConsumeResource() starts appending.
delay_changes.reserve(delay_.size());
bool result =
ConsumeResource(exclusive_start_time, end_time, resource, &delay_changes);
for (const auto& change_pair : delay_changes) {
delay_[change_pair.first] = change_pair.second;
// Apply the delay changes in reverse order. This ensures that the original
// value of each delay is restored.
// ConsumeResource() appends {time, previous delay} whenever it overwrites an
// entry, so if a time was recorded more than once its oldest record has to be
// the last one written.
if (!delay_changes.empty()) {
for (int64_t i = delay_changes.size() - 1; i >= 0; --i) {
const auto& [time, delay] = delay_changes[i];
delay_[time] = delay;
}
}
return result;
}

// Runs ConsumeResource() over `specs`, collecting the delay_ changes it
// reports in one shared vector, then writes the recorded previous values back
// into delay_. The result is true only if ConsumeResource() returned true for
// every spec it was called on.
//
// NOTE(review): this listing appears to be a rendered diff without +/-
// markers, so the flat_hash_map declaration and the forward range-for below
// are presumably the removed pre-change lines -- TODO confirm against the
// repository.
bool AsynchronousCopyResource::HasEnoughResourceMultiCheck(
const std::vector<ResourceSpec>& specs) {
absl::flat_hash_map<int64_t, float> delay_changes;
std::vector<std::pair<int64_t, float>> delay_changes;
// Reserve room for one recorded change per delay_ entry before
// ConsumeResource() starts appending.
delay_changes.reserve(delay_.size());
// delay_ is not restored between specs, so each spec is evaluated against the
// delay_ state left behind by the specs consumed before it.
bool result = absl::c_all_of(specs, [&](const ResourceSpec& spec) {
return ConsumeResource(spec.exclusive_start_time, spec.end_time,
spec.resource, &delay_changes);
});
for (const auto& change_pair : delay_changes) {
delay_[change_pair.first] = change_pair.second;
// Apply the delay changes in reverse order. This ensures that the original
// value of each delay is restored.
// ConsumeResource() appends {time, previous delay} whenever it overwrites an
// entry, so if a time was recorded more than once its oldest record has to be
// the last one written.
if (!delay_changes.empty()) {
for (int64_t i = delay_changes.size() - 1; i >= 0; --i) {
const auto& [time, delay] = delay_changes[i];
delay_[time] = delay;
}
}
return result;
}
Expand Down
10 changes: 7 additions & 3 deletions xla/service/memory_space_assignment/algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALGORITHM_H_

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <list>
#include <map>
Expand All @@ -35,6 +36,8 @@ limitations under the License.
#endif
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -218,12 +221,13 @@ class AsynchronousCopyResource {

private:
// Internal helper method to implement adding/removing/checking resources.
// ConsumeResource() may modify delay_. If delay_change_map is not null,
// ConsumeResource() may modify delay_. If delay_changes is not null,
// for any change to delay_[i], {i, delay_[i]} will be added to
// delay_change_map, allowing callers to undo any modifications.
// delay_changes, allowing callers to undo any modifications by iterating over
// the vector in reverse order.
bool ConsumeResource(
int64_t exclusive_start_time, int64_t end_time, float resource,
absl::flat_hash_map<int64_t, float>* delay_change_map = nullptr,
std::vector<std::pair<int64_t, float>>* delay_changes = nullptr,
float resource_to_free = 0.0);

// Same as the public RemoveCopy except it works on the async_copies_
Expand Down
38 changes: 0 additions & 38 deletions xla/service/time_utils.cc

This file was deleted.

20 changes: 16 additions & 4 deletions xla/service/time_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,22 @@ limitations under the License.
namespace xla {

// Convert between inclusive/exclusive start/end times.
int64_t ExclusiveToInclusiveStartTime(int64_t exclusive_time);
int64_t InclusiveToExclusiveStartTime(int64_t inclusive_time);
int64_t ExclusiveToInclusiveEndTime(int64_t exclusive_time);
int64_t InclusiveToExclusiveEndTime(int64_t inclusive_time);

// Converts an exclusive start time to the corresponding inclusive start time
// by returning `exclusive_time + 1`.
//
// Declared `constexpr` (which implies `inline`) so the header-only definition
// can also be evaluated in constant expressions.
constexpr int64_t ExclusiveToInclusiveStartTime(int64_t exclusive_time) {
  return exclusive_time + 1;
}

// Converts an inclusive start time to the corresponding exclusive start time
// by returning `inclusive_time - 1`.
//
// Declared `constexpr` (which implies `inline`) so the header-only definition
// can also be evaluated in constant expressions.
constexpr int64_t InclusiveToExclusiveStartTime(int64_t inclusive_time) {
  return inclusive_time - 1;
}

// Converts an exclusive end time to the corresponding inclusive end time by
// returning `exclusive_time - 1`.
//
// Declared `constexpr` (which implies `inline`) so the header-only definition
// can also be evaluated in constant expressions.
constexpr int64_t ExclusiveToInclusiveEndTime(int64_t exclusive_time) {
  return exclusive_time - 1;
}

// Converts an inclusive end time to the corresponding exclusive end time by
// returning `inclusive_time + 1`.
//
// Declared `constexpr` (which implies `inline`) so the header-only definition
// can also be evaluated in constant expressions.
constexpr int64_t InclusiveToExclusiveEndTime(int64_t inclusive_time) {
  return inclusive_time + 1;
}

} // namespace xla

Expand Down

0 comments on commit 43b2c3c

Please sign in to comment.