
fix: Properly assign the chunks to the right worker #449

Open · wants to merge 7 commits into base: main

Conversation

@tchaton (Collaborator) commented Jan 14, 2025

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not needed for typo and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #442

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@robmarkcole (Contributor) commented Jan 15, 2025

Repeated testing on this branch:

  • with batch_size 8 and num_workers 8, I no longer get the error ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1]) - SOLVED
  • with batch_size 4 and num_workers 4, I get Test batch size: 4, but only 20 of 24 images are processed - the final batch is missing.
  • with batch_size 8 and num_workers 4, I get Test batch size: 8, and all 24 images are processed.
  • with batch_size 4 and num_workers 2, I get Test batch size: 4, but only 20 of 24 images are processed - the final batch is missing.
  • with batch_size 8 and num_workers 2, I get Test batch size: 8, but only 16 of 24 images are processed - the final batch is missing.

So there is still an outstanding issue of a missing final batch.

I've also noticed that (for batch_size 8) the Train dataset length is now 192, whereas it was 200 before. I also get 196 for batch_size 4 (val and test set lengths are similarly affected).
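
For intuition about the shrunken dataset length, here is a rough, purely illustrative sketch (this is not litdata's actual code, and the exact rounding may differ): if every worker keeps only whole batches, trailing items that cannot fill a complete batch on every worker are dropped.

def usable_items(num_items: int, batch_size: int, world_size: int, num_workers: int) -> int:
    # Keep only full batches, then only as many batches as divide evenly
    # across all (rank, worker) slots.
    max_batches = num_items // batch_size
    batches_per_slot = max_batches // (world_size * num_workers)
    return batches_per_slot * world_size * num_workers * batch_size

# 200 items, batch_size 8, 1 node, 8 workers -> 192, matching the observation above.
print(usable_items(200, 8, 1, 8))  # 192

This reproduces 192 for batch_size 8 but not the 196 reported for batch_size 4, so the real accounting is clearly more involved; the point is only that truncating to full batches per worker shrinks the reported length.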


chunks_per_workers: List[List[int]] = [[] for _ in range(world_size)]
intervals_per_workers: List[List[List[int]]] = [[] for _ in range(world_size)]
max_batches = num_items // batch_size
Contributor: If there is a remainder here, that's a "partial" batch.

Collaborator (Author): That is why we are casting things to int, so it is an exact number.
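
As a minimal illustration of that remainder being a "partial" batch (illustrative only, not the code under review):

num_items, batch_size = 26, 8
max_batches, remainder = divmod(num_items, batch_size)
# 3 full batches; the 2 leftover items form a partial batch, which is exactly
# what drop_last semantics decide to keep or discard.
print(max_batches, remainder)  # 3 2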

tmp_arr = [0 for _ in range(num_workers)]

index = 0
for _ in range(int(max_batches // distributed_env.world_size)):
Contributor: What happens to the leftover batches here, i.e. the leftover full batches?

Collaborator (Author): They are given to the last machine and last worker.
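
To make "given to the last machine and last worker" concrete, here is a minimal sketch under assumed names (not the actual litdata implementation): full batches are dealt out evenly across ranks and workers, and any leftover full batches are appended to the last worker of the last rank.

def distribute_batches(max_batches: int, world_size: int, num_workers: int) -> list:
    per_slot = max_batches // (world_size * num_workers)
    assignment = [[per_slot] * num_workers for _ in range(world_size)]
    leftover = max_batches - per_slot * world_size * num_workers
    assignment[-1][-1] += leftover  # leftover full batches go to the last machine / last worker
    return assignment

# 25 batches over 2 ranks x 4 workers -> 3 each, plus 1 extra batch on the last worker.
print(distribute_batches(25, 2, 4))  # [[3, 3, 3, 3], [3, 3, 3, 4]]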

@robmarkcole (Contributor)

o1 suggests using PyTorch’s built-in distribution features (e.g. DistributedSampler) to avoid reinventing the wheel.

@tchaton (Collaborator, Author) commented Jan 15, 2025

Hey @robmarkcole. Yes, DistributedSampler won't work here.
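
For context, this is what standard DistributedSampler usage looks like for a map-style (indexable) dataset. It shards indices across ranks, which assumes random access to any sample; that model presumably does not fit a chunked streaming dataset whose workers read whole chunks sequentially. Illustrative sketch only, not litdata code:

import torch
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler

# 24 dummy samples; DistributedSampler splits *indices* across ranks.
dataset = TensorDataset(torch.arange(24))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False, drop_last=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for (batch,) in loader:
    print(batch)  # rank 0 receives indices 0, 2, 4, ... in batches of 4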


num_items_per_workers: Any = []

for rank in range(distributed_env.world_size):
Member: Why only the distributed world size here and not the global number of workers? Same question for below.

Collaborator (Author): Because we want to ensure we fill up all the workers for each process rank in the same way.
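
A minimal sketch of that idea, with assumed names (not the actual litdata code): iterating over ranks first and then rotating within each rank's workers means worker w is filled in the same pattern on every rank.

def assign_batches(num_batches: int, world_size: int, num_workers: int) -> dict:
    assignment = {(rank, worker): [] for rank in range(world_size) for worker in range(num_workers)}
    for batch_idx in range(num_batches):
        rank = batch_idx % world_size        # batches alternate across process ranks first
        local_idx = batch_idx // world_size  # position of this batch within its rank
        worker = local_idx % num_workers     # then rotate over that rank's workers
        assignment[(rank, worker)].append(batch_idx)
    return assignment

# 2 ranks x 2 workers, 8 batches: every rank's worker 0 and worker 1 end up with
# the same number of batches, filled in the same order:
# {(0, 0): [0, 4], (0, 1): [2, 6], (1, 0): [1, 5], (1, 1): [3, 7]}
print(assign_batches(8, world_size=2, num_workers=2))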

@robmarkcole (Contributor)

OK all my checks got the desired results now!

@tchaton (Collaborator, Author) commented Jan 15, 2025

Quoting @robmarkcole: "OK all my checks got the desired results now!"

Perfect!

@lantiga (Contributor) left a comment

Looks great! Just a minor suggestion.

src/litdata/utilities/shuffle.py (review comment outdated and resolved)
@robmarkcole (Contributor)

I'd be grateful for a release after this bugfix.

Merging this pull request may close the following issue: drop_last is not respected.