-
Notifications
You must be signed in to change notification settings - Fork 432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pad each batch, not the whole dataset #30
base: master
Are you sure you want to change the base?
Changes from 14 commits
150aaac
68f926b
bfdd032
e7d6e7b
517eb77
2f475cb
0f56c4a
2f1207a
8477c2e
1ba3929
d4e007f
1a00f96
e92ee7c
2c90fdb
d7a3c5e
bc89893
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
|
||
import torch | ||
from torch.nn.parallel import DistributedDataParallel | ||
from torch.nn.utils.rnn import pad_sequence | ||
from torch.utils.data import DataLoader, TensorDataset | ||
from ignite.engine import Engine, Events | ||
from ignite.handlers import ModelCheckpoint | ||
|
@@ -72,11 +73,63 @@ def build_input_from_segments(persona, history, reply, tokenizer, lm_labels=Fals | |
return instance, sequence # TODO: second arg is never used, delete it | ||
|
||
|
||
def pad_and_tensorize(batch_dict, padding): | ||
""" Pad the batch_dict.""" | ||
tensors = [] | ||
for name in MODEL_INPUTS: | ||
if name not in PADDED_INPUTS: | ||
tensors.append(torch.tensor(batch_dict[name])) | ||
continue | ||
entry = batch_dict[name] | ||
pad_id = padding if name != "lm_labels" else -1 | ||
padded = pad_sequence([torch.tensor(seq) for x in entry for seq in x], batch_first=True, | ||
padding_value=pad_id) | ||
bs, n_candidates = len(entry), len(entry[0]) | ||
tensors.append(padded.view(bs, n_candidates, -1)) | ||
return tensors | ||
|
||
class ChatDataset(torch.utils.data.Dataset): | ||
|
||
def __init__(self, fields, pad_id): | ||
self.fields = fields | ||
self.pad_id = pad_id | ||
|
||
def __getitem__(self, item) -> dict: | ||
return {f: self.fields[f][item] for f in MODEL_INPUTS} | ||
|
||
def collate_fn(self, examples): | ||
batch_dict = defaultdict(list) | ||
for input_name in MODEL_INPUTS: | ||
for e in examples: | ||
batch_dict[input_name].append(e[input_name]) | ||
tensors = pad_and_tensorize(batch_dict, padding=self.pad_id) | ||
return tensors | ||
|
||
def __len__(self): | ||
return len(self.fields['input_ids']) | ||
|
||
|
||
def get_data_loaders(args, tokenizer): | ||
""" Prepare the dataset for training and evaluation """ | ||
personachat = get_dataset(tokenizer, args.dataset_path, args.dataset_cache) | ||
|
||
logger.info("Build inputs and labels") | ||
datasets: dict = make_data_lists(args, personachat, tokenizer) | ||
pad_id = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1]) | ||
train_dataset = ChatDataset(datasets['train'], pad_id) | ||
valid_dataset = ChatDataset(datasets['valid'], pad_id) | ||
|
||
logger.info("Build train and validation dataloaders") | ||
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (maybe) put this in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. at some point might also want to document which tensors are 3D |
||
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None | ||
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed), | ||
collate_fn=train_dataset.collate_fn) | ||
valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False, | ||
collate_fn=valid_dataset.collate_fn) | ||
return train_loader, valid_loader, train_sampler, valid_sampler | ||
|
||
|
||
def make_data_lists(args, personachat, tokenizer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. docstring |
||
datasets = {"train": defaultdict(list), "valid": defaultdict(list)} | ||
for dataset_name, dataset in personachat.items(): | ||
num_candidates = len(dataset[0]["utterances"][0]["candidates"]) | ||
|
@@ -86,36 +139,20 @@ def get_data_loaders(args, tokenizer): | |
persona = dialog["personality"].copy() | ||
for _ in range(args.personality_permutations): | ||
for utterance in dialog["utterances"]: | ||
history = utterance["history"][-(2*args.max_history+1):] | ||
candidate_instances = defaultdict(list) | ||
history = utterance["history"][-(2 * args.max_history + 1):] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could add |
||
for j, candidate in enumerate(utterance["candidates"][-num_candidates:]): | ||
lm_labels = bool(j == num_candidates-1) | ||
instance, _ = build_input_from_segments(persona, history, candidate, tokenizer, lm_labels) | ||
lm_labels = bool(j == num_candidates - 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. better varname? |
||
instance, _ = build_input_from_segments(persona, history, candidate, | ||
tokenizer, lm_labels) | ||
for input_name, input_array in instance.items(): | ||
datasets[dataset_name][input_name].append(input_array) | ||
candidate_instances[input_name].append(input_array) | ||
for k in candidate_instances.keys(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
datasets[dataset_name][k].append(candidate_instances[k]) | ||
datasets[dataset_name]["mc_labels"].append(num_candidates - 1) | ||
datasets[dataset_name]["n_candidates"] = num_candidates | ||
persona = [persona[-1]] + persona[:-1] # permuted personalities | ||
|
||
logger.info("Pad inputs and convert to Tensor") | ||
tensor_datasets = {"train": [], "valid": []} | ||
for dataset_name, dataset in datasets.items(): | ||
dataset = pad_dataset(dataset, padding=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1])) | ||
for input_name in MODEL_INPUTS: | ||
tensor = torch.tensor(dataset[input_name]) | ||
if input_name != "mc_labels": | ||
tensor = tensor.view((-1, datasets[dataset_name]["n_candidates"]) + tensor.shape[1:]) | ||
tensor_datasets[dataset_name].append(tensor) | ||
|
||
logger.info("Build train and validation dataloaders") | ||
train_dataset, valid_dataset = TensorDataset(*tensor_datasets["train"]), TensorDataset(*tensor_datasets["valid"]) | ||
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None | ||
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None | ||
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed)) | ||
valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False) | ||
|
||
logger.info("Train dataset (Batch, Candidates, Seq length): {}".format(train_dataset.tensors[0].shape)) | ||
logger.info("Valid dataset (Batch, Candidates, Seq length): {}".format(valid_dataset.tensors[0].shape)) | ||
return train_loader, valid_loader, train_sampler, valid_sampler | ||
return datasets | ||
|
||
|
||
def train(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this and
ChatDataset
should be easy to unit test