diff --git a/lit_nlp/examples/lm_salience_demo.py b/lit_nlp/examples/lm_salience_demo.py
index f9637940..3f180e09 100644
--- a/lit_nlp/examples/lm_salience_demo.py
+++ b/lit_nlp/examples/lm_salience_demo.py
@@ -2,12 +2,14 @@
 from collections.abc import Sequence
 import functools
+import os
 import sys
 from typing import Optional
 
 from absl import app
 from absl import flags
 from absl import logging
+import keras
 
 from lit_nlp import dev_server
 from lit_nlp import server_flags
 from lit_nlp.api import layout
@@ -37,6 +39,10 @@
     ),
 )
 
+_KERAS_FLOATX = flags.DEFINE_string(
+    "keras_floatx", "bfloat16", "Floating-point type for Keras models."
+)
+
 # Custom frontend layout; see api/layout.py
 modules = layout.LitModuleName
 LM_LAYOUT = layout.LitCanonicalLayout(
@@ -109,6 +115,10 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
   if len(argv) > 1:
     raise app.UsageError("Too many command-line arguments.")
 
+  # Set Keras backend and floating-point precision.
+  os.environ["KERAS_BACKEND"] = "tensorflow"
+  keras.config.set_floatx(_KERAS_FLOATX.value)
+
   plaintextPrompts = functools.partial(  # pylint: disable=invalid-name
       lm_data.PlaintextSents, field_name="prompt"
   )
diff --git a/lit_nlp/examples/models/instrumented_keras_lms.py b/lit_nlp/examples/models/instrumented_keras_lms.py
index fc8f68e3..19dd9a4b 100644
--- a/lit_nlp/examples/models/instrumented_keras_lms.py
+++ b/lit_nlp/examples/models/instrumented_keras_lms.py
@@ -68,8 +68,8 @@ def __init__(
         self.model.preprocessor.tokenizer.id_to_token
     )
 
-    # map ids: [batch_size, num_tokens]
-    # to embs: [batch_size, num_tokens, emb_dim]
+    # map ids: [batch_size, num_tokens]
+    # to embs: [batch_size, num_tokens, emb_dim]
     self.embedder = self.model.backbone.token_embedding
 
   @classmethod
@@ -114,7 +114,7 @@ def embed_texts(self, texts: Sequence[str]):
     processed_inputs = self.encode_inputs(
         texts, sequence_length=self.max_length
     )
-    # [batch_size, num_tokens, emb_dim]
+    # [batch_size, num_tokens, emb_dim]
     embs = self.embedder(processed_inputs["token_ids"])
     # [batch_size, num_tokens]
     mask = processed_inputs["padding_mask"]
@@ -123,13 +123,13 @@ def embed_and_mean_pool(self, texts: Sequence[str]):
     """Return a single vector for each text."""
     embs, mask = self.embed_texts(texts)
-    # [batch_size, num_tokens, 1]
-    mask = tf.expand_dims(tf.cast(mask, dtype=tf.float32), axis=2)
-    # [batch_size, 1, emb_dim]
+    # [batch_size, num_tokens, 1]
+    mask = tf.expand_dims(tf.cast(mask, dtype=embs.dtype), axis=2)
+    # [batch_size, 1, emb_dim]
     pooled_embs = tf.reduce_sum(
         mask * embs, axis=1, keepdims=True
     ) / tf.reduce_sum(mask, axis=1, keepdims=True)
-    # [batch_size, emb_dim]
+    # [batch_size, emb_dim]
     return tf.squeeze(pooled_embs, axis=1)
 
   def predict_minibatch(
@@ -203,7 +203,7 @@ def __init__(self, *args, **kw):
 
   def _pred(self, input_ids, padding_mask, target_masks):
     """Predict a batch of tokenized text."""
-    # [batch_size, num_tokens]; ignore the last one in each row.
+    # [batch_size, num_tokens]; ignore the last one in each row.
     target_ids = tf.roll(input_ids, shift=-1, axis=1)
 
     ##
@@ -226,13 +226,13 @@ def _pred(self, input_ids, padding_mask, target_masks):
         axis=0,
     )
-    padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32)
+    padded_target_masks = tf.constant(padded_target_masks, dtype=tf.bool)
     # Shift masks back so they align with target_ids.
     loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1)
 
     embeddings = None
-    with tf.GradientTape(watch_accessed_variables=True) as tape:
+    with tf.GradientTape(watch_accessed_variables=False) as tape:
 
       def layer_intercept_fn(x, i):
         if i == -1:
@@ -241,7 +241,7 @@ def layer_intercept_fn(x, i):
           tape.watch(embeddings)
         return x
 
-      # [batch_size, num_tokens]
+      # [batch_size, num_tokens]
       per_token_loss = self.model.score(
           token_ids=input_ids,
           padding_mask=padding_mask,
@@ -249,13 +249,13 @@ def layer_intercept_fn(x, i):
          layer_intercept_fn=layer_intercept_fn,
          target_ids=target_ids,
      )
-      masked_loss = per_token_loss * loss_mask
+      masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)
 
-    # [batch_size, num_tokens, hdim]
+    # [batch_size, num_tokens, hdim]
     grads = tape.gradient(masked_loss, embeddings)
-    # [batch_size, num_tokens]
+    # [batch_size, num_tokens]
     grad_l2 = tf.norm(grads, axis=2)
-    # [batch_size, num_tokens]
+    # [batch_size, num_tokens]
     grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2)
 
     batched_outputs = {
diff --git a/lit_nlp/examples/models/pretrained_lms.py b/lit_nlp/examples/models/pretrained_lms.py
index 28baea46..9fdae45d 100644
--- a/lit_nlp/examples/models/pretrained_lms.py
+++ b/lit_nlp/examples/models/pretrained_lms.py
@@ -490,7 +490,7 @@ def predict_minibatch(self, inputs):
     responses = self.tokenizer.batch_decode(
         outputs[:, -self.max_new_tokens :], skip_special_tokens=True
     )
-    # Input embeddings: [batch_size, num_tokens, emb_dim]
+    # Input embeddings: [batch_size, num_tokens, emb_dim]
     embeddings = self.model.transformer.wte(outputs)
     batched_outputs = {
         "embs": embeddings,
@@ -532,7 +532,7 @@ def _pred(self, encoded_inputs, target_masks):
     """
     input_ids = encoded_inputs["input_ids"]
-    # [batch_size, num_tokens]; ignore the last one in each row.
+    # [batch_size, num_tokens]; ignore the last one in each row.
     target_ids = tf.roll(encoded_inputs["input_ids"], shift=-1, axis=1)
 
     ##
     # Process target masks
@@ -554,11 +554,11 @@ def _pred(self, encoded_inputs, target_masks):
         axis=0,
     )
-    padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32)
+    padded_target_masks = tf.constant(padded_target_masks, dtype=tf.bool)
     # Shift masks back so they align with target_ids.
     loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1)
 
-    with tf.GradientTape(watch_accessed_variables=True) as tape:
+    with tf.GradientTape(watch_accessed_variables=False) as tape:
       # We need to run the embedding layer ourselves so we can trace it.
       # See here for how the model normally does this:
       # http://google3/third_party/py/transformers/models/gpt2/modeling_tf_gpt2.py;l=450;rcl=578656271
@@ -574,18 +574,18 @@ def _pred(self, encoded_inputs, target_masks):
       loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
           from_logits=True, reduction="none"
       )
-      # [batch_size, num_tokens]
+      # [batch_size, num_tokens]
       per_token_loss = loss_fn(target_ids, out.logits)
-      masked_loss = per_token_loss * loss_mask
+      masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)
 
     grads = tape.gradient(
         masked_loss, embs
-    )  # [batch_size, num_tokens, hdim]
+    )  # [batch_size, num_tokens, hdim]
 
-    grad_l2 = tf.norm(grads, axis=2)  # [batch_size, num_tokens]
+    grad_l2 = tf.norm(grads, axis=2)  # [batch_size, num_tokens]
     grad_dot_input = tf.reduce_sum(
         grads * embs, axis=2
-    )  # [batch_size, num_tokens]
+    )  # [batch_size, num_tokens]
 
     batched_outputs = {
         "input_ids": encoded_inputs["input_ids"],
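Note on the recurring change above (not part of the patch): the target mask is kept boolean and cast to the loss tensor's dtype only at the point of multiplication, so the same salience code runs whether keras.config.set_floatx() is "float32" or "bfloat16". TensorFlow does not auto-promote mixed float dtypes, so multiplying a bfloat16 loss by a float32 mask would raise a dtype error. A minimal standalone sketch of that pattern, with illustrative values only:

    import tensorflow as tf

    # Loss as a model might produce it under keras.config.set_floatx("bfloat16").
    per_token_loss = tf.constant([[1.5, 0.5, 2.0]], dtype=tf.bfloat16)
    # Target mask stays boolean, as in the patch.
    loss_mask = tf.constant([[True, False, True]])

    # Cast at multiply time so the mask always matches the loss dtype.
    masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype)
    print(masked_loss.dtype, masked_loss.numpy())  # bfloat16, [[1.5 0 2]]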