diff --git a/interact.py b/interact.py
index e22f1fe..185088d 100644
--- a/interact.py
+++ b/interact.py
@@ -62,7 +62,6 @@ def sample_sequence(personality, history, tokenizer, model, args, current_output
     for i in range(args.max_length):
         instance = build_input_from_segments(personality, history, current_output, tokenizer, with_eos=False)
-
         input_ids = torch.tensor(instance["input_ids"], device=args.device).unsqueeze(0)
         token_type_ids = torch.tensor(instance["token_type_ids"], device=args.device).unsqueeze(0)
diff --git a/train.py b/train.py
index bf70da2..2245a3a 100644
--- a/train.py
+++ b/train.py
@@ -10,6 +10,7 @@ import torch
 from torch.nn.parallel import DistributedDataParallel
+from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader, TensorDataset
 from ignite.engine import Engine, Events
 from ignite.handlers import ModelCheckpoint
@@ -68,11 +69,63 @@ def build_input_from_segments(persona, history, reply, tokenizer, lm_labels=Fals
     return instance


+def pad_and_tensorize(batch_dict, padding):
+    """ Pad the batch_dict."""
+    tensors = []
+    for name in MODEL_INPUTS:
+        if name not in PADDED_INPUTS:
+            tensors.append(torch.tensor(batch_dict[name]))
+            continue
+        entry = batch_dict[name]
+        pad_id = padding if name != "lm_labels" else -1
+        padded = pad_sequence([torch.tensor(seq) for x in entry for seq in x], batch_first=True,
+                              padding_value=pad_id)
+        bs, n_candidates = len(entry), len(entry[0])
+        tensors.append(padded.view(bs, n_candidates, -1))
+    return tensors
+
+
+class ChatDataset(torch.utils.data.Dataset):
+
+    def __init__(self, fields, pad_id):
+        self.fields = fields
+        self.pad_id = pad_id
+
+    def __getitem__(self, item) -> dict:
+        return {f: self.fields[f][item] for f in MODEL_INPUTS}
+
+    def collate_fn(self, examples):
+        batch_dict = defaultdict(list)
+        for input_name in MODEL_INPUTS:
+            for e in examples:
+                batch_dict[input_name].append(e[input_name])
+        tensors = pad_and_tensorize(batch_dict, padding=self.pad_id)
+        return tensors
+
+    def __len__(self):
+        return len(self.fields['input_ids'])
+
+
 def get_data_loaders(args, tokenizer):
     """ Prepare the dataset for training and evaluation """
     personachat = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)

     logger.info("Build inputs and labels")
+    datasets: dict = make_data_lists(args, personachat, tokenizer)
+    pad_id = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1])
+    train_dataset = ChatDataset(datasets['train'], pad_id)
+    valid_dataset = ChatDataset(datasets['valid'], pad_id)
+
+    logger.info("Build train and validation dataloaders")
+    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
+    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
+    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed),
+                              collate_fn=train_dataset.collate_fn)
+    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False,
+                              collate_fn=valid_dataset.collate_fn)
+    return train_loader, valid_loader, train_sampler, valid_sampler
+
+
+def make_data_lists(args, personachat, tokenizer):
     datasets = {"train": defaultdict(list), "valid": defaultdict(list)}
     for dataset_name, dataset in personachat.items():
         num_candidates = len(dataset[0]["utterances"][0]["candidates"])
@@ -82,36 +135,19 @@ def get_data_loaders(args, tokenizer):
             persona = dialog["personality"].copy()
             for _ in range(args.personality_permutations):
                 for utterance in dialog["utterances"]:
-                    history = utterance["history"][-(2*args.max_history+1):]
+                    candidate_instances = defaultdict(list)
+                    history = utterance["history"][-(2 * args.max_history + 1):]
                     for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
                         lm_labels = bool(j == num_candidates-1)
                         instance = build_input_from_segments(persona, history, candidate, tokenizer, lm_labels)
                         for input_name, input_array in instance.items():
-                            datasets[dataset_name][input_name].append(input_array)
+                            candidate_instances[input_name].append(input_array)
+                    for k in candidate_instances.keys():
+                        datasets[dataset_name][k].append(candidate_instances[k])
                     datasets[dataset_name]["mc_labels"].append(num_candidates - 1)
                     datasets[dataset_name]["n_candidates"] = num_candidates
                 persona = [persona[-1]] + persona[:-1]  # permuted personalities
-
-    logger.info("Pad inputs and convert to Tensor")
-    tensor_datasets = {"train": [], "valid": []}
-    for dataset_name, dataset in datasets.items():
-        dataset = pad_dataset(dataset, padding=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1]))
-        for input_name in MODEL_INPUTS:
-            tensor = torch.tensor(dataset[input_name])
-            if input_name != "mc_labels":
-                tensor = tensor.view((-1, datasets[dataset_name]["n_candidates"]) + tensor.shape[1:])
-            tensor_datasets[dataset_name].append(tensor)
-
-    logger.info("Build train and validation dataloaders")
-    train_dataset, valid_dataset = TensorDataset(*tensor_datasets["train"]), TensorDataset(*tensor_datasets["valid"])
-    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
-    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
-    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed))
-    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False)
-
-    logger.info("Train dataset (Batch, Candidates, Seq length): {}".format(train_dataset.tensors[0].shape))
-    logger.info("Valid dataset (Batch, Candidates, Seq length): {}".format(valid_dataset.tensors[0].shape))
-    return train_loader, valid_loader, train_sampler, valid_sampler
+    return datasets


 def train():
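Editorial note, not part of the patch: the heart of this change is moving padding out of dataset construction (where `pad_dataset` previously padded every example to the corpus-wide maximum length) and into the `DataLoader` via a custom `collate_fn`, so each batch is padded only to its own longest sequence. A self-contained toy sketch of that pattern, with all names and data hypothetical rather than taken from the repo:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Variable-length sequences; padding is deferred to collate time."""
    def __init__(self, seqs, pad_id):
        self.seqs = seqs
        self.pad_id = pad_id

    def __getitem__(self, i):
        return torch.tensor(self.seqs[i])

    def __len__(self):
        return len(self.seqs)

    def collate_fn(self, examples):
        # Pad each batch only up to its own longest member.
        return pad_sequence(examples, batch_first=True, padding_value=self.pad_id)

ds = ToyDataset([[1, 2, 3], [4], [5, 6]], pad_id=0)
loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn)
for batch in loader:
    print(batch.shape)  # torch.Size([2, 3]), then torch.Size([1, 2])
```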
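The `padded.view(bs, n_candidates, -1)` in `pad_and_tensorize` relies on every example holding the same number of candidates: `make_data_lists` now appends one list of `n_candidates` instances per utterance instead of a flat stream, the comprehension flattens those (example, candidate) pairs, `pad_sequence` brings them to a common length, and the view restores the `(batch, candidates, seq_len)` layout the old `TensorDataset` code produced. A toy round trip of that shape contract (values made up):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two examples, each with two candidate sequences of varying length.
entry = [
    [[1, 2, 3], [4, 5]],    # example 0: candidate lengths 3 and 2
    [[6], [7, 8, 9, 10]],   # example 1: candidate lengths 1 and 4
]

# Flatten the (example, candidate) pairs, pad jointly, then restore the layout.
flat = [torch.tensor(seq) for cands in entry for seq in cands]
padded = pad_sequence(flat, batch_first=True, padding_value=0)
bs, n_candidates = len(entry), len(entry[0])
print(padded.view(bs, n_candidates, -1).shape)  # torch.Size([2, 2, 4])
```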
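One detail carried over from the old `pad_dataset` path: `lm_labels` are padded with `-1` rather than the tokenizer pad id, since `-1` is the ignore index used by the LM loss here, and padding labels with a real token would leak noise into the loss. A quick illustration (the pad id value below is hypothetical):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

pad_token_id = 40478  # hypothetical pad token id

seqs = [torch.tensor([10, 11, 12]), torch.tensor([13])]

# Model inputs are padded with the real pad token...
print(pad_sequence(seqs, batch_first=True, padding_value=pad_token_id))
# tensor([[   10,    11,    12],
#         [   13, 40478, 40478]])

# ...but lm_labels use -1 so padded positions are ignored by the loss.
print(pad_sequence(seqs, batch_first=True, padding_value=-1))
# tensor([[10, 11, 12],
#         [13, -1, -1]])
```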