Commit

Merge branch 'main' into clem_doc_readme
clefourrier authored Feb 7, 2024
2 parents 77984bd + 1e837a9 commit e9845ba
Showing 16 changed files with 940 additions and 415 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -50,7 +50,7 @@ keywords = ["evaluation", "nlp", "llm"]
dependencies = [
# Base dependencies
"transformers>=4.36.0",
- "huggingface_hub==0.19.4",
+ "huggingface_hub==0.20.3",
"torch>=2.0",
"GitPython==3.1.31", # for logging
"datasets>=2.14.0",
35 changes: 20 additions & 15 deletions src/lighteval/data.py
@@ -5,7 +5,14 @@
from torch.utils.data.distributed import DistributedSampler, T_co

from lighteval.logging.hierarchical_logger import hlog_warn
- from lighteval.tasks.requests import Request
+ from lighteval.tasks.requests import (
+ GreedyUntilRequest,
+ GreedyUntilWithLogitsRequest,
+ LoglikelihoodRequest,
+ LoglikelihoodRollingRequest,
+ LoglikelihoodSingleTokenRequest,
+ Request,
+ )


class DynamicBatchDataset(Dataset):
@@ -28,6 +35,9 @@ def __init__(
requests (List): A list of requests.
dataset_splits (int): The number of dataset splits.
"""
+ # We make sure the requests contain the tokenized versions of their values
+ if any(r.tokenized_context is None for r in requests):
+ raise ValueError("You passed a request for which tokenization had not happened yet.")

# sort the requests using the collate function and save the original order
enumerated_requests = list(enumerate(requests))
@@ -124,12 +134,12 @@ def __len__(self) -> int:
"""
return self.split_end - self.split_start

- def _sorting_criteria(self, x) -> int:
+ def _sorting_criteria(self, request) -> int:
raise NotImplementedError()


class LoglikelihoodDataset(DynamicBatchDataset):
- def _sorting_criteria(self, x) -> int:
+ def _sorting_criteria(self, request: LoglikelihoodRequest | LoglikelihoodRollingRequest) -> int:
"""
Collates the input data for batching.
@@ -149,13 +159,12 @@ def _sorting_criteria(self, x) -> int:
Returns:
tuple: A tuple containing the sorted input data.
"""

- toks = x[1] + x[2]
+ toks = request.tokenized_context + request.tokenized_continuation
return -len(toks)


class LoglikelihoodSingleTokenDataset(DynamicBatchDataset):
- def _sorting_criteria(self, x) -> int:
+ def _sorting_criteria(self, request: LoglikelihoodSingleTokenRequest) -> int:
"""
Collates the input data for batching.
@@ -167,19 +176,14 @@ def _sorting_criteria(self, x) -> int:
is useful to simplify the batching logic and more importantly to make
automatic adaptive batches much much easier to implement
- any OOMs will happen right away rather than near the end
- Args:
- x (tuple): A tuple containing the input data.
- Returns:
- tuple: A tuple containing the collated data.
"""
- toks = x[1] # We take only the prompt, no need for the continuation (since it's a list of single tokens)
+ # We take only the prompt, no need for the continuation (since it's a list of single tokens)
+ toks = request.tokenized_context
return -len(toks)


class GenerativeTaskDataset(DynamicBatchDataset):
- def _sorting_criteria(self, x) -> int:
+ def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsRequest) -> int:
"""
Collate function for generating batches.
@@ -189,7 +193,8 @@ def _sorting_criteria(self, x) -> int:
Returns:
Any: The collated data.
"""
- toks, (stop_tokens, gen_length) = x
+ toks = request.tokenized_context
+ gen_length = request.generation_size
return -(len(toks) + gen_length)
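
For illustration only (not part of this commit): a minimal sketch of the sorting behaviour the new request-based criterion gives, using a hypothetical stand-in for GreedyUntilRequest that models only the tokenized_context and generation_size fields referenced above.

```python
from dataclasses import dataclass


@dataclass
class FakeGreedyUntilRequest:
    # Hypothetical stand-in; the real GreedyUntilRequest lives in lighteval.tasks.requests.
    tokenized_context: list[int]
    generation_size: int


requests = [
    FakeGreedyUntilRequest(tokenized_context=[0] * 10, generation_size=32),
    FakeGreedyUntilRequest(tokenized_context=[0] * 500, generation_size=256),
    FakeGreedyUntilRequest(tokenized_context=[0] * 50, generation_size=8),
]

# Same key as GenerativeTaskDataset._sorting_criteria above: longest
# (context + expected generation) first, so any OOM surfaces in the first batch.
requests.sort(key=lambda r: -(len(r.tokenized_context) + r.generation_size))
print([len(r.tokenized_context) + r.generation_size for r in requests])  # [756, 58, 42]
```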


2 changes: 1 addition & 1 deletion src/lighteval/evaluator.py
@@ -8,7 +8,7 @@
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import hlog
from lighteval.models.base_model import BaseModel
- from lighteval.models.inference_client import ModelClient
+ from lighteval.models.tgi_model import ModelClient
from lighteval.tasks.lighteval_task import LightevalTask
from lighteval.tasks.requests import Doc, Request, RequestType, TaskExampleId

4 changes: 3 additions & 1 deletion src/lighteval/metrics/__init__.py
@@ -94,7 +94,9 @@ def apply_multichoice_metric(results: list[ModelReturn], formatted_doc: Doc, met
raise ValueError(
"You can't use a multi choice metric with only one choice. Use `acc_golds_likelihood` instead."
)
- choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))]
+
+ # Todo: make better system with return_bool_score instead of taking first element
+ choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))] # sum(
gold_ixs = as_list(formatted_doc.gold_index)

for metric in metrics:
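
For illustration only (not part of this commit): a hedged sketch of how an accuracy-style multichoice metric can be derived from the per-choice log-probabilities and gold indices gathered above; the argmax-against-gold rule is an assumption for the example, not necessarily what every entry in `metrics` computes.

```python
# Toy values standing in for results[i].result[0] and formatted_doc.gold_index.
choices_logprob = [-4.2, -1.3, -7.8, -2.9]  # one log-probability per choice
gold_ixs = [1]                              # index (or indices) of the gold choice(s)

# Pick the choice the model finds most likely and check it against the golds.
best_choice = max(range(len(choices_logprob)), key=lambda i: choices_logprob[i])
accuracy = int(best_choice in gold_ixs)
print(accuracy)  # 1
```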
155 changes: 155 additions & 0 deletions src/lighteval/models/abstract_model.py
@@ -0,0 +1,155 @@
from abc import ABC, abstractmethod
from typing import Optional, Union

import torch
from transformers import BatchEncoding

from lighteval.models.model_config import EnvConfig
from lighteval.models.model_output import GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
from lighteval.tasks.requests import (
GreedyUntilRequest,
GreedyUntilWithLogitsRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
)


TokenSequence = Union[list[int], torch.LongTensor, torch.Tensor, BatchEncoding]


class LightevalModel(ABC):
DATASET_SPLITS = 4

"""Abstract model class defining the API that every model to plug into lighteval must follow."""

@abstractmethod
def __init__(
self,
config,
env_config: EnvConfig,
):
return NotImplemented

def cleanup(self):
"""Clean up operations if needed, such as closing an endpoint."""
return

@property
@abstractmethod
def tokenizer(self):
raise NotImplementedError

@property
@abstractmethod
def add_special_tokens(self):
raise NotImplementedError

@property
@abstractmethod
def max_length(self) -> int:
"""Return the maximum sequence length of the model."""
raise NotImplementedError

@property
def disable_tqdm(self) -> bool:
raise NotImplementedError

def greedy_until_with_logits(
self,
requests: list[GreedyUntilWithLogitsRequest],
override_bs: Optional[int] = None,
) -> list[GenerateReturn]:
"""
Generates sequences greedily until a stopping condition is met,
returning both the generated sequences and the logits.
Args:
requests (list[tuple[str, dict]]): A list of input requests,
where each request is a tuple containing a prompt string and a dictionary of additional parameters.
disable_tqdm (bool, optional): Whether to disable the tqdm progress bar. Defaults to False.
override_bs (Optional[int], optional): Overrides the batch size for generation. Defaults to None.
Returns:
list[GenerateReturn]: A list of GenerateReturn objects,
where each object contains the generated sequence and the corresponding logits.
"""
return self.greedy_until(
requests=requests,
override_bs=override_bs,
returns_logits=True,
)

@abstractmethod
def greedy_until(
self,
requests: list[GreedyUntilRequest],
returns_logits: bool = False,
override_bs: Optional[int] = None,
) -> list[GenerateReturn]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.
Args:
requests (list[Request]): list of requests containing the context and ending conditions.
returns_logits (bool, optional): Whether to return the logits of the generated responses. Defaults to False.
disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
override_bs (int, optional): Override the batch size for generation. Defaults to None.
Returns:
list[GenerateReturn]: list of generated responses.
"""
return NotImplemented

@abstractmethod
def loglikelihood(
self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodReturn]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
return NotImplemented

@abstractmethod
def loglikelihood_rolling(
self, requests: list[LoglikelihoodRollingRequest], override_bs=None
) -> list[LoglikelihoodReturn]:
"""This function is used to compute the log likelihood of the context for perplexity metrics."""
return NotImplemented

@abstractmethod
def loglikelihood_single_token(
self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodSingleTokenReturn]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
return NotImplemented

# Tokenization utils
def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
if add_special_tokens is None:
add_special_tokens = self.add_special_tokens
if isinstance(str_to_encode, str):
return self.tokenizer.encode(str_to_encode, add_special_tokens=add_special_tokens)
return self.tokenizer(
str_to_encode,
padding=True,
add_special_tokens=add_special_tokens,
return_tensors="pt",
)

def tok_encode_pair(self, context, continuation):
"""Encodes a context, continuation pair by taking care of the spaces in between."""
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc

def tok_decode(self, tokens: torch.LongTensor) -> list[str]:
return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
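
For illustration only (not part of this commit): a standalone sketch of the whitespace handling in tok_encode_pair, using a GPT-2 tokenizer purely as an example. Trailing spaces are moved from the context onto the continuation so the pair tokenizes consistently with the concatenated string.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer, not mandated by lighteval


def encode_pair(context: str, continuation: str) -> tuple[list[int], list[int]]:
    # Move trailing spaces from the context to the continuation before encoding.
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]
    whole_enc = tokenizer.encode(context + continuation)
    context_enc = tokenizer.encode(context)
    # The continuation tokens are whatever the whole sequence adds past the context.
    continuation_enc = whole_enc[len(context_enc):]
    return context_enc, continuation_enc


ctx_tokens, cont_tokens = encode_pair("Question: 2 + 2 = ", "4")
print(len(ctx_tokens), len(cont_tokens))
```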
8 changes: 4 additions & 4 deletions src/lighteval/models/adapter_model.py
@@ -1,7 +1,7 @@
from contextlib import nullcontext

import torch
- from transformers import AutoModel, PreTrainedTokenizer
+ from transformers import AutoModelForCausalLM, PreTrainedTokenizer

from lighteval.logging.hierarchical_logger import hlog
from lighteval.models.base_model import BaseModel
@@ -20,7 +20,7 @@ def _create_auto_tokenizer(self, config: AdapterModelConfig, env_config: EnvConf
# (= the parent model, not the model of interest)
return self._create_auto_tokenizer_with_name(config.base_model, config=config, env_config=env_config)

- def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) -> AutoModel:
+ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) -> AutoModelForCausalLM:
"""Returns a PeftModel from a base model and a version fined tuned using PEFT."""
torch_dtype = _get_dtype(config.dtype, self._config)
config.model_parallel, max_memory, device_map = self.init_model_parallel(config.model_parallel)
@@ -31,7 +31,7 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig)

if self.accelerator.is_local_main_process if self.accelerator is not None else nullcontext():
hlog(f"Loading model from {adapter_weights} and applying adapter to {config.base_model}")
- base = self.AUTO_MODEL_CLASS.from_pretrained(
+ base = AutoModelForCausalLM.from_pretrained(
config.base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=env_config.token
)
# Should pass revision
@@ -43,7 +43,7 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig)

hlog(f"Loading model from {merged_path}")

- model = self.AUTO_MODEL_CLASS.from_pretrained(
+ model = AutoModelForCausalLM.from_pretrained(
merged_path,
max_memory=max_memory,
device_map=device_map,
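
For illustration only (not part of this commit): a minimal sketch of the adapter-merge step this method performs, assuming a causal-LM base model and a PEFT adapter repository; the model names and output path below are placeholders.

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Placeholders for config.base_model and the adapter weights repository.
base = AutoModelForCausalLM.from_pretrained(
    "org/base-model", torch_dtype=torch.float16, low_cpu_mem_usage=True
)
adapted = PeftModel.from_pretrained(base, "org/adapter-weights")

# Fold the adapter into the base weights, then save so the merged checkpoint
# can later be reloaded as a plain AutoModelForCausalLM.
merged = adapted.merge_and_unload()
merged.save_pretrained("/tmp/merged_model")
```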
