Commit 6e3a410
No public description
PiperOrigin-RevId: 644114565
RyanMullins committed Jun 27, 2024
1 parent 4eae13e commit 6e3a410
Showing 8 changed files with 86 additions and 24 deletions.
22 changes: 20 additions & 2 deletions python/notebooks/run_scripts_for_llm_comparator.ipynb
@@ -11,6 +11,24 @@
"# !pip install llm-comparator"
]
},
{
"metadata": {
"id": "QZlVpN83nJBv"
},
"cell_type": "code",
"source": [
"# Run this if using a google3 Colab Kernel, such as with\n",
"# blaze run //third_party/javascript/llm_comparator/python/src/llm_comparator:kernel\n",
"# Otherwise, import modules using the following cell.\n",
"from llm_comparator import model_helper\n",
"from llm_comparator import llm_judge_runner\n",
"from llm_comparator import rationale_bullet_generator\n",
"from llm_comparator import rationale_cluster_generator\n",
"import vertexai"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"execution_count": null,
@@ -22,7 +40,8 @@
"from llm_comparator import model_helper\n",
"from llm_comparator import llm_judge_runner\n",
"from llm_comparator import rationale_bullet_generator\n",
"from llm_comparator import rationale_cluster_generator"
"from llm_comparator import rationale_cluster_generator\n",
"import vertexai"
]
},
{
@@ -35,7 +54,6 @@
"source": [
"#@title Setup for using Vertex AI.\n",
"from google.colab import auth\n",
"import vertexai\n",
"\n",
"auth.authenticate_user()\n",
"\n",
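For context, a Vertex AI setup cell like the one above typically finishes with a `vertexai.init(...)` call; the following is a minimal sketch, assuming placeholder project and location values (the actual cell contents are truncated in this diff):

```python
# Minimal sketch of a Colab + Vertex AI setup, mirroring the imports this
# commit rearranges. The project and location are placeholders, not values
# from this commit.
from google.colab import auth
import vertexai

auth.authenticate_user()  # Grant the Colab kernel access to your GCP account.
vertexai.init(project='your-project-id', location='us-central1')
```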
8 changes: 8 additions & 0 deletions python/src/llm_comparator/_colab.py
@@ -0,0 +1,8 @@
"""Provides a constant that incidates whether we are running in Google Colab."""

try:
import google.colab # pylint: disable=g-import-not-at-top,unused-import # pytype: disable=import-error

IS_COLAB = True
except (ImportError, ModuleNotFoundError):
IS_COLAB = False
16 changes: 16 additions & 0 deletions python/src/llm_comparator/_logging.py
@@ -0,0 +1,16 @@
"""Logging utilities."""

import logging

import absl.logging

from llm_comparator import _colab

if _colab.IS_COLAB:
# Colab is set to log WARNING+ by default. This call resets the environment to
# log INFO+ instead. See
# https://stackoverflow.com/questions/54597462 and
# https://colab.sandbox.google.com/github/aviadr1/learn-advanced-python/blob/master/content/15_logging/logging_in_python.ipynb
logging.basicConfig(level=logging.INFO, force=True)

logger = absl.logging
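The rest of this commit consumes this module by aliasing the shared logger, as the diffs below show; a condensed usage sketch:

```python
# Condensed usage of the shared logger, matching the pattern in the diffs
# below: alias absl's logger at module scope, then log with printf-style
# formatting so string interpolation is deferred to the logging backend.
from llm_comparator import _logging

_logger = _logging.logger

_logger.info('Parsed %d example ratings.', 42)
_logger.warning('Invalid output with missing <%s> tags', 'result')
```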
15 changes: 10 additions & 5 deletions python/src/llm_comparator/llm_judge_runner.py
@@ -4,6 +4,7 @@
import math
from typing import Optional

from llm_comparator import _logging
from llm_comparator import model_helper
from llm_comparator import types
from llm_comparator import utils
@@ -15,6 +16,8 @@
_LLMJudgeOutput = types.LLMJudgeOutput
_GenerationModelHelper = model_helper.GenerationModelHelper

_logger = _logging.logger


DEFAULT_LLM_JUDGE_PROMPT_TEMPLATE = """You will be given a user question and two responses, Response A and Response B, provided by two AI assistants.
Your task is to act as a judge by determining which response is answering the user's question better.
@@ -121,7 +124,7 @@ def create_inputs_with_repeats_for_judge(
'response_b': ex['response_a'],
'is_flipped': True,
})
print(f'Created {len(inputs_with_repeats)} inputs for LLM judge.')
_logger.info('Created %d inputs for LLM judge.', len(inputs_with_repeats))
return inputs_with_repeats

def run_query(self, inputs: Sequence[_JsonDict]) -> Sequence[str]:
@@ -133,7 +136,7 @@ def run_query(self, inputs: Sequence[_JsonDict]) -> Sequence[str]:
for input in inputs
]
judge_outputs = self.generation_model_helper.predict_batch(judge_inputs)
print(f'Generated {len(judge_outputs)} outputs from LLM judge.')
_logger.info('Generated %d outputs from LLM judge.', len(judge_outputs))
return judge_outputs

# TODO(b/344919097): Add some unit tests.
@@ -163,7 +166,9 @@ def parse_output(raw_output: str):
try:
score = self.rating_to_score_map[rating_label]
except KeyError:
print(f'LLM judge returned an unknown rating label: {rating_label}')
        _logger.error(
            'LLM judge returned an unknown rating label: %s', rating_label
        )
return None
return (score, rating_label, rationale.strip(' \n'))

@@ -183,7 +188,7 @@ def parse_output(raw_output: str):
'rating_label': parsed_output[1],
'rationale': parsed_output[2],
})
print(f'Parsed {len(example_ratings)} example ratings.')
_logger.info('Parsed %d example ratings.', len(example_ratings))
return example_ratings

def postprocess_results(
@@ -210,5 +215,5 @@ def run(
outputs_from_judge, input_list_for_judge
)
scores_and_ratings = self.postprocess_results(example_ratings)
print(f'Generated ratings for {len(scores_and_ratings)} examples.')
_logger.info('Generated ratings for %d examples.', len(scores_and_ratings))
return scores_and_ratings
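As context for the `create_inputs_with_repeats_for_judge` hunk above: the runner judges each example multiple times, swapping the two responses on some repeats so that position bias averages out. A simplified reconstruction from the visible fragment (only the flipped branch appears in the diff; the even/odd alternation here is an assumption for illustration):

```python
# Simplified reconstruction of the input-duplication logic partially visible
# above. Each example is judged num_repeats times; on alternate repeats the
# responses are swapped and marked 'is_flipped' so position bias cancels out.
def create_inputs_with_repeats(examples, num_repeats):
  inputs_with_repeats = []
  for ex in examples:
    for i in range(num_repeats):
      flipped = i % 2 == 1
      inputs_with_repeats.append({
          'prompt': ex['prompt'],
          'response_a': ex['response_b'] if flipped else ex['response_a'],
          'response_b': ex['response_a'] if flipped else ex['response_b'],
          'is_flipped': flipped,
      })
  return inputs_with_repeats
```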
13 changes: 9 additions & 4 deletions python/src/llm_comparator/model_helper.py
@@ -9,12 +9,16 @@
from vertexai import language_models
import tqdm.auto

from llm_comparator import _logging


MAX_NUM_RETRIES = 5
DEFAULT_MAX_OUTPUT_TOKENS = 256

BATCH_EMBED_SIZE = 100

_logger = _logging.logger


class GenerationModelHelper(abc.ABC):
"""Class for managing calling LLMs."""
@@ -58,9 +62,9 @@ def predict(
)
except Exception as e: # pylint: disable=broad-except
if 'quota' in str(e):
print('\033[31mQuota limit exceeded\033[0m')
_logger.info('\033[31mQuota limit exceeded.\033[0m')
wait_time = 2**num_attempts
print(f'\033[31mWaiting {wait_time}s to retry...\033[0m')
_logger.info('\033[31mWaiting %ds to retry...\033[0m', wait_time)
          time.sleep(wait_time)

if isinstance(prediction, Iterable):
@@ -111,8 +115,9 @@ def _embed_single_run(
try:
embeddings = self.model.get_embeddings(texts)
except Exception as e: # pylint: disable=broad-except
print(f'Waiting to retry... ({e})')
time.sleep(2**num_attempts)
wait_time = 2**num_attempts
_logger.info('Waiting %ds to retry... (%s)', wait_time, e)
time.sleep(wait_time)

if embeddings is None:
return []
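The two hunks above share an exponential-backoff retry pattern: on failure, wait `2**attempt` seconds before retrying, up to `MAX_NUM_RETRIES`. A self-contained sketch of that pattern; `call_with_backoff` is a hypothetical helper, not part of this commit:

```python
import logging
import time

MAX_NUM_RETRIES = 5  # Same constant as in model_helper.py.


def call_with_backoff(fn, *args, **kwargs):
  """Calls fn, retrying on failure with exponential backoff (1s, 2s, 4s...)."""
  for num_attempts in range(MAX_NUM_RETRIES):
    try:
      return fn(*args, **kwargs)
    except Exception as e:  # Broad catch, mirroring the helper above.
      wait_time = 2**num_attempts
      logging.info('Waiting %ds to retry... (%s)', wait_time, e)
      time.sleep(wait_time)
  return None  # All retries exhausted.
```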
13 changes: 8 additions & 5 deletions python/src/llm_comparator/rationale_bullet_generator.py
@@ -5,6 +5,7 @@

import tqdm.auto

from llm_comparator import _logging
from llm_comparator import model_helper
from llm_comparator import prompt_templates
from llm_comparator import types
@@ -14,6 +15,8 @@
_LLMJudgeOutput = types.LLMJudgeOutput
_GenerationModelHelper = model_helper.GenerationModelHelper

_logger = _logging.logger


class _BulletGeneratorInput(TypedDict):
"""Intermediate output for rationale bullet generator."""
@@ -107,7 +110,7 @@ def _prepare_inputs_for_generating_bullets(
'ex_win_side': ex_win_side,
})

print('Done preparing inputs for generating bullets.')
_logger.info('Done preparing inputs for generating bullets.')
return inputs_for_generating_bullets

def _parse_xml_formatted_rationale_bullets(
@@ -160,9 +163,9 @@ def _generate_rationale_bullets_for_examples(
Returns:
List of bulleted lists.
"""
print(
'Start generating rationale bullets for '
f'{len(inputs_for_generating_bullets)} examples.'
_logger.info(
        'Start generating rationale bullets for %d examples.',
len(inputs_for_generating_bullets),
)
rationale_bullets_for_examples = []
for input_for_generation in tqdm.auto.tqdm(inputs_for_generating_bullets):
Expand All @@ -178,7 +181,7 @@ def _generate_rationale_bullets_for_examples(
bullets = []
rationale_bullets_for_examples.append(bullets)

print('Done generating rationale bullets')
_logger.info('Done generating rationale bullets')
return rationale_bullets_for_examples

def run(
14 changes: 8 additions & 6 deletions python/src/llm_comparator/rationale_cluster_generator.py
@@ -6,6 +6,7 @@
import numpy as np
import tqdm.auto

from llm_comparator import _logging
from llm_comparator import model_helper
from llm_comparator import prompt_templates
from llm_comparator import types
@@ -18,6 +19,7 @@
_RationaleCluster = types.RationaleCluster
_GenerationModelHelper = model_helper.GenerationModelHelper
_EmbeddingModelHelper = model_helper.EmbeddingModelHelper
_logger = _logging.logger


class RationaleClusterGenerator:
@@ -56,7 +58,7 @@ def _generate_paraphrased_rationale(phrase: str) -> str:
temperature=temperature_for_paraphrasing,
)

print('Start paraphrasing rationale bullets.')
_logger.info('Start paraphrasing rationale bullets.')
paraphrased_rationales = {}
for rationale in tqdm.auto.tqdm(rationales):
output = _generate_paraphrased_rationale(rationale)
@@ -70,7 +72,7 @@ def _generate_paraphrased_rationale(phrase: str) -> str:
if phrase.text
]

print('Done paraphrasing rationales.')
_logger.info('Done paraphrasing rationales.')
return paraphrased_rationales

def _embed_rationales(
@@ -79,7 +81,7 @@ def _embed_rationales(
"""Embed rationales by taking the average of the embeddings of paraphrases."""
rationales_with_embeddings = {}

print('Start computing embeddings.')
_logger.info('Start computing embeddings.')
for rationale, paraphrased_list in tqdm.auto.tqdm(
paraphrased_rationales.items()
):
@@ -88,7 +90,7 @@ def _embed_rationales(
avg_embedding = np.mean(np.array(embeddings), axis=0)
rationales_with_embeddings[rationale] = avg_embedding

print('Done computing embeddings.')
_logger.info('Done computing embeddings.')
return rationales_with_embeddings

def _generate_cluster_titles(
@@ -146,7 +148,7 @@ def _generate_cluster_titles(
item.text for item in output_parsed.findall('group') if item.text
][:num_clusters]

print('Done generating cluster titles.')
_logger.info('Done generating cluster titles.')
return cluster_titles

def _embed_cluster_titles(
@@ -203,7 +205,7 @@ def _store_similarities_to_rationale_bullets(
rationale_bullets_with_similarities_for_example
)

print('Done assigning cluster similarities to rationales.')
_logger.info('Done assigning cluster similarities to rationales.')
return rationale_bullets_with_similarities

def run(
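One detail worth noting from the `_embed_rationales` hunk above: each rationale's embedding is the mean of its paraphrases' embeddings. A minimal numpy sketch with illustrative values:

```python
import numpy as np

# Each rationale is represented by the mean of its paraphrases' embeddings,
# as in _embed_rationales above. The vectors here are illustrative.
paraphrase_embeddings = [
    [0.1, 0.2, 0.3],  # embedding of paraphrase 1
    [0.3, 0.2, 0.1],  # embedding of paraphrase 2
]
avg_embedding = np.mean(np.array(paraphrase_embeddings), axis=0)
print(avg_embedding)  # -> [0.2 0.2 0.2]
```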
9 changes: 7 additions & 2 deletions python/src/llm_comparator/utils.py
@@ -4,23 +4,28 @@
import re
from typing import Optional
import xml.etree.ElementTree as ET

import numpy as np

from llm_comparator import _logging

_logger = _logging.logger


def extract_xml_part(raw_output: str, tag_name: str) -> Optional[ET.Element]:
"""Find parts where <result> is in the XML-formatted output."""
xml_output = re.search(
rf'<{tag_name}>(.*?)</{tag_name}>', raw_output, flags=re.DOTALL
)
if not xml_output:
print(f'Invalid output with missing <{tag_name}> tags')
_logger.warning('Invalid output with missing <%s> tags', tag_name)
return None

try:
parsed_xml = ET.fromstring(xml_output.group(0))
return parsed_xml
except ET.ParseError as e:
print(f'Invalid format: {e} ({xml_output})')
_logger.warning('Invalid format: %s (%s)', e, xml_output)
return None


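A usage sketch for `extract_xml_part` as defined above; the tag name and payload are illustrative:

```python
# Usage sketch for extract_xml_part: pull a <result> element out of
# free-form judge output, then read its text. Inputs are illustrative.
raw_output = 'Verdict follows. <result>A is better</result> Thanks!'
result = extract_xml_part(raw_output, 'result')
if result is not None:
  print(result.text)  # -> 'A is better'
```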
