
Validation stuck when trainers have different data size #20561

Open
btian opened this issue Jan 23, 2025 · 1 comment
Labels: bug, needs triage, ver: 2.3.x

Comments


btian commented Jan 23, 2025

Bug description

[Screenshot of per-rank validation progress, not reproduced here]

I'm running validation where each trainer (rank) can have a different amount of data; however, validation gets stuck.

In the example above, rank 5 ran out of data after batch 50 while the other ranks still had data, and the program hung.

I'm using the FSDP strategy to train an LLM, and I'm not sure why validation batches are synchronized across ranks.

What version are you seeing the problem on?

v2.3
v2.5.0.post0

How to reproduce the bug

```python
from typing import Dict

import torch


def validation_step(self, input_dict: Dict, batch_idx: int) -> Dict[str, torch.Tensor]:
    # Method of a LightningModule; self.model is the LLM trained with the FSDP strategy.
    pred_dict: Dict[str, torch.Tensor] = self.model(input_dict)
    loss_dict: Dict[str, torch.Tensor] = self.model.get_loss_dict(
        gt_dict=input_dict, pred_dict=pred_dict
    )

    self.log_dict(
        {f"val_loss/{key}": loss for key, loss in loss_dict.items()},
        # During validation, all logs are aggregated and logged at epoch end.
        on_step=False,
        on_epoch=True,
        # Reduce the logged values across ranks (a distributed collective).
        sync_dist=True,
    )

    return pred_dict
```

Error messages and logs

No error message. NCCL timeout after 10 minutes.

Environment

Current environment
  • CUDA:
    - GPU:
        - NVIDIA L20
    - available: True
    - version: 12.4
  • Lightning:
    - efficientnet-pytorch: 0.7.1
    - lightning: 2.3.0
    - lightning-thunder: 0.1.0
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.5.0.post0
    - pytorch-quantization: 2.1.2
    - pytorch-triton: 3.0.0+a9bc1a364
    - torch: 2.5.1+cu124
    - torch-scatter: 2.1.2+pt25cu124
    - torchcodec: 0.1.1+cu124
    - torchdata: 0.7.1a0
    - torchmetrics: 1.6.1
    - torchvision: 0.20.1+cu124
    - xpilot-lightning-mini: 2.3.0
  • Packages:
    - absl-py: 2.1.0
    - accelerate: 1.2.1
    - adjusttext: 1.1.1
    - aiofiles: 22.1.0
    - aiohttp: 3.9.3
    - aiosignal: 1.3.1
    - aiosqlite: 0.20.0
    - albumentations: 1.1.0
    - alembic: 1.14.0
    - aliyun-python-sdk-core: 2.16.0
    - aliyun-python-sdk-kms: 2.16.5
    - annotated-types: 0.6.0
    - anyio: 3.7.1
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - astral: 3.2
    - astroid: 2.15.8
    - asttokens: 2.4.1
    - astunparse: 1.6.3
    - async-timeout: 4.0.3
    - attrs: 23.2.0
    - audioread: 3.0.1
    - av: 12.3.0
    - babel: 2.16.0
    - bcrypt: 4.2.1
    - beautifulsoup4: 4.12.3
    - bleach: 6.1.0
    - blessed: 1.20.0
    - blinker: 1.9.0
    - blis: 0.7.11
    - blosc2: 2.7.1
    - cachetools: 5.3.3
    - casadi: 3.6.5
    - catalogue: 2.0.10
    - cattrs: 24.1.2
    - ccimport: 0.4.4
    - certifi: 2024.2.2
    - cffi: 1.16.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - cloudpathlib: 0.16.0
    - cloudpickle: 2.2.1
    - cmake: 3.29.0.1
    - coloredlogs: 15.0.1
    - colorlog: 6.9.0
    - comm: 0.2.2
    - confection: 0.1.4
    - contextlib2: 21.6.0
    - contourpy: 1.2.1
    - crcmod: 1.7
    - cryptography: 44.0.0
    - cuda-python: 12.4.0rc7+3.ge75c8a9.dirty
    - cudf: 24.2.0
    - cudnn: 1.1.2
    - cugraph: 24.2.0
    - cugraph-dgl: 24.2.0
    - cugraph-service-client: 24.2.0
    - cugraph-service-server: 24.2.0
    - cuml: 24.2.0
    - cumm-cu120: 0.4.11
    - cupy-cuda12x: 13.0.0
    - cycler: 0.12.1
    - cymem: 2.0.8
    - cython: 3.0.10
    - dask: 2024.1.1
    - dask-cuda: 24.2.0
    - dask-cudf: 24.2.0
    - databricks-cli: 0.18.0
    - dbus-python: 1.2.18
    - debugpy: 1.8.1
    - decorator: 5.1.1
    - decord: 0.6.0
    - deepspeed: 0.15.0
    - defusedxml: 0.7.1
    - deprecated: 1.2.15
    - descartes: 1.1.0
    - dill: 0.3.9
    - distributed: 2024.1.1
    - distro: 1.7.0
    - dm-tree: 0.1.8
    - dnspython: 2.7.0
    - docker: 6.1.3
    - docstring-parser: 0.16
    - dotmap: 1.3.30
    - duckdb: 0.8.1
    - durationpy: 0.9
    - easydict: 1.11
    - editor: 1.6.6
    - efficientnet-pytorch: 0.7.1
    - einops: 0.7.0
    - elasticsearch: 7.17.12
    - elasticsearch-dsl: 7.4.1
    - engineering-notation: 0.10.0
    - entrypoints: 0.4
    - exceptiongroup: 1.2.0
    - execnet: 2.0.2
    - executing: 2.0.1
    - expecttest: 0.1.3
    - fastjsonschema: 2.19.1
    - fastrlock: 0.8.2
    - ffmpeg-python: 0.2.0
    - filelock: 3.13.3
    - filterpy: 1.4.5
    - fire: 0.7.0
    - flash-attn: 2.7.2.post1
    - flask: 2.3.3
    - flatbuffers: 24.12.23
    - fonttools: 4.51.0
    - fqdn: 1.5.1
    - frozenlist: 1.4.1
    - fsspec: 2022.11.0
    - ftfy: 6.2.3
    - func-timeout: 4.3.5
    - future: 1.0.0
    - fuyao: 4.5.1.post2
    - fuyao-all: 3.8.17
    - fuyao-skinny: 0.1.10
    - gast: 0.5.4
    - gitdb: 4.0.12
    - gitpython: 3.1.44
    - google-auth: 2.29.0
    - google-auth-oauthlib: 0.4.6
    - graphsurgeon: 0.4.6
    - greenlet: 3.1.1
    - grpcio: 1.62.1
    - gunicorn: 20.1.0
    - guppy3: 3.1.3
    - h5py: 3.12.1
    - hjson: 3.1.0
    - huggingface-hub: 0.27.1
    - humanfriendly: 10.0
    - hypothesis: 5.35.1
    - idna: 3.6
    - igraph: 0.11.4
    - imageio: 2.36.1
    - importlib-metadata: 4.13.0
    - iniconfig: 2.0.0
    - inquirer: 3.4.0
    - intel-openmp: 2021.4.0
    - ipcqueue: 0.9.7
    - ipykernel: 6.29.4
    - ipython: 8.21.0
    - ipython-genutils: 0.2.0
    - ipywidgets: 8.1.5
    - isoduration: 20.11.0
    - isort: 5.13.2
    - itsdangerous: 2.2.0
    - jedi: 0.19.1
    - jinja2: 3.1.3
    - jmespath: 0.10.0
    - joblib: 1.3.2
    - json5: 0.9.24
    - jsonpointer: 3.0.0
    - jsonschema: 4.21.1
    - jsonschema-specifications: 2023.12.1
    - jupyter: 1.1.1
    - jupyter-client: 6.2.0
    - jupyter-console: 6.6.3
    - jupyter-core: 5.7.2
    - jupyter-enterprise-gateway: 3.2.2
    - jupyter-events: 0.11.0
    - jupyter-server: 1.24.0
    - jupyter-server-fileid: 0.9.3
    - jupyter-server-ydoc: 0.8.0
    - jupyter-tensorboard: 0.2.0
    - jupyter-ydoc: 0.2.5
    - jupyterlab: 3.6.6
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.27.3
    - jupyterlab-widgets: 3.0.13
    - jupytext: 1.16.1
    - kafka-python: 2.0.2
    - kiwisolver: 1.4.5
    - kubernetes: 31.0.0
    - langcodes: 3.3.0
    - lark: 1.1.9
    - lazy-loader: 0.4
    - lazy-object-proxy: 1.10.0
    - librosa: 0.10.1
    - lightning: 2.3.0
    - lightning-thunder: 0.1.0
    - lightning-utilities: 0.11.2
    - llvmlite: 0.42.0
    - locket: 1.0.0
    - looseversion: 1.3.0
    - lz4: 4.3.2
    - mako: 1.3.8
    - markdown: 3.6
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - masksdk: 1.2.1
    - matplotlib: 3.8.4
    - matplotlib-inline: 0.1.6
    - mccabe: 0.7.0
    - mdit-py-plugins: 0.4.0
    - mdurl: 0.1.2
    - mistune: 3.0.2
    - mkl: 2021.1.1
    - mkl-devel: 2021.1.1
    - mkl-include: 2021.1.1
    - mlflow: 1.29.0
    - mock: 5.1.0
    - mpmath: 1.3.0
    - msgpack: 1.0.8
    - multidict: 6.0.5
    - multiscaledeformableattention: 1.0
    - murmurhash: 1.0.10
    - mypy-extensions: 1.0.0
    - namedatomiclock: 1.1.3
    - nbclassic: 1.1.0
    - nbclient: 0.10.0
    - nbconvert: 7.16.3
    - nbformat: 5.10.4
    - ndindex: 1.9.2
    - nest-asyncio: 1.6.0
    - networkx: 3.4.2
    - ninja: 1.11.1.1
    - notebook: 6.4.10
    - notebook-shim: 0.2.4
    - numba: 0.59.0+1.g20ae2b56c
    - numexpr: 2.10.2
    - numpy: 1.24.4
    - nuscenes-devkit: 1.1.9
    - nvfuser: 0.1.6a0+a684e2a
    - nvidia-cublas-cu12: 12.4.5.8
    - nvidia-cuda-cupti-cu12: 12.4.127
    - nvidia-cuda-nvrtc-cu12: 12.4.127
    - nvidia-cuda-runtime-cu12: 12.4.127
    - nvidia-cudnn-cu12: 9.1.0.70
    - nvidia-cufft-cu12: 11.2.1.3
    - nvidia-curand-cu12: 10.3.5.147
    - nvidia-cusolver-cu12: 11.6.1.9
    - nvidia-cusparse-cu12: 12.3.1.170
    - nvidia-dali-cuda120: 1.36.0
    - nvidia-ml-py: 12.560.30
    - nvidia-nccl-cu12: 2.21.5
    - nvidia-nvimgcodec-cu12: 0.2.0.7
    - nvidia-nvjitlink-cu12: 12.4.127
    - nvidia-nvtx-cu12: 12.4.127
    - nvidia-pyindex: 1.0.9
    - nvsmi: 0.4.2
    - nvtx: 0.2.5
    - oauthlib: 3.2.2
    - onnx: 1.16.0
    - onnxruntime: 1.16.0
    - opencv: 4.7.0
    - opencv-python: 4.5.5.62
    - opencv-python-headless: 4.5.5.62
    - opt-einsum: 3.3.0
    - optree: 0.11.0
    - orjson: 3.8.7
    - oss-middle-layer: 1.9.9
    - oss2: 2.16.0
    - ossfs: 2023.1.0
    - packaging: 21.3
    - pandas: 1.5.3
    - pandocfilters: 1.5.1
    - paramiko: 3.5.0
    - parso: 0.8.4
    - partd: 1.4.1
    - pccm: 0.4.16
    - peft: 0.13.2
    - pexpect: 4.9.0
    - pika: 1.3.2
    - pillow: 10.3.0
    - pip: 24.0
    - pipdeptree: 2.13.0
    - platformdirs: 4.2.0
    - pluggy: 1.4.0
    - ply: 3.11
    - polars: 0.18.6
    - polygraphy: 0.49.8
    - pooch: 1.8.1
    - portalocker: 3.1.1
    - preshed: 3.0.9
    - presto-python-client: 0.8.4
    - prettytable: 3.10.0
    - prometheus-client: 0.20.0
    - prometheus-flask-exporter: 0.23.1
    - prompt-toolkit: 3.0.43
    - protobuf: 3.20.3
    - psutil: 5.9.4
    - psycopg2-binary: 2.9.10
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py-cpuinfo: 9.0.0
    - py-spy: 0.3.14
    - pyarrow: 14.0.1
    - pyasn1: 0.6.0
    - pyasn1-modules: 0.4.0
    - pybind11: 2.12.0
    - pybind11-global: 2.12.0
    - pycls: 0.1.1
    - pycocotools: 2.0.8
    - pycparser: 2.22
    - pycryptodome: 3.21.0
    - pycryptodomex: 3.21.0
    - pydantic: 2.6.4
    - pydantic-core: 2.16.3
    - pygments: 2.17.2
    - pygobject: 3.42.1
    - pyjwt: 2.10.1
    - pylibcugraph: 24.2.0
    - pylibcugraphops: 24.2.0
    - pylibraft: 24.2.0
    - pylint: 2.17.7
    - pylint-exit: 1.2.0
    - pymongo: 4.10.1
    - pymysql: 1.0.2
    - pynacl: 1.5.0
    - pynvjitlink: 0.1.13
    - pynvml: 11.4.1
    - pyparsing: 3.1.2
    - pypcd: 0.1.1
    - pyquaternion: 0.9.9
    - pyrasite: 2.0.1
    - pytest: 8.1.1
    - pytest-flakefinder: 1.1.0
    - pytest-rerunfailures: 14.0
    - pytest-shard: 0.1.2
    - pytest-xdist: 3.5.0
    - python-dateutil: 2.9.0.post0
    - python-hostlist: 1.23.0
    - python-json-logger: 3.2.1
    - python-lzf: 0.2.6
    - python-rapidjson: 1.8
    - pytorch-lightning: 2.5.0.post0
    - pytorch-quantization: 2.1.2
    - pytorch-triton: 3.0.0+a9bc1a364
    - pytz: 2022.7.1
    - pyyaml: 6.0.1
    - pyzmq: 24.0.1
    - qudida: 0.0.4
    - querystring-parser: 1.2.4
    - raft-dask: 24.2.0
    - rapids-dask-dependency: 24.2.0a0
    - ratelimiter: 1.2.0.post0
    - readchar: 4.2.1
    - redis: 5.2.1
    - referencing: 0.34.0
    - regex: 2023.12.25
    - requests: 2.31.0
    - requests-oauthlib: 2.0.0
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.7.1
    - rmm: 24.2.0
    - rpds-py: 0.18.0
    - rsa: 4.9
    - ruamel.yaml: 0.18.10
    - ruamel.yaml.clib: 0.2.12
    - runs: 1.2.2
    - safetensors: 0.5.2
    - schema: 0.7.5
    - scikit-image: 0.25.0
    - scikit-learn: 1.2.0
    - scipy: 1.12.0
    - seaborn: 0.12.2
    - send2trash: 1.8.2
    - sentencepiece: 0.2.0
    - setuptools: 68.2.2
    - sh: 2.0.7
    - shapely: 1.8.1
    - simplejson: 3.19.3
    - six: 1.16.0
    - smart-open: 6.4.0
    - smmap: 5.0.2
    - sniffio: 1.3.1
    - sortedcontainers: 2.4.0
    - soundfile: 0.12.1
    - soupsieve: 2.5
    - soxr: 0.3.7
    - spacy: 3.7.4
    - spacy-legacy: 3.0.12
    - spacy-loggers: 1.0.5
    - spconv-cu120: 2.3.6
    - sphinx-glpi-theme: 0.6
    - sqlalchemy: 1.4.54
    - sqlparse: 0.5.3
    - srsly: 2.4.8
    - ssh-import-id: 5.11
    - stack-data: 0.6.3
    - sympy: 1.13.1
    - tables: 3.10.1
    - tabulate: 0.9.0
    - tbb: 2021.12.0
    - tblib: 3.0.0
    - tenacity: 9.0.0
    - tensorboard: 2.17.0
    - tensorboard-data-server: 0.7.2
    - tensorboard-plugin-wit: 1.8.1
    - tensorboardx: 2.6.2.2
    - tensorrt: 8.6.3
    - termcolor: 2.5.0
    - terminado: 0.18.1
    - texttable: 1.7.0
    - thinc: 8.2.3
    - threadpoolctl: 3.3.0
    - thriftpy2: 0.4.17
    - tifffile: 2024.12.12
    - timm: 0.5.4
    - tinycss2: 1.2.1
    - tk-tools: 0.16.0
    - tokenizers: 0.20.4
    - toml: 0.10.2
    - tomli: 2.0.1
    - tomlkit: 0.13.2
    - toolz: 0.12.1
    - torch: 2.5.1+cu124
    - torch-scatter: 2.1.2+pt25cu124
    - torchcodec: 0.1.1+cu124
    - torchdata: 0.7.1a0
    - torchmetrics: 1.6.1
    - torchvision: 0.20.1+cu124
    - tornado: 6.4
    - tqdm: 4.66.2
    - traitlets: 5.9.0
    - transformers: 4.46.2
    - treelite: 4.0.0
    - triton: 3.1.0
    - typed-argument-parser: 1.10.1
    - typer: 0.9.4
    - types-dataclasses: 0.6.6
    - types-python-dateutil: 2.9.0.20241206
    - typing-extensions: 4.10.0
    - typing-inspect: 0.9.0
    - ucx-py: 0.36.0
    - uff: 0.6.9
    - uri-template: 1.3.0
    - urllib3: 1.26.18
    - wasabi: 1.1.2
    - watchdog: 6.0.0
    - wcwidth: 0.2.13
    - weasel: 0.3.4
    - webcolors: 24.11.1
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - werkzeug: 3.0.2
    - wheel: 0.43.0
    - widgetsnbextension: 4.0.13
    - wrapt: 1.17.0
    - xcompress: 5.0.1
    - xdata-dataloader: 3.1.63
    - xdoctest: 1.0.2
    - xfoundation: 0.0.6.dev2
    - xgboost: 2.0.3
    - xmod: 1.8.1
    - xpilot-lightning-mini: 2.3.0
    - y-py: 0.6.2
    - yacs: 0.1.8
    - yarl: 1.9.4
    - yarn-api-client: 1.0.3
    - ypy-websocket: 0.8.4
    - zict: 3.0.0
    - zipp: 3.17.0
  • System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.10.12
    - release: 5.10.134-17.3.al8.x86_64
    - version: #1 SMP Thu Oct 31 14:29:57 CST 2024


btian added the bug and needs triage labels Jan 23, 2025
Contributor

ringohoffman commented Mar 2, 2025

I'm fairly sure all of your ranks need to receive data at the same time in order to trigger the all-gather of the sharded weights. Otherwise, the rank that isn't receiving data won't send its shard of the weights to the ranks that are still receiving data.
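A toy single-process model of this lockstep behavior may help explain why the batches are effectively synchronized. It assumes that each validation batch triggers one blocking collective (under FSDP the forward pass all-gathers sharded parameters, typically once per wrapped submodule, so this is a simplification), that a collective completes only when every rank joins it, and that per-rank batch counts are known up front. `simulate_lockstep` is a hypothetical helper for illustration, not a Lightning or PyTorch API.

```python
from typing import Dict, List, Tuple


def simulate_lockstep(batches_per_rank: Dict[int, int]) -> Tuple[int, List[int]]:
    """Toy model: return (completed_steps, ranks_blocked_in_a_collective).

    Assumption: every validation step issues one collective that all ranks
    must enter before any of them can continue.
    """
    remaining = dict(batches_per_rank)
    step = 0
    while True:
        # Ranks that still have a batch will enter the next collective.
        arrivals = [rank for rank, left in remaining.items() if left > 0]
        if not arrivals:
            # Every rank ran out at the same step: no hang.
            return step, []
        if len(arrivals) < len(remaining):
            # Some ranks entered the collective but their peers never will,
            # so the arrivals wait until the (NCCL) timeout fires.
            return step, sorted(arrivals)
        for rank in remaining:
            remaining[rank] -= 1
        step += 1
```

With eight ranks where rank 5 has 50 batches and the others have more, the model reports 50 completed steps and every other rank blocked in step 51, which matches the hang in the report. In the real job, rank 5 moves on to end-of-epoch work instead, which may issue its own collectives (for example, the `sync_dist` reduction), and those do not match what its peers are waiting on.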

You could consider giving the model a tensor with sequence length 0 or 1 on the exhausted rank so it keeps participating in the communication, and then throwing away its result instead of aggregating it. Alternatively, find a way to drop or split the data so that every rank processes the same number of batches.
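Both suggestions can be sketched in plain Python. The sketch assumes each rank's validation batches are available as a list and that the global maximum or minimum batch count has already been agreed on (in a real job, for example, by calling `torch.distributed.all_reduce` on a per-rank count tensor with `ReduceOp.MAX` or `ReduceOp.MIN`). `pad_batches`, `truncate_batches`, `mean_val_loss`, and the `make_dummy` callback are hypothetical names for illustration, not part of Lightning.

```python
from typing import Any, Callable, Iterator, List, Tuple


def pad_batches(
    batches: List[Any], global_max: int, make_dummy: Callable[[], Any]
) -> Iterator[Tuple[Any, bool]]:
    """Yield (batch, is_padding) so this rank runs exactly global_max steps.

    make_dummy (hypothetical) should build a tiny batch, e.g. sequence
    length 1, so the forward pass still joins the FSDP collectives.
    """
    for batch in batches:
        yield batch, False
    # Extra steps only keep this rank communicating; results are discarded.
    for _ in range(global_max - len(batches)):
        yield make_dummy(), True


def truncate_batches(batches: List[Any], global_min: int) -> List[Any]:
    """Drop trailing batches so every rank runs exactly global_min steps."""
    return batches[:global_min]


def mean_val_loss(losses_with_flags: List[Tuple[float, bool]]) -> float:
    """Average losses from real batches only; padding losses are ignored.

    Assumes at least one real batch is present.
    """
    real = [loss for loss, is_padding in losses_with_flags if not is_padding]
    return sum(real) / len(real)
```

Padding keeps every rank calling the model the same number of times while `mean_val_loss` ignores the dummy results; truncation is simpler but discards real validation data. For map-style datasets, `torch.utils.data.DistributedSampler` with `drop_last=True` also gives every replica the same number of samples, although that only yields equal batch counts when batches have a fixed size.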
