Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pad each batch, not the whole dataset #30

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
87 changes: 62 additions & 25 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint
Expand Down Expand Up @@ -72,11 +73,63 @@ def build_input_from_segments(persona, history, reply, tokenizer, lm_labels=Fals
return instance, sequence # TODO: second arg is never used, delete it


def pad_and_tensorize(batch_dict, padding):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this and ChatDataset should be easy to unit test

""" Pad the batch_dict."""
tensors = []
for name in MODEL_INPUTS:
if name not in PADDED_INPUTS:
tensors.append(torch.tensor(batch_dict[name]))
continue
entry = batch_dict[name]
pad_id = padding if name != "lm_labels" else -1
padded = pad_sequence([torch.tensor(seq) for x in entry for seq in x], batch_first=True,
padding_value=pad_id)
bs, n_candidates = len(entry), len(entry[0])
tensors.append(padded.view(bs, n_candidates, -1))
return tensors

class ChatDataset(torch.utils.data.Dataset):

def __init__(self, fields, pad_id):
self.fields = fields
self.pad_id = pad_id

def __getitem__(self, item) -> dict:
return {f: self.fields[f][item] for f in MODEL_INPUTS}

def collate_fn(self, examples):
batch_dict = defaultdict(list)
for input_name in MODEL_INPUTS:
for e in examples:
batch_dict[input_name].append(e[input_name])
tensors = pad_and_tensorize(batch_dict, padding=self.pad_id)
return tensors

def __len__(self):
return len(self.fields['input_ids'])


def get_data_loaders(args, tokenizer):
""" Prepare the dataset for training and evaluation """
personachat = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)

logger.info("Build inputs and labels")
datasets: dict = make_data_lists(args, personachat, tokenizer)
pad_id = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1])
train_dataset = ChatDataset(datasets['train'], pad_id)
valid_dataset = ChatDataset(datasets['valid'], pad_id)

logger.info("Build train and validation dataloaders")
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(maybe) put this in ChatDataset.to_loader(self, args, shuffle) -> sampler, loader

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

at some point might also want to document which tensors are 3D

valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed),
collate_fn=train_dataset.collate_fn)
valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False,
collate_fn=valid_dataset.collate_fn)
return train_loader, valid_loader, train_sampler, valid_sampler


def make_data_lists(args, personachat, tokenizer):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring

datasets = {"train": defaultdict(list), "valid": defaultdict(list)}
for dataset_name, dataset in personachat.items():
num_candidates = len(dataset[0]["utterances"][0]["candidates"])
Expand All @@ -86,36 +139,20 @@ def get_data_loaders(args, tokenizer):
persona = dialog["personality"].copy()
for _ in range(args.personality_permutations):
for utterance in dialog["utterances"]:
history = utterance["history"][-(2*args.max_history+1):]
candidate_instances = defaultdict(list)
history = utterance["history"][-(2 * args.max_history + 1):]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could add assert len(utterance['candidates']) >= num_candidates

for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
lm_labels = bool(j == num_candidates-1)
instance, _ = build_input_from_segments(persona, history, candidate, tokenizer, lm_labels)
lm_labels = bool(j == num_candidates - 1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better varname?

instance, _ = build_input_from_segments(persona, history, candidate,
tokenizer, lm_labels)
for input_name, input_array in instance.items():
datasets[dataset_name][input_name].append(input_array)
candidate_instances[input_name].append(input_array)
for k in candidate_instances.keys():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.items() will save some chars

datasets[dataset_name][k].append(candidate_instances[k])
datasets[dataset_name]["mc_labels"].append(num_candidates - 1)
datasets[dataset_name]["n_candidates"] = num_candidates
persona = [persona[-1]] + persona[:-1] # permuted personalities

logger.info("Pad inputs and convert to Tensor")
tensor_datasets = {"train": [], "valid": []}
for dataset_name, dataset in datasets.items():
dataset = pad_dataset(dataset, padding=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1]))
for input_name in MODEL_INPUTS:
tensor = torch.tensor(dataset[input_name])
if input_name != "mc_labels":
tensor = tensor.view((-1, datasets[dataset_name]["n_candidates"]) + tensor.shape[1:])
tensor_datasets[dataset_name].append(tensor)

logger.info("Build train and validation dataloaders")
train_dataset, valid_dataset = TensorDataset(*tensor_datasets["train"]), TensorDataset(*tensor_datasets["valid"])
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed))
valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False)

logger.info("Train dataset (Batch, Candidates, Seq length): {}".format(train_dataset.tensors[0].shape))
logger.info("Valid dataset (Batch, Candidates, Seq length): {}".format(valid_dataset.tensors[0].shape))
return train_loader, valid_loader, train_sampler, valid_sampler
return datasets


def train():
Expand Down