
GH-3429: Add example to docs; Guard against corpus being different on each process
jeffpicard committed Nov 8, 2024
1 parent eb6f5b0 commit d74deec
Showing 10 changed files with 133 additions and 29 deletions.
5 changes: 3 additions & 2 deletions examples/README.md
@@ -4,6 +4,7 @@ This folder contains actively maintained examples of use of Flair, organized alo

## Table of Tasks

| Task | Documentation
|--------------------------| -------------
| Named Entity Recognition (NER) | [Here](ner/)
| Multi GPU | [Here](multi_gpu/)
32 changes: 32 additions & 0 deletions examples/multi_gpu/README.md
@@ -0,0 +1,32 @@
# Multi GPU

Training can be distributed across multiple GPUs on a local machine when using
[`ModelTrainer`](#flair.trainers.trainer.ModelTrainer).

## Example

See the script `run_multi_gpu.py` and its comments.

## Tutorial

There are two changes that are always required, as well as a few things to consider.

Always Required:
1) Pass the argument `multi_gpu=True` to your [`.train()`](#flair.trainers.trainer.ModelTrainer.train) or `.fine_tune()`
2) Wrap your code in [`launch_distributed`](#flair.distributed_utils.launch_distributed), e.g.
`launch_distributed(main, *args)`. This spawns multiple processes, each driving a GPU

Other considerations:
- The corpus and other preprocessing must be the same on all processes. For example, if corpus initialization involves
anything random, you should either
- Set the random seed before initializing the corpus (e.g. `flair.set_seed(42)`) OR
- Initialize the corpus before calling `launch_distributed` and pass the corpus as an argument so it's serialized to
all processes
- The effective batch size will be larger by a factor of num_gpus
- Each GPU will now process `mini_batch_size` examples before the optimizer steps, resulting in fewer total steps
taken than when training with a single device. To obtain comparable results between single- and multi-GPU training,
both mathematically and in terms of wall time, consider the method in the example script.
- Large batch sizes may be necessary to see faster runs, otherwise the communication overhead may dominate

Only the parameter updates in the training process will be distributed across multiple GPUs. Evaluation and prediction
are still done on a single device.
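
A minimal sketch of the two required changes (the `build_corpus_and_model()` helper here is hypothetical and stands in for your own corpus and model setup):

```python
import flair
from flair.distributed_utils import launch_distributed
from flair.trainers import ModelTrainer


def main():
    # Same seed on every process so the corpus is identical everywhere
    flair.set_seed(42)
    corpus, model = build_corpus_and_model()  # hypothetical helper -- replace with your own setup
    trainer = ModelTrainer(model, corpus)
    trainer.train("resources/taggers/example", multi_gpu=True)  # required change 1


if __name__ == "__main__":
    launch_distributed(main)  # required change 2: spawns one process per GPU
```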
Empty file added examples/multi_gpu/__init__.py
Empty file.
54 changes: 54 additions & 0 deletions examples/multi_gpu/run_multi_gpu.py
@@ -0,0 +1,54 @@
import torch

import flair
from flair.datasets import IMDB
from flair.distributed_utils import launch_distributed
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer


def main(multi_gpu):
# Note: Multi-GPU can affect corpus loading
# This code will run multiple times -- each GPU gets its own process and each process runs this code. We need to
# ensure that the corpus has the same elements and order on all processes, despite sampling. We do that by using
# the same seed on all processes.
flair.set_seed(1336)

corpus = IMDB()
corpus.downsample(0.01)
label_type = "sentiment"
label_dictionary = corpus.make_label_dictionary(label_type)

embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")
model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)

# Note: Multi-GPU can affect choice of batch size.
# In order to compare batch updates fairly between single and multi-GPU training, we should:
# 1) Step the optimizer after the same number of examples
# 2) Process the same number of examples in each forward pass
mini_batch_chunk_size = 32  # Make this as large as possible without running out of GPU memory, to fully pack each device
num_devices_when_distributing = max(torch.cuda.device_count(), 1)
mini_batch_size = mini_batch_chunk_size if multi_gpu else mini_batch_chunk_size * num_devices_when_distributing
# e.g. Suppose your machine has 2 GPUs. If multi_gpu=False, the first gpu will process 32 examples, then the
# first gpu will process another 32 examples, then the optimizer will step. If multi_gpu=True, each gpu will
# process 32 examples at the same time, then the optimizer will step.

trainer = ModelTrainer(model, corpus)
trainer.train(
"resources/taggers/multi-gpu",
multi_gpu=multi_gpu, # Required for multi-gpu
max_epochs=2,
mini_batch_chunk_size=mini_batch_chunk_size,
mini_batch_size=mini_batch_size,
)


if __name__ == "__main__":
"""Minimal example demonstrating how to train a model on multiple GPUs."""
multi_gpu = True

if multi_gpu:
launch_distributed(main, multi_gpu) # Required for multi-gpu
else:
main(multi_gpu)
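
Note: presumably this script is started like any ordinary Python program (e.g. `python run_multi_gpu.py`); `launch_distributed` itself spawns the per-GPU worker processes, so no separate launcher such as `torchrun` should be needed.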
5 changes: 5 additions & 0 deletions flair/data.py
@@ -1071,6 +1071,11 @@ def __len__(self) -> int:
def __repr__(self) -> str:
return self.__str__()

def __eq__(self, o: object) -> bool:
if not isinstance(o, Sentence):
return False
return self.to_dict() == o.to_dict()

@property
def start_position(self) -> int:
return self._start_position
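For reference, a short sketch of the intended semantics of the new `Sentence.__eq__`, based on the equality tests added in this commit:

```python
from flair.data import Sentence

sentence1 = Sentence("This sentence will be labeled")
sentence2 = Sentence("This sentence will be labeled")
assert sentence1 == sentence2  # same text, no annotations yet

sentence1[1].set_label("ner", "B-subject")
sentence2[1].set_label("ner", "B-object")
assert sentence1 != sentence2  # same text, but the labels differ

sentence2[1].set_label("ner", "B-subject")  # overwrite so both labels match
assert sentence1 == sentence2
```

Because equality goes through `to_dict()`, two sentences compare equal exactly when their serialized text and annotations match.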
19 changes: 19 additions & 0 deletions flair/distributed_utils.py
@@ -1,14 +1,17 @@
import logging
import os
import random
from multiprocessing.connection import Connection
from typing import Callable

import numpy as np
import torch
import torch.multiprocessing as mp
from torch.distributed import destroy_process_group, init_process_group
from torch.utils.data import Dataset

import flair
from flair.data import Corpus, _len_dataset

log = logging.getLogger("flair")

@@ -66,3 +69,19 @@ def aggregate(value, aggregation_fn=np.mean):
else:
gathered_values = [value]
return aggregation_fn(gathered_values)


def validate_corpus_same_across_processes(corpus: Corpus) -> None:
for dataset in [corpus.train, corpus.dev, corpus.test]:
if dataset is not None:
validate_dataset_same_across_processes(dataset)


def validate_dataset_same_across_processes(dataset: Dataset, sample_size: int = 10) -> None:
"""Sanity checks a few examples to catch datasets that are obviously different, but not exhaustive to save time."""
random_indices = random.sample(range(_len_dataset(dataset)), min(sample_size, _len_dataset(dataset)))
for i in random_indices:
example = dataset[i]
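# Gather this process's copy of the example along with every other process's copy for comparison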
examples = aggregate(example, list)
if not all(example == examples[0] for example in examples):
raise ValueError("Dataset must be the same on each process")
6 changes: 5 additions & 1 deletion flair/nn/model.py
@@ -126,7 +126,11 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
model_state["model_card"] = self.model_card

# save model
torch.save(model_state, str(model_file), pickle_protocol=4)
if is_main_process():
torch.save(model_state, str(model_file), pickle_protocol=4)

if torch.distributed.is_initialized():
torch.distributed.barrier() # Prevent any process from loading a model until writing is complete

@classmethod
def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Model":
30 changes: 7 additions & 23 deletions flair/trainers/trainer.py
@@ -20,7 +20,7 @@
import flair.nn
from flair.data import Corpus, Dictionary, _len_dataset
from flair.datasets import DataLoader
from flair.distributed_utils import aggregate, is_main_process
from flair.distributed_utils import aggregate, is_main_process, validate_corpus_same_across_processes
from flair.samplers import FlairSampler
from flair.trainers.plugins import (
AnnealingPlugin,
@@ -495,6 +495,8 @@ def train_custom(
if multi_gpu:
if not torch.distributed.is_initialized():
raise RuntimeError("multi_gpu=True can only used inside flair.distributed_utils.launch_distributed()")
# Guard against each process initializing corpus differently due to e.g. different random seeds
validate_corpus_same_across_processes(self.corpus)
self.ddp_model = DistributedDataParallel(
self.model, device_ids=[flair.device.index], find_unused_parameters=True
)
@@ -787,14 +789,14 @@ def wrapped_forward_loss(*args, **kwargs2):

if save_best_model and current_epoch_has_best_model_so_far:
log.info("saving best model")
self._save_model(base_path / "best-model.pt", checkpoint=save_optimizer_state)
self.model.save(base_path / "best-model.pt", checkpoint=save_optimizer_state)

# - SWAPlugin -> restores SGD weights from SWA
self.dispatch("after_training_loop")

# if we do not use dev data for model selection, save final model
if save_final_model:
self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)

except KeyboardInterrupt:
log_line(log)
@@ -804,7 +806,7 @@ def wrapped_forward_loss(*args, **kwargs2):

if save_final_model:
log.info("Saving model ...")
self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
log.info("Done.")

except TrainingInterrupt as exc:
@@ -815,7 +817,7 @@ def wrapped_forward_loss(*args, **kwargs2):

if save_final_model:
log.info("Saving model ...")
self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
log.info("Done.")

except Exception:
@@ -956,23 +958,5 @@ def _initialize_model_card(self, **training_parameters):
def _record(self, metric):
self.dispatch("metric_recorded", metric)

def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
"""Saves the current model. Safe to call from a distributed context.
Args:
model_file: the model file
checkpoint: currently unused.
"""
if is_main_process():
self.model.save(model_file, checkpoint)

if torch.distributed.is_initialized():
torch.distributed.barrier() # Prevent any process from loading a model until writing is complete

def _load_model(self, model_file: Union[str, Path]) -> None:
"""Loads the model from the given file into the current state. Safe to call from a distributed context."""
self.model.load_state_dict(self.model.load(model_file).state_dict())
if torch.distributed.is_initialized():
self.ddp_model = DistributedDataParallel(
self.model, device_ids=[flair.device.index], find_unused_parameters=True
)
1 change: 0 additions & 1 deletion pyproject.toml
@@ -32,7 +32,6 @@ filterwarnings = [
'ignore:`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.', # transformers calls deprecated hf_hub
"ignore:`torch.cuda.amp.GradScaler", # GradScaler changes in torch 2.3.0 but we want to be backwards compatible.
"ignore:`clean_up_tokenization_spaces` was not set", # Default behavior changes in transformers v4.45, raising irrelevant FutureWarning for serialized models.
"ignore:1Torch was not compiled with flash attention", # You might want to install flash attention, but you don't have to.
]
markers = [
"integration",
10 changes: 8 additions & 2 deletions tests/test_sentence.py
@@ -13,9 +13,15 @@ def test_sentence_context():
def test_equality():
assert Sentence("Guten Tag!") != Sentence("Good day!")
assert Sentence("Guten Tag!", use_tokenizer=True) != Sentence("Guten Tag!", use_tokenizer=False)
sentence1 = Sentence("This sentence will be labeled")
sentence1[1].set_label("ner", "B-subject")
sentence2 = Sentence("This sentence will be labeled")
sentence2[1].set_label("ner", "B-object")
assert sentence1 != sentence2

# TODO: is this desirable? Or should two sentences with same text be considered same objects?
assert Sentence("Guten Tag!") != Sentence("Guten Tag!")
assert Sentence("Guten Tag!") == Sentence("Guten Tag!")
sentence2[1].set_label("ner", "B-subject")
assert sentence1 == sentence2


def test_token_labeling():
