From 8b1fbbc679559d2c5033157ecb6442bb50af2ab5 Mon Sep 17 00:00:00 2001 From: Chunqing Shan Date: Sat, 18 May 2024 10:29:06 +0200 Subject: [PATCH] Refactor FloatPyCast function for improved performance using lookup table --- ml_dtypes/_src/dtypes.cc | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc index bc292a93..c07caa81 100644 --- a/ml_dtypes/_src/dtypes.cc +++ b/ml_dtypes/_src/dtypes.cc @@ -178,16 +178,37 @@ struct TypeDescriptor : Int4TypeDescriptor { }; namespace { +template +struct FloatPyCaster { + static void Cast(void* from_void, void* to_void, npy_intp n) { + const auto* from = static_cast(from_void); + auto* to = static_cast(to_void); + for (npy_intp i = 0; i < n; ++i) { + to[i] = static_cast(static_cast(from[i])); + } + } +}; -// Performs a NumPy array cast from type 'From' to 'To' via float. template -void FloatPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr, - void* toarr) { - const auto* from = static_cast(from_void); - auto* to = static_cast(to_void); - for (npy_intp i = 0; i < n; ++i) { - to[i] = static_cast(static_cast(from[i])); +struct FloatPyCaster { + static void Cast(void* from_void, void* to_void, npy_intp n) { + const auto* from = static_cast(from_void); + auto* to = static_cast(to_void); + To table[256]; + // Use int for loop index to avoid overflow. + for (int i = 0; i < 256; ++i) { + table[i] = static_cast(static_cast(__builtin_bit_cast(From, uint8_t(i)))); + } + for (npy_intp i = 0; i < n; ++i) { + to[i] = table[__builtin_bit_cast(uint8_t, from[i])]; + } } +}; + +template +void FloatPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr, void* toarr) { + constexpr bool is_table = (sizeof(From) == 1); + FloatPyCaster::Cast(from_void, to_void, n); } template