From 2ad9c2664faad21e1634d8c22dd9caab89b15185 Mon Sep 17 00:00:00 2001
From: Alex Merose
Date: Thu, 30 Jan 2025 16:44:29 -0800
Subject: [PATCH] Basic implementation of WeightedCrossEntropy torchmetric. [WIP]

---
 models/dimamba.py | 72 ++++++++++++++++++++++++++++++++++++++++++++--
 requirements.yaml |  1 +
 2 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/models/dimamba.py b/models/dimamba.py
index 8127b8d..55e5cb9 100644
--- a/models/dimamba.py
+++ b/models/dimamba.py
@@ -1,6 +1,6 @@
 import math
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, Mapping
 
 import huggingface_hub
 import numpy as np
@@ -23,6 +23,7 @@
     BaseModelOutputWithNoAttention,
     MaskedLMOutput,
 )
+from composer.metrics.nlp import LanguageCrossEntropy
 
 try:
     from mamba_ssm.ops.triton.layernorm import (
@@ -835,13 +836,80 @@ def forward(
         return hidden_states, all_hidden_states
 
 
+class WeightedCrossEntropy(LanguageCrossEntropy):
+    """Cross-entropy metric that discounts tokens via per-position weights.
+
+    Each batch adds its weight-normalized mean loss scaled by its number of
+    non-ignored tokens. NOTE(review): this assumes the inherited `compute()`
+    returns `sum_loss / total_items`, as Composer's `LanguageCrossEntropy`
+    does; confirm against the pinned Composer version.
+
+    Args:
+        loss_weights: One weight per target position. It is flattened and must
+            have as many elements as every `target` passed to `update`.
+        dist_sync_on_step: Forwarded to `LanguageCrossEntropy`.
+        ignore_index: Target value excluded from the loss and the item count.
+    """
+
+    def __init__(self, loss_weights: Tensor, dist_sync_on_step: bool = False, ignore_index: int = -100):
+        super().__init__(dist_sync_on_step, ignore_index)
+        self.name = "weighted_cross_entropy"
+        self.loss_weights = loss_weights
+        # Note: this differs from the inherited implementation, whose `reduction='sum'`.
+        # Per-token losses are kept so that each one can be scaled by its weight.
+        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none')
+
+    def update(self, output: Union[Mapping, Tensor], target: Tensor) -> None:
+        """Accumulate one batch; adapts `LanguageCrossEntropy.update` with weights.
+
+        Args:
+            output: Logits tensor, or a mapping holding it under `'logits'`.
+            target: Class indices; entries equal to `ignore_index` are skipped.
+
+        Raises:
+            TypeError: If `output` is neither a mapping nor a tensor.
+            ValueError: If `loss_weights` and `target` differ in element count.
+        """
+        if isinstance(output, Mapping):
+            logits = output['logits']
+        elif isinstance(output, Tensor):
+            logits = output
+        else:
+            raise TypeError(f'Type {type(output)} for the output is unsupported.')
+
+        target = target.view(-1)
+        logits = logits.view(target.shape[0], -1)
+        losses = self.loss_fn(logits, target)
+
+        # Note: this is an addition to the original implementation.
+        weights = self.loss_weights.reshape(-1).to(losses)
+        if weights.shape != target.shape:
+            raise ValueError(
+                f'loss_weights has {weights.numel()} elements but target has {target.numel()}.'
+            )
+        # Mask out of place: writing zeros through a view would permanently
+        # alter `self.loss_weights` for every later batch.
+        keep = target != self.ignore_index
+        weights = weights.masked_fill(~keep, 0.0)
+
+        total_items = keep.sum()
+        self.total_items += total_items  # type: ignore (third-party)
+
+        # Clamp so a batch whose weights are all zero contributes 0 instead of NaN.
+        weight_sum = weights.sum().clamp_min(torch.finfo(weights.dtype).tiny)
+        # accumulate **weighted** loss over all batches: reduce to a scalar
+        # (`sum_loss` is assumed to be a scalar state) and rescale by the token
+        # count so the division by `total_items` does not normalize twice.
+        self.sum_loss += (losses * weights).sum() / weight_sum * total_items
+
+# Deprecated in favor of standard metric from Composer.
 def cross_entropy(logits, y, ignore_index=-100):
     """Cross entropy loss."""
     logits = logits.view(-1, logits.shape[-1])
     y = y.view(-1)
     return F.cross_entropy(logits, y, ignore_index=ignore_index)
 
-
+# Deprecated in favor of `WeightedCrossEntropy` above.
 def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
     """Weighted cross entropy loss (discounts certain tokens)."""
     logits = logits.view(-1, logits.shape[-1])
diff --git a/requirements.yaml b/requirements.yaml
index 9522111..e055ccc 100644
--- a/requirements.yaml
+++ b/requirements.yaml
@@ -12,6 +12,7 @@ dependencies:
   - pytorch-cuda=12.1
   - pip:
     - causal-conv1d==1.1.3.post1
+    - mosaicml
     - datasets==2.18.0
     - einops==0.7.0
     - fsspec==2024.2.0