Skip to content

Commit

Permalink
GH-3429: Remove serializability requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffpicard committed Nov 15, 2024
1 parent ffb8e29 commit 1873d7b
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 17 deletions.
6 changes: 3 additions & 3 deletions examples/multi_gpu/run_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ def main(multi_gpu):
# This code will run multiple times -- each GPU gets its own process and each process runs this code. We need to
# ensure that the corpus has the same elements and order on all processes, despite sampling. We do that by using
# the same seed on all processes.
flair.set_seed(1336)
flair.set_seed(42)

corpus = IMDB()
corpus.downsample(0.01)
corpus.downsample(0.1)
label_type = "sentiment"
label_dictionary = corpus.make_label_dictionary(label_type)

Expand All @@ -35,7 +35,7 @@ def main(multi_gpu):
# process 32 examples at the same time, then the optimizer will step.

trainer = ModelTrainer(model, corpus)
trainer.train(
trainer.fine_tune(
"resources/taggers/multi-gpu",
multi_gpu=multi_gpu, # Required for multi-gpu
max_epochs=2,
Expand Down
5 changes: 0 additions & 5 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,11 +1072,6 @@ def __len__(self) -> int:
def __repr__(self) -> str:
return self.__str__()

def __eq__(self, o: object) -> bool:
if not isinstance(o, Sentence):
return False
return self.to_dict() == o.to_dict()

@property
def start_position(self) -> int:
return self._start_position
Expand Down
11 changes: 6 additions & 5 deletions flair/distributed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,18 @@ def aggregate(value, aggregation_fn=np.mean):
return aggregation_fn(gathered_values)


def validate_corpus_same_across_processes(corpus: Corpus) -> None:
def validate_corpus_same_each_process(corpus: Corpus) -> None:
"""Catches most cases in which a corpus is not the same on each process. However, there is no guarantee for two
reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable"""
for dataset in [corpus.train, corpus.dev, corpus.test]:
if dataset is not None:
validate_dataset_same_across_processes(dataset)
_validate_dataset_same_each_process(dataset)


def validate_dataset_same_across_processes(dataset: Dataset, sample_size: int = 10) -> None:
"""Sanity checks a few examples to catch datasets that are obviously different, but not exhaustive to save time."""
def _validate_dataset_same_each_process(dataset: Dataset, sample_size: int = 10) -> None:
random_indices = random.sample(range(_len_dataset(dataset)), min(sample_size, _len_dataset(dataset)))
for i in random_indices:
example = dataset[i]
example = str(dataset[i])
examples = aggregate(example, list)
if not all(example == examples[0] for example in examples):
raise ValueError("Dataset must be the same on each process")
7 changes: 3 additions & 4 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import warnings
from inspect import signature
from pathlib import Path
from queue import Queue
from typing import Optional, Tuple, Type, Union
from typing import Optional, Union

import numpy as np
import torch
Expand All @@ -21,7 +20,7 @@
import flair.nn
from flair.data import Corpus, Dictionary, _len_dataset
from flair.datasets import DataLoader
from flair.distributed_utils import aggregate, is_main_process, validate_corpus_same_across_processes
from flair.distributed_utils import aggregate, is_main_process, validate_corpus_same_each_process
from flair.samplers import FlairSampler
from flair.trainers.plugins import (
AnnealingPlugin,
Expand Down Expand Up @@ -497,7 +496,7 @@ def train_custom(
if not torch.distributed.is_initialized():
raise RuntimeError("multi_gpu=True can only used inside flair.distributed_utils.launch_distributed()")
# Guard against each process initializing corpus differently due to e.g. different random seeds
validate_corpus_same_across_processes(self.corpus)
validate_corpus_same_each_process(self.corpus)
self.ddp_model = DistributedDataParallel(
self.model, device_ids=[flair.device.index], find_unused_parameters=True
)
Expand Down

0 comments on commit 1873d7b

Please sign in to comment.