Add batch_size flag to lm_salience_demo and fix minor bugs.
PiperOrigin-RevId: 623274231
bdu91 authored and LIT team committed Apr 9, 2024
1 parent 8ca1312 commit 8ea325b
lit_nlp/examples/lm_salience_demo.py: 13 additions & 6 deletions
@@ -8,7 +8,7 @@
 To run with the default configuration (Gemma on TensorFlow via Keras):
 
   blaze run -c opt examples:lm_salience_demo -- \
-      --models=gemma_instruct_2b_en:gemma_instruct_2b_en \
+      --models=gemma_1.1_instruct_2b_en:gemma_1.1_instruct_2b_en \
       --port=8890 --alsologtostderr
 
 MODELS:
@@ -64,7 +64,7 @@
 
 _MODELS = flags.DEFINE_list(
     "models",
-    ["gemma_instruct_2b_en:gemma_instruct_2b_en"],
+    ["gemma_1.1_instruct_2b_en:gemma_1.1_instruct_2b_en"],
     "Models to load, as <name>:<path>. Path can be a URL, a local file path, or"
     " the name of a preset for the configured Deep Learning framework (either"
     " KerasNLP or HuggingFace Transformers; see --dl_framework for more). This"
@@ -91,6 +91,10 @@
),
)

_BATCH_SIZE = flags.DEFINE_integer(
"batch_size", 4, "The number of examples to process per batch.",
)

_DL_BACKEND = flags.DEFINE_enum(
"dl_backend",
"tensorflow",
@@ -278,18 +282,17 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
     path = file_cache.cached_path(
         path,
         extract_compressed_file=path.endswith(".tar.gz"),
-        copy_directories=True,
     )
 
-    if _DL_FRAMEWORK.value == "keras":
+    if _DL_FRAMEWORK.value == "kerasnlp":
       # pylint: disable=g-import-not-at-top
       from keras_nlp import models as keras_models
       from lit_nlp.examples.models import instrumented_keras_lms as lit_keras
       # pylint: enable=g-import-not-at-top
       # Load the weights once for the underlying Keras model.
       model = keras_models.CausalLM.from_preset(path)
       models |= lit_keras.initialize_model_group_for_salience(
-          model_name, model, max_length=512, batch_size=4
+          model_name, model, max_length=512, batch_size=_BATCH_SIZE.value
       )
       # Disable embeddings from the generation model.
       # TODO(lit-dev): re-enable embeddings if we can figure out why UMAP was
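
The `models |=` call above merges the mapping returned by the
`initialize_model_group_for_salience` helper into the server's `models` dict
using the in-place dict union operator (Python 3.9+). A quick illustration
with placeholder entries:

  models = {"gemma_1.1_instruct_2b_en": "generation wrapper"}
  models |= {"gemma_1.1_instruct_2b_en_salience": "salience wrapper"}
  # models now holds both entries; on key collisions, the right-hand
  # mapping's values win.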
@@ -301,7 +304,11 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
       # Assuming a valid decoder model name supported by
       # `transformers.AutoModelForCausalLM` is provided to "path".
       models |= pretrained_lms.initialize_model_group_for_salience(
-          model_name, path, framework=_DL_BACKEND.value, max_new_tokens=512
+          model_name,
+          path,
+          batch_size=_BATCH_SIZE.value,
+          framework=_DL_BACKEND.value,
+          max_new_tokens=512,
       )
 
   for name in datasets:

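With this change, the salience batch size is tunable from the command line;
for example (the batch_size value of 8 here is illustrative):

  blaze run -c opt examples:lm_salience_demo -- \
      --models=gemma_1.1_instruct_2b_en:gemma_1.1_instruct_2b_en \
      --batch_size=8 --port=8890 --alsologtostderr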