Commit 6a37c54

No public description
PiperOrigin-RevId: 643488612
RyanMullins committed Jun 27, 2024
1 parent cd5a0e1 commit 6a37c54
Showing 9 changed files with 886 additions and 486 deletions.
483 changes: 47 additions & 436 deletions python/notebooks/run_scripts_for_llm_comparator.ipynb

Large diffs are not rendered by default.

45 changes: 22 additions & 23 deletions python/src/llm_comparator/llm_judge_runner.py
@@ -2,18 +2,18 @@

 from collections.abc import Sequence
 import math
-import re
 from typing import Optional
-import xml.etree.ElementTree as ET

 from llm_comparator import model_helper
 from llm_comparator import types
+from llm_comparator import utils


 _IndividualRating = types.IndividualRating
 _JsonDict = types.JsonDict
 _LLMJudgeInput = types.LLMJudgeInput
 _LLMJudgeOutput = types.LLMJudgeOutput
-_ModelHelper = model_helper.ModelHelper
+_GenerationModelHelper = model_helper.GenerationModelHelper


 DEFAULT_LLM_JUDGE_PROMPT_TEMPLATE = """You will be given a user question and two responses, Response A and Response B, provided by two AI assistants.

@@ -70,7 +70,7 @@ class LLMJudgeRunner:

   def __init__(
       self,
-      generation_model_helper: _ModelHelper,
+      generation_model_helper: _GenerationModelHelper,
       llm_judge_prompt_template: str = DEFAULT_LLM_JUDGE_PROMPT_TEMPLATE,
       rating_to_score_map: Optional[dict[str, float]] = None,
   ):

@@ -96,8 +96,8 @@ def create_prompt_for_judge(
     return prompt_for_judge

   def create_inputs_with_repeats_for_judge(
-      self, inputs: list[_LLMJudgeInput], num_repeats: int
-  ) -> list[_JsonDict]:
+      self, inputs: Sequence[_LLMJudgeInput], num_repeats: int
+  ) -> Sequence[_JsonDict]:
     """Creates inputs with repeated runs for LLM Judge."""
     inputs_with_repeats = []
     for index, ex in enumerate(inputs):

@@ -124,7 +124,7 @@ def create_inputs_with_repeats_for_judge(
     print(f'Created {len(inputs_with_repeats)} inputs for LLM judge.')
     return inputs_with_repeats

-  def run_query(self, inputs: Sequence[_JsonDict]) -> list[str]:
+  def run_query(self, inputs: Sequence[_JsonDict]) -> Sequence[str]:
     """Runs LLM judge."""
     judge_inputs = [
         self.create_prompt_for_judge(

@@ -141,26 +141,25 @@ def parse_results(
       self,
       outputs_from_judge: Sequence[str],
       inputs_for_judge: Sequence[_JsonDict],
-  ) -> list[_IndividualRating]:
+  ) -> Sequence[Sequence[_IndividualRating]]:
     """Parses XML-formatted LLM judge outputs."""

     def parse_output(raw_output: str):
-      # Find parts where <result> is in the XML-formatted output.
-      xml_output = re.search(
-          r'<result>(.*?)</result>', raw_output, flags=re.DOTALL
-      )
-      if not xml_output:
-        print('Invalid output with missing <result> tags')
+      parsed_xml = utils.extract_xml_part(raw_output, 'result')
+      if not parsed_xml:
         return None

-      try:
-        parsed_xml = ET.fromstring(xml_output.group(0))
-      except ET.ParseError as e:
-        print(f'Invalid format: {e} ({xml_output})')
+      if (rationale := parsed_xml.find('explanation')) is None:
+        return None
+      if (rationale := rationale.text) is None:
         return None

+      if (rating_label := parsed_xml.find('verdict')) is None:
+        return None
+      if (rating_label := rating_label.text) is None:
+        return None

-      rationale = parsed_xml.find('explanation').text
-      rating_label = parsed_xml.find('verdict').text
       score = self.rating_to_score_map[rating_label]
       return (score, rating_label, rationale.strip(' \n'))
@@ -185,19 +184,19 @@ def parse_output(raw_output: str):

   def postprocess_results(
       self, example_ratings: Sequence[Sequence[_IndividualRating]]
-  ) -> list[_LLMJudgeOutput]:
-    results = []
+  ) -> Sequence[_LLMJudgeOutput]:
+    results: list[_LLMJudgeOutput] = []
     for ratings in example_ratings:
       score = sum([rating['score'] for rating in ratings]) / len(ratings)
       results.append({
           'score': score,
-          'individual_ratings': ratings,
+          'individual_rater_scores': list(ratings),
       })
     return results

   def run(
       self, inputs: Sequence[_LLMJudgeInput], num_repeats=6
-  ) -> list[_LLMJudgeOutput]:
+  ) -> Sequence[_LLMJudgeOutput]:
     """Runs the LLM judge pipeline."""
     input_list_for_judge = self.create_inputs_with_repeats_for_judge(
         inputs, num_repeats
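Taken together, these changes swap the hand-rolled regex/ElementTree parsing for the shared `utils` helper, widen the public signatures from `list` to `Sequence`, and rename the per-repeat ratings field to `individual_rater_scores`. A minimal usage sketch of the refactored runner (the input field names and example data are illustrative, not taken from this diff; `VertexGenerationModelHelper` comes from the `model_helper.py` changes below):

from llm_comparator import llm_judge_runner
from llm_comparator import model_helper

# Hypothetical LLMJudgeInput records: a prompt plus the two responses
# to compare.
inputs = [{
    'prompt': 'What is the tallest building in the world?',
    'response_a': 'The Burj Khalifa, at 828 meters.',
    'response_b': 'The Eiffel Tower.',
}]

judge = llm_judge_runner.LLMJudgeRunner(
    generation_model_helper=model_helper.VertexGenerationModelHelper(),
)

# Each output averages the repeated judge calls into 'score' and keeps the
# per-repeat tuples under the renamed 'individual_rater_scores' key.
outputs = judge.run(inputs, num_repeats=6)
print(outputs[0]['score'])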
117 changes: 91 additions & 26 deletions python/src/llm_comparator/model_helper.py
@@ -1,37 +1,36 @@
 """Helper classes for calling LLMs."""

 import abc
+from collections.abc import Iterable, Sequence
 import time
-from typing import Optional, Union
+from typing import Optional

-from tqdm.auto import tqdm
-import vertexai
-from vertexai.generative_models import GenerationConfig
-from vertexai.generative_models import GenerativeModel
+from vertexai import generative_models
+from vertexai import language_models
+import tqdm.auto


 MAX_NUM_RETRIES = 5
 DEFAULT_MAX_OUTPUT_TOKENS = 256
+BATCH_EMBED_SIZE = 100


-class ModelHelper(abc.ABC):
+class GenerationModelHelper(abc.ABC):
   """Class for managing calling LLMs."""

-  def predict(
-      self,
-      prompt: str,
-      temperature: float,
-      max_output_tokens: Optional[int] = DEFAULT_MAX_OUTPUT_TOKENS,
-  ) -> Union[list[str], str]:
+  def predict(self, prompt: str, **kwargs) -> str:
     raise NotImplementedError()

+  def predict_batch(self, prompts: Sequence[str], **kwargs) -> Sequence[str]:
+    raise NotImplementedError()


-class VertexModelHelper(ModelHelper):
-  """Vertex AI model API calls."""
+class VertexGenerationModelHelper(GenerationModelHelper):
+  """Vertex AI text generation model API calls."""

-  def __init__(self, project_id: str, region: str, model_name='gemini-pro'):
-    vertexai.init(project=project_id, location=region)
-    self.engine = GenerativeModel(model_name)
+  def __init__(self, model_name='gemini-pro'):
+    self.engine = generative_models.GenerativeModel(model_name)

   def predict(
       self,
@@ -43,32 +42,98 @@ def predict(
       return ''
     num_attempts = 0
     response = None
+    prediction = None

     while num_attempts < MAX_NUM_RETRIES and response is None:
+      num_attempts += 1
+
       try:
         prediction = self.engine.generate_content(
             prompt,
-            generation_config=GenerationConfig(
+            generation_config=generative_models.GenerationConfig(
                 temperature=temperature,
                 candidate_count=1,
                 max_output_tokens=max_output_tokens,
             ),
         )
-        num_attempts += 1
         response = prediction.text
       except Exception as e:  # pylint: disable=broad-except
         if 'quota' in str(e):
-          print('\033[31mQuota limit exceeded. Waiting to retry...\033[0m')
-          time.sleep(2**num_attempts)
-    return response if response is not None else ''
+          print('\033[31mQuota limit exceeded\033[0m')
+        wait_time = 2**num_attempts
+        print(f'\033[31mWaiting {wait_time}s to retry...\033[0m')
+        time.sleep(2**num_attempts)
+
+    if isinstance(prediction, Iterable):
+      prediction = list(prediction)[0]
+
+    return prediction.text if prediction is not None else ''

   def predict_batch(
       self,
-      prompts: list[str],
+      prompts: Sequence[str],
       temperature: Optional[float] = None,
       max_output_tokens: Optional[int] = DEFAULT_MAX_OUTPUT_TOKENS,
-  ) -> list[str]:
+  ) -> Sequence[str]:
     outputs = []
-    for i in tqdm(range(0, len(prompts))):
+    for i in tqdm.auto.tqdm(range(0, len(prompts))):
       # TODO(b/344631023): Implement multiprocessing.
       outputs.append(self.predict(prompts[i], temperature, max_output_tokens))
     return outputs


+class EmbeddingModelHelper(abc.ABC):
+  """Class for managing calling text embedding models."""
+
+  def embed(self, text: str) -> Sequence[float]:
+    raise NotImplementedError()
+
+  def embed_batch(self, texts: Sequence[str]) -> Sequence[Sequence[float]]:
+    raise NotImplementedError()
+
+
+class VertexEmbeddingModelHelper(EmbeddingModelHelper):
+  """Vertex AI text embedding model API calls."""
+
+  def __init__(self, model_name: str = 'textembedding-gecko@003'):
+    self.model = language_models.TextEmbeddingModel.from_pretrained(model_name)
+
+  def _embed_single_run(
+      self, texts: Sequence[str]
+  ) -> Sequence[Sequence[float]]:
+    """Embeds a list of strings into the model's embedding space."""
+    num_attempts = 0
+    embeddings = None
+
+    if not isinstance(texts, list):
+      texts = list(texts)
+
+    while num_attempts < MAX_NUM_RETRIES and embeddings is None:
+      try:
+        embeddings = self.model.get_embeddings(texts)
+      except Exception as e:  # pylint: disable=broad-except
+        print(f'Waiting to retry... ({e})')
+        time.sleep(2**num_attempts)
+
+    if embeddings is None:
+      return []
+
+    return [embedding.values for embedding in embeddings]
+
+  def embed(self, text: str) -> Sequence[float]:
+    results = self._embed_single_run([text])
+    return results[0]
+
+  def embed_batch(self, texts: Sequence[str]) -> Sequence[Sequence[float]]:
+    if len(texts) <= BATCH_EMBED_SIZE:
+      return self._embed_single_run(texts)
+    else:
+      results = []
+      for batch_start_index in tqdm.auto.tqdm(
+          range(0, len(texts), BATCH_EMBED_SIZE)
+      ):
+        results.extend(
+            self._embed_single_run(
+                texts[batch_start_index : batch_start_index + BATCH_EMBED_SIZE]
+            )
+        )
+      return results
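The net effect in this file: the old all-purpose `ModelHelper` is split into parallel generation and embedding hierarchies, each with single and batch entry points, retries with exponential backoff (2**attempt seconds, at most MAX_NUM_RETRIES = 5 tries), and tqdm progress bars for batches. A rough usage sketch of the new surface (project ID, region, and the sample texts are placeholders):

import vertexai

from llm_comparator import model_helper

# The new constructors no longer call vertexai.init() (the removed
# VertexModelHelper.__init__ did), so the caller initializes the SDK.
vertexai.init(project='my-gcp-project', location='us-central1')

gen = model_helper.VertexGenerationModelHelper()  # defaults to 'gemini-pro'
emb = model_helper.VertexEmbeddingModelHelper()   # 'textembedding-gecko@003'

# predict_batch() loops predict() one prompt at a time under a progress bar;
# multiprocessing is still a TODO in this commit.
greeting = gen.predict('Say hello.', 0.9)
replies = gen.predict_batch(['Say hello.', 'Say goodbye.'])

# embed_batch() sends at most BATCH_EMBED_SIZE (100) texts per API call and
# concatenates the per-chunk results.
vectors = emb.embed_batch(['first document', 'second document'])
print(len(greeting), len(replies), len(vectors[0]))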