Little mistake when calculating the perplexity #26

Open
AgentDS opened this issue Oct 29, 2024 · 1 comment
AgentDS commented Oct 29, 2024

In __get_ppl() of PPLInferencer, at line 186 we have

lens = (inputs["input_ids"] != self.tokenizer.pad_token_id).sum(-1).cpu().numpy()

where it tries to compute the number of tokens in each text sample of input_texts by counting the token IDs that are not equal to tokenizer.pad_token_id.

However, when the loss is calculated, the tokens that are actually scored start from the second token of each input rather than the first, as shown at line 173

shift_labels = inputs["input_ids"][..., 1:].contiguous()

Thus, I think the correct way to count the tokens at line 186 should be

lens = (inputs["input_ids"][..., 1:] != self.tokenizer.pad_token_id).sum(-1).cpu().numpy()

The new version differs from the original only slightly, namely new_lens = orig_lens - 1.
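
To make the off-by-one concrete, here is a minimal sketch (with a hand-built, right-padded toy batch and a made-up pad_token_id = 0, no tokenizer involved) comparing the two counts:

    import torch

    pad_token_id = 0                          # hypothetical pad id for the toy batch
    input_ids = torch.tensor([
        [11, 12, 13, 14],                     # 4 real tokens
        [21, 22, 23, 0],                      # 3 real tokens + 1 pad
    ])

    # Original counting (line 186): all non-pad tokens.
    orig_lens = (input_ids != pad_token_id).sum(-1)          # tensor([4, 3])

    # Proposed counting: drop the first token, matching shift_labels = input_ids[..., 1:].
    new_lens = (input_ids[..., 1:] != pad_token_id).sum(-1)  # tensor([3, 2])

    print(orig_lens, new_lens)                # new_lens == orig_lens - 1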



AgentDS commented Oct 30, 2024

A modified version could be

    def __get_ppl(self, input_texts: List[str], mask_length=None):
        # (Assumes the module-level imports of the original file: torch, numpy as np, typing.List.)
        if self.call_api:
            return api_get_ppl(self.api_name, input_texts)
        # Right padding keeps the real tokens at the front, so the shift below is safe.
        self.tokenizer.padding_side = "right"
        inputs = self.tokenizer(input_texts, padding=True, return_tensors='pt', truncation=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        outputs = self.model(**inputs)

        # Standard causal-LM shift: the logits at position t predict the token at position t + 1.
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = inputs["input_ids"][..., 1:].contiguous()
        shift_attention_mask_batch = inputs["attention_mask"][..., 1:].contiguous()

        # Per-position loss; pad positions contribute zero thanks to ignore_index.
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=self.tokenizer.pad_token_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(
            shift_labels.size())

        if mask_length is not None:
            # Zero out the loss on the first (mask_length[i] - 1) shifted positions,
            # i.e. the context tokens that should not count towards the perplexity.
            mask = torch.zeros_like(shift_labels)  # [batch, seqlen - 1]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask

        # Number of predicted (non-pad) tokens per sample, i.e. orig_len - 1.
        lens = shift_attention_mask_batch.sum(1).cpu().numpy()
        if mask_length is not None:
            # Only (mask_length[i] - 1) positions were zeroed above, because the shift already
            # dropped the first token; subtracting the full mask_length would remove one too many.
            lens -= np.array(mask_length) - 1
        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
        return ce_loss
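
If I am reading the mask loop right, it zeroes out mask_length[i] - 1 positions of the shifted sequence, so that (rather than the full mask_length[i]) is what should be subtracted from the shifted token count. A toy sanity check (hand-built tensors only, no model or tokenizer; the numbers are made up):

    import numpy as np
    import torch

    # One sample with 5 real tokens, of which the first mask_length = 2 are context.
    orig_len, mask_length = 5, [2]
    shift_positions = orig_len - 1                      # positions in the shifted sequence

    mask = torch.zeros(1, shift_positions)
    for i in range(len(mask)):
        for j in range(mask_length[i] - 1, len(mask[i])):
            mask[i][j] = 1                              # same loop as in __get_ppl

    counted = int(mask.sum())                           # positions that keep their loss
    lens = np.array([shift_positions]) - (np.array(mask_length) - 1)
    print(counted, lens)                                # 3 [3] -- they agree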
