-
-
Notifications
You must be signed in to change notification settings - Fork 255
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use rayon to speed up linfa-logistic #355
base: master
Are you sure you want to change the base?
Conversation
@@ -495,8 +495,7 @@ fn logistic_grad<F: Float, A: Data<Elem = F>>( | |||
let yz = x.dot(¶ms.into_shape((params.len(), 1)).unwrap()) + intercept; | |||
let len = yz.len(); | |||
let mut yz = yz.into_shape(len).unwrap() * y; | |||
yz.mapv_inplace(logistic); | |||
yz -= F::one(); | |||
yz.par_mapv_inplace(|v| logistic(v) - F::one()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you put * y
from below inside par_mapv_inplace, like -1
? Plus there is * y
above too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, y
is also a vector, so this would need to be more like map on zipped vectors, store in one of them
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe another comment, the expression we have right now, isn't it -logistic(-v)
, so we could avoid possibly numerically bad subtraction.
@@ -454,9 +454,9 @@ fn log_sum_exp<F: linfa::Float, A: Data<Elem = F>>( | |||
/// Computes `exp(n - max) / sum(exp(n- max))`, which is a numerically stable version of softmax | |||
fn softmax_inplace<F: linfa::Float, A: DataMut<Elem = F>>(v: &mut ArrayBase<A, Ix1>) { | |||
let max = v.iter().copied().reduce(F::max).unwrap(); | |||
v.mapv_inplace(|n| (n - max).exp()); | |||
v.par_mapv_inplace(|n| (n - max).exp()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there par_reduce
by any chance, for max and sum?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alas, I implemented a benchmark, but it seems that the limiting step (if you have many more samples than "features") is a matrix multiply that I didn't parallelize. And moreover I don't see a good way to make I'll leave this open for now, in case someone wants to use it as a starting point for doing this properly, but feel free to close when you wish to. |
This PR speeds up a test logistic regression by a factor of two on my laptop, from 2 minutes and 13 seconds to just 1 minute.