Skip to content

Commit

Permalink
pybind11: expose CPU-only
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Mar 13, 2024
1 parent 7b6e788 commit ae2e009
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 30 deletions.
52 changes: 34 additions & 18 deletions cuvec/include/cuvec_pybind11.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,53 @@
#include <pybind11/pybind11.h> // pybind11, PYBIND11_MAKE_OPAQUE
#include <pybind11/stl.h> // std::vector

#ifndef CUVEC_DISABLE_CUDA // ensure CPU-only alternative exists
#define NDxVEC_MAKE_OPAQUE(T) \
PYBIND11_MAKE_OPAQUE(NDCuVec<T>); \
PYBIND11_MAKE_OPAQUE(NDCVec<T>);
#else
#define NDxVEC_MAKE_OPAQUE(T) PYBIND11_MAKE_OPAQUE(NDCuVec<T>);
#endif // CUVEC_DISABLE_CUDA

PYBIND11_MAKE_OPAQUE(std::vector<size_t>);
PYBIND11_MAKE_OPAQUE(NDCuVec<signed char>);
PYBIND11_MAKE_OPAQUE(NDCuVec<unsigned char>);
PYBIND11_MAKE_OPAQUE(NDCuVec<char>);
PYBIND11_MAKE_OPAQUE(NDCuVec<short>);
PYBIND11_MAKE_OPAQUE(NDCuVec<unsigned short>);
PYBIND11_MAKE_OPAQUE(NDCuVec<int>);
PYBIND11_MAKE_OPAQUE(NDCuVec<unsigned int>);
PYBIND11_MAKE_OPAQUE(NDCuVec<long long>);
PYBIND11_MAKE_OPAQUE(NDCuVec<unsigned long long>);
NDxVEC_MAKE_OPAQUE(signed char);
NDxVEC_MAKE_OPAQUE(unsigned char);
NDxVEC_MAKE_OPAQUE(char);
NDxVEC_MAKE_OPAQUE(short);
NDxVEC_MAKE_OPAQUE(unsigned short);
NDxVEC_MAKE_OPAQUE(int);
NDxVEC_MAKE_OPAQUE(unsigned int);
NDxVEC_MAKE_OPAQUE(long long);
NDxVEC_MAKE_OPAQUE(unsigned long long);
#ifdef _CUVEC_HALF
PYBIND11_MAKE_OPAQUE(NDCuVec<_CUVEC_HALF>);
NDxVEC_MAKE_OPAQUE(_CUVEC_HALF);
template <> struct pybind11::format_descriptor<_CUVEC_HALF> : pybind11::format_descriptor<float> {
static std::string format() { return "e"; }
};
#endif
PYBIND11_MAKE_OPAQUE(NDCuVec<float>);
PYBIND11_MAKE_OPAQUE(NDCuVec<double>);
NDxVEC_MAKE_OPAQUE(float);
NDxVEC_MAKE_OPAQUE(double);

#define PYBIND11_BIND_NDCUVEC(T, typechar) \
pybind11::class_<NDCuVec<T>>(m, PYBIND11_TOSTRING(NDCuVec_##typechar), \
pybind11::buffer_protocol()) \
.def_buffer([](NDCuVec<T> &v) -> pybind11::buffer_info { \
#define PYBIND11_BIND_NDVEC(Vec, T, typechar) \
pybind11::class_<Vec<T>>(m, PYBIND11_TOSTRING(Vec##_##typechar), pybind11::buffer_protocol()) \
.def_buffer([](Vec<T> &v) -> pybind11::buffer_info { \
return pybind11::buffer_info(v.vec.data(), sizeof(T), \
pybind11::format_descriptor<T>::format(), v.shape.size(), \
v.shape, v.strides()); \
}) \
.def(pybind11::init<>()) \
.def(pybind11::init<std::vector<size_t>>()) \
.def_property( \
"shape", [](const NDCuVec<T> &v) { return &v.shape; }, &NDCuVec<T>::reshape) \
.def_property_readonly("address", [](const NDCuVec<T> &v) { return (size_t)v.vec.data(); })
"shape", [](const Vec<T> &v) { return &v.shape; }, &Vec<T>::reshape) \
.def_property_readonly("address", [](const Vec<T> &v) { return (size_t)v.vec.data(); })
#define PYBIND11_BIND_NDCUVEC(T, typechar) PYBIND11_BIND_NDVEC(NDCuVec, T, typechar)
#ifndef CUVEC_DISABLE_CUDA // ensure CPU-only alternative exists
#define PYBIND11_BIND_NDCVEC(T, typechar) PYBIND11_BIND_NDVEC(NDCVec, T, typechar)
#else
#define PYBIND11_BIND_NDCVEC(T, typechar)
#endif // CUVEC_DISABLE_CUDA
#define PYBIND11_BIND_NDxVEC(T, typechar) \
PYBIND11_BIND_NDCVEC(T, typechar); \
PYBIND11_BIND_NDCUVEC(T, typechar)

#endif // _CUVEC_PYBIND11_H_
24 changes: 12 additions & 12 deletions cuvec/src/pybind11.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@ PYBIND11_MODULE(cuvec_pybind11, m) {
m.doc() = "PyBind11 external module.";
pybind11::bind_vector<std::vector<size_t>>(m, "Shape");
pybind11::implicitly_convertible<pybind11::tuple, std::vector<size_t>>();
PYBIND11_BIND_NDCUVEC(signed char, b);
PYBIND11_BIND_NDCUVEC(unsigned char, B);
PYBIND11_BIND_NDCUVEC(char, c);
PYBIND11_BIND_NDCUVEC(short, h);
PYBIND11_BIND_NDCUVEC(unsigned short, H);
PYBIND11_BIND_NDCUVEC(int, i);
PYBIND11_BIND_NDCUVEC(unsigned int, I);
PYBIND11_BIND_NDCUVEC(long long, q);
PYBIND11_BIND_NDCUVEC(unsigned long long, Q);
PYBIND11_BIND_NDxVEC(signed char, b);
PYBIND11_BIND_NDxVEC(unsigned char, B);
PYBIND11_BIND_NDxVEC(char, c);
PYBIND11_BIND_NDxVEC(short, h);
PYBIND11_BIND_NDxVEC(unsigned short, H);
PYBIND11_BIND_NDxVEC(int, i);
PYBIND11_BIND_NDxVEC(unsigned int, I);
PYBIND11_BIND_NDxVEC(long long, q);
PYBIND11_BIND_NDxVEC(unsigned long long, Q);
#ifdef _CUVEC_HALF
PYBIND11_BIND_NDCUVEC(_CUVEC_HALF, e);
PYBIND11_BIND_NDxVEC(_CUVEC_HALF, e);
#endif
PYBIND11_BIND_NDCUVEC(float, f);
PYBIND11_BIND_NDCUVEC(double, d);
PYBIND11_BIND_NDxVEC(float, f);
PYBIND11_BIND_NDxVEC(double, d);
m.attr("__author__") = "Casper da Costa-Luis (https://github.com/casperdcl)";
m.attr("__date__") = "2024";
m.attr("__version__") = "2.0.0";
Expand Down

0 comments on commit ae2e009

Please sign in to comment.