From cf1b3e609c442402f3214e4b5d6e532506b30092 Mon Sep 17 00:00:00 2001
From: Ryan Mullins
Date: Tue, 10 Sep 2024 19:52:52 +0000
Subject: [PATCH] Adds LIT app server code

---
 lit_nlp/examples/gcp/constants.py    |   6 ++
 lit_nlp/examples/gcp/model.py        | 106 +++++++++++++++++++++++++++
 lit_nlp/examples/gcp/model_server.py |   9 ++-
 lit_nlp/examples/gcp/server.py       |  78 ++++++++++++++++++++
 4 files changed, 196 insertions(+), 3 deletions(-)
 create mode 100644 lit_nlp/examples/gcp/constants.py
 create mode 100644 lit_nlp/examples/gcp/model.py
 create mode 100644 lit_nlp/examples/gcp/server.py

diff --git a/lit_nlp/examples/gcp/constants.py b/lit_nlp/examples/gcp/constants.py
new file mode 100644
index 00000000..244c7637
--- /dev/null
+++ b/lit_nlp/examples/gcp/constants.py
@@ -0,0 +1,6 @@
+import enum
+
+class LlmHTTPEndpoints(enum.Enum):
+  GENERATE = 'predict'
+  SALIENCE = 'salience'
+  TOKENIZE = 'tokenize'
\ No newline at end of file
diff --git a/lit_nlp/examples/gcp/model.py b/lit_nlp/examples/gcp/model.py
new file mode 100644
index 00000000..67192acb
--- /dev/null
+++ b/lit_nlp/examples/gcp/model.py
@@ -0,0 +1,106 @@
+"""Wrapper for connecting to LLMs on GCP via the model_server HTTP API."""
+
+import enum
+
+from lit_nlp.api import model as lit_model
+from lit_nlp.api import types as lit_types
+from lit_nlp.api.types import Spec
+from lit_nlp.examples.gcp import constants as lit_gcp_constants
+from lit_nlp.examples.prompt_debugging import constants as pd_constants
+from lit_nlp.examples.prompt_debugging import utils as pd_utils
+from lit_nlp.lib import serialize
+import requests
+
+"""
+Plan for this module:
+
+From GitHub:
+
+* Rebase to include cl/672527408 and the CL described above
+* Define an enum to track HTTP endpoints across Python modules
+* Adopt HTTP endpoint enum across model_server.py and LlmOverHTTP
+* Adopt model_specs.py in LlmOverHTTP, using HTTP endpoint enum for
+  conditional additions
+
+"""
+
+_LlmHTTPEndpoints = lit_gcp_constants.LlmHTTPEndpoints
+
+
+class LlmOverHTTP(lit_model.BatchedRemoteModel):
+
+  def __init__(
+      self,
+      base_url: str,
+      endpoint: str | _LlmHTTPEndpoints,
+      max_concurrent_requests: int = 4,
+      max_qps: int | float = 25
+  ):
+    super().__init__(max_concurrent_requests, max_qps)
+    self.endpoint = _LlmHTTPEndpoints(endpoint)
+    self.url = f'{base_url}/{self.endpoint.value}'
+
+  def input_spec(self) -> lit_types.Spec:
+    input_spec = pd_constants.INPUT_SPEC
+
+    if self.endpoint == _LlmHTTPEndpoints.SALIENCE:
+      input_spec |= pd_constants.INPUT_SPEC_SALIENCE
+
+    return input_spec
+
+  def output_spec(self) -> lit_types.Spec:
+    if self.endpoint == _LlmHTTPEndpoints.GENERATE:
+      return (
+          pd_constants.OUTPUT_SPEC_GENERATION
+          | pd_constants.OUTPUT_SPEC_GENERATION_EMBEDDINGS
+      )
+    elif self.endpoint == _LlmHTTPEndpoints.SALIENCE:
+      return pd_constants.OUTPUT_SPEC_SALIENCE
+    else:
+      return pd_constants.OUTPUT_SPEC_TOKENIZER
+
+  def predict_minibatch(
+      self, inputs: list[lit_types.JsonDict]
+  ) -> list[lit_types.JsonDict]:
+    """Run prediction on a batch of inputs.
+
+    Sends the inputs to the remote model server and parses its response.
+
+    Args:
+      inputs: sequence of inputs, following model.input_spec()
+
+    Returns:
+      list of outputs, following model.output_spec()
+    """
+    response = requests.post(
+        self.url, data=serialize.to_json(list(inputs), simple=True)
+    )
+
+    if not (200 <= response.status_code < 300):
+      raise RuntimeError(f'Request failed with status {response.status_code}')
+
+    outputs = serialize.from_json(response.text)
+    return outputs
+
+
+def initialize_model_group_for_salience(
+    name: str, base_url: str, *args, **kw
+) -> dict[str, lit_model.Model]:
+  """Creates '{name}' and '_{name}_salience' and '_{name}_tokenizer'."""
+  salience_name, tokenizer_name = pd_utils.generate_model_group_names(name)
+
+  generation_model = LlmOverHTTP(
+      *args, base_url=base_url, endpoint=_LlmHTTPEndpoints.GENERATE, **kw
+  )
+  salience_model = LlmOverHTTP(
+      *args, base_url=base_url, endpoint=_LlmHTTPEndpoints.SALIENCE, **kw
+  )
+  tokenizer_model = LlmOverHTTP(
+      *args, base_url=base_url, endpoint=_LlmHTTPEndpoints.TOKENIZE, **kw
+  )
+
+  return {
+      name: generation_model,
+      salience_name: salience_model,
+      tokenizer_name: tokenizer_model,
+  }
diff --git a/lit_nlp/examples/gcp/model_server.py b/lit_nlp/examples/gcp/model_server.py
index 22f590cd..d818c778 100644
--- a/lit_nlp/examples/gcp/model_server.py
+++ b/lit_nlp/examples/gcp/model_server.py
@@ -7,6 +7,7 @@
 
 from absl import app
 from lit_nlp import dev_server
+from lit_nlp.examples.gcp import constants as lit_gcp_constants
 from lit_nlp.examples.prompt_debugging import models as pd_models
 from lit_nlp.examples.prompt_debugging import utils as pd_utils
 from lit_nlp.lib import serialize
@@ -19,6 +20,8 @@
 DEFAULT_BATCH_SIZE = 1
 DEFAULT_MODELS = 'gemma_1.1_2b_IT:gemma_1.1_instruct_2b_en'
 
+_LlmHTTPEndpoints = lit_gcp_constants.LlmHTTPEndpoints
+
 
 def get_wsgi_app() -> wsgi_app.App:
   """Return WSGI app for an LLM server."""
@@ -60,9 +63,9 @@ def _handler(app: wsgi_app.App, request, unused_environ):
   sal_name, tok_name = pd_utils.generate_model_group_names(gen_name)
 
   handlers = {
-      '/predict': models[gen_name].predict,
-      '/salience': models[sal_name].predict,
-      '/tokenize': models[tok_name].predict,
+      f'/{_LlmHTTPEndpoints.GENERATE.value}': models[gen_name].predict,
+      f'/{_LlmHTTPEndpoints.SALIENCE.value}': models[sal_name].predict,
+      f'/{_LlmHTTPEndpoints.TOKENIZE.value}': models[tok_name].predict,
   }
 
   wrapped_handlers = {
diff --git a/lit_nlp/examples/gcp/server.py b/lit_nlp/examples/gcp/server.py
new file mode 100644
index 00000000..437b5378
--- /dev/null
+++ b/lit_nlp/examples/gcp/server.py
@@ -0,0 +1,78 @@
+"""Server for sequence salience with a left-to-right language model."""
+
+from collections.abc import Mapping, Sequence
+import sys
+from typing import Optional
+
+from absl import app
+from absl import flags
+from absl import logging
+from lit_nlp import dev_server
+from lit_nlp import server_flags
+from lit_nlp.api import model as lit_model
+from lit_nlp.api import types as lit_types
+from lit_nlp.examples.gcp import model as lit_gcp_model
+from lit_nlp.examples.prompt_debugging import datasets as pd_datasets
+from lit_nlp.examples.prompt_debugging import layouts as pd_layouts
+
+
+_FLAGS = flags.FLAGS
+
+_SPLASH_SCREEN_DOC = """
+# Language Model Salience
+
+To begin, select an example, then click the segment(s) (tokens, words, etc.)
+of the output that you would like to explain. Preceding segment(s) will be
+highlighted according to their importance to the selected target segment(s),
+with darker colors indicating a greater influence (salience) of that segment on
+the model's likelihood of the target segment.
+"""
+
+
+def init_llm_on_gcp(
+    name: str, base_url: str, *args, **kw
+) -> Mapping[str, lit_model.Model]:
+  return lit_gcp_model.initialize_model_group_for_salience(
+      name=name, base_url=base_url, *args, **kw
+  )
+
+
+def get_wsgi_app() -> Optional[dev_server.LitServerType]:
+  """Return WSGI app for container-hosted demos."""
+  _FLAGS.set_default("server_type", "external")
+  _FLAGS.set_default("demo_mode", True)
+  _FLAGS.set_default("page_title", "LM Prompt Debugging")
+  _FLAGS.set_default("default_layout", pd_layouts.THREE_PANEL)
+  # Parse flags without calling app.run(main), to avoid conflict with
+  # gunicorn command line flags.
+  unused = flags.FLAGS(sys.argv, known_only=True)
+  if unused:
+    logging.info("lm_demo:get_wsgi_app() called with unused args: %s", unused)
+  return main([])
+
+
+def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
+  if len(argv) > 1:
+    raise app.UsageError("Too many command-line arguments.")
+
+  lit_demo = dev_server.Server(
+      models={},
+      datasets={},
+      layouts=pd_layouts.PROMPT_DEBUGGING_LAYOUTS,
+      model_loaders={
+          'LLM on GCP': (init_llm_on_gcp, {
+              'name': lit_types.String(),
+              'base_url': lit_types.String(),
+              'max_concurrent_requests': lit_types.Integer(default=1),
+              'max_qps': lit_types.Scalar(default=25),
+          })
+      },
+      dataset_loaders=pd_datasets.get_dataset_loaders(),
+      onboard_start_doc=_SPLASH_SCREEN_DOC,
+      **server_flags.get_flags(),
+  )
+  return lit_demo.serve()
+
+
+if __name__ == "__main__":
+  app.run(main)
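
For reference, the request/response protocol that LlmOverHTTP wraps can also be exercised directly against a running model_server. The sketch below is not part of the patch: the host/port, the 'prompt' field name, and the example prompt are illustrative assumptions, while the serialization calls mirror LlmOverHTTP.predict_minibatch above.

# client_sketch.py -- hypothetical example, not included in this patch.
from lit_nlp.examples.gcp import constants as lit_gcp_constants
from lit_nlp.lib import serialize
import requests

BASE_URL = 'http://localhost:5432'  # Assumed address of a running model_server.


def predict(prompts: list[str]) -> list[dict]:
  """POSTs a batch of prompts to the generate endpoint and parses the reply."""
  endpoint = lit_gcp_constants.LlmHTTPEndpoints.GENERATE.value
  # Assumes the prompt_debugging INPUT_SPEC keys examples by a 'prompt' field.
  inputs = [{'prompt': p} for p in prompts]
  # Same wire format as LlmOverHTTP: LIT JSON in, LIT JSON out.
  response = requests.post(
      f'{BASE_URL}/{endpoint}',
      data=serialize.to_json(inputs, simple=True),
  )
  if not (200 <= response.status_code < 300):
    raise RuntimeError(f'Request failed with status {response.status_code}')
  return serialize.from_json(response.text)


if __name__ == '__main__':
  print(predict(['Tell me about the LIT framework.']))

The same pattern covers the salience and tokenize endpoints by swapping the enum member, which is what initialize_model_group_for_salience does when it builds its three LlmOverHTTP instances.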