Skip to content

Commit

Permalink
Modify the output to be a list to ensure compatibility with JSON seri…
Browse files Browse the repository at this point in the history
…alization.
  • Loading branch information
llcourage committed Sep 5, 2024
1 parent 9d39a00 commit c9c7aff
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion lit_nlp/examples/gcp/model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _handler(app, request, environ):
data = serialize.from_json(request.data) if len(request.data) else None
inputs = data['inputs']
outputs = predict_fn(inputs)
response_body = serialize.to_json(outputs, simple=True)
response_body = serialize.to_json(list(outputs), simple=True)
return app.respond(request, response_body, 'application/json', 200)

return _handler
Expand Down
16 changes: 8 additions & 8 deletions lit_nlp/examples/gcp/model_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@ class TestWSGIApp(absltest.TestCase):
def test_predict_endpoint(self, mock_get_models):

mock_model = mock.MagicMock()
mock_model.predict.side_effect = [{'response': 'test output text'}]
mock_model.predict.side_effect = [[{'response': 'test output text'}]]

salience_model = mock.MagicMock()
salience_model.predict.side_effect = [{
salience_model.predict.side_effect = [[{
'tokens': ['test', 'output', 'text'],
'grad_l2': [0.1234, 0.3456, 0.5678],
'grad_dot_input': [0.1234, -0.3456, 0.5678],
}]
}]]

tokenize_model = mock.MagicMock()
tokenize_model.predict.side_effect = [
{'tokens': ['test', 'output', 'text']}
[{'tokens': ['test', 'output', 'text']}]
]

mock_get_models.return_value = {
Expand All @@ -33,22 +33,22 @@ def test_predict_endpoint(self, mock_get_models):

response = app.post_json('/predict', {'inputs': 'test_input'})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json, {'response': 'test output text'})
self.assertEqual(response.json, [{'response': 'test output text'}])

response = app.post_json('/salience', {'inputs': 'test_input'})
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json,
{
[{
'tokens': ['test', 'output', 'text'],
'grad_l2': [0.1234, 0.3456, 0.5678],
'grad_dot_input': [0.1234, -0.3456, 0.5678],
},
}],
)

response = app.post_json('/tokenize', {'inputs': 'test_input'})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json, {'tokens': ['test', 'output', 'text']})
self.assertEqual(response.json, [{'tokens': ['test', 'output', 'text']}])


if __name__ == '__main__':
Expand Down

0 comments on commit c9c7aff

Please sign in to comment.