Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic implementation of WeightedCrossEntropy torchmetric. #1

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions models/dimamba.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from functools import partial
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, Mapping

import huggingface_hub
import numpy as np
Expand All @@ -23,6 +23,7 @@
BaseModelOutputWithNoAttention,
MaskedLMOutput,
)
from composer.metrics.nlp import LanguageCrossEntropy

try:
from mamba_ssm.ops.triton.layernorm import (
Expand Down Expand Up @@ -835,13 +836,47 @@ def forward(
return hidden_states, all_hidden_states


class WeightedCrossEntropy(LanguageCrossEntropy):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a better approach for this would be to subclass torchmetrics.Metric directly since it would then be portable across Lightning, Composer and Torchtitan (at the very least), without needing composer installed. This is also pretty much overriding everything LanguageCrossEntropy does, so I see little advantage to it.

"""A weighted variant of cross entropy loss, which will discount certain tokens."""

def __init__(self, loss_weights: Tensor, dist_sync_on_step: bool = False, ignore_index: int = -100):
super().__init__(dist_sync_on_step, ignore_index)
self.name = "weighted_cross_entropy"
self.loss_weights = loss_weights
# Note: this differs from the inherited implementation, whose `reduction='sum'`.
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none')

def update(self, output: Union[Mapping, Tensor], target: Tensor) -> None:
"""An adaptation of `LangaugeCrossEntropy` that takes a weighted loss."""
if isinstance(output, Mapping):
logits = output['logits']
elif isinstance(output, Tensor):
logits = output
else:
raise Exception(f'Type {type(output)} for the output is unsupported.')

target = target.view(-1)
logits = logits.view(target.shape[0], -1)
losses = self.loss_fn(logits, target)

# Note: this is an addition to the original implementation.
loss_weights = self.loss_weights.view(-1)
loss_weights[target == self.ignore_index] = 0.0

total_items = (target != self.ignore_index).sum()
self.total_items += total_items # type: ignore (third-party)

# accumulate **weighted** loss over all batches
self.sum_loss += losses * (loss_weights / loss_weights.sum())

# Deprecated in favor of standard metric from Composer.
def cross_entropy(logits, y, ignore_index=-100):
"""Cross entropy loss."""
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
return F.cross_entropy(logits, y, ignore_index=ignore_index)


# Deprecated in favor of above implementation
def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
"""Weighted cross entropy loss (discounts certain tokens)."""
logits = logits.view(-1, logits.shape[-1])
Expand Down
1 change: 1 addition & 0 deletions requirements.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies:
- pytorch-cuda=12.1
- pip:
- causal-conv1d==1.1.3.post1
- mosaicml
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove this re: https://github.com/Open-Athena/mdlm/pull/1/files#r1937418877. It's great that Composer is designed to be decoupled from models like that. Torchtitan is too AFAIK. Lightning is not, and I'm not sure what that means yet for running lightning models like this on other training frameworks yet. Either way, there shouldn't be any need to depend on composer in mdlm.

- datasets==2.18.0
- einops==0.7.0
- fsspec==2024.2.0
Expand Down