Sine Gaussian Waveform #23
@tedwards2412 I hope you don't mind us hopping on board. We envision this being helpful for PE follow-up to Burst searches. @ravioli1369 is a student working with me, @ThibeauWouters, and others.
times = jnp.arange(num, dtype=jnp.float64) / sample_rate
times -= duration / 2.0

# add dimension for calculating waveforms in batch
I am not sure whether this is needed for jax?
Hi all, thanks a lot for the contributions, and great to see the extensive checks of the code! I quickly went over this and will jot down some comments from looking at the main .py file; @tedwards2412 can indicate whether he agrees or not:
I have made most of the changes, but when I removed 64-bit precision from my notebook and ran it again, I noticed a large change in the accuracy of the implementation. Following that, I added a comparison between the 64-bit and 32-bit precision waveforms and how they compare to the LALInference implementation (refer to the code cells just below the corresponding markdown heading). As for the …
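For reference, toggling 64-bit precision in JAX is a one-line config change, and the size of the precision effect can be quantified with a simple overlap mismatch. The sketch below is only illustrative: `overlap_mismatch` is a hypothetical plain inner-product mismatch, not the noise-weighted match used in the notebook.

```python
import jax
import jax.numpy as jnp

# Must be set before any arrays are created; without it JAX silently
# downcasts float64 inputs to float32.
jax.config.update("jax_enable_x64", True)

def overlap_mismatch(h1, h2):
    # Plain (noise-free) normalized inner-product mismatch between two
    # waveforms sampled on the same time grid.
    inner = jnp.vdot(h1, h2).real
    norm = jnp.sqrt(jnp.vdot(h1, h1).real * jnp.vdot(h2, h2).real)
    return 1.0 - inner / norm
```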
Regarding float64 vs float32, I think one has to scale the signal appropriately to avoid loss of accuracy. I don't see why the sine Gaussian would need float64 accuracy, so it should be okay to refactor it into float32. @tedwards2412 can comment more on this. A side note going forward: I think we should start incorporating tests in the code base as it grows.
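One possible reading of "scaling the signal" is to factor out the overall strain amplitude (of order 1e-21) and apply it at the end, so the float32 arithmetic acts on O(1) numbers (inner products of raw ~1e-21 strains, for example, would underflow badly in float32). The parameterization below is a hypothetical sketch, not ripple's actual SineGaussian signature.

```python
import jax.numpy as jnp

def sine_gaussian_unit(times, f0, tau, phi0):
    # Unit-amplitude sine-Gaussian: the waveform shape is computed on O(1) numbers.
    envelope = jnp.exp(-((times / tau) ** 2))
    return envelope * jnp.sin(2.0 * jnp.pi * f0 * times + phi0)

def sine_gaussian(times, amplitude, f0, tau, phi0):
    # The overall (tiny) strain amplitude is applied as a final rescaling.
    return amplitude * sine_gaussian_unit(times, f0, tau, phi0)
```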
Sorry for the delay, this looks great so far! @mcoughlin no problem at all, I'm just happy that people are starting to find the code useful and want to contribute :) A couple of comments:
I had a few doubts regarding this:
I pushed the results back to the notebook towards the end (https://github.com/ravioli1369/ripple/blob/sine-gaussian/notebooks/check_SineGaussian.ipynb). Is this correct, or should I have done something else?
@ravioli1369 @tedwards2412 Perhaps it might be good to start dividing up the ripple source code into frequency domain and time domain waveforms? I believe that time domain waveforms will see more support in the future for other use cases as well. Thomas can indicate whether he agrees with this.
@ThibeauWouters I agree that splitting the waveforms makes sense and will probably make things easier to track as things grow.
I ran the timing benchmarks again with the …
The notebook (https://github.com/ravioli1369/sine-gaussian/blob/main/speed-comparision.ipynb) has more details on how I ran the benchmark. I even changed the output of the sine Gaussian function to return a single array so that I don't have to evaluate it in a list comprehension, but the results of that were also similar.
I tried it myself with the merged version of the SineGaussian waveform in ripple and found that vmap is actually quicker. Here is the code:
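(The code block from this comment did not survive extraction. A minimal sketch of such a vmap-vs-loop timing comparison, using a hypothetical `gen_sine_gaussian` helper rather than ripple's actual generator, might look like the following.)

```python
import time
import jax
import jax.numpy as jnp

def gen_sine_gaussian(times, theta):
    # Hypothetical stand-in for the SineGaussian generator; ripple's
    # actual signature may differ.
    amplitude, f0, tau, phi0 = theta
    envelope = jnp.exp(-((times / tau) ** 2))
    return amplitude * envelope * jnp.sin(2.0 * jnp.pi * f0 * times + phi0)

sample_rate, duration = 4096.0, 4.0
times = jnp.arange(int(sample_rate * duration)) / sample_rate - duration / 2.0
thetas = jnp.array([[1.0, 100.0 + i, 0.1, 0.0] for i in range(100)])

# Loop version: one Python-level call per parameter set.
loop_fn = jax.jit(lambda theta: gen_sine_gaussian(times, theta))
# Batched version: vmap over the parameter axis, then jit the whole thing.
batched_fn = jax.jit(jax.vmap(lambda theta: gen_sine_gaussian(times, theta)))

# Warm up once so compilation time is excluded from the timings.
_ = jnp.stack([loop_fn(t) for t in thetas]).block_until_ready()
_ = batched_fn(thetas).block_until_ready()

start = time.perf_counter()
_ = jnp.stack([loop_fn(t) for t in thetas]).block_until_ready()
print("loop:", time.perf_counter() - start)

start = time.perf_counter()
_ = batched_fn(thetas).block_until_ready()
print("vmap:", time.perf_counter() - start)
```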
My output gives:
Note I just did this directly on my laptop on the CPU, so the effect will be further magnified when using a GPU. Also, I had to remove the reshapes, which now incorrectly add dimensions once you use vmap. The resulting array should have shape (batch_dimension, time_grid).
It does indeed look like vmap is faster than running through the parameters in a for loop. The way I tested it was to send all the parameters into the function and call reshape inside of it; this gave results that were faster than removing the reshape and running vmap. I'm not sure why this is happening. The jitted versions (with and without reshape) give identical results, so I think it should be fine to leave it this way, although it does warrant some investigation to see why vmap is performing worse than reshaping.
I think this overall makes sense; once you add the jit and your manual reshaping, you are basically manually vectorizing the function, so it should perform similarly to vmap + jit. Overall though, it's not good practice in JAX to add this kind of manual reshaping when you can instead use vmap :) vmapping doesn't do the jit for you, so the jit is still required to make it fast!
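To make the recommendation concrete, here is a minimal sketch (with a hypothetical `waveform` function) of the usual composition: vmap over the parameter batch with `in_axes`, then wrap the whole thing in jit; this yields the (batch_dimension, time_grid) output directly, with no manual reshaping.

```python
import jax
import jax.numpy as jnp

def waveform(theta, times):
    # Hypothetical single-event waveform evaluated on a fixed time grid.
    amplitude, f0, tau, phi0 = theta
    return amplitude * jnp.exp(-((times / tau) ** 2)) * jnp.sin(2.0 * jnp.pi * f0 * times + phi0)

# vmap vectorizes over the leading axis of theta only (times is shared);
# jit is still needed on top, since vmap alone does not compile anything.
batched_waveform = jax.jit(jax.vmap(waveform, in_axes=(0, None)))

times = jnp.linspace(-2.0, 2.0, 4096)
thetas = jnp.ones((8, 4))              # batch of 8 parameter sets
out = batched_waveform(thetas, times)  # shape (8, 4096)
```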
This PR adds the `SineGaussian` waveform to ripple, along with a detailed Python notebook showing the mismatch between the LALInference and JAX implementations.