Commit

[sml] optimize preprocessing by eliminating unnecessary where function (#622)

# Pull Request

## What problem does this PR solve?
A small optimization in `sml/preprocessing`: remove the `jnp.where` call from the original solution.
In the original version, `where` was meant to skip a redundant computation branch.
However, `where` evaluates both of its branches before selecting between them,
so it did not save any work.
Removing the `where` call reduces the computation without affecting the result.
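
A minimal sketch of the point above, not the code from `preprocessing.py`: the helper names `edges_with_where` and `edges_without_where` and the sample values are hypothetical. It shows that `jnp.where` is an elementwise select whose value operands are both computed for every element, and that the two versions agree on indices up to `diverse_n_bin`. Entries past `diverse_n_bin` do differ (clamped to `maxval` versus extrapolated), so the sketch assumes downstream code ignores those entries.

```python
import jax.numpy as jnp


def edges_with_where(minval, maxval, diverse_n_bin, arrange_array):
    delta = (maxval - minval) / diverse_n_bin
    # jnp.where selects elementwise; `minval + arrange_array * delta`
    # is still computed for every element before the selection happens.
    return jnp.where(
        arrange_array <= diverse_n_bin, minval + arrange_array * delta, maxval
    )


def edges_without_where(minval, maxval, diverse_n_bin, arrange_array):
    delta = (maxval - minval) / diverse_n_bin
    # Same arithmetic, minus the comparison and the select.
    return minval + arrange_array * delta


n_bins = 5          # static upper bound on the number of bins (assumed value)
diverse_n_bin = 3   # bins actually used by this feature (assumed value)
arrange_array = jnp.arange(n_bins + 1)

old = edges_with_where(0.0, 6.0, diverse_n_bin, arrange_array)
new = edges_without_where(0.0, 6.0, diverse_n_bin, arrange_array)

# The first diverse_n_bin + 1 edges are identical in both versions.
print(old[: diverse_n_bin + 1], new[: diverse_n_bin + 1])
```

The trailing entries are the only place the two versions diverge, which is why the change is safe only if those padded edges are never read.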
winnylyc authored Mar 26, 2024
1 parent 10bf15e commit d1850ed
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions sml/preprocessing/preprocessing.py
@@ -838,7 +838,7 @@ def bin_func(x, KMEANS):
 elif vectorize == True:
     diverse_n_bins = self.diverse_n_bins
     ### directly using jnp.linspace will cause dynamic shape problem,
-    ### so we need to use jnp.arange and a branch function jnp.where
+    ### so we need to use jnp.arange with a public value n_bins
     if self.strategy == "uniform":
         arrange_array = jnp.arange(n_bins + 1)

@@ -848,9 +848,7 @@ def bin_func(x, diverse_n_bin, arrange_array):
 delta = (maxval - minval) / diverse_n_bin

 def bin_element_func(x_inner):
-    return jnp.where(
-        x_inner <= diverse_n_bin, minval + x_inner * delta, maxval
-    )
+    return minval + x_inner * delta

 return jax.vmap(bin_element_func)(arrange_array)
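
For context, here is a hedged, self-contained sketch of how the pattern in this hunk might be used across several features. It is an illustration, not the surrounding code from `preprocessing.py`: the input `X`, the per-feature `diverse_n_bins` values, and computing `minval`/`maxval` with `jnp.min`/`jnp.max` are all assumptions. The point it illustrates is the one in the code comment: the output length comes from the static `n_bins`, while the per-feature bin count only enters through `delta`, so no shape depends on a traced value.

```python
import jax
import jax.numpy as jnp

n_bins = 4  # static (public) value, fixes the output length at trace time
arrange_array = jnp.arange(n_bins + 1)


def bin_func(x, diverse_n_bin, arrange_array):
    # Assumption: edges span the observed range of the feature column.
    minval = jnp.min(x)
    maxval = jnp.max(x)
    delta = (maxval - minval) / diverse_n_bin

    def bin_element_func(x_inner):
        # Plain arithmetic; no select is needed to produce an edge.
        return minval + x_inner * delta

    return jax.vmap(bin_element_func)(arrange_array)


# Two features (columns) with different effective bin counts (assumed data).
X = jnp.array([[0.0, 10.0], [1.0, 20.0], [2.0, 30.0], [4.0, 40.0]])
diverse_n_bins = jnp.array([2, 4])

# Map over feature columns and their bin counts; arrange_array is shared.
bin_edges = jax.vmap(bin_func, in_axes=(1, 0, None))(
    X, diverse_n_bins, arrange_array
)
print(bin_edges.shape)
```

For the first feature only the first `2 + 1` edges are meaningful; the remaining entries are padding that extends past the feature's maximum.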
