Merge branch 'main' into clem_doc_readme
Showing 16 changed files with 940 additions and 415 deletions.
@@ -0,0 +1,155 @@
from abc import ABC, abstractmethod
from typing import Optional, Union

import torch
from transformers import BatchEncoding

from lighteval.models.model_config import EnvConfig
from lighteval.models.model_output import GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
from lighteval.tasks.requests import (
    GreedyUntilRequest,
    GreedyUntilWithLogitsRequest,
    LoglikelihoodRequest,
    LoglikelihoodRollingRequest,
    LoglikelihoodSingleTokenRequest,
)


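# A tokenized sequence in any of the forms the tokenization utils below may return.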
TokenSequence = Union[list[int], torch.LongTensor, torch.Tensor, BatchEncoding]


class LightevalModel(ABC):
    """Abstract model class defining the API that every model to plug into lighteval must follow."""

    DATASET_SPLITS = 4

    @abstractmethod
    def __init__(
        self,
        config,
        env_config: EnvConfig,
    ):
        return NotImplemented

    def cleanup(self):
        """Clean up operations if needed, such as closing an endpoint."""
        return

    @property
    @abstractmethod
    def tokenizer(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def add_special_tokens(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def max_length(self) -> int:
        """Return the maximum sequence length of the model."""
        raise NotImplementedError

    @property
    def disable_tqdm(self) -> bool:
        raise NotImplementedError

    def greedy_until_with_logits(
        self,
        requests: list[GreedyUntilWithLogitsRequest],
        override_bs: Optional[int] = None,
    ) -> list[GenerateReturn]:
        """
        Generates sequences greedily until a stopping condition is met,
        returning both the generated sequences and the logits.

        Args:
            requests (list[GreedyUntilWithLogitsRequest]): A list of input requests,
                each containing a context and its stopping conditions.
            override_bs (Optional[int], optional): Overrides the batch size for generation. Defaults to None.

        Returns:
            list[GenerateReturn]: A list of GenerateReturn objects,
                where each object contains the generated sequence and the corresponding logits.
        """
        return self.greedy_until(
            requests=requests,
            override_bs=override_bs,
            returns_logits=True,
        )

    @abstractmethod
    def greedy_until(
        self,
        requests: list[GreedyUntilRequest],
        returns_logits: bool = False,
        override_bs: Optional[int] = None,
    ) -> list[GenerateReturn]:
        """
        Generates responses using a greedy decoding strategy until certain ending conditions are met.

        Args:
            requests (list[GreedyUntilRequest]): list of requests containing the context and ending conditions.
            returns_logits (bool, optional): Whether to return the logits of the generated responses. Defaults to False.
            override_bs (int, optional): Override the batch size for generation. Defaults to None.

        Returns:
            list[GenerateReturn]: list of generated responses.
        """
        return NotImplemented

    @abstractmethod
    def loglikelihood(
        self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodReturn]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        return NotImplemented

    @abstractmethod
    def loglikelihood_rolling(
        self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodReturn]:
        """This function is used to compute the log likelihood of the context for perplexity metrics."""
        return NotImplemented

    @abstractmethod
    def loglikelihood_single_token(
        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodSingleTokenReturn]:
        """Tokenize the context and continuations and compute the log likelihood of each
        single-token continuation.
        """
        return NotImplemented

    # Tokenization utils
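    # Encodes a single string to a plain list of token ids, or a list of strings
    # to a padded BatchEncoding of PyTorch tensors; `add_special_tokens` falls
    # back to the model's own setting when not provided.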
    def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
        if add_special_tokens is None:
            add_special_tokens = self.add_special_tokens
        if isinstance(str_to_encode, str):
            return self.tokenizer.encode(str_to_encode, add_special_tokens=add_special_tokens)
        return self.tokenizer(
            str_to_encode,
            padding=True,
            add_special_tokens=add_special_tokens,
            return_tensors="pt",
        )

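    # Trailing whitespace is moved from the context onto the continuation before
    # encoding, e.g. ("Question: ", "Paris") becomes "Question:" + " Paris", so
    # the continuation tokens carry the leading space.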
    def tok_encode_pair(self, context, continuation):
        """Encodes a context, continuation pair by taking care of the spaces in between."""
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

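    # Batch-decodes token ids back to strings, dropping special tokens.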
    def tok_decode(self, tokens: torch.LongTensor) -> list[str]:
        return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
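
For orientation, here is a minimal sketch of what a concrete subclass might look like. Everything in it is hypothetical: `StubModel`, the `gpt2` tokenizer choice, the assumed module path, and the stubbed method bodies are illustration only, not part of this commit.

from transformers import AutoTokenizer

# Assumed module path for the class defined above.
from lighteval.models.abstract_model import LightevalModel


class StubModel(LightevalModel):
    """Hypothetical subclass showing the minimum surface to implement."""

    def __init__(self, config, env_config):
        # Any HF tokenizer would do; gpt2 is a placeholder assumption.
        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def add_special_tokens(self):
        return False

    @property
    def max_length(self) -> int:
        return 2048

    def greedy_until(self, requests, returns_logits=False, override_bs=None):
        raise NotImplementedError  # wire your generation backend in here

    def loglikelihood(self, requests, override_bs=None):
        raise NotImplementedError

    def loglikelihood_rolling(self, requests, override_bs=None):
        raise NotImplementedError

    def loglikelihood_single_token(self, requests, override_bs=None):
        raise NotImplementedError

Once the three properties are provided, the inherited tok_encode, tok_encode_pair, and tok_decode helpers work unchanged.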