Skip to content

Commit

Permalink
fix: update openai_function_call
Browse files Browse the repository at this point in the history
  • Loading branch information
AlmogBaku committed Sep 27, 2023
1 parent 5f05d46 commit 853190d
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
2 changes: 2 additions & 0 deletions openai_streaming/fn_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def o_func(func):
"""
if hasattr(func, 'func'):
return o_func(func.func)
if hasattr(func, '__func'):
return o_func(func.__func)
return func


Expand Down
16 changes: 12 additions & 4 deletions openai_streaming/openai_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
#
# Since the original project has taken a huge pivot and provide many unnecessary features - this is a stripped version
# of the openai_function decorator copied from
# https://github.com/jxnl/instructor/blob/0.2.3/openai_function_call/function_calls.py
# https://github.com/jxnl/instructor/blob/0.2.8/instructor/function_calls.py

import json
from docstring_parser import parse
from functools import wraps
from typing import Any, Callable
from pydantic import validate_arguments
Expand All @@ -35,7 +36,7 @@ def _remove_a_key(d, remove_key) -> None:
"""Remove a key from a dictionary recursively"""
if isinstance(d, dict):
for key in list(d.keys()):
if key == remove_key:
if key == remove_key and "type" in d.keys():
del d[key]
else:
_remove_a_key(d[key], remove_key)
Expand Down Expand Up @@ -70,20 +71,27 @@ def sum(a: int, b: int) -> int:
def __init__(self, func: Callable) -> None:
self.func = func
self.validate_func = validate_arguments(func)
self.docstring = parse(self.func.__doc__ or "")

parameters = self.validate_func.model.model_json_schema()
parameters["properties"] = {
k: v
for k, v in parameters["properties"].items()
if k not in ("v__duplicate_kwargs", "args", "kwargs")
}
for param in self.docstring.params:
if (name := param.arg_name) in parameters["properties"] and (
description := param.description
):
parameters["properties"][name]["description"] = description
parameters["required"] = sorted(
k for k, v in parameters["properties"].items() if not "default" in v
)
_remove_a_key(parameters, "additionalProperties")
_remove_a_key(parameters, "title")
self.openai_schema = {
"name": self.func.__name__,
"description": self.func.__doc__,
"description": self.docstring.short_description,
"parameters": parameters,
}
self.model = self.validate_func.model
Expand All @@ -106,7 +114,7 @@ def from_response(self, completion, throw_error=True):
Returns:
result (any): result of the function call
"""
message = completion.choices[0].message
message = completion["choices"][0]["message"]

if throw_error:
assert "function_call" in message, "No function call detected"
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ python = "^3.9"
openai = "^0.27.8"
json-streamer = "^0.1.0"
pydantic = "^2.0.2"
docstring-parser = "^0.15"

[dev-dependencies]
pytest = "^6.2"
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
openai==0.27.8
json-streamer==0.1.0
pydantic==2.0.2
docstring-parser==0.15

0 comments on commit 853190d

Please sign in to comment.