-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
1,203 additions
and
241 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ build/ | |
dist | ||
openai_streaming.egg-info/ | ||
.benchmarks | ||
junit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,88 @@ | ||
from collections.abc import AsyncGenerator | ||
from inspect import iscoroutinefunction | ||
from inspect import iscoroutinefunction, signature | ||
from types import FunctionType | ||
from typing import Generator, get_origin, Union, Optional, Any | ||
from typing import Generator, get_origin, Union, Optional, Any, get_type_hints | ||
from typing import get_args | ||
from .openai_function import openai_function | ||
|
||
from docstring_parser import parse | ||
from openai.types.beta.assistant import ToolFunction | ||
from openai.types.shared import FunctionDefinition | ||
from pydantic import create_model | ||
|
||
|
||
def openai_streaming_function(func: FunctionType) -> Any: | ||
""" | ||
Decorator that converts a function to an OpenAI streaming function using the `openai-function-call` package. | ||
It simply "reduces" the type of the arguments to the Generator type, and uses `openai_function` to do the rest. | ||
Decorator that creates an OpenAI Schema for your function, while support using Generators for Streaming. | ||
To document your function (so the model will know how to use it), simply use docstring. | ||
Using standard docstring styles will also allow you to document your argument's description | ||
:Example: | ||
```python | ||
@openai_streaming_function | ||
async def error_message(typ: str, description: AsyncGenerator[str, None]): | ||
\""" | ||
You MUST use this function when requested to do something that you cannot do. | ||
:param typ: The error's type | ||
:param description: The error description | ||
\""" | ||
pass | ||
``` | ||
:param func: The function to convert | ||
:return: Wrapped function with a `openai_schema` attribute | ||
:return: Your function with additional attribute `openai_schema` | ||
""" | ||
if not iscoroutinefunction(func): | ||
raise ValueError("openai_streaming_function can only be applied to async functions") | ||
raise ValueError("openai_streaming only supports async functions.") | ||
|
||
for key, val in func.__annotations__.items(): | ||
optional = False | ||
type_hints = get_type_hints(func) | ||
for key, val in type_hints.items(): | ||
|
||
args = get_args(val) | ||
if get_origin(val) is Union and len(args) == 2: | ||
gen = None | ||
other = None | ||
for arg in args: | ||
if isinstance(arg, type(None)): | ||
optional = True | ||
if get_origin(arg) is get_origin(Generator) or get_origin(arg) is AsyncGenerator: | ||
gen = arg | ||
else: | ||
other = arg | ||
if gen is not None and (get_args(gen)[0] is other or optional): | ||
val = gen | ||
|
||
args = get_args(val) | ||
# Unpack optionals | ||
optional = False | ||
if val is Optional or (get_origin(val) is Union and len(args) == 2 and args[1] is type(None)): | ||
optional = True | ||
val = args[0] | ||
args = get_args(val) | ||
|
||
if get_origin(val) is get_origin(Generator): | ||
raise ValueError("openai_streaming_function does not support Generator type. Use AsyncGenerator instead.") | ||
raise ValueError("openai_streaming does not support `Generator` type, instead use `AsyncGenerator`.") | ||
if get_origin(val) is AsyncGenerator: | ||
val = args[0] | ||
|
||
if optional: | ||
val = Optional[val] | ||
func.__annotations__[key] = val | ||
|
||
wrapped = openai_function(func) | ||
if hasattr(wrapped, "model") and "self" in wrapped.model.model_fields: | ||
del wrapped.model.model_fields["self"] | ||
if hasattr(wrapped, "openai_schema") and "self" in wrapped.openai_schema["parameters"]["properties"]: | ||
del wrapped.openai_schema["parameters"]["properties"]["self"] | ||
for i, required in enumerate(wrapped.openai_schema["parameters"]["required"]): | ||
if required == "self": | ||
del wrapped.openai_schema["parameters"]["required"][i] | ||
break | ||
return wrapped | ||
|
||
type_hints[key] = val | ||
|
||
# Prepare fields for the dynamic model | ||
fields = { | ||
param.name: (type_hints[param.name], ...) | ||
for param in signature(func).parameters.values() | ||
if param.name in type_hints | ||
} | ||
|
||
# Create a Pydantic model dynamically | ||
model = create_model(func.__name__, **fields) | ||
|
||
# parse the function docstring | ||
docstring = parse(func.__doc__ or "") | ||
|
||
# prepare the parameters(arguments) | ||
parameters = model.model_json_schema() | ||
|
||
# extract parameter documentations from the docstring | ||
for param in docstring.params: | ||
if (name := param.arg_name) in parameters["properties"] and (description := param.description): | ||
parameters["properties"][name]["description"] = description | ||
|
||
func.openai_schema = ToolFunction(type='function', function=FunctionDefinition( | ||
name=func.__name__, | ||
description=docstring.short_description, | ||
parameters=parameters, | ||
)) | ||
|
||
return func |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.