fix: Properly assign the chunks to the right worker #449
base: main
Conversation
Repeated testing on this branch. There's still an outstanding issue of a missing batch. I've also noticed (for …
chunks_per_workers: List[List[int]] = [[] for _ in range(world_size)]
intervals_per_workers: List[List[List[int]]] = [[] for _ in range(world_size)]
max_batches = num_items // batch_size
If there is a remainder here, that's a “partial” batch.
That's why we are casting to int, so it's an exact number.
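For intuition, a minimal sketch with hypothetical numbers: floor division gives the exact count of full batches, and any remainder is the partial batch.

# Hypothetical numbers, purely for illustration.
num_items = 103
batch_size = 10
max_batches = num_items // batch_size  # 10 full batches, an exact int
remainder = num_items % batch_size     # 3 leftover items form the "partial" batch
assert max_batches * batch_size + remainder == num_items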
src/litdata/utilities/shuffle.py (outdated)
tmp_arr = [0 for _ in range(num_workers)]
index = 0
for _ in range(int(max_batches // distributed_env.world_size)):
What happens to the leftover batches here, i.e. the leftover full batches?
They are given to the last machine and the last worker.
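A hedged sketch of that behavior (variable names and counts are illustrative, not the actual litdata code): full batches that don't divide evenly across ranks are appended to the last rank's last worker.

# Illustrative only: distribute full batches evenly, then hand any
# leftover full batches to the last machine's last worker.
world_size, num_workers, max_batches = 4, 2, 10

batches = list(range(max_batches))
per_rank = max_batches // world_size          # 2 full batches per rank
leftover = batches[world_size * per_rank:]    # 2 leftover full batches

assignments = [[[] for _ in range(num_workers)] for _ in range(world_size)]
for i, batch in enumerate(batches[: world_size * per_rank]):
    rank = i // per_rank
    assignments[rank][i % num_workers].append(batch)

# Leftover full batches go to the last rank's last worker.
assignments[-1][-1].extend(leftover)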
o1 suggests using PyTorch's built-in distribution features (e.g. DistributedSampler) to avoid reinventing the wheel.
Hey @robmarkcole. Yes, DistributedSampler won't work here.
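For readers wondering why: DistributedSampler shards sample indices across ranks and assumes random access by index; it knows nothing about the chunk files a streaming dataset reads, so it can't keep a rank's reads within its assigned chunks. A minimal usage sketch of the standard PyTorch API:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))

# DistributedSampler partitions *indices* across ranks; it has no notion
# of chunk boundaries, which is what litdata's assignment has to respect.
sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)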
num_items_per_workers: Any = []
for rank in range(distributed_env.world_size):
Why only the distributed world size here and not the global number of workers? Same question for below.
Because we want to ensure we fill up all the workers for each process rank in the same way.
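A minimal sketch of that invariant (illustrative names only): round-robin over the flattened (rank, worker) slots fills worker i in the same pattern on every rank, keeping each rank's DataLoader workers aligned.

# Illustrative only: fill workers rank by rank so that worker i is
# populated identically for every process rank.
world_size, num_workers = 2, 3
slots = [[] for _ in range(world_size * num_workers)]

for batch_id in range(12):
    slots[batch_id % (world_size * num_workers)].append(batch_id)

# rank 0, worker 0 -> [0, 6]; rank 1, worker 0 -> [3, 9]; and so on,
# so every rank's workers advance in lockstep.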
OK, all my checks got the desired results now!
Perfect!
Looks great! Just a minor suggestion.
I'd be grateful for a release after this bugfix.
Before submitting
What does this PR do?
Fixes #442
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃