Skip to content

Commit

Permalink
[SYCL][Graph] Add common reference semantics (#16788)
Browse files Browse the repository at this point in the history
Adds missing common reference semantic functionality such as operator==,
operator!= and hash functions to all sycl graph related classes.

---------

Co-authored-by: Ewan Crawford <[email protected]>
  • Loading branch information
fabiomestre and EwanC authored Feb 13, 2025
1 parent a63f8b4 commit 822cf9b
Show file tree
Hide file tree
Showing 6 changed files with 333 additions and 5 deletions.
82 changes: 82 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ class __SYCL_EXPORT node {
/// Update the Range of this node if it is a kernel execution node
template <int Dimensions> void update_range(range<Dimensions> executionRange);

/// Common Reference Semantics
friend bool operator==(const node &LHS, const node &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const node &LHS, const node &RHS) {
return !operator==(LHS, RHS);
}

private:
node(const std::shared_ptr<detail::node_impl> &Impl) : impl(Impl) {}

Expand Down Expand Up @@ -181,6 +189,16 @@ class __SYCL_EXPORT dynamic_command_group {
size_t get_active_index() const;
void set_active_index(size_t Index);

/// Common Reference Semantics
friend bool operator==(const dynamic_command_group &LHS,
const dynamic_command_group &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const dynamic_command_group &LHS,
const dynamic_command_group &RHS) {
return !operator==(LHS, RHS);
}

private:
template <class Obj>
friend const decltype(Obj::impl) &
Expand Down Expand Up @@ -307,6 +325,16 @@ class __SYCL_EXPORT modifiable_command_graph
/// Get a list of all root nodes (nodes without dependencies) in this graph.
std::vector<node> get_root_nodes() const;

/// Common Reference Semantics
friend bool operator==(const modifiable_command_graph &LHS,
const modifiable_command_graph &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const modifiable_command_graph &LHS,
const modifiable_command_graph &RHS) {
return !operator==(LHS, RHS);
}

protected:
/// Constructor used internally by the runtime.
/// @param Impl Detail implementation class to construct object with.
Expand Down Expand Up @@ -386,6 +414,16 @@ class __SYCL_EXPORT executable_command_graph
/// @param Nodes The nodes to use for updating the graph.
void update(const std::vector<node> &Nodes);

/// Common Reference Semantics
friend bool operator==(const executable_command_graph &LHS,
const executable_command_graph &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const executable_command_graph &LHS,
const executable_command_graph &RHS) {
return !operator==(LHS, RHS);
}

protected:
/// Constructor used by internal runtime.
/// @param Graph Detail implementation class to construct with.
Expand Down Expand Up @@ -452,6 +490,16 @@ class __SYCL_EXPORT dynamic_parameter_base {
Graph,
size_t ParamSize, const void *Data);

/// Common Reference Semantics
friend bool operator==(const dynamic_parameter_base &LHS,
const dynamic_parameter_base &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const dynamic_parameter_base &LHS,
const dynamic_parameter_base &RHS) {
return !operator==(LHS, RHS);
}

protected:
void updateValue(const void *NewValue, size_t Size);

Expand Down Expand Up @@ -512,3 +560,37 @@ command_graph(const context &SyclContext, const device &SyclDevice,

} // namespace _V1
} // namespace sycl

namespace std {
template <> struct __SYCL_EXPORT hash<sycl::ext::oneapi::experimental::node> {
size_t operator()(const sycl::ext::oneapi::experimental::node &Node) const;
};

template <>
struct __SYCL_EXPORT
hash<sycl::ext::oneapi::experimental::dynamic_command_group> {
size_t operator()(const sycl::ext::oneapi::experimental::dynamic_command_group
&DynamicCGH) const;
};

template <sycl::ext::oneapi::experimental::graph_state State>
struct __SYCL_EXPORT
hash<sycl::ext::oneapi::experimental::command_graph<State>> {
size_t operator()(const sycl::ext::oneapi::experimental::command_graph<State>
&Graph) const {
auto ID = sycl::detail::getSyclObjImpl(Graph)->getID();
return std::hash<decltype(ID)>()(ID);
}
};

template <typename ValueT>
struct __SYCL_EXPORT
hash<sycl::ext::oneapi::experimental::dynamic_parameter<ValueT>> {
size_t
operator()(const sycl::ext::oneapi::experimental::dynamic_parameter<ValueT>
&DynamicParam) const {
auto ID = sycl::detail::getSyclObjImpl(DynamicParam)->getID();
return std::hash<decltype(ID)>()(ID);
}
};
} // namespace std
23 changes: 20 additions & 3 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ graph_impl::graph_impl(const sycl::context &SyclContext,
const sycl::device &SyclDevice,
const sycl::property_list &PropList)
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
MEventsMap(), MInorderQueueMap() {
MEventsMap(), MInorderQueueMap(),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
checkGraphPropertiesAndThrow(PropList);
if (PropList.has_property<property::graph::no_cycle_check>()) {
MSkipCycleChecks = true;
Expand Down Expand Up @@ -913,7 +914,8 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
MExecutionEvents(),
MIsUpdatable(PropList.has_property<property::graph::updatable>()),
MEnableProfiling(
PropList.has_property<property::graph::enable_profiling>()) {
PropList.has_property<property::graph::enable_profiling>()),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
checkGraphPropertiesAndThrow(PropList);
// If the graph has been marked as updatable then check if the backend
// actually supports that. Devices supporting aspect::ext_oneapi_graph must
Expand Down Expand Up @@ -2026,7 +2028,8 @@ void dynamic_parameter_impl::updateCGAccessor(

dynamic_command_group_impl::dynamic_command_group_impl(
const command_graph<graph_state::modifiable> &Graph)
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0) {}
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {}

void dynamic_command_group_impl::finalizeCGFList(
const std::vector<std::function<void(handler &)>> &CGFList) {
Expand Down Expand Up @@ -2150,3 +2153,17 @@ void dynamic_command_group::set_active_index(size_t Index) {
} // namespace ext
} // namespace _V1
} // namespace sycl

size_t std::hash<sycl::ext::oneapi::experimental::node>::operator()(
const sycl::ext::oneapi::experimental::node &Node) const {
auto ID = sycl::detail::getSyclObjImpl(Node)->getID();
return std::hash<decltype(ID)>()(ID);
}

size_t
std::hash<sycl::ext::oneapi::experimental::dynamic_command_group>::operator()(
const sycl::ext::oneapi::experimental::dynamic_command_group &DynamicCG)
const {
auto ID = sycl::detail::getSyclObjImpl(DynamicCG)->getID();
return std::hash<decltype(ID)>()(ID);
}
32 changes: 30 additions & 2 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,8 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
return MBarrierDependencyMap[Queue];
}

unsigned long long getID() const { return MID; }

private:
/// Iterate over the graph depth-first and run \p NodeFunc on each node.
/// @param NodeFunc A function which receives as input a node in the graph to
Expand Down Expand Up @@ -1198,6 +1200,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
MBarrierDependencyMap;

unsigned long long MID;
// Used for std::hash in order to create a unique hash for the instance.
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};

/// Class representing the implementation of command_graph<executable>.
Expand Down Expand Up @@ -1297,6 +1303,8 @@ class exec_graph_impl {

void updateImpl(std::shared_ptr<node_impl> NodeImpl);

unsigned long long getID() const { return MID; }

private:
/// Create a command-group for the node and add it to command-buffer by going
/// through the scheduler.
Expand Down Expand Up @@ -1408,21 +1416,27 @@ class exec_graph_impl {
// Stores a cache of node ids from modifiable graph nodes to the companion
// node(s) in this graph. Used for quick access when updating this graph.
std::multimap<node_impl::id_type, std::shared_ptr<node_impl>> MIDCache;

unsigned long long MID;
// Used for std::hash in order to create a unique hash for the instance.
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};

class dynamic_parameter_impl {
public:
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl,
size_t ParamSize, const void *Data)
: MGraph(GraphImpl), MValueStorage(ParamSize) {
: MGraph(GraphImpl), MValueStorage(ParamSize),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
std::memcpy(MValueStorage.data(), Data, ParamSize);
}

/// sycl_ext_oneapi_raw_kernel_arg constructor
/// Parameter size is taken from member of raw_kernel_arg object.
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl, size_t,
raw_kernel_arg *Data)
: MGraph(GraphImpl) {
: MGraph(GraphImpl),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
size_t RawArgSize = Data->MArgSize;
const void *RawArgData = Data->MArgData;
MValueStorage.reserve(RawArgSize);
Expand Down Expand Up @@ -1493,13 +1507,20 @@ class dynamic_parameter_impl {
int ArgIndex,
const sycl::detail::AccessorBaseHost *Acc);

unsigned long long getID() const { return MID; }

// Weak ptrs to node_impls which will be updated
std::vector<std::pair<std::weak_ptr<node_impl>, int>> MNodes;
// Dynamic command-groups which will be updated
std::vector<DynamicCGInfo> MDynCGs;

std::shared_ptr<graph_impl> MGraph;
std::vector<std::byte> MValueStorage;

private:
unsigned long long MID;
// Used for std::hash in order to create a unique hash for the instance.
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};

class dynamic_command_group_impl
Expand Down Expand Up @@ -1540,6 +1561,13 @@ class dynamic_command_group_impl

/// List of nodes using this dynamic command-group.
std::vector<std::weak_ptr<node_impl>> MNodes;

unsigned long long getID() const { return MID; }

private:
unsigned long long MID;
// Used for std::hash in order to create a unique hash for the instance.
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};
} // namespace detail
} // namespace experimental
Expand Down
6 changes: 6 additions & 0 deletions sycl/test/abi/sycl_symbols_windows.dump
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,12 @@
??0dynamic_parameter_base@detail@experimental@oneapi@ext@_V1@sycl@@QEAA@$$QEAV0123456@@Z
??0dynamic_parameter_base@detail@experimental@oneapi@ext@_V1@sycl@@QEAA@AEBV0123456@@Z
??0dynamic_parameter_base@detail@experimental@oneapi@ext@_V1@sycl@@QEAA@V?$command_graph@$0A@@23456@_KPEBX@Z
??4?$hash@Vdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@AEBU01@@Z
??4?$hash@Vdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@$$QEAU01@@Z
??R?$hash@Vdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@std@@QEBA_KAEBVdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@Z
??R?$hash@Vnode@experimental@oneapi@ext@_V1@sycl@@@std@@QEBA_KAEBVnode@experimental@oneapi@ext@_V1@sycl@@@Z
??4?$hash@Vnode@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@$$QEAU01@@Z
??4?$hash@Vnode@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@AEBU01@@Z
??0event@_V1@sycl@@AEAA@V?$shared_ptr@Vevent_impl@detail@_V1@sycl@@@std@@@Z
??0event@_V1@sycl@@QEAA@$$QEAV012@@Z
??0event@_V1@sycl@@QEAA@AEBV012@@Z
Expand Down
1 change: 1 addition & 0 deletions sycl/unittests/Extensions/CommandGraph/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set(CMAKE_CXX_EXTENSIONS OFF)
add_sycl_unittest(CommandGraphExtensionTests OBJECT
Barrier.cpp
CommandGraph.cpp
CommonReferenceSemantics.cpp
Exceptions.cpp
InOrderQueue.cpp
MultiThreaded.cpp
Expand Down
Loading

0 comments on commit 822cf9b

Please sign in to comment.