From 853190dc1cabb135bccc99ca477c169acfad16c3 Mon Sep 17 00:00:00 2001 From: almogbaku Date: Wed, 27 Sep 2023 17:13:42 +0300 Subject: [PATCH] fix: update openai_function_call --- openai_streaming/fn_dispatcher.py | 2 ++ openai_streaming/openai_function.py | 16 ++++++++++++---- pyproject.toml | 1 + requirements.txt | 1 + 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/openai_streaming/fn_dispatcher.py b/openai_streaming/fn_dispatcher.py index 6e774a8..410835e 100644 --- a/openai_streaming/fn_dispatcher.py +++ b/openai_streaming/fn_dispatcher.py @@ -27,6 +27,8 @@ def o_func(func): """ if hasattr(func, 'func'): return o_func(func.func) + if hasattr(func, '__func'): + return o_func(func.__func) return func diff --git a/openai_streaming/openai_function.py b/openai_streaming/openai_function.py index cb5bad8..18abfc8 100644 --- a/openai_streaming/openai_function.py +++ b/openai_streaming/openai_function.py @@ -23,9 +23,10 @@ # # Since the original project has taken a huge pivot and provide many unnecessary features - this is a stripped version # of the openai_function decorator copied from -# https://github.com/jxnl/instructor/blob/0.2.3/openai_function_call/function_calls.py +# https://github.com/jxnl/instructor/blob/0.2.8/instructor/function_calls.py import json +from docstring_parser import parse from functools import wraps from typing import Any, Callable from pydantic import validate_arguments @@ -35,7 +36,7 @@ def _remove_a_key(d, remove_key) -> None: """Remove a key from a dictionary recursively""" if isinstance(d, dict): for key in list(d.keys()): - if key == remove_key: + if key == remove_key and "type" in d.keys(): del d[key] else: _remove_a_key(d[key], remove_key) @@ -70,12 +71,19 @@ def sum(a: int, b: int) -> int: def __init__(self, func: Callable) -> None: self.func = func self.validate_func = validate_arguments(func) + self.docstring = parse(self.func.__doc__ or "") + parameters = self.validate_func.model.model_json_schema() parameters["properties"] = { k: v for k, v in parameters["properties"].items() if k not in ("v__duplicate_kwargs", "args", "kwargs") } + for param in self.docstring.params: + if (name := param.arg_name) in parameters["properties"] and ( + description := param.description + ): + parameters["properties"][name]["description"] = description parameters["required"] = sorted( k for k, v in parameters["properties"].items() if not "default" in v ) @@ -83,7 +91,7 @@ def __init__(self, func: Callable) -> None: _remove_a_key(parameters, "title") self.openai_schema = { "name": self.func.__name__, - "description": self.func.__doc__, + "description": self.docstring.short_description, "parameters": parameters, } self.model = self.validate_func.model @@ -106,7 +114,7 @@ def from_response(self, completion, throw_error=True): Returns: result (any): result of the function call """ - message = completion.choices[0].message + message = completion["choices"][0]["message"] if throw_error: assert "function_call" in message, "No function call detected" diff --git a/pyproject.toml b/pyproject.toml index 39c18fc..c2c2dfb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ python = "^3.9" openai = "^0.27.8" json-streamer = "^0.1.0" pydantic = "^2.0.2" +docstring-parser = "^0.15" [dev-dependencies] pytest = "^6.2" diff --git a/requirements.txt b/requirements.txt index 418359b..879631f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ openai==0.27.8 json-streamer==0.1.0 pydantic==2.0.2 +docstring-parser==0.15 \ No newline at end of file