diff --git a/.github/workflows/contrib-openai.yml b/.github/workflows/contrib-openai.yml
index b1b3e35e478f..7e8fb0033177 100644
--- a/.github/workflows/contrib-openai.yml
+++ b/.github/workflows/contrib-openai.yml
@@ -111,46 +111,7 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
- CompressionTest:
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: ["3.9"]
- runs-on: ${{ matrix.os }}
- environment: openai1
- steps:
- # checkout to pr branch
- - name: Checkout
- uses: actions/checkout@v4
- with:
- ref: ${{ github.event.pull_request.head.sha }}
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install packages and dependencies
- run: |
- docker --version
- python -m pip install --upgrade pip wheel
- pip install -e .
- python -c "import autogen"
- pip install pytest-cov>=5 pytest-asyncio
- - name: Install packages for test when needed
- run: |
- pip install docker
- - name: Coverage
- env:
- OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
- AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
- OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
- run: |
- pytest test/agentchat/contrib/test_compressible_agent.py
- - name: Upload coverage to Codecov
- uses: codecov/codecov-action@v3
- with:
- file: ./coverage.xml
- flags: unittests
+
GPTAssistantAgent:
strategy:
matrix:
@@ -306,44 +267,7 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
- ContextHandling:
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: ["3.11"]
- runs-on: ${{ matrix.os }}
- environment: openai1
- steps:
- # checkout to pr branch
- - name: Checkout
- uses: actions/checkout@v4
- with:
- ref: ${{ github.event.pull_request.head.sha }}
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install packages and dependencies
- run: |
- docker --version
- python -m pip install --upgrade pip wheel
- pip install -e .
- python -c "import autogen"
- pip install pytest-cov>=5
- - name: Coverage
- env:
- OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
- AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
- OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
- BING_API_KEY: ${{ secrets.BING_API_KEY }}
- run: |
- pytest test/agentchat/contrib/capabilities/test_context_handling.py
- - name: Upload coverage to Codecov
- uses: codecov/codecov-action@v3
- with:
- file: ./coverage.xml
- flags: unittests
+
ImageGen:
strategy:
matrix:
diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml
index 7d8a932b0254..3abe257dfad6 100644
--- a/.github/workflows/contrib-tests.yml
+++ b/.github/workflows/contrib-tests.yml
@@ -9,6 +9,8 @@ on:
paths:
- "autogen/**"
- "test/agentchat/contrib/**"
+ - "test/test_browser_utils.py"
+ - "test/test_retrieve_utils.py"
- ".github/workflows/contrib-tests.yml"
- "setup.py"
@@ -85,6 +87,10 @@ jobs:
--health-retries 5
ports:
- 5432:5432
+ mongodb:
+ image: mongodb/mongodb-atlas-local:latest
+ ports:
+ - 27017:27017
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -102,6 +108,9 @@ jobs:
- name: Install pgvector when on linux
run: |
pip install -e .[retrievechat-pgvector]
+ - name: Install mongodb when on linux
+ run: |
+ pip install -e .[retrievechat-mongodb]
- name: Install unstructured when python-version is 3.9 and on linux
if: matrix.python-version == '3.9'
run: |
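
Reviewer note: the new `mongodb/mongodb-atlas-local` service container publishes port 27017, and the `retrievechat-mongodb` extra installs the Python driver. A minimal sketch of how a test could confirm the service is reachable, assuming the extra pulls in `pymongo` (connection string and timeout are illustrative):

```python
from pymongo import MongoClient

# The service maps 27017:27017, so the Atlas-local instance is reachable on
# localhost without authentication in this workflow.
client = MongoClient("mongodb://localhost:27017", serverSelectionTimeoutMS=5000)
client.admin.command("ping")  # raises ServerSelectionTimeoutError if the service is down
```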
@@ -154,41 +163,6 @@ jobs:
file: ./coverage.xml
flags: unittests
- CompressionTest:
- runs-on: ${{ matrix.os }}
- strategy:
- fail-fast: false
- matrix:
- os: [ubuntu-latest, macos-latest, windows-2019]
- python-version: ["3.10"]
- steps:
- - uses: actions/checkout@v4
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install packages and dependencies for all tests
- run: |
- python -m pip install --upgrade pip wheel
- pip install pytest-cov>=5
- - name: Install packages and dependencies for Compression
- run: |
- pip install -e .
- - name: Set AUTOGEN_USE_DOCKER based on OS
- shell: bash
- run: |
- if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
- echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
- fi
- - name: Coverage
- run: |
- pytest test/agentchat/contrib/test_compressible_agent.py --skip-openai
- - name: Upload coverage to Codecov
- uses: codecov/codecov-action@v3
- with:
- file: ./coverage.xml
- flags: unittests
-
GPTAssistantAgent:
runs-on: ${{ matrix.os }}
strategy:
@@ -375,41 +349,6 @@ jobs:
file: ./coverage.xml
flags: unittests
- ContextHandling:
- runs-on: ${{ matrix.os }}
- strategy:
- fail-fast: false
- matrix:
- os: [ubuntu-latest, macos-latest, windows-2019]
- python-version: ["3.11"]
- steps:
- - uses: actions/checkout@v4
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install packages and dependencies for all tests
- run: |
- python -m pip install --upgrade pip wheel
- pip install pytest-cov>=5
- - name: Install packages and dependencies for Context Handling
- run: |
- pip install -e .
- - name: Set AUTOGEN_USE_DOCKER based on OS
- shell: bash
- run: |
- if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
- echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
- fi
- - name: Coverage
- run: |
- pytest test/agentchat/contrib/capabilities/test_context_handling.py --skip-openai
- - name: Upload coverage to Codecov
- uses: codecov/codecov-action@v3
- with:
- file: ./coverage.xml
- flags: unittests
-
TransformMessages:
runs-on: ${{ matrix.os }}
strategy:
@@ -476,7 +415,6 @@ jobs:
file: ./coverage.xml
flags: unittests
-
AnthropicTest:
runs-on: ${{ matrix.os }}
strategy:
@@ -598,3 +536,119 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
+
+ GroqTest:
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-2019]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
+ exclude:
+ - os: macos-latest
+ python-version: "3.9"
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies for all tests
+ run: |
+ python -m pip install --upgrade pip wheel
+ pip install "pytest-cov>=5"
+ - name: Install packages and dependencies for Groq
+ run: |
+ pip install -e .[groq,test]
+ - name: Set AUTOGEN_USE_DOCKER based on OS
+ shell: bash
+ run: |
+ if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+ echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+ fi
+ - name: Coverage
+ run: |
+ pytest test/oai/test_groq.py --skip-openai
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
+
+ CohereTest:
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies for all tests
+ run: |
+ python -m pip install --upgrade pip wheel
+ pip install "pytest-cov>=5"
+ - name: Install packages and dependencies for Cohere
+ run: |
+ pip install -e .[cohere,test]
+ - name: Set AUTOGEN_USE_DOCKER based on OS
+ shell: bash
+ run: |
+ if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+ echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+ fi
+ - name: Coverage
+ run: |
+ pytest test/oai/test_cohere.py --skip-openai
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
+
+ BedrockTest:
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-2019]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
+ exclude:
+ - os: macos-latest
+ python-version: "3.9"
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies for all tests
+ run: |
+ python -m pip install --upgrade pip wheel
+ pip install "pytest-cov>=5"
+ - name: Install packages and dependencies for Amazon Bedrock
+ run: |
+ pip install -e .[boto3,test]
+ - name: Set AUTOGEN_USE_DOCKER based on OS
+ shell: bash
+ run: |
+ if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+ echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+ fi
+ - name: Coverage
+ run: |
+ pytest test/oai/test_bedrock.py --skip-openai
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
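
Reviewer note: each new provider job installs an optional extra (`groq`, `cohere`, `boto3`) and runs its provider-specific test with `--skip-openai`. For context, a hedged sketch of the client configuration these extras enable; the `api_type` value follows autogen's alt-model client convention, and the model name and API key are placeholders, not fixtures from this PR:

```python
import autogen

# Illustrative OAI_CONFIG_LIST-style entry for the Groq client exercised by
# test/oai/test_groq.py; the model name and key below are placeholders.
config_list = [{"api_type": "groq", "model": "llama3-8b-8192", "api_key": "YOUR_GROQ_API_KEY"}]

assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
```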
diff --git a/.github/workflows/dotnet-build.yml b/.github/workflows/dotnet-build.yml
index f4074b061693..6aac54d3818c 100644
--- a/.github/workflows/dotnet-build.yml
+++ b/.github/workflows/dotnet-build.yml
@@ -43,24 +43,47 @@ jobs:
if: steps.filter.outputs.workflows == 'true'
build:
name: Dotnet Build
- runs-on: ubuntu-latest
needs: paths-filter
if: needs.paths-filter.outputs.hasChanges == 'true'
defaults:
run:
working-directory: dotnet
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest, macos-latest ]
+ python-version: ["3.11"]
+ runs-on: ${{ matrix.os }}
+ timeout-minutes: 30
steps:
- uses: actions/checkout@v4
with:
lfs: true
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install jupyter and ipykernel
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install jupyter
+ python -m pip install ipykernel
+ - name: list available kernels
+ run: |
+ python -m jupyter kernelspec list
- name: Setup .NET
uses: actions/setup-dotnet@v4
with:
- global-json-file: dotnet/global.json
+ dotnet-version: '8.0.x'
- name: Restore dependencies
run: |
# dotnet nuget add source --name dotnet-tool https://pkgs.dev.azure.com/dnceng/public/_packaging/dotnet-tools/nuget/v3/index.json --configfile NuGet.config
dotnet restore -bl
+ - name: Format check
+ run: |
+ echo "Format check"
+ echo "If you see any error in this step, please run 'dotnet format' locally to format the code."
+ dotnet format --verify-no-changes -v diag --no-restore
- name: Build
run: |
echo "Build AutoGen"
@@ -87,7 +110,7 @@ jobs:
- name: Setup dotnet
uses: actions/setup-dotnet@v4
with:
- global-json-file: dotnet/global.json
+ dotnet-version: '8.0.x'
- name: publish AOT testApp, assert static analysis warning count, and run the app
shell: pwsh
@@ -105,6 +128,18 @@ jobs:
- uses: actions/checkout@v4
with:
lfs: true
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v5
+ with:
+ python-version: 3.11
+ - name: Install jupyter and ipykernel
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install jupyter
+ python -m pip install ipykernel
+ - name: list available kernels
+ run: |
+ python -m jupyter kernelspec list
- name: Setup .NET
uses: actions/setup-dotnet@v4
with:
@@ -176,12 +211,14 @@ jobs:
env:
AZURE_ARTIFACTS_FEED_URL: https://devdiv.pkgs.visualstudio.com/DevDiv/_packaging/AutoGen/nuget/v3/index.json
NUGET_AUTH_TOKEN: ${{ secrets.AZURE_DEVOPS_TOKEN }}
+ continue-on-error: true
- name: Publish nightly package to github package
run: |
echo "Publish nightly package to github package"
echo "ls output directory"
ls -R ./output/nightly
dotnet nuget push --api-key ${{ secrets.GITHUB_TOKEN }} --source "https://nuget.pkg.github.com/microsoft/index.json" ./output/nightly/*.nupkg --skip-duplicate
+ continue-on-error: true
- name: Publish nightly package to agentchat myget feed
run: |
echo "Publish nightly package to agentchat myget feed"
@@ -190,3 +227,5 @@ jobs:
dotnet nuget push --api-key ${{ secrets.MYGET_TOKEN }} --source "https://www.myget.org/F/agentchat/api/v3/index.json" ./output/nightly/*.nupkg --skip-duplicate
env:
MYGET_TOKEN: ${{ secrets.MYGET_TOKEN }}
+ continue-on-error: true
+
diff --git a/.github/workflows/dotnet-release.yml b/.github/workflows/dotnet-release.yml
index 2877d058377b..23f4258a0e0c 100644
--- a/.github/workflows/dotnet-release.yml
+++ b/.github/workflows/dotnet-release.yml
@@ -29,10 +29,22 @@ jobs:
- uses: actions/checkout@v4
with:
lfs: true
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v5
+ with:
+ python-version: 3.11
+ - name: Install jupyter and ipykernel
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install jupyter
+ python -m pip install ipykernel
+ - name: list available kernels
+ run: |
+ python -m jupyter kernelspec list
- name: Setup .NET
uses: actions/setup-dotnet@v4
with:
- global-json-file: dotnet/global.json
+ dotnet-version: '8.0.x'
- name: Restore dependencies
run: |
dotnet restore -bl
diff --git a/CITATION.cff b/CITATION.cff
index bc9a03f375a8..5e4c468067f7 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -5,7 +5,7 @@ preferred-citation:
given-names: "Qingyun"
affiliation: "Penn State University, University Park PA USA"
- family-names: "Bansal"
- given-names: "Gargan"
+ given-names: "Gagan"
affiliation: "Microsoft Research, Redmond WA USA"
- family-names: "Zhang"
given-names: "Jieyu"
@@ -43,6 +43,7 @@ preferred-citation:
- family-names: "Wang"
given-names: "Chi"
affiliation: "Microsoft Research, Redmond WA USA"
- booktitle: "ArXiv preprint arXiv:2308.08155"
+ booktitle: "COLM"
title: "AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework"
- year: 2023
+ year: 2024
+ url: "https://aka.ms/autogen-pdf"
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
new file mode 100644
index 000000000000..4726588453b4
--- /dev/null
+++ b/CONTRIBUTORS.md
@@ -0,0 +1,43 @@
+# Contributors
+
+## Special thanks to all the people who help this project:
+> These individuals dedicate their time and expertise to improve this project. We are deeply grateful for their contributions.
+
+| Name | GitHub Handle | Organization | Features | Roadmap Lead | Additional Information |
+|---|---|---|---|---|---|
+| Qingyun Wu | [qingyun-wu](https://github.com/qingyun-wu) | Penn State University | all, alt-models, autobuilder | Yes | Available most of the time (US Eastern Time) |
+| Chi Wang | [sonichi](https://github.com/sonichi) | - | all | Yes | |
+| Li Jiang | [thinkall](https://github.com/thinkall) | Microsoft | rag, autobuilder, group chat | Yes | [Issue #1657](https://github.com/microsoft/autogen/issues/1657) - Beijing, GMT+8 |
+| Mark Sze | [marklysze](https://github.com/marklysze) | - | alt-models, group chat | No | Generally available (Sydney, AU time) - Group Chat "auto" speaker selection |
+| Hrushikesh Dokala | [Hk669](https://github.com/Hk669) | - | alt-models, swebench, logging, rag | No | [Issue #2946](https://github.com/microsoft/autogen/issues/2946), [Pull Request #2933](https://github.com/microsoft/autogen/pull/2933) - Available most of the time (India, GMT+5:30) |
+| Jiale Liu | [LeoLjl](https://github.com/LeoLjl) | Penn State University | autobuild, group chat | No | |
+| Shaokun Zhang | [skzhang1](https://github.com/skzhang1) | Penn State University | AgentOptimizer, Teachability | Yes | [Issue #521](https://github.com/microsoft/autogen/issues/521) |
+| Rajan Chari | [rajan-chari](https://github.com/rajan-chari) | Microsoft Research | CAP, Survey of other frameworks | No | |
+| Victor Dibia | [victordibia](https://github.com/victordibia) | Microsoft Research | autogenstudio | Yes | [Issue #737](https://github.com/microsoft/autogen/issues/737) |
+| Yixuan Zhai | [randombet](https://github.com/randombet) | Meta | group chat, sequential_chats, rag | No | |
+| Xiaoyun Zhang | [LittleLittleCloud](https://github.com/LittleLittleCloud) | Microsoft | AutoGen.Net, group chat | Yes | [Backlog - AutoGen.Net](https://github.com/microsoft/autogen/issues) - Available most of the time (PST) |
+| Yiran Wu | [yiranwu0](https://github.com/yiranwu0) | Penn State University | alt-models, group chat, logging | Yes | |
+| Beibin Li | [BeibinLi](https://github.com/BeibinLi) | Microsoft Research | alt-models | Yes | |
+| Gagan Bansal | [gagb](https://github.com/gagb) | Microsoft Research | All | | |
+| Adam Fourney | [afourney](https://github.com/afourney) | Microsoft Research | Complex Tasks | | |
+| Ricky Loynd | [rickyloynd-microsoft](https://github.com/rickyloynd-microsoft) | Microsoft Research | Teachability | | |
+| Eric Zhu | [ekzhu](https://github.com/ekzhu) | Microsoft Research | All, Infra | | |
+| Jack Gerrits | [jackgerrits](https://github.com/jackgerrits) | Microsoft Research | All, Infra | | |
+| David Luong | [DavidLuong98](https://github.com/DavidLuong98) | Microsoft | AutoGen.Net | | |
+| Davor Runje | [davorrunje](https://github.com/davorrunje) | airt.ai | Tool calling, IO | | Available most of the time (Central European Time) |
+| Friederike Niedtner | [Friderike](https://www.microsoft.com/en-us/research/people/fniedtner/) | Microsoft Research | PM | | |
+| Rafah Hosn | [Rafah](https://www.microsoft.com/en-us/research/people/raaboulh/) | Microsoft Research | PM | | |
+| Robin Moeur | [Robin](https://www.linkedin.com/in/rmoeur/) | Microsoft Research | PM | | |
+| Jingya Chen | [jingyachen](https://github.com/JingyaChen) | Microsoft | UX Design, AutoGen Studio | | |
+| Suff Syed | [suffsyed](https://github.com/suffsyed) | Microsoft | UX Design, AutoGen Studio | | |
+
+## I would like to join this list. How can I help the project?
+> We're always looking for new contributors to join our team and help improve the project. For more information, please refer to our [CONTRIBUTING](https://microsoft.github.io/autogen/docs/contributor-guide/contributing) guide.
+
+
+## Are you missing from this list?
+> Please open a PR to help us fix this.
+
+
+## Acknowledgements
+This template was adapted from [GitHub Template Guide](https://github.com/cezaraugusto/github-template-guidelines/blob/master/.github/CONTRIBUTORS.md) by [cezaraugusto](https://github.com/cezaraugusto).
diff --git a/README.md b/README.md
index 5bff3300a50e..8595bb60506c 100644
--- a/README.md
+++ b/README.md
@@ -1,20 +1,57 @@
-[![PyPI version](https://badge.fury.io/py/pyautogen.svg)](https://badge.fury.io/py/pyautogen)
-[![Build](https://github.com/microsoft/autogen/actions/workflows/python-package.yml/badge.svg)](https://github.com/microsoft/autogen/actions/workflows/python-package.yml)
-![Python Version](https://img.shields.io/badge/3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)
+
+
+
+
+
+![Python Version](https://img.shields.io/badge/3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue) [![PyPI version](https://img.shields.io/badge/PyPI-v0.2.34-blue.svg)](https://pypi.org/project/pyautogen/)
+[![NuGet version](https://badge.fury.io/nu/AutoGen.Core.svg)](https://badge.fury.io/nu/AutoGen.Core)
+
[![Downloads](https://static.pepy.tech/badge/pyautogen/week)](https://pepy.tech/project/pyautogen)
[![Discord](https://img.shields.io/discord/1153072414184452236?logo=discord&style=flat)](https://aka.ms/autogen-dc)
+
[![Twitter](https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow%20%40pyautogen)](https://twitter.com/pyautogen)
-[![NuGet version](https://badge.fury.io/nu/AutoGen.Core.svg)](https://badge.fury.io/nu/AutoGen.Core)
+
# AutoGen
-[📚 Cite paper](#related-papers).
-
+
+AutoGen is an open-source programming framework for building AI agents and facilitating cooperation among multiple agents to solve tasks. AutoGen aims to streamline the development and research of agentic AI, much like PyTorch does for deep learning. It offers features such as agents that can interact with one another, support for a wide range of large language models (LLMs) and tool use, autonomous and human-in-the-loop workflows, and multi-agent conversation patterns.
+
+> [!IMPORTANT]
+> *Note for contributors and users*: [microsoft/autogen](https://aka.ms/autogen-gh) is the official repository of the AutoGen project, actively developed and maintained under the MIT license. We welcome contributions from developers and organizations worldwide. Our goal is to foster a collaborative and inclusive community where diverse perspectives and expertise can drive innovation and enhance the project's capabilities. We acknowledge the invaluable contributions from our existing contributors, as listed in [contributors.md](./CONTRIBUTORS.md). Whether you are an individual contributor or represent an organization, we invite you to join us in shaping the future of this project. For further information, please also see the [Microsoft open-source contributing guidelines](https://github.com/microsoft/autogen?tab=readme-ov-file#contributing).
+>
+> -_Maintainers (Sept 6th, 2024)_
+
+
+![AutoGen Overview](https://github.com/microsoft/autogen/blob/main/website/static/img/autogen_agentchat.png)
+
+- AutoGen enables building next-gen LLM applications based on [multi-agent conversations](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat) with minimal effort. It simplifies the orchestration, automation, and optimization of a complex LLM workflow. It maximizes the performance of LLM models and overcomes their weaknesses.
+- It supports [diverse conversation patterns](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat#supporting-diverse-conversation-patterns) for complex workflows. With customizable and conversable agents, developers can use AutoGen to build a wide range of conversation patterns concerning conversation autonomy,
+ the number of agents, and agent conversation topology.
+- It provides a collection of working systems with different complexities. These systems span a [wide range of applications](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat#diverse-applications-implemented-with-autogen) from various domains and complexities. This demonstrates how AutoGen can easily support diverse conversation patterns.
+- AutoGen provides [enhanced LLM inference](https://microsoft.github.io/autogen/docs/Use-Cases/enhanced_inference#api-unification). It offers utilities like API unification and caching, and advanced usage patterns, such as error handling, multi-config inference, context programming, etc.
+
+AutoGen was created out of collaborative [research](https://microsoft.github.io/autogen/docs/Research) from Microsoft, Penn State University, and the University of Washington.
+
+
+
+ ↑ Back to Top ↑
+
+
+
+
+
+## News
+
+<details>
+<summary>Expand</summary>
+
+:fire: June 6, 2024: WIRED publishes a new article on AutoGen: [Chatbot Teamwork Makes the AI Dream Work](https://www.wired.com/story/chatbot-teamwork-makes-the-ai-dream-work/) based on an interview with [Adam Fourney](https://github.com/afourney).
+
+:fire: June 4th, 2024: Microsoft Research Forum publishes a new update and video on [AutoGen and Complex Tasks](https://www.microsoft.com/en-us/research/video/autogen-update-complex-tasks-and-agents/), presented by [Adam Fourney](https://github.com/afourney).
+
:fire: May 29, 2024: DeepLearning.ai launched a new short course [AI Agentic Design Patterns with AutoGen](https://www.deeplearning.ai/short-courses/ai-agentic-design-patterns-with-autogen), made in collaboration with Microsoft and Penn State University, and taught by AutoGen creators [Chi Wang](https://github.com/sonichi) and [Qingyun Wu](https://github.com/qingyun-wu).
:fire: May 24, 2024: Foundation Capital published an article on [Forbes: The Promise of Multi-Agent AI](https://www.forbes.com/sites/joannechen/2024/05/24/the-promise-of-multi-agent-ai/?sh=2c1e4f454d97) and a video [AI in the Real World Episode 2: Exploring Multi-Agent AI and AutoGen with Chi Wang](https://www.youtube.com/watch?v=RLwyXRVvlNk).
@@ -23,7 +60,7 @@
:fire: May 11, 2024: [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://openreview.net/pdf?id=uAjxFFing2) received the best paper award at the [ICLR 2024 LLM Agents Workshop](https://llmagents.github.io/).
-:fire: Apr 26, 2024: [AutoGen.NET](https://microsoft.github.io/autogen-for-net/) is available for .NET developers!
+:fire: Apr 26, 2024: [AutoGen.NET](https://microsoft.github.io/autogen-for-net/) is available for .NET developers! Thanks to [XiaoYun Zhang](https://www.linkedin.com/in/xiaoyun-zhang-1b531013a/).
:fire: Apr 17, 2024: Andrew Ng cited AutoGen in [The Batch newsletter](https://www.deeplearning.ai/the-batch/issue-245/) and [What's next for AI agentic workflows](https://youtu.be/sal78ACtGTc?si=JduUzN_1kDnMq0vF) at Sequoia Capital's AI Ascent (Mar 26).
@@ -58,31 +95,7 @@
:fire: FLAML supports Code-First AutoML & Tuning – Private Preview in [Microsoft Fabric Data Science](https://learn.microsoft.com/en-us/fabric/data-science/). -->
-
-
- ↑ Back to Top ↑
-
-
-
-## What is AutoGen
-
-AutoGen is a framework that enables the development of LLM applications using multiple agents that can converse with each other to solve tasks. AutoGen agents are customizable, conversable, and seamlessly allow human participation. They can operate in various modes that employ combinations of LLMs, human inputs, and tools.
-
-![AutoGen Overview](https://github.com/microsoft/autogen/blob/main/website/static/img/autogen_agentchat.png)
-
-- AutoGen enables building next-gen LLM applications based on [multi-agent conversations](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat) with minimal effort. It simplifies the orchestration, automation, and optimization of a complex LLM workflow. It maximizes the performance of LLM models and overcomes their weaknesses.
-- It supports [diverse conversation patterns](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat#supporting-diverse-conversation-patterns) for complex workflows. With customizable and conversable agents, developers can use AutoGen to build a wide range of conversation patterns concerning conversation autonomy,
- the number of agents, and agent conversation topology.
-- It provides a collection of working systems with different complexities. These systems span a [wide range of applications](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat#diverse-applications-implemented-with-autogen) from various domains and complexities. This demonstrates how AutoGen can easily support diverse conversation patterns.
-- AutoGen provides [enhanced LLM inference](https://microsoft.github.io/autogen/docs/Use-Cases/enhanced_inference#api-unification). It offers utilities like API unification and caching, and advanced usage patterns, such as error handling, multi-config inference, context programming, etc.
-
-AutoGen is created out of collaborative [research](https://microsoft.github.io/autogen/docs/Research) from Microsoft, Penn State University, and the University of Washington.
-
-
-
- ↑ Back to Top ↑
-
-
+
## Roadmaps
@@ -242,16 +255,25 @@ In addition, you can find:
## Related Papers
-[AutoGen](https://arxiv.org/abs/2308.08155)
+[AutoGen Studio](https://www.microsoft.com/en-us/research/publication/autogen-studio-a-no-code-developer-tool-for-building-and-debugging-multi-agent-systems/)
+
+```
+@inproceedings{dibia2024studio,
+ title={AutoGen Studio: A No-Code Developer Tool for Building and Debugging Multi-Agent Systems},
+ author={Victor Dibia and Jingya Chen and Gagan Bansal and Suff Syed and Adam Fourney and Erkang (Eric) Zhu and Chi Wang and Saleema Amershi},
+ year={2024},
+ booktitle={Pre-Print}
+}
+```
+
+[AutoGen](https://aka.ms/autogen-pdf)
```
@inproceedings{wu2023autogen,
title={AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework},
author={Qingyun Wu and Gagan Bansal and Jieyu Zhang and Yiran Wu and Beibin Li and Erkang Zhu and Li Jiang and Xiaoyun Zhang and Shaokun Zhang and Jiale Liu and Ahmed Hassan Awadallah and Ryen W White and Doug Burger and Chi Wang},
- year={2023},
- eprint={2308.08155},
- archivePrefix={arXiv},
- primaryClass={cs.AI}
+ year={2024},
+ booktitle={COLM},
}
```
@@ -288,6 +310,16 @@ In addition, you can find:
}
```
+[StateFlow](https://arxiv.org/abs/2403.11322)
+```
+@article{wu2024stateflow,
+ title={StateFlow: Enhancing LLM Task-Solving through State-Driven Workflows},
+ author={Wu, Yiran and Yue, Tianwei and Zhang, Shaokun and Wang, Chi and Wu, Qingyun},
+ journal={arXiv preprint arXiv:2403.11322},
+ year={2024}
+}
+```
+
↑ Back to Top ↑
@@ -339,7 +371,7 @@ may be either trademarks or registered trademarks of Microsoft in the United Sta
The licenses for this project do not grant you rights to use any Microsoft names, logos, or trademarks.
Microsoft's general trademark guidelines can be found at http://go.microsoft.com/fwlink/?LinkID=254653.
-Privacy information can be found at https://privacy.microsoft.com/en-us/
+Privacy information can be found at https://go.microsoft.com/fwlink/?LinkId=521839
Microsoft and any contributors reserve all other rights, whether under their respective copyrights, patents,
or trademarks, whether by implication, estoppel, or otherwise.
diff --git a/TRANSPARENCY_FAQS.md b/TRANSPARENCY_FAQS.md
index 206af084748b..addf29d8b8d3 100644
--- a/TRANSPARENCY_FAQS.md
+++ b/TRANSPARENCY_FAQS.md
@@ -31,6 +31,8 @@ While AutoGen automates LLM workflows, decisions about how to use specific LLM o
- Current version of AutoGen was evaluated on six applications to illustrate its potential in simplifying the development of high-performance multi-agent applications. These applications are selected based on their real-world relevance, problem difficulty and problem solving capabilities enabled by AutoGen, and innovative potential.
- These applications involve using AutoGen to solve math problems, question answering, decision making in text world environments, supply chain optimization, etc. For each of these domains, AutoGen was evaluated on various success-based metrics (i.e., how often the AutoGen-based implementation solved the task). And, in some cases, the AutoGen-based approach was also evaluated on implementation efficiency (e.g., to track reductions in developer effort to build). More details can be found at: https://aka.ms/AutoGen/TechReport
- The team has conducted tests where a “red” agent attempts to get the default AutoGen assistant to break from its alignment and guardrails. The team has observed that out of 70 attempts to break guardrails, only 1 was successful in producing text that would have been flagged as problematic by Azure OpenAI filters. The team has not observed any evidence that AutoGen (or GPT models as hosted by OpenAI or Azure) can produce novel code exploits or jailbreak prompts, since direct prompts to “be a hacker”, “write exploits”, or “produce a phishing email” are refused by existing filters.
+- We also evaluated [a team of AutoGen agents](https://github.com/microsoft/autogen/tree/gaia_multiagent_v01_march_1st/samples/tools/autogenbench/scenarios/GAIA/Templates/Orchestrator) on the [GAIA benchmarks](https://arxiv.org/abs/2311.12983), and got [SOTA results](https://huggingface.co/spaces/gaia-benchmark/leaderboard) as of
+ March 1, 2024.
## What are the limitations of AutoGen? How can users minimize the impact of AutoGen’s limitations when using the system?
AutoGen relies on existing LLMs. Experimenting with AutoGen would retain common limitations of large language models; including:
diff --git a/autogen/agentchat/chat.py b/autogen/agentchat/chat.py
index 97411e9fc004..d07b4d15cb62 100644
--- a/autogen/agentchat/chat.py
+++ b/autogen/agentchat/chat.py
@@ -107,6 +107,15 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite
return chat_order
+def _post_process_carryover_item(carryover_item):
+ if isinstance(carryover_item, str):
+ return carryover_item
+ elif isinstance(carryover_item, dict) and "content" in carryover_item:
+ return str(carryover_item["content"])
+ else:
+ return str(carryover_item)
+
+
def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
iostream = IOStream.get_default()
@@ -116,7 +125,7 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
UserWarning,
)
print_carryover = (
- ("\n").join([t for t in chat_info["carryover"]])
+ ("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
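
Reviewer note: the new `_post_process_carryover_item` helper lets carryover lists mix plain strings, message dicts, and arbitrary values before `__post_carryover_processing` joins them. A standalone sketch of the behavior (the helper body is copied verbatim for illustration):

```python
def _post_process_carryover_item(carryover_item):
    # Strings pass through; dicts are reduced to their "content" value;
    # anything else is stringified.
    if isinstance(carryover_item, str):
        return carryover_item
    elif isinstance(carryover_item, dict) and "content" in carryover_item:
        return str(carryover_item["content"])
    else:
        return str(carryover_item)

carryover = ["plain string", {"role": "user", "content": "dict with content"}, 42]
print("\n".join(_post_process_carryover_item(t) for t in carryover))
# plain string
# dict with content
# 42
```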
@@ -153,7 +162,7 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
For example:
- `"sender"` - the sender agent.
- `"recipient"` - the recipient agent.
- - `"clear_history" (bool) - whether to clear the chat history with the agent.
+ - `"clear_history"` (bool) - whether to clear the chat history with the agent.
Default is True.
- `"silent"` (bool or None) - (Experimental) whether to print the messages in this
conversation. Default is False.
diff --git a/autogen/agentchat/contrib/agent_builder.py b/autogen/agentchat/contrib/agent_builder.py
index c9a2d79607dd..430017d13fc9 100644
--- a/autogen/agentchat/contrib/agent_builder.py
+++ b/autogen/agentchat/contrib/agent_builder.py
@@ -103,7 +103,7 @@ class AgentBuilder:
"""
AGENT_NAME_PROMPT = """# Your task
-Suggest no more then {max_agents} experts with their name according to the following user requirement.
+Suggest no more than {max_agents} experts with their name according to the following user requirement.
## User requirement
{task}
diff --git a/autogen/agentchat/contrib/agent_eval/README.md b/autogen/agentchat/contrib/agent_eval/README.md
index 6588a1ec6113..478f28fd74ec 100644
--- a/autogen/agentchat/contrib/agent_eval/README.md
+++ b/autogen/agentchat/contrib/agent_eval/README.md
@@ -1,7 +1,9 @@
-Agents for running the AgentEval pipeline.
+Agents for running the [AgentEval](https://microsoft.github.io/autogen/blog/2023/11/20/AgentEval/) pipeline.
AgentEval is a process for evaluating an LLM-based system's performance on a given task.
When given a task to evaluate and a few example runs, the critic and subcritic agents create evaluation criteria for evaluating a system's solution. Once the criteria have been created, the quantifier agent can evaluate subsequent task solutions based on the generated criteria.
For more information see: [AgentEval Integration Roadmap](https://github.com/microsoft/autogen/issues/2162)
+
+See our [blog post](https://microsoft.github.io/autogen/blog/2024/06/21/AgentEval) for usage examples and general explanations.
diff --git a/autogen/agentchat/contrib/capabilities/context_handling.py b/autogen/agentchat/contrib/capabilities/context_handling.py
deleted file mode 100644
index 44b10259f1b7..000000000000
--- a/autogen/agentchat/contrib/capabilities/context_handling.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import sys
-from typing import Dict, List, Optional
-from warnings import warn
-
-import tiktoken
-from termcolor import colored
-
-from autogen import ConversableAgent, token_count_utils
-
-warn(
- "Context handling with TransformChatHistory is deprecated and will be removed in `0.2.30`. "
- "Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/topics/handling_long_contexts/intro_to_transform_messages",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class TransformChatHistory:
- """
- An agent's chat history with other agents is a common context that it uses to generate a reply.
- This capability allows the agent to transform its chat history prior to using it to generate a reply.
- It does not permanently modify the chat history, but rather processes it on every invocation.
-
- This capability class enables various strategies to transform chat history, such as:
- - Truncate messages: Truncate each message to first maximum number of tokens.
- - Limit number of messages: Truncate the chat history to a maximum number of (recent) messages.
- - Limit number of tokens: Truncate the chat history to number of recent N messages that fit in
- maximum number of tokens.
- Note that the system message, because of its special significance, is always kept as is.
-
- The three strategies can be combined. For example, when each of these parameters are specified
- they are used in the following order:
- 1. First truncate messages to a maximum number of tokens
- 2. Second, it limits the number of message to keep
- 3. Third, it limits the total number of tokens in the chat history
-
- When adding this capability to an agent, the following are modified:
- - A hook is added to the hookable method `process_all_messages_before_reply` to transform the
- received messages for possible truncation.
- Not modifying the stored message history.
- """
-
- def __init__(
- self,
- *,
- max_tokens_per_message: Optional[int] = None,
- max_messages: Optional[int] = None,
- max_tokens: Optional[int] = None,
- ):
- """
- Args:
- max_tokens_per_message (Optional[int]): Maximum number of tokens to keep in each message.
- max_messages (Optional[int]): Maximum number of messages to keep in the context.
- max_tokens (Optional[int]): Maximum number of tokens to keep in the context.
- """
- self.max_tokens_per_message = max_tokens_per_message if max_tokens_per_message else sys.maxsize
- self.max_messages = max_messages if max_messages else sys.maxsize
- self.max_tokens = max_tokens if max_tokens else sys.maxsize
-
- def add_to_agent(self, agent: ConversableAgent):
- """
- Adds TransformChatHistory capability to the given agent.
- """
- agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)
-
- def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
- """
- Args:
- messages: List of messages to process.
-
- Returns:
- List of messages with the first system message and the last max_messages messages,
- ensuring each message does not exceed max_tokens_per_message.
- """
- temp_messages = messages.copy()
- processed_messages = []
- system_message = None
- processed_messages_tokens = 0
-
- if messages[0]["role"] == "system":
- system_message = messages[0].copy()
- temp_messages.pop(0)
-
- total_tokens = sum(
- token_count_utils.count_token(msg["content"]) for msg in temp_messages
- ) # Calculate tokens for all messages
-
- # Truncate each message's content to a maximum token limit of each message
-
- # Process recent messages first
- for msg in reversed(temp_messages[-self.max_messages :]):
- msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
- msg_tokens = token_count_utils.count_token(msg["content"])
- if processed_messages_tokens + msg_tokens > self.max_tokens:
- break
- # append the message to the beginning of the list to preserve order
- processed_messages = [msg] + processed_messages
- processed_messages_tokens += msg_tokens
- if system_message:
- processed_messages.insert(0, system_message)
- # Optionally, log the number of truncated messages and tokens if needed
- num_truncated = len(messages) - len(processed_messages)
-
- if num_truncated > 0 or total_tokens > processed_messages_tokens:
- print(
- colored(
- f"Truncated {num_truncated} messages. Reduced from {len(messages)} to {len(processed_messages)}.",
- "yellow",
- )
- )
- print(
- colored(
- f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
- "yellow",
- )
- )
- return processed_messages
-
-
-def truncate_str_to_tokens(text: str, max_tokens: int, model: str = "gpt-3.5-turbo-0613") -> str:
- """Truncate a string so that the number of tokens is less than or equal to max_tokens using tiktoken.
-
- Args:
- text: The string to truncate.
- max_tokens: The maximum number of tokens to keep.
- model: The target OpenAI model for tokenization alignment.
-
- Returns:
- The truncated string.
- """
-
- encoding = tiktoken.encoding_for_model(model) # Get the appropriate tokenizer
-
- encoded_tokens = encoding.encode(text)
- truncated_tokens = encoded_tokens[:max_tokens]
- truncated_text = encoding.decode(truncated_tokens) # Decode back to text
-
- return truncated_text
diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py
index e96dc39fa7bc..1ce219bdadfa 100644
--- a/autogen/agentchat/contrib/capabilities/transform_messages.py
+++ b/autogen/agentchat/contrib/capabilities/transform_messages.py
@@ -1,9 +1,8 @@
import copy
from typing import Dict, List
-from autogen import ConversableAgent
-
from ....formatting_utils import colored
+from ...conversable_agent import ConversableAgent
from .transforms import MessageTransform
diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py
index dad3fc335edf..d9ad365b91b3 100644
--- a/autogen/agentchat/contrib/capabilities/transforms.py
+++ b/autogen/agentchat/contrib/capabilities/transforms.py
@@ -53,13 +53,16 @@ class MessageHistoryLimiter:
It trims the conversation history by removing older messages, retaining only the most recent messages.
"""
- def __init__(self, max_messages: Optional[int] = None):
+ def __init__(self, max_messages: Optional[int] = None, keep_first_message: bool = False):
"""
Args:
max_messages Optional[int]: Maximum number of messages to keep in the context. Must be greater than 0 if not None.
+ keep_first_message bool: Whether to keep the original first message in the conversation history.
+ Defaults to False.
"""
self._validate_max_messages(max_messages)
self._max_messages = max_messages
+ self._keep_first_message = keep_first_message
def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Truncates the conversation history to the specified maximum number of messages.
@@ -75,10 +78,31 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
List[Dict]: A new list containing the most recent messages up to the specified maximum.
"""
- if self._max_messages is None:
+ if self._max_messages is None or len(messages) <= self._max_messages:
return messages
- return messages[-self._max_messages :]
+ truncated_messages = []
+ remaining_count = self._max_messages
+
+ # Start with the first message if we need to keep it
+ if self._keep_first_message:
+ truncated_messages = [messages[0]]
+ remaining_count -= 1
+
+ # Loop through messages in reverse
+ for i in range(len(messages) - 1, 0, -1):
+ if remaining_count > 1:
+ truncated_messages.insert(1 if self._keep_first_message else 0, messages[i])
+ if remaining_count == 1:
+ # If there's only 1 slot left and it's a 'tool' message, ignore it: a tool
+ # response without its preceding tool call would be invalid.
+ if messages[i].get("role") != "tool":
+ truncated_messages.insert(1 if self._keep_first_message else 0, messages[i])
+
+ remaining_count -= 1
+ if remaining_count == 0:
+ break
+
+ return truncated_messages
def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
pre_transform_messages_len = len(pre_transform_messages)
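
Reviewer note: a short usage sketch of the new `keep_first_message` option. When enabled, the original first message (often the task description) is pinned and the remaining slots are filled with the most recent messages:

```python
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter

messages = [
    {"role": "user", "content": "hello"},        # original first message
    {"role": "assistant", "content": "hi"},
    {"role": "user", "content": "task details"},
    {"role": "assistant", "content": "working"},
    {"role": "user", "content": "status?"},
]

# Default: keep only the 3 most recent messages.
recent = MessageHistoryLimiter(max_messages=3).apply_transform(messages)

# New option: pin the first message, then fill with the most recent ones,
# yielding "hello", "working", "status?".
pinned = MessageHistoryLimiter(max_messages=3, keep_first_message=True).apply_transform(messages)
```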
@@ -421,3 +445,95 @@ def _compress_text(self, text: str) -> Tuple[str, int]:
def _validate_min_tokens(self, min_tokens: Optional[int]):
if min_tokens is not None and min_tokens <= 0:
raise ValueError("min_tokens must be greater than 0 or None")
+
+
+class TextMessageContentName:
+ """A transform for including the agent's name in the content of a message."""
+
+ def __init__(
+ self,
+ position: str = "start",
+ format_string: str = "{name}:\n",
+ deduplicate: bool = True,
+ filter_dict: Optional[Dict] = None,
+ exclude_filter: bool = True,
+ ):
+ """
+ Args:
+ position (str): The position to add the name to the content. The possible options are 'start' or 'end'. Defaults to 'start'.
+ format_string (str): The format string for adding the name to the content. Use '{name}' as a placeholder for the agent's name. Defaults to '{name}:\n' and must contain '{name}'.
+ deduplicate (bool): Whether to deduplicate the formatted string so it doesn't appear twice (sometimes the LLM will add it to new messages itself). Defaults to True.
+ filter_dict (None or dict): A dictionary to filter out messages that you want/don't want the name added to.
+ If None, no filters will be applied.
+ exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
+ excluded from the transform. If False, only messages that match the filter will be transformed.
+ """
+
+ assert isinstance(position, str) and position is not None
+ assert position in ["start", "end"]
+ assert isinstance(format_string, str) and format_string is not None
+ assert "{name}" in format_string
+ assert isinstance(deduplicate, bool) and deduplicate is not None
+
+ self._position = position
+ self._format_string = format_string
+ self._deduplicate = deduplicate
+ self._filter_dict = filter_dict
+ self._exclude_filter = exclude_filter
+
+ # Track the number of messages changed for logging
+ self._messages_changed = 0
+
+ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+ """Applies the name change to the message based on the position and format string.
+
+ Args:
+ messages (List[Dict]): A list of message dictionaries.
+
+ Returns:
+ List[Dict]: A list of dictionaries with the message content updated with names.
+ """
+ # Make sure there is at least one message
+ if not messages:
+ return messages
+
+ messages_changed = 0
+ processed_messages = copy.deepcopy(messages)
+ for message in processed_messages:
+ # Some messages may not have content.
+ if not transforms_util.is_content_right_type(
+ message.get("content")
+ ) or not transforms_util.is_content_right_type(message.get("name")):
+ continue
+
+ if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
+ continue
+
+ if transforms_util.is_content_text_empty(message["content"]) or transforms_util.is_content_text_empty(
+ message["name"]
+ ):
+ continue
+
+ # Get and format the name in the content
+ content = message["content"]
+ formatted_name = self._format_string.format(name=message["name"])
+
+ if self._position == "start":
+ if not self._deduplicate or not content.startswith(formatted_name):
+ message["content"] = f"{formatted_name}{content}"
+
+ messages_changed += 1
+ else:
+ if not self._deduplicate or not content.endswith(formatted_name):
+ message["content"] = f"{content}{formatted_name}"
+
+ messages_changed += 1
+
+ self._messages_changed = messages_changed
+ return processed_messages
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ if self._messages_changed > 0:
+ return f"{self._messages_changed} message(s) changed to incorporate name.", True
+ else:
+ return "No messages changed to incorporate name.", False
diff --git a/autogen/agentchat/contrib/compressible_agent.py b/autogen/agentchat/contrib/compressible_agent.py
deleted file mode 100644
index bea4058b94ac..000000000000
--- a/autogen/agentchat/contrib/compressible_agent.py
+++ /dev/null
@@ -1,436 +0,0 @@
-import copy
-import inspect
-import logging
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
-from warnings import warn
-
-from autogen import Agent, ConversableAgent, OpenAIWrapper
-from autogen.token_count_utils import count_token, get_max_token_limit, num_tokens_from_functions
-
-from ...formatting_utils import colored
-
-logger = logging.getLogger(__name__)
-
-warn(
- "Context handling with CompressibleAgent is deprecated and will be removed in `0.2.30`. "
- "Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/topics/handling_long_contexts/intro_to_transform_messages",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class CompressibleAgent(ConversableAgent):
- """CompressibleAgent agent. While this agent retains all the default functionalities of the `AssistantAgent`,
- it also provides the added feature of compression when activated through the `compress_config` setting.
-
- `compress_config` is set to False by default, making this agent equivalent to the `AssistantAgent`.
- This agent does not work well in a GroupChat: The compressed messages will not be sent to all the agents in the group.
- The default system message is the same as AssistantAgent.
- `human_input_mode` is default to "NEVER"
- and `code_execution_config` is default to False.
- This agent doesn't execute code or function call by default.
- """
-
- DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant.
-Solve tasks using your coding and language skills.
-In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
- 1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
- 2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
-Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.
-When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.
-If you want the user to save the code in a file before executing it, put # filename: inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
-If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
-When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
-Reply "TERMINATE" in the end when everything is done.
- """
- DEFAULT_COMPRESS_CONFIG = {
- "mode": "TERMINATE",
- "compress_function": None,
- "trigger_count": 0.7,
- "async": False,
- "broadcast": True,
- "verbose": False,
- "leave_last_n": 2,
- }
-
- def __init__(
- self,
- name: str,
- system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
- is_termination_msg: Optional[Callable[[Dict], bool]] = None,
- max_consecutive_auto_reply: Optional[int] = None,
- human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
- function_map: Optional[Dict[str, Callable]] = None,
- code_execution_config: Optional[Union[Dict, bool]] = False,
- llm_config: Optional[Union[Dict, bool]] = None,
- default_auto_reply: Optional[Union[str, Dict, None]] = "",
- compress_config: Optional[Dict] = False,
- description: Optional[str] = None,
- **kwargs,
- ):
- """
- Args:
- name (str): agent name.
- system_message (str): system message for the ChatCompletion inference.
- Please override this attribute if you want to reprogram the agent.
- llm_config (dict): llm inference configuration.
- Note: you must set `model` in llm_config. It will be used to compute the token count.
- Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
- for available options.
- is_termination_msg (function): a function that takes a message in the form of a dictionary
- and returns a boolean value indicating if this received message is a termination message.
- The dict can contain the following keys: "content", "role", "name", "function_call".
- max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
- default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
- The limit only plays a role when human_input_mode is not "ALWAYS".
- compress_config (dict or True/False): config for compression before oai_reply. Default to False.
- You should contain the following keys:
- - "mode" (Optional, str, default to "TERMINATE"): Choose from ["COMPRESS", "TERMINATE", "CUSTOMIZED"].
- 1. `TERMINATE`: terminate the conversation ONLY when token count exceeds the max limit of current model. `trigger_count` is NOT used in this mode.
- 2. `COMPRESS`: compress the messages when the token count exceeds the limit.
- 3. `CUSTOMIZED`: pass in a customized function to compress the messages.
- - "compress_function" (Optional, callable, default to None): Must be provided when mode is "CUSTOMIZED".
- The function should takes a list of messages and returns a tuple of (is_compress_success: bool, compressed_messages: List[Dict]).
- - "trigger_count" (Optional, float, int, default to 0.7): the threshold to trigger compression.
- If a float between (0, 1], it is the percentage of token used. if a int, it is the number of tokens used.
- - "async" (Optional, bool, default to False): whether to compress asynchronously.
- - "broadcast" (Optional, bool, default to True): whether to update the compressed message history to sender.
- - "verbose" (Optional, bool, default to False): Whether to print the content before and after compression. Used when mode="COMPRESS".
- - "leave_last_n" (Optional, int, default to 0): If provided, the last n messages will not be compressed. Used when mode="COMPRESS".
- description (str): a short description of the agent. This description is used by other agents
- (e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message)
- **kwargs (dict): Please refer to other kwargs in
- [ConversableAgent](../conversable_agent#__init__).
- """
- super().__init__(
- name=name,
- system_message=system_message,
- is_termination_msg=is_termination_msg,
- max_consecutive_auto_reply=max_consecutive_auto_reply,
- human_input_mode=human_input_mode,
- function_map=function_map,
- code_execution_config=code_execution_config,
- llm_config=llm_config,
- default_auto_reply=default_auto_reply,
- description=description,
- **kwargs,
- )
-
- self._set_compress_config(compress_config)
-
- # create a separate client for compression.
- if llm_config is False:
- self.llm_compress_config = False
- self.compress_client = None
- else:
- if "model" not in llm_config:
- raise ValueError("llm_config must contain the 'model' field.")
- self.llm_compress_config = self.llm_config.copy()
- # remove functions
- if "functions" in self.llm_compress_config:
- del self.llm_compress_config["functions"]
- self.compress_client = OpenAIWrapper(**self.llm_compress_config)
-
- self._reply_func_list.clear()
- self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
- self.register_reply([Agent], CompressibleAgent.on_oai_token_limit) # check token limit
- self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
- self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
- self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
-
- def _set_compress_config(self, compress_config: Optional[Dict] = False):
- if compress_config:
- if compress_config is True:
- compress_config = {}
- if not isinstance(compress_config, dict):
- raise ValueError("compress_config must be a dict or True/False.")
-
- allowed_modes = ["COMPRESS", "TERMINATE", "CUSTOMIZED"]
- if compress_config.get("mode", "TERMINATE") not in allowed_modes:
- raise ValueError(f"Invalid compression mode. Allowed values are: {', '.join(allowed_modes)}")
-
- self.compress_config = self.DEFAULT_COMPRESS_CONFIG.copy()
- self.compress_config.update(compress_config)
-
- if not isinstance(self.compress_config["leave_last_n"], int) or self.compress_config["leave_last_n"] < 0:
- raise ValueError("leave_last_n must be a non-negative integer.")
-
- # convert trigger_count to int, default to 0.7
- trigger_count = self.compress_config["trigger_count"]
- if not (isinstance(trigger_count, int) or isinstance(trigger_count, float)) or trigger_count <= 0:
- raise ValueError("trigger_count must be a positive number.")
- if isinstance(trigger_count, float) and 0 < trigger_count <= 1:
- self.compress_config["trigger_count"] = int(
- trigger_count * get_max_token_limit(self.llm_config["model"])
- )
- trigger_count = self.compress_config["trigger_count"]
- init_count = self._compute_init_token_count()
- if trigger_count < init_count:
- print(
- f"Warning: trigger_count {trigger_count} is less than the initial token count {init_count} (system message + function description if passed), compression will be disabled. Please increase trigger_count if you want to enable compression."
- )
- self.compress_config = False
-
- if self.compress_config["mode"] == "CUSTOMIZED" and self.compress_config["compress_function"] is None:
- raise ValueError("compress_function must be provided when mode is CUSTOMIZED.")
- if self.compress_config["mode"] != "CUSTOMIZED" and self.compress_config["compress_function"] is not None:
- print("Warning: compress_function is provided but mode is not 'CUSTOMIZED'.")
-
- else:
- self.compress_config = False
-
- def generate_reply(
- self,
- messages: Optional[List[Dict]] = None,
- sender: Optional[Agent] = None,
- exclude: Optional[List[Callable]] = None,
- ) -> Union[str, Dict, None]:
- """
-
- Adding to line 202:
- ```
- if messages is not None and messages != self._oai_messages[sender]:
- messages = self._oai_messages[sender]
- ```
- """
- if all((messages is None, sender is None)):
- error_msg = f"Either {messages=} or {sender=} must be provided."
- logger.error(error_msg)
- raise AssertionError(error_msg)
-
- if messages is None:
- messages = self._oai_messages[sender]
-
- for reply_func_tuple in self._reply_func_list:
- reply_func = reply_func_tuple["reply_func"]
- if exclude and reply_func in exclude:
- continue
- if inspect.iscoroutinefunction(reply_func):
- continue
- if self._match_trigger(reply_func_tuple["trigger"], sender):
- final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"])
- if messages is not None and sender is not None and messages != self._oai_messages[sender]:
- messages = self._oai_messages[sender]
- if final:
- return reply
- return self._default_auto_reply
-
- def _compute_init_token_count(self):
- """Check if the agent is LLM-based and compute the initial token count."""
- if self.llm_config is False:
- return 0
-
- func_count = 0
- if "functions" in self.llm_config:
- func_count = num_tokens_from_functions(self.llm_config["functions"], self.llm_config["model"])
-
- return func_count + count_token(self._oai_system_message, self.llm_config["model"])
-
- def _manage_history_on_token_limit(self, messages, token_used, max_token_allowed, model):
- """Manage the message history with different modes when token limit is reached.
- Return:
- final (bool): whether to terminate the agent.
- compressed_messages (List[Dict]): the compressed messages. None if no compression or compression failed.
- """
- # 1. mode = "TERMINATE", terminate the agent if no token left.
- if self.compress_config["mode"] == "TERMINATE":
- if max_token_allowed - token_used <= 0:
- # Terminate if no token left.
- print(
- colored(
- f'Warning: Terminate Agent "{self.name}" due to no token left for oai reply. max token for {model}: {max_token_allowed}, existing token count: {token_used}',
- "yellow",
- ),
- flush=True,
- )
- return True, None
- return False, None
-
- # if token_used is less than trigger_count, no compression will be used.
- if token_used < self.compress_config["trigger_count"]:
- return False, None
-
- # 2. mode = "COMPRESS" or mode = "CUSTOMIZED", compress the messages
- copied_messages = copy.deepcopy(messages)
- if self.compress_config["mode"] == "COMPRESS":
- _, compress_messages = self.compress_messages(copied_messages)
- elif self.compress_config["mode"] == "CUSTOMIZED":
- _, compress_messages = self.compress_config["compress_function"](copied_messages)
- else:
- raise ValueError(f"Unknown compression mode: {self.compress_config['mode']}")
-
- if compress_messages is not None:
- for i in range(len(compress_messages)):
- compress_messages[i] = self._get_valid_oai_message(compress_messages[i])
- return False, compress_messages
-
- def _get_valid_oai_message(self, message):
- """Convert a message into a valid OpenAI ChatCompletion message."""
- oai_message = {k: message[k] for k in ("content", "function_call", "name", "context", "role") if k in message}
- if "content" not in oai_message:
- if "function_call" in oai_message:
- oai_message["content"] = None # if only function_call is provided, content will be set to None.
- else:
- raise ValueError(
- "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
- )
- if "function_call" in oai_message:
- oai_message["role"] = "assistant" # only messages with role 'assistant' can have a function call.
- oai_message["function_call"] = dict(oai_message["function_call"])
- return oai_message
-
- def _print_compress_info(self, init_token_count, token_used, token_after_compression):
- to_print = "Token Count (including {} tokens from system msg and function descriptions). Before compression : {} | After: {}".format(
- init_token_count,
- token_used,
- token_after_compression,
- )
- print(colored(to_print, "magenta"), flush=True)
- print("-" * 80, flush=True)
-
- def on_oai_token_limit(
- self,
- messages: Optional[List[Dict]] = None,
- sender: Optional[Agent] = None,
- config: Optional[Any] = None,
- ) -> Tuple[bool, Union[str, Dict, None]]:
- """(Experimental) Compress previous messages when a threshold of tokens is reached.
-
- TODO: async compress
- TODO: maintain a list for old oai messages (messages before compression)
- """
- llm_config = self.llm_config if config is None else config
- if self.compress_config is False:
- return False, None
- if messages is None:
- messages = self._oai_messages[sender]
-
- model = llm_config["model"]
- init_token_count = self._compute_init_token_count()
- token_used = init_token_count + count_token(messages, model)
- final, compressed_messages = self._manage_history_on_token_limit(
- messages, token_used, get_max_token_limit(model), model
- )
-
- # update message history with compressed messages
- if compressed_messages is not None:
- self._print_compress_info(
- init_token_count, token_used, count_token(compressed_messages, model) + init_token_count
- )
- self._oai_messages[sender] = compressed_messages
- if self.compress_config["broadcast"]:
- # update the compressed message history to sender
- sender._oai_messages[self] = copy.deepcopy(compressed_messages)
- # switching the role of the messages for the sender
- for i in range(len(sender._oai_messages[self])):
- cmsg = sender._oai_messages[self][i]
- if "function_call" in cmsg or cmsg["role"] == "user":
- cmsg["role"] = "assistant"
- elif cmsg["role"] == "assistant":
- cmsg["role"] = "user"
- sender._oai_messages[self][i] = cmsg
-
- # successfully compressed, return False, None for generate_oai_reply to be called with the updated messages
- return False, None
- return final, None
-
- def compress_messages(
- self,
- messages: Optional[List[Dict]] = None,
- config: Optional[Any] = None,
- ) -> Tuple[bool, Union[str, Dict, None, List]]:
- """Compress a list of messages into one message.
-
- The first message (the initial prompt) will not be compressed.
- The rest of the messages will be compressed into one message, the model is asked to distinguish the role of each message: USER, ASSISTANT, FUNCTION_CALL, FUNCTION_RETURN.
- Check out the compress_sys_msg.
-
- TODO: model used in compression agent is different from assistant agent: For example, if original model used by is gpt-4; we start compressing at 70% of usage, 70% of 8092 = 5664; and we use gpt 3.5 here max_toke = 4096, it will raise error. choosinng model automatically?
- """
- # 1. use the compression client
- client = self.compress_client if config is None else config
-
- # 2. stop if there is only one message in the list
- leave_last_n = self.compress_config.get("leave_last_n", 0)
- if leave_last_n + 1 >= len(messages):
- logger.warning(
- f"Warning: Compression skipped at trigger count threshold. The first msg and last {leave_last_n} msgs will not be compressed. current msg count: {len(messages)}. Consider raising trigger_count."
- )
- return False, None
-
- # 3. put all history into one, except the first one
- if self.compress_config["verbose"]:
- print(colored("*" * 30 + "Start compressing the following content:" + "*" * 30, "magenta"), flush=True)
-
- compressed_prompt = "Below is the compressed content from the previous conversation, evaluate the process and continue if necessary:\n"
- chat_to_compress = "To be compressed:\n"
-
- for m in messages[1 : len(messages) - leave_last_n]: # 0, 1, 2, 3, 4
- # Handle function role
- if m.get("role") == "function":
- chat_to_compress += f"##FUNCTION_RETURN## (from function \"{m['name']}\"): \n{m['content']}\n"
-
- # If name exists in the message
- elif "name" in m:
- chat_to_compress += f"##{m['name']}({m['role'].upper()})## {m['content']}\n"
-
- # Handle case where content is not None and name is absent
- elif m.get("content"): # This condition will also handle None and empty string
- if compressed_prompt in m["content"]:
- chat_to_compress += m["content"].replace(compressed_prompt, "") + "\n"
- else:
- chat_to_compress += f"##{m['role'].upper()}## {m['content']}\n"
-
- # Handle function_call in the message
- if "function_call" in m:
- function_name = m["function_call"].get("name")
- function_args = m["function_call"].get("arguments")
-
- if not function_name or not function_args:
- chat_to_compress += f"##FUNCTION_CALL## {m['function_call']}\n"
- else:
- chat_to_compress += f"##FUNCTION_CALL## \nName: {function_name}\nArgs: {function_args}\n"
-
- chat_to_compress = [{"role": "user", "content": chat_to_compress}]
-
- if self.compress_config["verbose"]:
- print(chat_to_compress[0]["content"])
-
- # 4. use LLM to compress
- compress_sys_msg = """You are a helpful assistant that will summarize and compress conversation history.
-Rules:
-1. Please summarize each of the message and reserve the exact titles: ##USER##, ##ASSISTANT##, ##FUNCTION_CALL##, ##FUNCTION_RETURN##, ##SYSTEM##, ##()## (e.g. ##Bob(ASSISTANT)##).
-2. Try to compress the content but reserve important information (a link, a specific number, etc.).
-3. Use words to summarize the code blocks or functions calls (##FUNCTION_CALL##) and their goals. For code blocks, please use ##CODE## to mark it.
-4. For returns from functions (##FUNCTION_RETURN##) or returns from code execution: summarize the content and indicate the status of the return (e.g. success, error, etc.).
-"""
- try:
- response = client.create(
- context=None,
- messages=[{"role": "system", "content": compress_sys_msg}] + chat_to_compress,
- )
- except Exception as e:
- print(colored(f"Failed to compress the content due to {e}", "red"), flush=True)
- return False, None
-
- compressed_message = self.client.extract_text_or_completion_object(response)[0]
- assert isinstance(compressed_message, str), f"compressed_message should be a string: {compressed_message}"
- if self.compress_config["verbose"]:
- print(
- colored("*" * 30 + "Content after compressing:" + "*" * 30, "magenta"),
- flush=True,
- )
- print(compressed_message, colored("\n" + "*" * 80, "magenta"))
-
- # 5. add compressed message to the first message and return
- return (
- True,
- [
- messages[0],
- {
- "content": compressed_prompt + compressed_message,
- "role": "system",
- },
- ]
- + messages[len(messages) - leave_last_n :],
- )
diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py
index 0dcad27b16d5..244f5ed81894 100644
--- a/autogen/agentchat/contrib/gpt_assistant_agent.py
+++ b/autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -209,10 +209,12 @@ def _invoke_assistant(
for message in pending_messages:
if message["content"].strip() == "":
continue
+            # Map the message role via _map_role_for_api to a role accepted by the OpenAI Assistants API
+ api_role = self._map_role_for_api(message["role"])
self._openai_client.beta.threads.messages.create(
thread_id=assistant_thread.id,
content=message["content"],
- role=message["role"],
+ role=api_role,
)
# Create a new run to get responses from the assistant
@@ -240,6 +242,28 @@ def _invoke_assistant(
self._unread_index[sender] = len(self._oai_messages[sender]) + 1
return True, response
+ def _map_role_for_api(self, role: str) -> str:
+ """
+ Maps internal message roles to the roles expected by the OpenAI Assistant API.
+
+ Args:
+ role (str): The role from the internal message.
+
+ Returns:
+ str: The mapped role suitable for the API.
+ """
+ if role in ["function", "tool"]:
+ return "assistant"
+ elif role == "system":
+ return "system"
+ elif role == "user":
+ return "user"
+ elif role == "assistant":
+ return "assistant"
+ else:
+ # Default to 'assistant' for any other roles not recognized by the API
+ return "assistant"
+
def _get_run_response(self, thread, run):
"""
Waits for and processes the response of a run from the OpenAI assistant.
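
The role mapping added above reduces to a small rule: `user` and `system` pass through unchanged, while function/tool results and anything unrecognized are submitted as `assistant`, the only other role the Assistants API accepts for thread messages. A condensed standalone restatement (illustrative, not part of the patch):

```python
def map_role_for_api(role: str) -> str:
    # 'user' and 'system' pass through; function/tool results and any
    # unrecognized role are sent as 'assistant' messages.
    return role if role in ("user", "system") else "assistant"

assert map_role_for_api("function") == "assistant"
assert map_role_for_api("tool") == "assistant"
assert map_role_for_api("unknown") == "assistant"
assert map_role_for_api("user") == "user"
```
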
diff --git a/autogen/agentchat/contrib/graph_rag/__init__.py b/autogen/agentchat/contrib/graph_rag/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/autogen/agentchat/contrib/graph_rag/document.py b/autogen/agentchat/contrib/graph_rag/document.py
new file mode 100644
index 000000000000..9730269c7ab6
--- /dev/null
+++ b/autogen/agentchat/contrib/graph_rag/document.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Optional
+
+
+class DocumentType(Enum):
+ """
+    Enum of supported document types.
+ """
+
+ TEXT = auto()
+ HTML = auto()
+ PDF = auto()
+
+
+@dataclass
+class Document:
+ """
+    A wrapper of an input document with its type, raw data, and path or URL.
+ """
+
+ doctype: DocumentType
+ data: Optional[object] = None
+ path_or_url: Optional[str] = ""
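
`Document` is a plain dataclass, so wrapping inputs for a query engine is direct; a brief usage sketch (the file path and HTML snippet are illustrative):

```python
from autogen.agentchat.contrib.graph_rag.document import Document, DocumentType

# Wrap an on-disk text file and an in-memory HTML snippet as input documents.
input_documents = [
    Document(doctype=DocumentType.TEXT, path_or_url="./the_matrix.txt"),
    Document(doctype=DocumentType.HTML, data="<p>Keanu Reeves played Neo.</p>"),
]
```
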
diff --git a/autogen/agentchat/contrib/graph_rag/graph_query_engine.py b/autogen/agentchat/contrib/graph_rag/graph_query_engine.py
new file mode 100644
index 000000000000..28ef6ede84a6
--- /dev/null
+++ b/autogen/agentchat/contrib/graph_rag/graph_query_engine.py
@@ -0,0 +1,51 @@
+from dataclasses import dataclass, field
+from typing import List, Optional, Protocol
+
+from .document import Document
+
+
+@dataclass
+class GraphStoreQueryResult:
+ """
+ A wrapper of graph store query results.
+
+ answer: human readable answer to question/query.
+ results: intermediate results to question/query, e.g. node entities.
+ """
+
+ answer: Optional[str] = None
+ results: list = field(default_factory=list)
+
+
+class GraphQueryEngine(Protocol):
+    """An abstract base class that represents a graph query engine on top of an underlying graph database.
+
+    This interface defines the basic methods for graph RAG.
+ """
+
+    def init_db(self, input_doc: Optional[List[Document]] = None):
+        """
+        This method initializes the graph database with the input documents or records.
+        Usually, it takes the following steps:
+        1. connect to a graph database.
+        2. extract graph nodes and edges from the input data according to the graph schema.
+        3. build indexes, etc.
+
+        Args:
+            input_doc: a list of input documents that are used to build the graph in the database.
+
+        Returns: GraphStore
+        """
+ pass
+
+ def add_records(self, new_records: List) -> bool:
+ """
+ Add new records to the underlying database and add to the graph if required.
+ """
+ pass
+
+ def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryResult:
+ """
+        This method transforms a string-format question into a database query and returns the result.
+ """
+ pass
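
Because `GraphQueryEngine` is declared as a `Protocol`, concrete engines satisfy it structurally and need not inherit from it. A toy in-memory engine showing the required surface (a sketch only, with no real graph backend):

```python
from typing import List, Optional

from autogen.agentchat.contrib.graph_rag.document import Document
from autogen.agentchat.contrib.graph_rag.graph_query_engine import GraphStoreQueryResult


class EchoQueryEngine:
    """Satisfies the GraphQueryEngine protocol without inheriting from it."""

    def init_db(self, input_doc: Optional[List[Document]] = None):
        # A real engine would connect to a graph DB and extract nodes/edges here.
        self.docs = list(input_doc or [])

    def add_records(self, new_records: List) -> bool:
        self.docs.extend(new_records)
        return True

    def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryResult:
        # A real engine would translate `question` into a graph query.
        return GraphStoreQueryResult(answer=f"{len(self.docs)} documents indexed", results=[])
```
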
diff --git a/autogen/agentchat/contrib/graph_rag/graph_rag_capability.py b/autogen/agentchat/contrib/graph_rag/graph_rag_capability.py
new file mode 100644
index 000000000000..b6412305e069
--- /dev/null
+++ b/autogen/agentchat/contrib/graph_rag/graph_rag_capability.py
@@ -0,0 +1,56 @@
+from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+from .graph_query_engine import GraphQueryEngine
+
+
+class GraphRagCapability(AgentCapability):
+ """
+    A graph RAG capability uses a graph query engine to give a conversable agent graph RAG ability.
+
+    An agent class with graph RAG capability could
+    1. create a graph in the underlying database with input documents.
+    2. retrieve relevant information based on messages received by the agent.
+    3. generate answers from retrieved information and send messages back.
+
+ For example,
+ graph_query_engine = GraphQueryEngine(...)
+ graph_query_engine.init_db([Document(doc1), Document(doc2), ...])
+
+ graph_rag_agent = ConversableAgent(
+ name="graph_rag_agent",
+ max_consecutive_auto_reply=3,
+ ...
+ )
+    graph_rag_capability = GraphRagCapability(graph_query_engine)
+ graph_rag_capability.add_to_agent(graph_rag_agent)
+
+ user_proxy = UserProxyAgent(
+ name="user_proxy",
+ code_execution_config=False,
+ is_termination_msg=lambda msg: "TERMINATE" in msg["content"],
+ human_input_mode="ALWAYS",
+ )
+ user_proxy.initiate_chat(graph_rag_agent, message="Name a few actors who've played in 'The Matrix'")
+
+ # ChatResult(
+ # chat_id=None,
+ # chat_history=[
+ # {'content': 'Name a few actors who've played in \'The Matrix\'', 'role': 'graph_rag_agent'},
+ # {'content': 'A few actors who have played in The Matrix are:
+ # - Keanu Reeves
+ # - Laurence Fishburne
+ # - Carrie-Anne Moss
+ # - Hugo Weaving',
+ # 'role': 'user_proxy'},
+ # ...)
+
+ """
+
+ def __init__(self, query_engine: GraphQueryEngine):
+ """
+        Initialize graph RAG capability with a graph query engine.
+ """
+ ...
+
+ def add_to_agent(self, agent: ConversableAgent): ...
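
Both methods are intentionally stubs, so a concrete capability must supply them. One plausible shape, assuming the existing `process_last_received_message` hook on `ConversableAgent` (a sketch, not the shipped implementation):

```python
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.graph_rag.graph_query_engine import GraphQueryEngine
from autogen.agentchat.conversable_agent import ConversableAgent


class SimpleGraphRagCapability(AgentCapability):
    """Illustrative capability: augments incoming messages with graph query results."""

    def __init__(self, query_engine: GraphQueryEngine):
        self._query_engine = query_engine

    def add_to_agent(self, agent: ConversableAgent):
        agent.register_hook(hookable_method="process_last_received_message", hook=self._augment)

    def _augment(self, message: str) -> str:
        result = self._query_engine.query(message)
        return f"{message}\n\nGraph RAG context:\n{result.answer}"
```
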
diff --git a/autogen/agentchat/contrib/llamaindex_conversable_agent.py b/autogen/agentchat/contrib/llamaindex_conversable_agent.py
index f7a9c3e615dc..dbf6f274ae87 100644
--- a/autogen/agentchat/contrib/llamaindex_conversable_agent.py
+++ b/autogen/agentchat/contrib/llamaindex_conversable_agent.py
@@ -8,15 +8,14 @@
try:
from llama_index.core.agent.runner.base import AgentRunner
+ from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.chat_engine.types import AgentChatResponse
- from llama_index_client import ChatMessage
except ImportError as e:
logger.fatal("Failed to import llama-index. Try running 'pip install llama-index'")
raise e
class LLamaIndexConversableAgent(ConversableAgent):
-
def __init__(
self,
name: str,
diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
index ea81de6dff11..f1cc6947d50e 100644
--- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
@@ -1,3 +1,4 @@
+import warnings
from typing import Callable, Dict, List, Literal, Optional
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
@@ -93,6 +94,11 @@ def __init__(
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
"""
+ warnings.warn(
+            "The QdrantRetrieveUserProxyAgent is deprecated. Please use the RetrieveUserProxyAgent instead, and set `vector_db` to `qdrant`.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__(name, human_input_mode, is_termination_msg, retrieve_config, **kwargs)
self._client = self._retrieve_config.get("client", QdrantClient(":memory:"))
self._embedding_model = self._retrieve_config.get("embedding_model", "BAAI/bge-small-en-v1.5")
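
With the deprecation above, the equivalent setup goes through the generic agent with `vector_db` set to `qdrant` in `retrieve_config` (values are illustrative):

```python
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",  # illustrative path
        "vector_db": "qdrant",  # replaces the dedicated Qdrant agent
    },
)
```
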
diff --git a/autogen/agentchat/contrib/retrieve_assistant_agent.py b/autogen/agentchat/contrib/retrieve_assistant_agent.py
index 9b5ace200dc6..173bc4432e78 100644
--- a/autogen/agentchat/contrib/retrieve_assistant_agent.py
+++ b/autogen/agentchat/contrib/retrieve_assistant_agent.py
@@ -1,3 +1,4 @@
+import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from autogen.agentchat.agent import Agent
@@ -16,6 +17,11 @@ class RetrieveAssistantAgent(AssistantAgent):
"""
def __init__(self, *args, **kwargs):
+ warnings.warn(
+ "The RetrieveAssistantAgent is deprecated. Please use the AssistantAgent instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__(*args, **kwargs)
self.register_reply(Agent, RetrieveAssistantAgent._generate_retrieve_assistant_reply)
diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
index 59a4abccb1d6..b247d7a158f6 100644
--- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -1,6 +1,7 @@
import hashlib
import os
import re
+import uuid
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from IPython import get_ipython
@@ -135,7 +136,7 @@ def __init__(
- `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a
default client `chromadb.Client()` will be used. If you want to use other
vector db, extend this class and override the `retrieve_docs` function.
- **Deprecated**: use `vector_db` instead.
+ *[Deprecated]* use `vector_db` instead.
- `docs_path` (Optional, Union[str, List[str]]) - the path to the docs directory. It
can also be the path to a single file, the url to a single file or a list
of directories, files and urls. Default is None, which works only if the
@@ -149,7 +150,7 @@ def __init__(
By default, "extra_docs" is set to false, starting document IDs from zero.
This poses a risk as new documents might overwrite existing ones, potentially
causing unintended loss or alteration of data in the collection.
- **Deprecated**: use `new_docs` when use `vector_db` instead of `client`.
+                *[Deprecated]* use `new_docs` when using `vector_db` instead of `client`.
- `new_docs` (Optional, bool) - when True, only adds new documents to the collection;
when False, updates existing documents and adds new ones. Default is True.
Document id is used to determine if a document is new or existing. By default, the
@@ -172,12 +173,12 @@ def __init__(
models can be found at `https://www.sbert.net/docs/pretrained_models.html`.
The default model is a fast model. If you want to use a high performance model,
`all-mpnet-base-v2` is recommended.
- **Deprecated**: no need when use `vector_db` instead of `client`.
+                *[Deprecated]* not needed when using `vector_db` instead of `client`.
- `embedding_function` (Optional, Callable) - the embedding function for creating the
vector db. Default is None, SentenceTransformer with the given `embedding_model`
will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
functions, you can pass it here,
- follow the examples in `https://docs.trychroma.com/embeddings`.
+ follow the examples in `https://docs.trychroma.com/guides/embeddings`.
- `customized_prompt` (Optional, str) - the customized prompt for the retrieve chat.
Default is None.
- `customized_answer_prefix` (Optional, str) - the customized answer prefix for the
@@ -188,7 +189,7 @@ def __init__(
interactive retrieval. Default is True.
- `collection_name` (Optional, str) - the name of the collection.
If key not provided, a default name `autogen-docs` will be used.
- - `get_or_create` (Optional, bool) - Whether to get the collection if it exists. Default is True.
+ - `get_or_create` (Optional, bool) - Whether to get the collection if it exists. Default is False.
- `overwrite` (Optional, bool) - Whether to overwrite the collection if it exists. Default is False.
Case 1. if the collection does not exist, create the collection.
Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
@@ -219,7 +220,7 @@ def __init__(
Example of overriding retrieve_docs - If you have set up a customized vector db, and it's
not compatible with chromadb, you can easily plug in it with below code.
- **Deprecated**: Use `vector_db` instead. You can extend VectorDB and pass it to the agent.
+ *[Deprecated]* use `vector_db` instead. You can extend VectorDB and pass it to the agent.
```python
class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
def query_vector_db(
@@ -305,6 +306,10 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._db_config["embedding_function"] = self._embedding_function
self._vector_db = VectorDBFactory.create_vector_db(db_type=self._vector_db, **self._db_config)
self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=2)
+ self.register_hook(
+ hookable_method="process_message_before_send",
+ hook=self._check_update_context_before_send,
+ )
def _init_db(self):
if not self._vector_db:
@@ -365,7 +370,11 @@ def _init_db(self):
else:
all_docs_ids = set()
- chunk_ids = [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
+ chunk_ids = (
+ [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
+                if self._vector_db.type != "qdrant"
+ else [str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest())) for chunk in chunks]
+ )
chunk_ids_set = set(chunk_ids)
chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set]
docs = [
@@ -395,6 +404,34 @@ def _is_termination_msg_retrievechat(self, message):
update_context_case1, update_context_case2 = self._check_update_context(message)
return not (contain_code or update_context_case1 or update_context_case2)
+ def _check_update_context_before_send(self, sender, message, recipient, silent):
+ if not isinstance(message, (str, dict)):
+ return message
+ elif isinstance(message, dict):
+            msg_text = message.get("content") or ""
+ else:
+ msg_text = message
+
+ if "UPDATE CONTEXT" == msg_text.strip().upper():
+ doc_contents = self._get_context(self._results)
+
+ # Always use self.problem as the query text to retrieve docs, but each time we replace the context with the
+ # next similar docs in the retrieved doc results.
+ if not doc_contents:
+ for _tmp_retrieve_count in range(1, 5):
+ self._reset(intermediate=True)
+ self.retrieve_docs(
+ self.problem, self.n_results * (2 * _tmp_retrieve_count + 1), self._search_string
+ )
+ doc_contents = self._get_context(self._results)
+ if doc_contents or self.n_results * (2 * _tmp_retrieve_count + 1) >= len(self._results[0]):
+ break
+ msg_text = self._generate_message(doc_contents, task=self._task)
+
+ if isinstance(message, dict):
+ message["content"] = msg_text
+ return message
+
@staticmethod
def get_max_tokens(model="gpt-3.5-turbo"):
if "32k" in model:
@@ -514,7 +551,7 @@ def _generate_retrieve_user_reply(
self.problem, self.n_results * (2 * _tmp_retrieve_count + 1), self._search_string
)
doc_contents = self._get_context(self._results)
- if doc_contents:
+ if doc_contents or self.n_results * (2 * _tmp_retrieve_count + 1) >= len(self._results[0]):
break
elif update_context_case2:
# Use the current intermediate info as the query text to retrieve docs, and each time we append the top similar
@@ -526,7 +563,7 @@ def _generate_retrieve_user_reply(
)
self._get_context(self._results)
doc_contents = "\n".join(self._doc_contents) # + "\n" + "\n".join(self._intermediate_answers)
- if doc_contents:
+ if doc_contents or self.n_results * (2 * _tmp_retrieve_count + 1) >= len(self._results[0]):
break
self.clear_history()
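
The id-generation branch in `_init_db` exists because Qdrant accepts only unsigned integers or UUIDs as point ids, while the other backends take arbitrary strings. Restated standalone (`HASH_LENGTH` is defined elsewhere in the module; the value below is assumed for illustration):

```python
import hashlib
import uuid

HASH_LENGTH = 8  # assumed; the module defines the real value


def chunk_id(chunk: str, db_type: str) -> str:
    if db_type == "qdrant":
        # Qdrant point ids must be unsigned ints or UUIDs, so derive a
        # deterministic UUID from the md5 digest of the chunk.
        return str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest()))
    # Other backends accept arbitrary string ids: a truncated blake2b digest.
    return hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH]
```
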
diff --git a/autogen/agentchat/contrib/society_of_mind_agent.py b/autogen/agentchat/contrib/society_of_mind_agent.py
index 2f6be5088a4d..e76768187c9f 100644
--- a/autogen/agentchat/contrib/society_of_mind_agent.py
+++ b/autogen/agentchat/contrib/society_of_mind_agent.py
@@ -39,6 +39,7 @@ def __init__(
code_execution_config: Union[Dict, Literal[False]] = False,
llm_config: Optional[Union[Dict, Literal[False]]] = False,
default_auto_reply: Optional[Union[str, Dict, None]] = "",
+ **kwargs,
):
super().__init__(
name=name,
@@ -50,6 +51,7 @@ def __init__(
code_execution_config=code_execution_config,
llm_config=llm_config,
default_auto_reply=default_auto_reply,
+ **kwargs,
)
self.update_chat_manager(chat_manager)
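
The `**kwargs` passthrough means constructor options handled by `ConversableAgent`, such as `description`, can now be passed straight to `SocietyOfMindAgent`. A sketch (a real `GroupChatManager` is required to actually run the inner chat):

```python
from autogen.agentchat.contrib.society_of_mind_agent import SocietyOfMindAgent

som_agent = SocietyOfMindAgent(
    name="society_of_mind",
    chat_manager=None,  # illustrative; supply a GroupChatManager in practice
    llm_config=False,
    description="Answers questions by running an inner group chat.",  # forwarded via **kwargs
)
```
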
diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py
index 29a080086193..d7d49d6200ca 100644
--- a/autogen/agentchat/contrib/vectordb/base.py
+++ b/autogen/agentchat/contrib/vectordb/base.py
@@ -1,4 +1,16 @@
-from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, TypedDict, Union, runtime_checkable
+from typing import (
+ Any,
+ Callable,
+ List,
+ Mapping,
+ Optional,
+ Protocol,
+ Sequence,
+ Tuple,
+ TypedDict,
+ Union,
+ runtime_checkable,
+)
Metadata = Union[Mapping[str, Any], None]
Vector = Union[Sequence[float], Sequence[int]]
@@ -49,6 +61,9 @@ class VectorDB(Protocol):
active_collection: Any = None
type: str = ""
+ embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = (
+ None # embeddings = embedding_function(sentences)
+ )
def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
"""
@@ -171,7 +186,8 @@ def get_docs_by_ids(
ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: str | The name of the collection. Default is None.
include: List[str] | The fields to include. Default is None.
- If None, will include ["metadatas", "documents"], ids will always be included.
+ If None, will include ["metadatas", "documents"], ids will always be included. This may differ
+ depending on the implementation.
kwargs: dict | Additional keyword arguments.
Returns:
@@ -185,7 +201,7 @@ class VectorDBFactory:
Factory class for creating vector databases.
"""
- PREDEFINED_VECTOR_DB = ["chroma", "pgvector"]
+ PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "mongodb", "qdrant"]
@staticmethod
def create_vector_db(db_type: str, **kwargs) -> VectorDB:
@@ -207,6 +223,14 @@ def create_vector_db(db_type: str, **kwargs) -> VectorDB:
from .pgvectordb import PGVectorDB
return PGVectorDB(**kwargs)
+ if db_type.lower() in ["mdb", "mongodb", "atlas"]:
+ from .mongodb import MongoDBAtlasVectorDB
+
+ return MongoDBAtlasVectorDB(**kwargs)
+ if db_type.lower() in ["qdrant", "qdrantdb"]:
+ from .qdrant import QdrantVectorDB
+
+ return QdrantVectorDB(**kwargs)
else:
raise ValueError(
f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py
index 1ed8708409d3..bef4a1090219 100644
--- a/autogen/agentchat/contrib/vectordb/chromadb.py
+++ b/autogen/agentchat/contrib/vectordb/chromadb.py
@@ -14,6 +14,11 @@
except ImportError:
raise ImportError("Please install chromadb: `pip install chromadb`")
+try:
+ from chromadb.errors import ChromaError
+except ImportError:
+ ChromaError = Exception
+
CHROMADB_MAX_BATCH_SIZE = os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000)
logger = get_logger(__name__)
@@ -84,7 +89,7 @@ def create_collection(
collection = self.active_collection
else:
collection = self.client.get_collection(collection_name, embedding_function=self.embedding_function)
- except ValueError:
+ except (ValueError, ChromaError):
collection = None
if collection is None:
return self.client.create_collection(
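
The aliasing import keeps the widened `except (ValueError, ChromaError)` clause working on chromadb versions that predate the `chromadb.errors` module; in isolation the pattern is:

```python
try:
    from chromadb.errors import ChromaError
except ImportError:
    # Older chromadb releases raise plain ValueError from get_collection;
    # aliasing to Exception keeps the except clause valid either way.
    ChromaError = Exception
```
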
diff --git a/autogen/agentchat/contrib/vectordb/mongodb.py b/autogen/agentchat/contrib/vectordb/mongodb.py
new file mode 100644
index 000000000000..2e0580fe826b
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/mongodb.py
@@ -0,0 +1,553 @@
+from copy import deepcopy
+from time import monotonic, sleep
+from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Set, Tuple, Union
+
+import numpy as np
+from pymongo import MongoClient, UpdateOne, errors
+from pymongo.collection import Collection
+from pymongo.driver_info import DriverInfo
+from pymongo.operations import SearchIndexModel
+from sentence_transformers import SentenceTransformer
+
+from .base import Document, ItemID, QueryResults, VectorDB
+from .utils import get_logger
+
+logger = get_logger(__name__)
+
+DEFAULT_INSERT_BATCH_SIZE = 100_000
+_SAMPLE_SENTENCE = ["The weather is lovely today in paradise."]
+_DELAY = 0.5
+
+
+def with_id_rename(docs: Iterable) -> List[Dict[str, Any]]:
+    """Utility that renames the "_id" field of Collection records to "id" for Documents."""
+ return [{**{k: v for k, v in d.items() if k != "_id"}, "id": d["_id"]} for d in docs]
+
+
+class MongoDBAtlasVectorDB(VectorDB):
+ """
+    A vector database implementation that uses MongoDB Atlas as the backend.
+ """
+
+ def __init__(
+ self,
+ connection_string: str = "",
+ database_name: str = "vector_db",
+ embedding_function: Callable = SentenceTransformer("all-MiniLM-L6-v2").encode,
+ collection_name: str = None,
+ index_name: str = "vector_index",
+ overwrite: bool = False,
+ wait_until_index_ready: float = None,
+ wait_until_document_ready: float = None,
+ ):
+ """
+ Initialize the vector database.
+
+ Args:
+ connection_string: str | The MongoDB connection string to connect to. Default is ''.
+ database_name: str | The name of the database. Default is 'vector_db'.
+ embedding_function: Callable | The embedding function used to generate the vector representation.
+            collection_name: str | The name of the collection to create for this vector database.
+                Defaults to None.
+            index_name: str | Index name for the vector database, defaults to 'vector_index'.
+            overwrite: bool | Whether to overwrite an existing collection. Defaults to False.
+            wait_until_index_ready: float | None | Blocking call to wait until the
+                database indexes are ready. None, the default, means no wait.
+            wait_until_document_ready: float | None | Blocking call to wait until
+                inserted documents are queryable. None, the default, means no wait.
+ """
+ self.embedding_function = embedding_function
+ self.index_name = index_name
+ self._wait_until_index_ready = wait_until_index_ready
+ self._wait_until_document_ready = wait_until_document_ready
+
+        # Determine the embedding dimension by encoding a sample sentence
+ self.dimensions = self._get_embedding_size()
+
+ try:
+ self.client = MongoClient(connection_string, driver=DriverInfo(name="autogen"))
+ self.client.admin.command("ping")
+ logger.debug("Successfully created MongoClient")
+ except errors.ServerSelectionTimeoutError as err:
+ raise ConnectionError("Could not connect to MongoDB server") from err
+
+ self.db = self.client[database_name]
+ logger.debug(f"Atlas Database name: {self.db.name}")
+ if collection_name:
+ self.active_collection = self.create_collection(collection_name, overwrite)
+ else:
+ self.active_collection = None
+
+ def _is_index_ready(self, collection: Collection, index_name: str):
+        """Check the list of available search indexes to see if the specified
+        index is present with status READY.
+
+        Args:
+            collection (Collection): MongoDB Collection to check for the search indexes
+            index_name (str): Vector Search Index name
+
+        Returns:
+            bool : True if the index is present and READY, False otherwise
+ """
+ for index in collection.list_search_indexes(index_name):
+ if index["type"] == "vectorSearch" and index["status"] == "READY":
+ return True
+ return False
+
+ def _wait_for_index(self, collection: Collection, index_name: str, action: str = "create"):
+ """Waits for the index action to be completed. Otherwise throws a TimeoutError.
+
+ Timeout set on instantiation.
+ action: "create" or "delete"
+ """
+ assert action in ["create", "delete"], f"{action=} must be create or delete."
+ start = monotonic()
+ while monotonic() - start < self._wait_until_index_ready:
+ if action == "create" and self._is_index_ready(collection, index_name):
+ return
+ elif action == "delete" and len(list(collection.list_search_indexes())) == 0:
+ return
+ sleep(_DELAY)
+
+        raise TimeoutError(f"Index {index_name} is not ready!")
+
+ def _wait_for_document(self, collection: Collection, index_name: str, doc: Document):
+ start = monotonic()
+ while monotonic() - start < self._wait_until_document_ready:
+ query_result = _vector_search(
+ embedding_vector=np.array(self.embedding_function(doc["content"])).tolist(),
+ n_results=1,
+ collection=collection,
+ index_name=index_name,
+ )
+ if query_result and query_result[0][0]["_id"] == doc["id"]:
+ return
+ sleep(_DELAY)
+
+        raise TimeoutError(f"Document {doc['id']} is not ready!")
+
+ def _get_embedding_size(self):
+ return len(self.embedding_function(_SAMPLE_SENTENCE)[0])
+
+ def list_collections(self):
+ """
+ List the collections in the vector database.
+
+ Returns:
+ List[str] | The list of collections.
+ """
+ return self.db.list_collection_names()
+
+ def create_collection(
+ self,
+ collection_name: str,
+ overwrite: bool = False,
+ get_or_create: bool = True,
+ ) -> Collection:
+ """
+ Create a collection in the vector database and create a vector search index in the collection.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get or create the collection. Default is True
+ """
+ if overwrite:
+ self.delete_collection(collection_name)
+
+ if collection_name not in self.db.list_collection_names():
+ # Create a new collection
+ coll = self.db.create_collection(collection_name)
+ self.create_index_if_not_exists(index_name=self.index_name, collection=coll)
+ return coll
+
+ if get_or_create:
+ # The collection already exists, return it.
+ coll = self.db[collection_name]
+ self.create_index_if_not_exists(index_name=self.index_name, collection=coll)
+ return coll
+ else:
+ # get_or_create is False and the collection already exists, raise an error.
+ raise ValueError(f"Collection {collection_name} already exists.")
+
+ def create_index_if_not_exists(self, index_name: str = "vector_index", collection: Collection = None) -> None:
+ """
+ Creates a vector search index on the specified collection in MongoDB.
+
+ Args:
+            index_name (str, optional): The name of the vector search index to create. Defaults to "vector_index".
+ collection (Collection, optional): The MongoDB collection to create the index on. Defaults to None.
+ """
+ if not self._is_index_ready(collection, index_name):
+ self.create_vector_search_index(collection, index_name)
+
+ def get_collection(self, collection_name: str = None) -> Collection:
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection. Default is None. If None, return the
+ current active collection.
+
+ Returns:
+ Collection | The collection object.
+ """
+ if collection_name is None:
+ if self.active_collection is None:
+ raise ValueError("No collection is specified.")
+ else:
+ logger.debug(
+ f"No collection is specified. Using current active collection {self.active_collection.name}."
+ )
+ else:
+ self.active_collection = self.db[collection_name]
+
+ return self.active_collection
+
+ def delete_collection(self, collection_name: str) -> None:
+ """
+ Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+ """
+ for index in self.db[collection_name].list_search_indexes():
+ self.db[collection_name].drop_search_index(index["name"])
+ if self._wait_until_index_ready:
+ self._wait_for_index(self.db[collection_name], index["name"], "delete")
+ return self.db[collection_name].drop()
+
+ def create_vector_search_index(
+ self,
+ collection: Collection,
+ index_name: Union[str, None] = "vector_index",
+ similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine",
+ ) -> None:
+ """Create a vector search index in the collection.
+
+ Args:
+ collection: An existing Collection in the Atlas Database.
+ index_name: Vector Search Index name.
+ similarity: Algorithm used for measuring vector similarity.
+ kwargs: Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ search_index_model = SearchIndexModel(
+ definition={
+ "fields": [
+ {
+ "type": "vector",
+ "numDimensions": self.dimensions,
+ "path": "embedding",
+ "similarity": similarity,
+ },
+ ]
+ },
+ name=index_name,
+ type="vectorSearch",
+ )
+ # Create the search index
+ try:
+ collection.create_search_index(model=search_index_model)
+ if self._wait_until_index_ready:
+ self._wait_for_index(collection, index_name, "create")
+ logger.debug(f"Search index {index_name} created successfully.")
+ except Exception as e:
+ logger.error(
+ f"Error creating search index: {e}. \n"
+ f"Your client must be connected to an Atlas cluster. "
+ f"You may have to manually create a Collection and Search Index "
+ f"if you are on a free/shared cluster."
+ )
+ raise e
+
+ def insert_docs(
+ self,
+ docs: List[Document],
+ collection_name: str = None,
+ upsert: bool = False,
+ batch_size=DEFAULT_INSERT_BATCH_SIZE,
+ **kwargs,
+ ) -> None:
+ """Insert Documents and Vector Embeddings into the collection of the vector database.
+
+ For large numbers of Documents, insertion is performed in batches.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ batch_size: Number of documents to be inserted in each batch
+ """
+ if not docs:
+ logger.info("No documents to insert.")
+ return
+
+ collection = self.get_collection(collection_name)
+ if upsert:
+ self.update_docs(docs, collection.name, upsert=True)
+ else:
+ # Sanity checking the first document
+ if docs[0].get("content") is None:
+ raise ValueError("The document content is required.")
+ if docs[0].get("id") is None:
+ raise ValueError("The document id is required.")
+
+ input_ids = set()
+ result_ids = set()
+ id_batch = []
+ text_batch = []
+ metadata_batch = []
+ size = 0
+ i = 0
+ for doc in docs:
+ id = doc["id"]
+ text = doc["content"]
+ metadata = doc.get("metadata", {})
+ id_batch.append(id)
+ text_batch.append(text)
+ metadata_batch.append(metadata)
+ id_size = 1 if isinstance(id, int) else len(id)
+ size += len(text) + len(metadata) + id_size
+                if (i + 1) % batch_size == 0 or size >= 47_000_000:  # flush before MongoDB's 48MB message-size limit
+ result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
+ input_ids.update(id_batch)
+ id_batch = []
+ text_batch = []
+ metadata_batch = []
+ size = 0
+ i += 1
+ if text_batch:
+ result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch)) # type: ignore
+ input_ids.update(id_batch)
+
+ if result_ids != input_ids:
+ logger.warning(
+ "Possible data corruption. "
+ "input_ids not in result_ids: {in_diff}.\n"
+ "result_ids not in input_ids: {out_diff}".format(
+ in_diff=input_ids.difference(result_ids), out_diff=result_ids.difference(input_ids)
+ )
+ )
+ if self._wait_until_document_ready and docs:
+ self._wait_for_document(collection, self.index_name, docs[-1])
+
+ def _insert_batch(
+ self, collection: Collection, texts: List[str], metadatas: List[Mapping[str, Any]], ids: List[ItemID]
+ ) -> Set[ItemID]:
+ """Compute embeddings for and insert a batch of Documents into the Collection.
+
+ For performance reasons, we chose to call self.embedding_function just once,
+        with the hopefully small tradeoff of having to recreate the Document dicts.
+
+ Args:
+ collection: MongoDB Collection
+ texts: List of the main contents of each document
+ metadatas: List of metadata mappings
+ ids: List of ids. Note that these are stored as _id in Collection.
+
+ Returns:
+            Set of ids inserted.
+ """
+ n_texts = len(texts)
+ if n_texts == 0:
+            return set()
+ # Embed and create the documents
+ embeddings = self.embedding_function(texts).tolist()
+ assert (
+ len(embeddings) == n_texts
+        ), f"The number of embeddings produced by self.embedding_function ({len(embeddings)}) does not match the number of texts provided to it ({n_texts})."
+ to_insert = [
+ {"_id": i, "content": t, "metadata": m, "embedding": e}
+ for i, t, m, e in zip(ids, texts, metadatas, embeddings)
+ ]
+ # insert the documents in MongoDB Atlas
+ insert_result = collection.insert_many(to_insert) # type: ignore
+ return insert_result.inserted_ids # TODO Remove this. Replace by log like update_docs
+
+ def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs: Any) -> None:
+ """Update documents, including their embeddings, in the Collection.
+
+ Optionally allow upsert as kwarg.
+
+ Uses deepcopy to avoid changing docs.
+
+ Args:
+ docs: List[Document] | A list of documents.
+ collection_name: str | The name of the collection. Default is None.
+            kwargs: Any | Use `upsert=True` to insert documents whose ids are not present in collection.
+ """
+
+ n_docs = len(docs)
+ logger.info(f"Preparing to embed and update {n_docs=}")
+ # Compute the embeddings
+ embeddings: list[list[float]] = self.embedding_function([doc["content"] for doc in docs]).tolist()
+ # Prepare the updates
+ all_updates = []
+ for i in range(n_docs):
+ doc = deepcopy(docs[i])
+ doc["embedding"] = embeddings[i]
+ doc["_id"] = doc.pop("id")
+
+ all_updates.append(UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=kwargs.get("upsert", False)))
+ # Perform update in bulk
+ collection = self.get_collection(collection_name)
+ result = collection.bulk_write(all_updates)
+
+ if self._wait_until_document_ready and docs:
+ self._wait_for_document(collection, self.index_name, docs[-1])
+
+ # Log a result summary
+ logger.info(
+ "Matched: %s, Modified: %s, Upserted: %s",
+ result.matched_count,
+ result.modified_count,
+ result.upserted_count,
+ )
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs):
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ """
+ collection = self.get_collection(collection_name)
+ return collection.delete_many({"_id": {"$in": ids}})
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include: List[str] = None, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+ include: List[str] | The fields to include.
+ If None, will include ["metadata", "content"], ids will always be included.
+ Basically, use include to choose whether to include embedding and metadata
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
+ """
+ if include is None:
+ include_fields = {"_id": 1, "content": 1, "metadata": 1}
+ else:
+ include_fields = {k: 1 for k in set(include).union({"_id"})}
+ collection = self.get_collection(collection_name)
+ if ids is not None:
+ docs = collection.find({"_id": {"$in": ids}}, include_fields)
+ # Return with _id field from Collection into id for Document
+ return with_id_rename(docs)
+ else:
+ docs = collection.find({}, include_fields)
+ # Return with _id field from Collection into id for Document
+ return with_id_rename(docs)
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = -1,
+ **kwargs,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+ distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
+ returned. Don't filter with it if < 0. Default is -1.
+            kwargs: Dict | Additional keyword arguments. Notable ones include:
+                oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm.
+                It determines the number of nearest neighbor candidates to consider during the search phase.
+                A higher value leads to more accuracy, but is slower. Default is 10.
+
+ Returns:
+ QueryResults | For each query string, a list of nearest documents and their scores.
+ """
+ collection = self.get_collection(collection_name)
+ # Trivial case of an empty collection
+ if collection.count_documents({}) == 0:
+ return []
+
+ logger.debug(f"Using index: {self.index_name}")
+ results = []
+ for query_text in queries:
+ # Compute embedding vector from semantic query
+ logger.debug(f"Query: {query_text}")
+ query_vector = np.array(self.embedding_function([query_text])).tolist()[0]
+ # Find documents with similar vectors using the specified index
+ query_result = _vector_search(
+ query_vector,
+ n_results,
+ collection,
+ self.index_name,
+ distance_threshold,
+ **kwargs,
+ oversampling_factor=kwargs.get("oversampling_factor", 10),
+ )
+ # Change each _id key to id. with_id_rename, but with (doc, score) tuples
+ results.append(
+ [({**{k: v for k, v in d[0].items() if k != "_id"}, "id": d[0]["_id"]}, d[1]) for d in query_result]
+ )
+ return results
+
+
+def _vector_search(
+ embedding_vector: List[float],
+ n_results: int,
+ collection: Collection,
+ index_name: str,
+ distance_threshold: float = -1.0,
+ oversampling_factor=10,
+ include_embedding=False,
+) -> List[Tuple[Dict, float]]:
+ """Core $vectorSearch Aggregation pipeline.
+
+ Args:
+ embedding_vector: Embedding vector of semantic query
+        n_results: Number of documents to return.
+ collection: MongoDB Collection with vector index
+ index_name: Name of the vector index
+ distance_threshold: Only distance measures smaller than this will be returned.
+            Don't filter with it if < 0. Default is -1.
+ oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm.
+ It determines the number of nearest neighbor candidates to consider during the search phase.
+ A higher value leads to more accuracy, but is slower. Default = 10
+
+ Returns:
+ List of tuples of length n_results from Collection.
+ Each tuple contains a document dict and a score.
+ """
+
+ pipeline = [
+ {
+ "$vectorSearch": {
+ "index": index_name,
+ "limit": n_results,
+ "numCandidates": n_results * oversampling_factor,
+ "queryVector": embedding_vector,
+ "path": "embedding",
+ }
+ },
+ {"$set": {"score": {"$meta": "vectorSearchScore"}}},
+ ]
+ if distance_threshold >= 0.0:
+ similarity_threshold = 1.0 - distance_threshold
+ pipeline.append({"$match": {"score": {"$gte": similarity_threshold}}})
+
+ if not include_embedding:
+ pipeline.append({"$project": {"embedding": 0}})
+
+ logger.debug("pipeline: %s", pipeline)
+ agg = collection.aggregate(pipeline)
+ return [(doc, doc.pop("score")) for doc in agg]
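
Taken together, a typical round trip through the new class looks like the sketch below. All values are illustrative, and a deployment that supports Atlas search indexes (e.g. Atlas, or a local Atlas container) is required:

```python
from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB

db = MongoDBAtlasVectorDB(
    connection_string="mongodb://localhost:27017/?directConnection=true",  # illustrative
    database_name="vector_db",
    collection_name="autogen-docs",
    wait_until_index_ready=120,     # block up to 120s for the search index
    wait_until_document_ready=120,  # block up to 120s until inserts are searchable
)
db.insert_docs(
    [
        {"id": "1", "content": "Keanu Reeves played Neo in The Matrix.", "metadata": {}},
        {"id": "2", "content": "The weather is lovely today.", "metadata": {}},
    ]
)
results = db.retrieve_docs(["Who played Neo?"], n_results=1)
top_doc, score = results[0][0]  # nearest document and its score for the first query
print(top_doc["content"], score)
```
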
diff --git a/autogen/agentchat/contrib/vectordb/qdrant.py b/autogen/agentchat/contrib/vectordb/qdrant.py
new file mode 100644
index 000000000000..d9c4ee1d2e5a
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/qdrant.py
@@ -0,0 +1,328 @@
+import abc
+import logging
+import os
+from typing import Callable, List, Optional, Sequence, Tuple, Union
+
+from .base import Document, ItemID, QueryResults, VectorDB
+from .utils import get_logger
+
+try:
+ from qdrant_client import QdrantClient, models
+except ImportError:
+ raise ImportError("Please install qdrant-client: `pip install qdrant-client`")
+
+logger = get_logger(__name__)
+
+Embeddings = Union[Sequence[float], Sequence[int]]
+
+
+class EmbeddingFunction(abc.ABC):
+ @abc.abstractmethod
+ def __call__(self, inputs: List[str]) -> List[Embeddings]:
+ raise NotImplementedError
+
+
+class FastEmbedEmbeddingFunction(EmbeddingFunction):
+ """Embedding function implementation using FastEmbed - https://qdrant.github.io/fastembed."""
+
+ def __init__(
+ self,
+ model_name: str = "BAAI/bge-small-en-v1.5",
+ batch_size: int = 256,
+ cache_dir: Optional[str] = None,
+ threads: Optional[int] = None,
+ parallel: Optional[int] = None,
+ **kwargs,
+ ):
+ """Initialize fastembed.TextEmbedding.
+
+ Args:
+ model_name (str): The name of the model to use. Defaults to `"BAAI/bge-small-en-v1.5"`.
+ batch_size (int): Batch size for encoding. Higher values will use more memory, but be faster.\
+ Defaults to 256.
+ cache_dir (str, optional): The path to the model cache directory.\
+ Can also be set using the `FASTEMBED_CACHE_PATH` env variable.
+ threads (int, optional): The number of threads single onnxruntime session can use.
+ parallel (int, optional): If `>1`, data-parallel encoding will be used, recommended for large datasets.\
+ If `0`, use all available cores.\
+ If `None`, don't use data-parallel processing, use default onnxruntime threading.\
+ Defaults to None.
+ **kwargs: Additional options to pass to fastembed.TextEmbedding
+ Raises:
+            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-small-en-v1.5.
+ """
+ try:
+ from fastembed import TextEmbedding
+ except ImportError as e:
+ raise ValueError(
+ "The 'fastembed' package is not installed. Please install it with `pip install fastembed`",
+ ) from e
+ self._batch_size = batch_size
+ self._parallel = parallel
+ self._model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs)
+
+ def __call__(self, inputs: List[str]) -> List[Embeddings]:
+ embeddings = self._model.embed(inputs, batch_size=self._batch_size, parallel=self._parallel)
+
+ return [embedding.tolist() for embedding in embeddings]
+
+
+class QdrantVectorDB(VectorDB):
+ """
+ A vector database implementation that uses Qdrant as the backend.
+ """
+
+ def __init__(
+ self,
+ *,
+ client=None,
+        embedding_function: Optional[EmbeddingFunction] = None,
+ content_payload_key: str = "_content",
+ metadata_payload_key: str = "_metadata",
+        collection_options: Optional[dict] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Initialize the vector database.
+
+ Args:
+ client: qdrant_client.QdrantClient | An instance of QdrantClient.
+ embedding_function: Callable | The embedding function used to generate the vector representation
+ of the documents. Defaults to FastEmbedEmbeddingFunction.
+            content_payload_key: str | The payload key used to store the document content. Default is "_content".
+            metadata_payload_key: str | The payload key used to store the document metadata. Default is "_metadata".
+            collection_options: dict | The options for creating the collection. Default is None.
+ kwargs: dict | Additional keyword arguments.
+ """
+ self.client: QdrantClient = client or QdrantClient(location=":memory:")
+ self.embedding_function = embedding_function or FastEmbedEmbeddingFunction()
+        self.collection_options = collection_options or {}
+ self.content_payload_key = content_payload_key
+ self.metadata_payload_key = metadata_payload_key
+ self.type = "qdrant"
+
+ def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> None:
+ """
+ Create a collection in the vector database.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+ otherwise it raise a ValueError.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get the collection if it exists. Default is True.
+
+ Returns:
+ Any | The collection object.
+ """
+ embeddings_size = len(self.embedding_function(["test"])[0])
+
+ if self.client.collection_exists(collection_name) and overwrite:
+ self.client.delete_collection(collection_name)
+
+ if not self.client.collection_exists(collection_name):
+ self.client.create_collection(
+ collection_name,
+ vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
+ **self.collection_options,
+ )
+ elif not get_or_create:
+ raise ValueError(f"Collection {collection_name} already exists.")
+
+ def get_collection(self, collection_name: str = None):
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ Any | The collection object.
+ """
+ if collection_name is None:
+ raise ValueError("The collection name is required.")
+
+ return self.client.get_collection(collection_name)
+
+ def delete_collection(self, collection_name: str) -> None:
+ """Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ Any
+ """
+ return self.client.delete_collection(collection_name)
+
+ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
+ """
+ Insert documents into the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ if not docs:
+ return
+ if any(doc.get("content") is None for doc in docs):
+ raise ValueError("The document content is required.")
+ if any(doc.get("id") is None for doc in docs):
+ raise ValueError("The document id is required.")
+
+ if not upsert and not self._validate_upsert_ids(collection_name, [doc["id"] for doc in docs]):
+            # Logger.log takes the level as its first positional argument; skip as the message states.
+            logger.log(logging.WARN, "Some IDs already exist. Skipping insert")
+            return
+
+ self.client.upsert(collection_name, points=self._documents_to_points(docs))
+
+ def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
+ if not docs:
+ return
+ if any(doc.get("id") is None for doc in docs):
+ raise ValueError("The document id is required.")
+ if any(doc.get("content") is None for doc in docs):
+ raise ValueError("The document content is required.")
+ if self._validate_update_ids(collection_name, [doc["id"] for doc in docs]):
+ return self.client.upsert(collection_name, points=self._documents_to_points(docs))
+
+ raise ValueError("Some IDs do not exist. Skipping update")
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ self.client.delete(collection_name, ids)
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = 0,
+ **kwargs,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+            distance_threshold: float | The minimum similarity score a result must reach to be returned; it is passed
+                to Qdrant as `score_threshold`. Use a negative value to effectively disable filtering. Default is 0.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+            QueryResults | The query results: one list per query, each a list of (document, distance) tuples.
+ """
+ embeddings = self.embedding_function(queries)
+ requests = [
+ models.SearchRequest(
+ vector=embedding,
+ limit=n_results,
+ score_threshold=distance_threshold,
+ with_payload=True,
+ with_vector=False,
+ )
+ for embedding in embeddings
+ ]
+
+ batch_results = self.client.search_batch(collection_name, requests)
+ return [self._scored_points_to_documents(results) for results in batch_results]
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include=True, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+            include: bool | List[str] | The payload fields to include, or True to include all payload fields.
+                Ids are always included. Default is True.
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
+ """
+ if ids is None:
+ results = self.client.scroll(collection_name=collection_name, with_payload=include, with_vectors=True)[0]
+ else:
+ results = self.client.retrieve(collection_name, ids=ids, with_payload=include, with_vectors=True)
+ return [self._point_to_document(result) for result in results]
+
+ def _point_to_document(self, point) -> Document:
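+        """Convert a Qdrant point back into the `Document` TypedDict used by this vector database interface."""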
+ return {
+ "id": point.id,
+ "content": point.payload.get(self.content_payload_key, ""),
+ "metadata": point.payload.get(self.metadata_payload_key, {}),
+ "embedding": point.vector,
+ }
+
+ def _points_to_documents(self, points) -> List[Document]:
+ return [self._point_to_document(point) for point in points]
+
+ def _scored_point_to_document(self, scored_point: models.ScoredPoint) -> Tuple[Document, float]:
+ return self._point_to_document(scored_point), scored_point.score
+
+ def _documents_to_points(self, documents: List[Document]):
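+        """Embed the document contents and convert each document into a Qdrant `PointStruct`."""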
+ contents = [document["content"] for document in documents]
+ embeddings = self.embedding_function(contents)
+ points = [
+ models.PointStruct(
+ id=documents[i]["id"],
+ vector=embeddings[i],
+ payload={
+ self.content_payload_key: documents[i].get("content"),
+ self.metadata_payload_key: documents[i].get("metadata"),
+ },
+ )
+ for i in range(len(documents))
+ ]
+ return points
+
+ def _scored_points_to_documents(self, scored_points: List[models.ScoredPoint]) -> List[Tuple[Document, float]]:
+ return [self._scored_point_to_document(scored_point) for scored_point in scored_points]
+
+ def _validate_update_ids(self, collection_name: str, ids: List[str]) -> bool:
+ """
+ Validates all the IDs exist in the collection
+ """
+ retrieved_ids = [
+ point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
+ ]
+
+ if missing_ids := set(ids) - set(retrieved_ids):
+ logger.log(f"Missing IDs: {missing_ids}. Skipping update", level=logging.WARN)
+ return False
+
+ return True
+
+ def _validate_upsert_ids(self, collection_name: str, ids: List[str]) -> bool:
+ """
+ Validate none of the IDs exist in the collection
+ """
+ retrieved_ids = [
+ point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
+ ]
+
+ if existing_ids := set(ids) & set(retrieved_ids):
+ logger.log(f"Existing IDs: {existing_ids}.", level=logging.WARN)
+ return False
+
+ return True
diff --git a/autogen/agentchat/contrib/web_surfer.py b/autogen/agentchat/contrib/web_surfer.py
index af07be6d3432..f74915a9b403 100644
--- a/autogen/agentchat/contrib/web_surfer.py
+++ b/autogen/agentchat/contrib/web_surfer.py
@@ -41,6 +41,7 @@ def __init__(
summarizer_llm_config: Optional[Union[Dict, Literal[False]]] = None,
default_auto_reply: Optional[Union[str, Dict, None]] = "",
browser_config: Optional[Union[Dict, None]] = None,
+ **kwargs,
):
super().__init__(
name=name,
@@ -53,6 +54,7 @@ def __init__(
code_execution_config=code_execution_config,
llm_config=llm_config,
default_auto_reply=default_auto_reply,
+ **kwargs,
)
self._create_summarizer_client(summarizer_llm_config, llm_config)
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index b434fc648eb1..e19cbd56de2b 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -11,6 +11,7 @@
from openai import BadRequestError
+from autogen.agentchat.chat import _post_process_carryover_item
from autogen.exception_utils import InvalidCarryOverType, SenderRequired
from .._pydantic import model_dump
@@ -77,6 +78,7 @@ def __init__(
default_auto_reply: Union[str, Dict] = "",
description: Optional[str] = None,
chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
+ silent: Optional[bool] = None,
):
"""
Args:
@@ -125,6 +127,8 @@ def __init__(
chat_messages (dict or None): the previous chat messages that this agent had in the past with other agents.
Can be used to give the agent a memory by providing the chat history. This will allow the agent to
                resume previously held conversations. Defaults to an empty chat history.
+            silent (bool or None): (Experimental) whether to print the messages sent. If None, the `silent` value
+                passed to each function is used.
"""
# we change code_execution_config below and we have to make sure we don't change the input
# in case of UserProxyAgent, without this we could even change the default value {}
@@ -146,6 +150,7 @@ def __init__(
if is_termination_msg is not None
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)
+ self.silent = silent
# Take a copy to avoid modifying the given dict
if isinstance(llm_config, dict):
try:
@@ -262,6 +267,10 @@ def _validate_llm_config(self, llm_config):
)
self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config)
+ @staticmethod
+ def _is_silent(agent: Agent, silent: Optional[bool] = False) -> bool:
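+        """Return the agent's own `silent` attribute when it is set, otherwise fall back to the passed-in value."""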
+ return agent.silent if agent.silent is not None else silent
+
@property
def name(self) -> str:
"""Get the name of the agent."""
@@ -368,9 +377,9 @@ def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable)
f["reply_func"] = new_reply_func
@staticmethod
- def _summary_from_nested_chats(
+ def _get_chats_to_run(
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
- ) -> Tuple[bool, str]:
+ ) -> List[Dict[str, Any]]:
"""A simple chat reply function.
        This function initiates one or a sequence of chats between the "recipient" and the agents in the
chat_queue.
@@ -397,22 +406,59 @@ def _summary_from_nested_chats(
if message:
current_c["message"] = message
chat_to_run.append(current_c)
+ return chat_to_run
+
+ @staticmethod
+ def _summary_from_nested_chats(
+ chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+ ) -> Tuple[bool, Union[str, None]]:
+ """A simple chat reply function.
+        This function initiates one or a sequence of chats between the "recipient" and the agents in the
+ chat_queue.
+
+ It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
+
+ Returns:
+            Tuple[bool, Union[str, None]]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
+ """
+ chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
if not chat_to_run:
return True, None
res = initiate_chats(chat_to_run)
return True, res[-1].summary
+ @staticmethod
+ async def _a_summary_from_nested_chats(
+ chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+ ) -> Tuple[bool, Union[str, None]]:
+ """A simple chat reply function.
+        This function initiates one or a sequence of chats between the "recipient" and the agents in the
+ chat_queue.
+
+ It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
+
+ Returns:
+            Tuple[bool, Union[str, None]]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
+ """
+ chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
+ if not chat_to_run:
+ return True, None
+ res = await a_initiate_chats(chat_to_run)
+        last_chat_id = chat_to_run[-1]["chat_id"]
+        return True, res[last_chat_id].summary
+
def register_nested_chats(
self,
chat_queue: List[Dict[str, Any]],
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats",
position: int = 2,
+ use_async: Union[bool, None] = None,
**kwargs,
) -> None:
"""Register a nested chat reply function.
Args:
- chat_queue (list): a list of chat objects to be initiated.
+            chat_queue (list): a list of chat objects to be initiated. If use_async is used, then all chats in chat_queue must have a `chat_id` associated with them.
trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details.
reply_func_from_nested_chats (Callable, str): the reply function for the nested chat.
The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
@@ -427,15 +473,33 @@ def reply_func_from_nested_chats(
) -> Tuple[bool, Union[str, Dict, None]]:
```
position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply.
+            use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, set this to True so the nested chats also run asynchronously.
kwargs: Ref to `register_reply` for details.
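+
+            A minimal illustrative sketch of an async nested chat registration (the agents named here are
+            hypothetical):
+            ```python
+            agent.register_nested_chats(
+                chat_queue=[{"chat_id": 1, "recipient": writer, "summary_method": "last_msg"}],
+                trigger=user_proxy,
+                use_async=True,
+            )
+            ```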
"""
- if reply_func_from_nested_chats == "summary_from_nested_chats":
- reply_func_from_nested_chats = self._summary_from_nested_chats
- if not callable(reply_func_from_nested_chats):
- raise ValueError("reply_func_from_nested_chats must be a callable")
+ if use_async:
+ for chat in chat_queue:
+ if chat.get("chat_id") is None:
+ raise ValueError("chat_id is required for async nested chats")
+ if reply_func_from_nested_chats == "summary_from_nested_chats":
+ reply_func_from_nested_chats = self._a_summary_from_nested_chats
+ if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction(
+ reply_func_from_nested_chats
+ ):
+ raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine")
- def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
- return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
+ async def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
+ return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
+
+ else:
+ if reply_func_from_nested_chats == "summary_from_nested_chats":
+ reply_func_from_nested_chats = self._summary_from_nested_chats
+ if not callable(reply_func_from_nested_chats):
+ raise ValueError("reply_func_from_nested_chats must be a callable")
+
+ def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
+ return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats)
@@ -445,7 +509,9 @@ def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
position,
kwargs.get("config"),
kwargs.get("reset_config"),
- ignore_async_in_sync_chat=kwargs.get("ignore_async_in_sync_chat"),
+ ignore_async_in_sync_chat=(
+ not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat")
+ ),
)
@property
@@ -555,7 +621,7 @@ def _assert_valid_name(name):
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
return name
- def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent) -> bool:
+ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent, is_sending: bool) -> bool:
"""Append a message to the ChatCompletion conversation.
If the message received is a string, it will be put in the "content" field of the new dictionary.
@@ -567,6 +633,7 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
message (dict or str): message to be appended to the ChatCompletion conversation.
role (str): role of the message, can be "assistant" or "function".
conversation_id (Agent): id of the conversation, should be the recipient or sender.
+            is_sending (bool): Whether the agent (i.e. self) is sending the message to the conversation_id agent; otherwise it is receiving it.
Returns:
bool: whether the message is appended to the ChatCompletion conversation.
@@ -596,7 +663,15 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
if oai_message.get("function_call", False) or oai_message.get("tool_calls", False):
oai_message["role"] = "assistant" # only messages with role 'assistant' can have a function call.
+ elif "name" not in oai_message:
+ # If we don't have a name field, append it
+ if is_sending:
+ oai_message["name"] = self.name
+ else:
+ oai_message["name"] = conversation_id.name
+
self._oai_messages[conversation_id].append(oai_message)
+
return True
def _process_message_before_send(
@@ -605,7 +680,9 @@ def _process_message_before_send(
"""Process the message before sending it to the recipient."""
hook_list = self.hook_lists["process_message_before_send"]
for hook in hook_list:
- message = hook(sender=self, message=message, recipient=recipient, silent=silent)
+ message = hook(
+ sender=self, message=message, recipient=recipient, silent=ConversableAgent._is_silent(self, silent)
+ )
return message
def send(
@@ -647,10 +724,10 @@ def send(
Raises:
ValueError: if the message can't be converted into a valid ChatCompletion message.
"""
- message = self._process_message_before_send(message, recipient, silent)
+ message = self._process_message_before_send(message, recipient, ConversableAgent._is_silent(self, silent))
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
- valid = self._append_oai_message(message, "assistant", recipient)
+ valid = self._append_oai_message(message, "assistant", recipient, is_sending=True)
if valid:
recipient.receive(message, self, request_reply, silent)
else:
@@ -697,10 +774,10 @@ async def a_send(
Raises:
ValueError: if the message can't be converted into a valid ChatCompletion message.
"""
- message = self._process_message_before_send(message, recipient, silent)
+ message = self._process_message_before_send(message, recipient, ConversableAgent._is_silent(self, silent))
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
- valid = self._append_oai_message(message, "assistant", recipient)
+ valid = self._append_oai_message(message, "assistant", recipient, is_sending=True)
if valid:
await recipient.a_receive(message, self, request_reply, silent)
else:
@@ -771,7 +848,7 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool):
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
- valid = self._append_oai_message(message, "user", sender)
+ valid = self._append_oai_message(message, "user", sender, is_sending=False)
if logging_enabled():
log_event(self, "received_message", message=message, sender=sender.name, valid=valid)
@@ -779,7 +856,8 @@ def _process_received_message(self, message: Union[Dict, str], sender: Agent, si
raise ValueError(
"Received message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
)
- if not silent:
+
+ if not ConversableAgent._is_silent(sender, silent):
self._print_received_message(message, sender)
def receive(
@@ -1580,8 +1658,8 @@ async def a_generate_function_call_reply(
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
- if "function_call" in message:
- func_call = message["function_call"]
+ func_call = message.get("function_call")
+ if func_call:
func_name = func_call.get("name", "")
func = self._function_map.get(func_name, None)
if func and inspect.iscoroutinefunction(func):
@@ -1722,7 +1800,7 @@ def check_termination_and_human_reply(
sender_name = "the sender" if sender is None else sender.name
if self.human_input_mode == "ALWAYS":
reply = self.get_human_input(
- f"Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
+ f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
@@ -1835,7 +1913,7 @@ async def a_check_termination_and_human_reply(
sender_name = "the sender" if sender is None else sender.name
if self.human_input_mode == "ALWAYS":
reply = await self.a_get_human_input(
- f"Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
+ f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
@@ -2184,7 +2262,7 @@ def _format_json_str(jstr):
Ex 2:
"{\n \"location\": \"Boston, MA\"\n}" -> "{"location": "Boston, MA"}"
- 2. this function also handles JSON escape sequences inside quotes,
+ 2. this function also handles JSON escape sequences inside quotes.
Ex 1:
'{"args": "a\na\na\ta"}' -> '{"args": "a\\na\\na\\ta"}'
"""
@@ -2233,7 +2311,7 @@ def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict
arguments = json.loads(input_string)
except json.JSONDecodeError as e:
arguments = None
- content = f"Error: {e}\n You argument should follow json format."
+ content = f"Error: {e}\n The argument must be in JSON format."
# Try to execute the function
if arguments is not None:
@@ -2290,7 +2368,7 @@ async def a_execute_function(self, func_call):
arguments = json.loads(input_string)
except json.JSONDecodeError as e:
arguments = None
- content = f"Error: {e}\n You argument should follow json format."
+ content = f"Error: {e}\n The argument must be in JSON format."
# Try to execute the function
if arguments is not None:
@@ -2364,7 +2442,7 @@ def _process_carryover(self, content: str, kwargs: dict) -> str:
if isinstance(kwargs["carryover"], str):
content += "\nContext: \n" + kwargs["carryover"]
elif isinstance(kwargs["carryover"], list):
- content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
+ content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
@@ -2526,14 +2604,16 @@ def _wrap_function(self, func: F) -> F:
@functools.wraps(func)
def _wrapped_func(*args, **kwargs):
retval = func(*args, **kwargs)
- log_function_use(self, func, kwargs, retval)
+ if logging_enabled():
+ log_function_use(self, func, kwargs, retval)
return serialize_to_str(retval)
@load_basemodels_if_needed
@functools.wraps(func)
async def _a_wrapped_func(*args, **kwargs):
retval = await func(*args, **kwargs)
- log_function_use(self, func, kwargs, retval)
+ if logging_enabled():
+ log_function_use(self, func, kwargs, retval)
return serialize_to_str(retval)
wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func
diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py
index 48f11d526cc6..c6355a13b94d 100644
--- a/autogen/agentchat/groupchat.py
+++ b/autogen/agentchat/groupchat.py
@@ -5,7 +5,7 @@
import re
import sys
from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from ..code_utils import content_str
from ..exception_utils import AgentNameConflict, NoEligibleSpeaker, UndefinedNextAgent
@@ -17,6 +17,12 @@
from .chat import ChatResult
from .conversable_agent import ConversableAgent
+try:
+ # Non-core module
+ from .contrib.capabilities import transform_messages
+except ImportError:
+ transform_messages = None
+
logger = logging.getLogger(__name__)
@@ -76,6 +82,8 @@ def custom_speaker_selection_func(
of times until a single agent is returned or it exhausts the maximum attempts.
Applies only to "auto" speaker selection method.
Default is 2.
+ - select_speaker_transform_messages: (optional) the message transformations to apply to the nested select speaker agent-to-agent chat messages.
+ Takes a TransformMessages object, defaults to None and is only utilised when the speaker selection method is "auto".
- select_speaker_auto_verbose: whether to output the select speaker responses and selections
If set to True, the outputs from the two agents in the nested select speaker chat will be output, along with
whether the responses were successful, or not, in selecting an agent
@@ -132,6 +140,7 @@ def custom_speaker_selection_func(
The names are case-sensitive and should not be abbreviated or changed.
The only names that are accepted are {agentlist}.
Respond with ONLY the name of the speaker and DO NOT provide a reason."""
+ select_speaker_transform_messages: Optional[Any] = None
select_speaker_auto_verbose: Optional[bool] = False
role_for_select_speaker_messages: Optional[str] = "system"
@@ -249,6 +258,20 @@ def __post_init__(self):
elif self.max_retries_for_selecting_speaker < 0:
raise ValueError("max_retries_for_selecting_speaker must be greater than or equal to zero")
+        # Load message transforms once for the GroupChat so they don't have to be re-initialised and their cache is maintained across subsequent select-speaker calls.
+ self._speaker_selection_transforms = None
+ if self.select_speaker_transform_messages is not None:
+ if transform_messages is not None:
+ if isinstance(self.select_speaker_transform_messages, transform_messages.TransformMessages):
+ self._speaker_selection_transforms = self.select_speaker_transform_messages
+ else:
+ raise ValueError("select_speaker_transform_messages must be None or MessageTransforms.")
+ else:
+ logger.warning(
+ "TransformMessages could not be loaded, the 'select_speaker_transform_messages' transform"
+ "will not apply."
+ )
+
# Validate select_speaker_auto_verbose
if self.select_speaker_auto_verbose is None or not isinstance(self.select_speaker_auto_verbose, bool):
raise ValueError("select_speaker_auto_verbose cannot be None or non-bool")
@@ -649,11 +672,16 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
if self.select_speaker_prompt_template is not None:
start_message = {
"content": self.select_speaker_prompt(agents),
+ "name": "checking_agent",
"override_role": self.role_for_select_speaker_messages,
}
else:
start_message = messages[-1]
+ # Add the message transforms, if any, to the speaker selection agent
+ if self._speaker_selection_transforms is not None:
+ self._speaker_selection_transforms.add_to_agent(speaker_selection_agent)
+
# Run the speaker selection chat
result = checking_agent.initiate_chat(
speaker_selection_agent,
@@ -748,6 +776,10 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
else:
start_message = messages[-1]
+ # Add the message transforms, if any, to the speaker selection agent
+ if self._speaker_selection_transforms is not None:
+ self._speaker_selection_transforms.add_to_agent(speaker_selection_agent)
+
# Run the speaker selection chat
result = await checking_agent.a_initiate_chat(
speaker_selection_agent,
@@ -813,6 +845,7 @@ def _validate_speaker_name(
return True, {
"content": self.select_speaker_auto_multiple_template.format(agentlist=agentlist),
+ "name": "checking_agent",
"override_role": self.role_for_select_speaker_messages,
}
else:
@@ -842,6 +875,7 @@ def _validate_speaker_name(
return True, {
"content": self.select_speaker_auto_none_template.format(agentlist=agentlist),
+ "name": "checking_agent",
"override_role": self.role_for_select_speaker_messages,
}
else:
@@ -965,6 +999,7 @@ def __init__(
# Store groupchat
self._groupchat = groupchat
+ self._last_speaker = None
self._silent = silent
# Order of register_reply is important.
@@ -1006,6 +1041,53 @@ def _prepare_chat(
if (recipient != agent or prepare_recipient) and isinstance(agent, ConversableAgent):
agent._prepare_chat(self, clear_history, False, reply_at_receive)
+ @property
+ def last_speaker(self) -> Agent:
+ """Return the agent who sent the last message to group chat manager.
+
+ In a group chat, an agent will always send a message to the group chat manager, and the group chat manager will
+ send the message to all other agents in the group chat. So, when an agent receives a message, it will always be
+ from the group chat manager. With this property, the agent receiving the message can know who actually sent the
+ message.
+
+ Example:
+ ```python
+ from autogen import ConversableAgent
+ from autogen import GroupChat, GroupChatManager
+
+
+ def print_messages(recipient, messages, sender, config):
+ # Print the message immediately
+ print(
+ f"Sender: {sender.name} | Recipient: {recipient.name} | Message: {messages[-1].get('content')}"
+ )
+ print(f"Real Sender: {sender.last_speaker.name}")
+ assert sender.last_speaker.name in messages[-1].get("content")
+ return False, None # Required to ensure the agent communication flow continues
+
+
+ agent_a = ConversableAgent("agent A", default_auto_reply="I'm agent A.")
+ agent_b = ConversableAgent("agent B", default_auto_reply="I'm agent B.")
+ agent_c = ConversableAgent("agent C", default_auto_reply="I'm agent C.")
+ for agent in [agent_a, agent_b, agent_c]:
+ agent.register_reply(
+ [ConversableAgent, None], reply_func=print_messages, config=None
+ )
+ group_chat = GroupChat(
+ [agent_a, agent_b, agent_c],
+ messages=[],
+ max_round=6,
+ speaker_selection_method="random",
+ allow_repeat_speaker=True,
+ )
+ chat_manager = GroupChatManager(group_chat)
+ groupchat_result = agent_a.initiate_chat(
+ chat_manager, message="Hi, there, I'm agent A."
+ )
+ ```
+ """
+ return self._last_speaker
+
def run_chat(
self,
messages: Optional[List[Dict]] = None,
@@ -1034,6 +1116,7 @@ def run_chat(
a.previous_cache = a.client_cache
a.client_cache = self.client_cache
for i in range(groupchat.max_round):
+ self._last_speaker = speaker
groupchat.append(message, speaker)
# broadcast the message to all agents except the speaker
for agent in groupchat.agents:
@@ -1212,11 +1295,10 @@ def resume(
if not message_speaker_agent and message["name"] == self.name:
message_speaker_agent = self
- # Add previous messages to each agent (except their own messages and the last message, as we'll kick off the conversation with it)
+ # Add previous messages to each agent (except the last message, as we'll kick off the conversation with it)
if i != len(messages) - 1:
for agent in self._groupchat.agents:
- if agent.name != message["name"]:
- self.send(message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True)
+ self.send(message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True)
# Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly
if message_speaker_agent:
@@ -1258,7 +1340,7 @@ def resume(
async def a_resume(
self,
messages: Union[List[Dict], str],
- remove_termination_string: Union[str, Callable[[str], str]],
+ remove_termination_string: Union[str, Callable[[str], str]] = None,
silent: Optional[bool] = False,
) -> Tuple[ConversableAgent, Dict]:
"""Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established
@@ -1316,13 +1398,12 @@ async def a_resume(
if not message_speaker_agent and message["name"] == self.name:
message_speaker_agent = self
- # Add previous messages to each agent (except their own messages and the last message, as we'll kick off the conversation with it)
+ # Add previous messages to each agent (except the last message, as we'll kick off the conversation with it)
if i != len(messages) - 1:
for agent in self._groupchat.agents:
- if agent.name != message["name"]:
- await self.a_send(
- message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True
- )
+ await self.a_send(
+ message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True
+ )
# Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly
if message_speaker_agent:
diff --git a/autogen/agentchat/user_proxy_agent.py b/autogen/agentchat/user_proxy_agent.py
index a80296a8355a..d50e4d8b89c5 100644
--- a/autogen/agentchat/user_proxy_agent.py
+++ b/autogen/agentchat/user_proxy_agent.py
@@ -35,6 +35,7 @@ def __init__(
llm_config: Optional[Union[Dict, Literal[False]]] = False,
system_message: Optional[Union[str, List]] = "",
description: Optional[str] = None,
+ **kwargs,
):
"""
Args:
@@ -79,6 +80,8 @@ def __init__(
Only used when llm_config is not False. Use it to reprogram the agent.
description (str): a short description of the agent. This description is used by other agents
(e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message)
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](conversable_agent#__init__).
"""
super().__init__(
name=name,
@@ -93,6 +96,7 @@ def __init__(
description=(
description if description is not None else self.DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS[human_input_mode]
),
+ **kwargs,
)
if logging_enabled():
diff --git a/autogen/coding/base.py b/autogen/coding/base.py
index ccbfe6b92932..7c9e19d73f33 100644
--- a/autogen/coding/base.py
+++ b/autogen/coding/base.py
@@ -4,7 +4,6 @@
from pydantic import BaseModel, Field
-from ..agentchat.agent import LLMAgent
from ..types import UserMessageImageContentPart, UserMessageTextContentPart
__all__ = ("CodeBlock", "CodeResult", "CodeExtractor", "CodeExecutor", "CodeExecutionConfig")
diff --git a/autogen/coding/func_with_reqs.py b/autogen/coding/func_with_reqs.py
index 6f199573822b..f255f1df0179 100644
--- a/autogen/coding/func_with_reqs.py
+++ b/autogen/coding/func_with_reqs.py
@@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from importlib.abc import SourceLoader
from textwrap import dedent, indent
-from typing import Any, Callable, Generic, List, TypeVar, Union
+from typing import Any, Callable, Generic, List, Set, TypeVar, Union
from typing_extensions import ParamSpec
@@ -159,12 +159,12 @@ def _build_python_functions_file(
funcs: List[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]]
) -> str:
# First collect all global imports
- global_imports = set()
+ global_imports: Set[str] = set()
for func in funcs:
if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)):
- global_imports.update(func.global_imports)
+ global_imports.update(map(_import_to_str, func.global_imports))
- content = "\n".join(map(_import_to_str, global_imports)) + "\n\n"
+ content = "\n".join(global_imports) + "\n\n"
for func in funcs:
content += _to_code(func) + "\n\n"
diff --git a/autogen/coding/jupyter/jupyter_client.py b/autogen/coding/jupyter/jupyter_client.py
index b3de374fce9b..787009dafe2f 100644
--- a/autogen/coding/jupyter/jupyter_client.py
+++ b/autogen/coding/jupyter/jupyter_client.py
@@ -39,6 +39,10 @@ def _get_headers(self) -> Dict[str, str]:
return {}
return {"Authorization": f"token {self._connection_info.token}"}
+ def _get_cookies(self) -> str:
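+        """Return the session cookies as a single "name=value; ..." string for the websocket handshake."""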
+ cookies = self._session.cookies.get_dict()
+ return "; ".join([f"{name}={value}" for name, value in cookies.items()])
+
def _get_api_base_url(self) -> str:
protocol = "https" if self._connection_info.use_https else "http"
port = f":{self._connection_info.port}" if self._connection_info.port else ""
@@ -87,7 +91,7 @@ def restart_kernel(self, kernel_id: str) -> None:
def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient:
ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels"
- ws = websocket.create_connection(ws_url, header=self._get_headers())
+ ws = websocket.create_connection(ws_url, header=self._get_headers(), cookie=self._get_cookies())
return JupyterKernelClient(ws)
diff --git a/autogen/coding/local_commandline_code_executor.py b/autogen/coding/local_commandline_code_executor.py
index 620b359a4aee..2280f7f030d8 100644
--- a/autogen/coding/local_commandline_code_executor.py
+++ b/autogen/coding/local_commandline_code_executor.py
@@ -221,7 +221,12 @@ def _setup_functions(self) -> None:
cmd = [py_executable, "-m", "pip", "install"] + required_packages
try:
result = subprocess.run(
- cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout)
+ cmd,
+ cwd=self._work_dir,
+ capture_output=True,
+ text=True,
+ timeout=float(self._timeout),
+ encoding="utf-8",
)
except subprocess.TimeoutExpired as e:
raise ValueError("Pip install timed out") from e
@@ -303,7 +308,13 @@ def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> Comman
try:
result = subprocess.run(
- cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout), env=env
+ cmd,
+ cwd=self._work_dir,
+ capture_output=True,
+ text=True,
+ timeout=float(self._timeout),
+ env=env,
+ encoding="utf-8",
)
except subprocess.TimeoutExpired:
logs_all += "\n" + TIMEOUT_MSG
diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py
index af5583587f66..07c9c3b76a76 100644
--- a/autogen/logger/file_logger.py
+++ b/autogen/logger/file_logger.py
@@ -18,7 +18,10 @@
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
+ from autogen.oai.bedrock import BedrockClient
+ from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
+ from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
from autogen.oai.together import TogetherClient
@@ -87,7 +90,7 @@ def log_chat_completion(
thread_id = threading.get_ident()
source_name = None
         if isinstance(source, str):
             source_name = source
         else:
-            source_name = source.name
+            source_name = getattr(source, "name", "unknown")
try:
@@ -204,7 +207,17 @@ def log_new_wrapper(
def log_new_client(
self,
- client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient,
+ client: (
+ AzureOpenAI
+ | OpenAI
+ | GeminiClient
+ | AnthropicClient
+ | MistralAIClient
+ | TogetherClient
+ | GroqClient
+ | CohereClient
+ | BedrockClient
+ ),
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:
diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py
index 969a943017e3..f76d039ce9de 100644
--- a/autogen/logger/sqlite_logger.py
+++ b/autogen/logger/sqlite_logger.py
@@ -19,7 +19,10 @@
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
+ from autogen.oai.bedrock import BedrockClient
+ from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
+ from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
from autogen.oai.together import TogetherClient
@@ -391,7 +394,17 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st
def log_new_client(
self,
- client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient],
+ client: Union[
+ AzureOpenAI,
+ OpenAI,
+ GeminiClient,
+ AnthropicClient,
+ MistralAIClient,
+ TogetherClient,
+ GroqClient,
+ CohereClient,
+ BedrockClient,
+ ],
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:
diff --git a/autogen/oai/anthropic.py b/autogen/oai/anthropic.py
index e2448929e618..7bb4608bd258 100644
--- a/autogen/oai/anthropic.py
+++ b/autogen/oai/anthropic.py
@@ -16,6 +16,27 @@
]
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
+
+Example usage for Anthropic Bedrock:
+
+Install the `anthropic` package by running `pip install --upgrade anthropic`.
+- https://docs.anthropic.com/en/docs/quickstart-guide
+
+import autogen
+
+config_list = [
+ {
+ "model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
+ "aws_access_key":,
+ "aws_secret_key":,
+ "aws_session_token":,
+ "aws_region":"us-east-1",
+ "api_type": "anthropic",
+ }
+]
+
+assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
+
"""
from __future__ import annotations
@@ -28,7 +49,7 @@
import warnings
from typing import Any, Dict, List, Tuple, Union
-from anthropic import Anthropic
+from anthropic import Anthropic, AnthropicBedrock
from anthropic import __version__ as anthropic_version
from anthropic.types import Completion, Message, TextBlock, ToolUseBlock
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
@@ -64,14 +85,38 @@ def __init__(self, **kwargs: Any):
api_key (str): The API key for the Anthropic API or set the `ANTHROPIC_API_KEY` environment variable.
"""
self._api_key = kwargs.get("api_key", None)
+ self._aws_access_key = kwargs.get("aws_access_key", None)
+ self._aws_secret_key = kwargs.get("aws_secret_key", None)
+ self._aws_session_token = kwargs.get("aws_session_token", None)
+ self._aws_region = kwargs.get("aws_region", None)
if not self._api_key:
self._api_key = os.getenv("ANTHROPIC_API_KEY")
- if self._api_key is None:
- raise ValueError("API key is required to use the Anthropic API.")
+ if not self._aws_access_key:
+ self._aws_access_key = os.getenv("AWS_ACCESS_KEY")
+
+ if not self._aws_secret_key:
+ self._aws_secret_key = os.getenv("AWS_SECRET_KEY")
+
+ if not self._aws_region:
+ self._aws_region = os.getenv("AWS_REGION")
+
+ if self._api_key is None and (
+ self._aws_access_key is None or self._aws_secret_key is None or self._aws_region is None
+ ):
+ raise ValueError("API key or AWS credentials are required to use the Anthropic API.")
+
+ if self._api_key is not None:
+ self._client = Anthropic(api_key=self._api_key)
+ else:
+ self._client = AnthropicBedrock(
+ aws_access_key=self._aws_access_key,
+ aws_secret_key=self._aws_secret_key,
+ aws_session_token=self._aws_session_token,
+ aws_region=self._aws_region,
+ )
- self._client = Anthropic(api_key=self._api_key)
self._last_tooluse_status = {}
def load_config(self, params: Dict[str, Any]):
@@ -107,6 +152,22 @@ def cost(self, response) -> float:
def api_key(self):
return self._api_key
+ @property
+ def aws_access_key(self):
+ return self._aws_access_key
+
+ @property
+ def aws_secret_key(self):
+ return self._aws_secret_key
+
+ @property
+ def aws_session_token(self):
+ return self._aws_session_token
+
+ @property
+ def aws_region(self):
+ return self._aws_region
+
def create(self, params: Dict[str, Any]) -> Completion:
if "tools" in params:
converted_functions = self.convert_tools_to_functions(params["tools"])
@@ -253,7 +314,7 @@ def oai_messages_to_anthropic_messages(params: Dict[str, Any]) -> list[dict[str,
last_tool_result_index = -1
for message in params["messages"]:
if message["role"] == "system":
- params["system"] = message["content"]
+ params["system"] = params.get("system", "") + (" " if "system" in params else "") + message["content"]
else:
# New messages will be added here, manage role alternations
expected_role = "user" if len(processed_messages) % 2 == 0 else "assistant"
diff --git a/autogen/oai/bedrock.py b/autogen/oai/bedrock.py
new file mode 100644
index 000000000000..7894781e3ee5
--- /dev/null
+++ b/autogen/oai/bedrock.py
@@ -0,0 +1,606 @@
+"""
+Create a compatible client for the Amazon Bedrock Converse API.
+
+Example usage:
+Install the `boto3` package by running `pip install --upgrade boto3`.
+- https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
+
+import autogen
+
+config_list = [
+ {
+ "api_type": "bedrock",
+ "model": "meta.llama3-1-8b-instruct-v1:0",
+ "aws_region": "us-west-2",
+ "aws_access_key": "",
+ "aws_secret_key": "",
+ "price" : [0.003, 0.015]
+ }
+]
+
+assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
+
+"""
+
+from __future__ import annotations
+
+import base64
+import json
+import os
+import re
+import time
+import warnings
+from typing import Any, Dict, List, Literal, Tuple
+
+import boto3
+import requests
+from botocore.config import Config
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import validate_parameter
+
+
+class BedrockClient:
+ """Client for Amazon's Bedrock Converse API."""
+
+ _retries = 5
+
+ def __init__(self, **kwargs: Any):
+ """
+ Initialises BedrockClient for Amazon's Bedrock Converse API
+ """
+ self._aws_access_key = kwargs.get("aws_access_key", None)
+ self._aws_secret_key = kwargs.get("aws_secret_key", None)
+ self._aws_session_token = kwargs.get("aws_session_token", None)
+ self._aws_region = kwargs.get("aws_region", None)
+ self._aws_profile_name = kwargs.get("aws_profile_name", None)
+
+ if not self._aws_access_key:
+ self._aws_access_key = os.getenv("AWS_ACCESS_KEY")
+
+ if not self._aws_secret_key:
+ self._aws_secret_key = os.getenv("AWS_SECRET_KEY")
+
+ if not self._aws_session_token:
+ self._aws_session_token = os.getenv("AWS_SESSION_TOKEN")
+
+ if not self._aws_region:
+ self._aws_region = os.getenv("AWS_REGION")
+
+ if self._aws_region is None:
+ raise ValueError("Region is required to use the Amazon Bedrock API.")
+
+ # Initialize Bedrock client, session, and runtime
+ bedrock_config = Config(
+ region_name=self._aws_region,
+ signature_version="v4",
+ retries={"max_attempts": self._retries, "mode": "standard"},
+ )
+
+ session = boto3.Session(
+ aws_access_key_id=self._aws_access_key,
+ aws_secret_access_key=self._aws_secret_key,
+ aws_session_token=self._aws_session_token,
+ profile_name=self._aws_profile_name,
+ )
+
+ self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)
+
+ def message_retrieval(self, response):
+ """Retrieve the messages from the response."""
+ return [choice.message for choice in response.choices]
+
+ def parse_custom_params(self, params: Dict[str, Any]):
+ """
+ Parses custom parameters for logic in this client class
+ """
+
+        # Whether to separate system messages into their own request parameter (default True).
+        # This is required because not all models support a system prompt (e.g. Mistral Instruct).
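+        # For example, a config_list entry could set "supports_system_prompts": False for such models (illustrative).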
+ self._supports_system_prompts = params.get("supports_system_prompts", True)
+
+ def parse_params(self, params: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str, Any]]:
+ """
+ Loads the valid parameters required to invoke Bedrock Converse
+ Returns a tuple of (base_params, additional_params)
+ """
+
+ base_params = {}
+ additional_params = {}
+
+ # Amazon Bedrock base model IDs are here:
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
+ self._model_id = params.get("model", None)
+        assert self._model_id, "Please provide the 'model' in the config_list to use Amazon Bedrock"
+
+ # Parameters vary based on the model used.
+ # As we won't cater for all models and parameters, it's the developer's
+ # responsibility to implement the parameters and they will only be
+ # included if the developer has it in the config.
+ #
+ # Important:
+ # No defaults will be used (as they can vary per model)
+ # No ranges will be used (as they can vary)
+ # We will cover all the main parameters but there may be others
+ # that need to be added later
+ #
+ # Here are some pages that show the parameters available for different models
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-chat-completion.html
+
+ # Here are the possible "base" parameters and their suitable types
+ base_parameters = [["temperature", (float, int)], ["topP", (float, int)], ["maxTokens", (int)]]
+
+ for param_name, suitable_types in base_parameters:
+ if param_name in params:
+ base_params[param_name] = validate_parameter(
+ params, param_name, suitable_types, False, None, None, None
+ )
+
+ # Here are the possible "model-specific" parameters and their suitable types, known as additional parameters
+ additional_parameters = [
+ ["top_p", (float, int)],
+ ["top_k", (int)],
+ ["k", (int)],
+ ["seed", (int)],
+ ]
+
+ for param_name, suitable_types in additional_parameters:
+ if param_name in params:
+ additional_params[param_name] = validate_parameter(
+ params, param_name, suitable_types, False, None, None, None
+ )
+
+ # Streaming
+ if "stream" in params:
+ self._streaming = params["stream"]
+ else:
+ self._streaming = False
+
+ # For this release we will not support streaming as many models do not support streaming with tool use
+ if self._streaming:
+ warnings.warn(
+ "Streaming is not currently supported, streaming will be disabled.",
+ UserWarning,
+ )
+ self._streaming = False
+
+ return base_params, additional_params
+
+ def create(self, params):
+ """Run Amazon Bedrock inference and return AutoGen response"""
+
+ # Set custom client class settings
+ self.parse_custom_params(params)
+
+ # Parse the inference parameters
+ base_params, additional_params = self.parse_params(params)
+
+ has_tools = "tools" in params
+ messages = oai_messages_to_bedrock_messages(params["messages"], has_tools, self._supports_system_prompts)
+
+ if self._supports_system_prompts:
+ system_messages = extract_system_messages(params["messages"])
+
+ tool_config = format_tools(params["tools"] if has_tools else [])
+
+ request_args = {"messages": messages, "modelId": self._model_id}
+
+ # Base and additional args
+ if len(base_params) > 0:
+ request_args["inferenceConfig"] = base_params
+
+ if len(additional_params) > 0:
+ request_args["additionalModelRequestFields"] = additional_params
+
+ if self._supports_system_prompts:
+ request_args["system"] = system_messages
+
+ if len(tool_config["tools"]) > 0:
+ request_args["toolConfig"] = tool_config
+
+ try:
+ response = self.bedrock_runtime.converse(
+ **request_args,
+ )
+ except Exception as e:
+ raise RuntimeError(f"Failed to get response from Bedrock: {e}")
+
+ if response is None:
+ raise RuntimeError(f"Failed to get response from Bedrock after retrying {self._retries} times.")
+
+ finish_reason = convert_stop_reason_to_finish_reason(response["stopReason"])
+ response_message = response["output"]["message"]
+
+ if finish_reason == "tool_calls":
+ tool_calls = format_tool_calls(response_message["content"])
+ # text = ""
+ else:
+ tool_calls = None
+
+ text = ""
+ for content in response_message["content"]:
+ if "text" in content:
+ text = content["text"]
+ # NOTE: other types of output may be dealt with here
+
+ message = ChatCompletionMessage(role="assistant", content=text, tool_calls=tool_calls)
+
+ response_usage = response["usage"]
+ usage = CompletionUsage(
+ prompt_tokens=response_usage["inputTokens"],
+ completion_tokens=response_usage["outputTokens"],
+ total_tokens=response_usage["totalTokens"],
+ )
+
+ return ChatCompletion(
+ id=response["ResponseMetadata"]["RequestId"],
+ choices=[Choice(finish_reason=finish_reason, index=0, message=message)],
+ created=int(time.time()),
+ model=self._model_id,
+ object="chat.completion",
+ usage=usage,
+ )
+
+ def cost(self, response: ChatCompletion) -> float:
+ """Calculate the cost of the response."""
+ return calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens, response.model)
+
+ @staticmethod
+ def get_usage(response) -> Dict:
+ """Get the usage of tokens and their cost information."""
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "cost": response.cost,
+ "model": response.model,
+ }
+
+
+def extract_system_messages(messages: List[dict]) -> List:
+ """Extract the system messages from the list of messages.
+
+ Args:
+ messages (list[dict]): List of messages.
+
+ Returns:
+        List[dict]: The first system message as a Bedrock content block, e.g. [{"text": ...}]; [] if none exists.
+ """
+ for message in messages:
+ if message.get("role") == "system":
+ if isinstance(message["content"], str):
+ return [{"text": message.get("content")}]
+ else:
+ return [{"text": message.get("content")[0]["text"]}]
+ return []
+
+
+def oai_messages_to_bedrock_messages(
+ messages: List[Dict[str, Any]], has_tools: bool, supports_system_prompts: bool
+) -> List[Dict]:
+ """
+ Convert messages from OAI format to Bedrock format.
+ We correct for any specific role orders and types, etc.
+ AWS Bedrock requires messages to alternate between user and assistant roles. This function ensures that the messages
+ are in the correct order and format for Bedrock by inserting "Please continue" messages as needed.
+    This is the same method as the one in the AutoGen Anthropic client.
+ """
+
+ # Track whether we have tools passed in. If not, tool use / result messages should be converted to text messages.
+ # Bedrock requires a tools parameter with the tools listed, if there are other messages with tool use or tool results.
+ # This can occur when we don't need tool calling, such as for group chat speaker selection
+
+ # Convert messages to Bedrock compliant format
+
+ # Take out system messages if the model supports it, otherwise leave them in.
+ if supports_system_prompts:
+ messages = [x for x in messages if not x["role"] == "system"]
+ else:
+ # Replace role="system" with role="user"
+ for msg in messages:
+ if msg["role"] == "system":
+ msg["role"] = "user"
+
+ processed_messages = []
+
+ # Used to interweave user messages to ensure user/assistant alternating
+ user_continue_message = {"content": [{"text": "Please continue."}], "role": "user"}
+ assistant_continue_message = {
+ "content": [{"text": "Please continue."}],
+ "role": "assistant",
+ }
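+    # For example, two consecutive "user" messages become: user -> assistant ("Please continue.") -> user.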
+
+ tool_use_messages = 0
+ tool_result_messages = 0
+ last_tool_use_index = -1
+ last_tool_result_index = -1
+ for message in messages:
+ # New messages will be added here, manage role alternations
+ expected_role = "user" if len(processed_messages) % 2 == 0 else "assistant"
+
+ if "tool_calls" in message:
+ # Map the tool call options to Bedrock's format
+ tool_uses = []
+ tool_names = []
+ for tool_call in message["tool_calls"]:
+ tool_uses.append(
+ {
+ "toolUse": {
+ "toolUseId": tool_call["id"],
+ "name": tool_call["function"]["name"],
+ "input": json.loads(tool_call["function"]["arguments"]),
+ }
+ }
+ )
+ if has_tools:
+ tool_use_messages += 1
+ tool_names.append(tool_call["function"]["name"])
+
+ if expected_role == "user":
+ # Insert an extra user message as we will append an assistant message
+ processed_messages.append(user_continue_message)
+
+ if has_tools:
+ processed_messages.append({"role": "assistant", "content": tool_uses})
+ last_tool_use_index = len(processed_messages) - 1
+ else:
+ # Not using tools, so put in a plain text message
+ processed_messages.append(
+ {
+ "role": "assistant",
+ "content": [
+ {"text": f"Some internal function(s) that could be used: [{', '.join(tool_names)}]"}
+ ],
+ }
+ )
+ elif "tool_call_id" in message:
+ if has_tools:
+ # Map the tool usage call to tool_result for Bedrock
+ tool_result = {
+ "toolResult": {
+ "toolUseId": message["tool_call_id"],
+ "content": [{"text": message["content"]}],
+ }
+ }
+
+ # If the previous message also had a tool_result, add it to that
+ # Otherwise append a new message
+ if last_tool_result_index == len(processed_messages) - 1:
+ processed_messages[-1]["content"].append(tool_result)
+ else:
+ if expected_role == "assistant":
+ # Insert an extra assistant message as we will append a user message
+ processed_messages.append(assistant_continue_message)
+
+ processed_messages.append({"role": "user", "content": [tool_result]})
+ last_tool_result_index = len(processed_messages) - 1
+
+ tool_result_messages += 1
+ else:
+ # Not using tools, so put in a plain text message
+ processed_messages.append(
+ {
+ "role": "user",
+ "content": [{"text": f"Running the function returned: {message['content']}"}],
+ }
+ )
+ elif message["content"] == "":
+ # Ignoring empty messages
+ pass
+ else:
+ if expected_role != message["role"] and not (len(processed_messages) == 0 and message["role"] == "system"):
+ # Inserting the alternating continue message (ignore if it's the first message and a system message)
+ processed_messages.append(
+ user_continue_message if expected_role == "user" else assistant_continue_message
+ )
+
+ processed_messages.append(
+ {
+ "role": message["role"],
+ "content": parse_content_parts(message=message),
+ }
+ )
+
+ # We'll replace the last tool_use if there's no tool_result (occurs if we finish the conversation before running the function)
+ if has_tools and tool_use_messages != tool_result_messages:
+ processed_messages[last_tool_use_index] = assistant_continue_message
+
+ # name is not a valid field on messages
+ for message in processed_messages:
+ if "name" in message:
+ message.pop("name", None)
+
+ # Note: When using reflection_with_llm we may end up with an "assistant" message as the last message and that may cause a blank response
+ # So, if the last role is not user, add a 'user' continue message at the end
+ if processed_messages[-1]["role"] != "user":
+ processed_messages.append(user_continue_message)
+
+ return processed_messages
+
+
+def parse_content_parts(
+ message: Dict[str, Any],
+) -> List[dict]:
+ content: str | List[Dict[str, Any]] = message.get("content")
+ if isinstance(content, str):
+ return [
+ {
+ "text": content,
+ }
+ ]
+ content_parts = []
+ for part in content:
+        if "text" in part:
+ content_parts.append(
+ {
+ "text": part.get("text"),
+ }
+ )
+ elif "image_url" in part: # part_content:
+ image_data, content_type = parse_image(part.get("image_url").get("url"))
+ content_parts.append(
+ {
+ "image": {
+ "format": content_type[6:], # image/
+ "source": {"bytes": image_data},
+ },
+ }
+ )
+        else:
+            # Ignore unsupported content parts
+            continue
+ return content_parts
+
+
+def parse_image(image_url: str) -> Tuple[bytes, str]:
+ """Try to get the raw data from an image url.
+
+ Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html
+ returns a tuple of (Image Data, Content Type)
+ """
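+    # Illustrative behaviour (a sketch; the URL and bytes below are made up):
+    #   parse_image("data:image/png;base64,iVBO...") -> (b"\x89PNG...", "image/png")
+    #   parse_image("https://example.com/cat.jpg") -> (<response bytes>, "image/jpeg")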
+ pattern = r"^data:(image/[a-z]*);base64,\s*"
+ content_type = re.search(pattern, image_url)
+    # If the URL is already a base64-encoded data URL.
+    # Bedrock only supports 'image/jpeg', 'image/png', 'image/gif' and 'image/webp'.
+ if content_type:
+ image_data = re.sub(pattern, "", image_url)
+ return base64.b64decode(image_data), content_type.group(1)
+
+    # Otherwise, fetch the image from the URL
+    response = requests.get(image_url)
+    if response.status_code == 200:
+        content_type = response.headers.get("Content-Type", "")
+        if not content_type.startswith("image"):
+            content_type = "image/jpeg"
+        return response.content, content_type
+    else:
+        raise RuntimeError(f"Unable to access the image URL: {image_url}")
+
+
+def format_tools(tools: List[Dict[str, Any]]) -> Dict[Literal["tools"], List[Dict[str, Any]]]:
+ converted_schema = {"tools": []}
+
+ for tool in tools:
+ if tool["type"] == "function":
+ function = tool["function"]
+ converted_tool = {
+ "toolSpec": {
+ "name": function["name"],
+ "description": function["description"],
+ "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}},
+ }
+ }
+
+ for prop_name, prop_details in function["parameters"]["properties"].items():
+ converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name] = {
+ "type": prop_details["type"],
+ "description": prop_details.get("description", ""),
+ }
+ if "enum" in prop_details:
+ converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["enum"] = prop_details[
+ "enum"
+ ]
+ if "default" in prop_details:
+ converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["default"] = (
+ prop_details["default"]
+ )
+
+ if "required" in function["parameters"]:
+ converted_tool["toolSpec"]["inputSchema"]["json"]["required"] = function["parameters"]["required"]
+
+ converted_schema["tools"].append(converted_tool)
+
+ return converted_schema
+
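+# A sketch of the mapping performed by format_tools above; the "get_weather" tool is
+# hypothetical and only illustrates the OpenAI -> Bedrock toolSpec translation:
+#   {"type": "function", "function": {"name": "get_weather", "description": "...",
+#    "parameters": {"properties": {"city": {"type": "string"}}, "required": ["city"]}}}
+# becomes
+#   {"toolSpec": {"name": "get_weather", "description": "...",
+#    "inputSchema": {"json": {"type": "object",
+#                             "properties": {"city": {"type": "string", "description": ""}},
+#                             "required": ["city"]}}}}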
+
+def format_tool_calls(content):
+ """Converts Converse API response tool calls to AutoGen format"""
+ tool_calls = []
+ for tool_request in content:
+ if "toolUse" in tool_request:
+ tool = tool_request["toolUse"]
+
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=tool["toolUseId"],
+ function={
+ "name": tool["name"],
+ "arguments": json.dumps(tool["input"]),
+ },
+ type="function",
+ )
+ )
+ return tool_calls
+
+
+def convert_stop_reason_to_finish_reason(
+ stop_reason: str,
+) -> Literal["stop", "length", "tool_calls", "content_filter"]:
+ """
+ Converts Bedrock finish reasons to our finish reasons, according to OpenAI:
+
+ - stop: if the model hit a natural stop point or a provided stop sequence,
+ - length: if the maximum number of tokens specified in the request was reached,
+ - content_filter: if content was omitted due to a flag from our content filters,
+ - tool_calls: if the model called a tool
+ """
+    if stop_reason:
+        finish_reason_mapping = {
+            "tool_use": "tool_calls",
+            "finished": "stop",
+            "end_turn": "stop",
+            "max_tokens": "length",
+            "stop_sequence": "stop",
+            "complete": "stop",
+            "content_filtered": "content_filter",
+        }
+        if stop_reason.lower() in finish_reason_mapping:
+            return finish_reason_mapping[stop_reason.lower()]
+
+        # Fall back to the raw reason, but warn as it is not a recognised OpenAI finish reason
+        warnings.warn(f"Unsupported stop reason: {stop_reason}", UserWarning)
+        return stop_reason.lower()
+
+    return None
+
+
# NOTE: Bedrock pricing changes frequently, so developers are expected to set the "price" parameter in their config.
+# These defaults may be removed in a future release.
+PRICES_PER_K_TOKENS = {
+ "meta.llama3-8b-instruct-v1:0": (0.0003, 0.0006),
+ "meta.llama3-70b-instruct-v1:0": (0.00265, 0.0035),
+ "mistral.mistral-7b-instruct-v0:2": (0.00015, 0.0002),
+ "mistral.mixtral-8x7b-instruct-v0:1": (0.00045, 0.0007),
+ "mistral.mistral-large-2402-v1:0": (0.004, 0.012),
+ "mistral.mistral-small-2402-v1:0": (0.001, 0.003),
+}
+
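+# As a worked example using the table above: 2,000 input and 500 output tokens on
+# "meta.llama3-8b-instruct-v1:0" cost (2000/1000)*0.0003 + (500/1000)*0.0006 = $0.0009.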
+
+def calculate_cost(input_tokens: int, output_tokens: int, model_id: str) -> float:
+ """Calculate the cost of the completion using the Bedrock pricing."""
+
+ if model_id in PRICES_PER_K_TOKENS:
+ input_cost_per_k, output_cost_per_k = PRICES_PER_K_TOKENS[model_id]
+ input_cost = (input_tokens / 1000) * input_cost_per_k
+ output_cost = (output_tokens / 1000) * output_cost_per_k
+ return input_cost + output_cost
+ else:
+ warnings.warn(
+ f'Cannot get the costs for {model_id}. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.',
+ UserWarning,
+ )
+ return 0
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 2c14ca0d4a0c..8f6e3f185b6a 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -70,6 +70,27 @@
except ImportError as e:
together_import_exception = e
+try:
+ from autogen.oai.groq import GroqClient
+
+ groq_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ groq_import_exception = e
+
+try:
+ from autogen.oai.cohere import CohereClient
+
+ cohere_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ cohere_import_exception = e
+
+try:
+ from autogen.oai.bedrock import BedrockClient
+
+ bedrock_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ bedrock_import_exception = e
+
logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
@@ -258,7 +279,12 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
# Prepare the final ChatCompletion object based on the accumulated data
model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API
- prompt_tokens = count_token(params["messages"], model)
+ try:
+ prompt_tokens = count_token(params["messages"], model)
+ except NotImplementedError as e:
+ # Catch token calculation error if streaming with customized models.
+ logger.warning(str(e))
+ prompt_tokens = 0
response = ChatCompletion(
id=chunk.id,
model=chunk.model,
@@ -440,12 +466,23 @@ def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[st
azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
+ def _configure_openai_config_for_bedrock(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None:
+ """Update openai_config with AWS credentials from config."""
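+        # A config entry of the assumed shape, for illustration:
+        #   {"api_type": "bedrock", "model": "meta.llama3-8b-instruct-v1:0",
+        #    "aws_access_key": "...", "aws_secret_key": "...", "aws_region": "us-east-1"}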
+ required_keys = ["aws_access_key", "aws_secret_key", "aws_region"]
+ optional_keys = ["aws_session_token", "aws_profile_name"]
+ for key in required_keys:
+ if key in config:
+ openai_config[key] = config[key]
+ for key in optional_keys:
+ if key in config:
+ openai_config[key] = config[key]
+
def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None:
"""Create a client with the given config to override openai_config,
after removing extra kwargs.
For Azure models/deployment names there's a convenience modification of model removing dots in
- the it's value (Azure deploment names can't have dots). I.e. if you have Azure deployment name
+        its value (Azure deployment names can't have dots). I.e. if you have Azure deployment name
"gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot
from the name and create a client that connects to "gpt-35-turbo" Azure deployment.
"""
@@ -471,6 +508,8 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
client = GeminiClient(**openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("anthropic"):
+ if "api_key" not in config:
+ self._configure_openai_config_for_bedrock(config, openai_config)
if anthropic_import_exception:
raise ImportError("Please install `anthropic` to use Anthropic API.")
client = AnthropicClient(**openai_config)
@@ -483,7 +522,24 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
elif api_type is not None and api_type.startswith("together"):
if together_import_exception:
raise ImportError("Please install `together` to use the Together.AI API.")
- self._clients.append(TogetherClient(**config))
+ client = TogetherClient(**openai_config)
+ self._clients.append(client)
+ elif api_type is not None and api_type.startswith("groq"):
+ if groq_import_exception:
+ raise ImportError("Please install `groq` to use the Groq API.")
+ client = GroqClient(**openai_config)
+ self._clients.append(client)
+ elif api_type is not None and api_type.startswith("cohere"):
+ if cohere_import_exception:
+ raise ImportError("Please install `cohere` to use the Cohere API.")
+ client = CohereClient(**openai_config)
+ self._clients.append(client)
+ elif api_type is not None and api_type.startswith("bedrock"):
+ self._configure_openai_config_for_bedrock(config, openai_config)
+ if bedrock_import_exception:
+ raise ImportError("Please install `boto3` to use the Amazon Bedrock API.")
+ client = BedrockClient(**openai_config)
+ self._clients.append(client)
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
@@ -770,7 +826,7 @@ def _cost_with_customized_price(
n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr]
if n_output_tokens is None:
n_output_tokens = 0
- return n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]
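+        # price_1k is expressed per 1,000 tokens, so the token-weighted sum is divided by 1000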
+ return (n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]) / 1000
@staticmethod
def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int:
diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py
new file mode 100644
index 000000000000..e9a89c9cabd8
--- /dev/null
+++ b/autogen/oai/cohere.py
@@ -0,0 +1,516 @@
+"""Create an OpenAI-compatible client using Cohere's API.
+
+Example:
+ llm_config={
+ "config_list": [{
+ "api_type": "cohere",
+ "model": "command-r-plus",
+        "api_key": os.environ.get("COHERE_API_KEY"),
+ "client_name": "autogen-cohere", # Optional parameter
+ }
+ ]}
+
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Install Cohere's python library using: pip install --upgrade cohere
+
+Resources:
+- https://docs.cohere.com/reference/chat
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import os
+import random
+import sys
+import time
+import warnings
+from typing import Any, Dict, List
+
+from cohere import Client as Cohere
+from cohere.types import ToolParameterDefinitionsValue, ToolResult
+from flaml.automl.logger import logger_formatter
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import validate_parameter
+
+logger = logging.getLogger(__name__)
+if not logger.handlers:
+ # Add the console handler.
+ _ch = logging.StreamHandler(stream=sys.stdout)
+ _ch.setFormatter(logger_formatter)
+ logger.addHandler(_ch)
+
+
+COHERE_PRICING_1K = {
+ "command-r-plus": (0.003, 0.015),
+ "command-r": (0.0005, 0.0015),
+ "command-nightly": (0.00025, 0.00125),
+ "command": (0.015, 0.075),
+ "command-light": (0.008, 0.024),
+ "command-light-nightly": (0.008, 0.024),
+}
+
+
+class CohereClient:
+ """Client for Cohere's API."""
+
+ def __init__(self, **kwargs):
+ """Requires api_key or environment variable to be set
+
+ Args:
+ api_key (str): The API key for using Cohere (or environment variable COHERE_API_KEY needs to be set)
+ """
+ # Ensure we have the api_key upon instantiation
+ self.api_key = kwargs.get("api_key", None)
+ if not self.api_key:
+ self.api_key = os.getenv("COHERE_API_KEY")
+
+ assert (
+ self.api_key
+ ), "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable."
+
+ def message_retrieval(self, response) -> List:
+ """
+ Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+ NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+ since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+ """
+ return [choice.message for choice in response.choices]
+
+ def cost(self, response) -> float:
+ return response.cost
+
+ @staticmethod
+ def get_usage(response) -> Dict:
+ """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "cost": response.cost,
+ "model": response.model,
+ }
+
+ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Loads the parameters for Cohere API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
+ cohere_params = {}
+
+ # Check that we have what we need to use Cohere's API
+ # We won't enforce the available models as they are likely to change
+ cohere_params["model"] = params.get("model", None)
+ assert cohere_params[
+ "model"
+ ], "Please specify the 'model' in your config list entry to nominate the Cohere model to use."
+
+ # Validate allowed Cohere parameters
+ # https://docs.cohere.com/reference/chat
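+        # validate_parameter arguments, as used here: params, parameter name, allowed types,
+        # whether None is allowed, default value, numerical (min, max) bound, and allowed values.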
+ cohere_params["temperature"] = validate_parameter(
+ params, "temperature", (int, float), False, 0.3, (0, None), None
+ )
+ cohere_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
+ cohere_params["k"] = validate_parameter(params, "k", int, False, 0, (0, 500), None)
+ cohere_params["p"] = validate_parameter(params, "p", (int, float), False, 0.75, (0.01, 0.99), None)
+ cohere_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
+ cohere_params["frequency_penalty"] = validate_parameter(
+ params, "frequency_penalty", (int, float), True, 0, (0, 1), None
+ )
+ cohere_params["presence_penalty"] = validate_parameter(
+ params, "presence_penalty", (int, float), True, 0, (0, 1), None
+ )
+
+        # Cohere parameters we are ignoring:
+        # preamble - the system prompt goes in here.
+        # parallel_tool_calls - defaults to True, which is what we want.
+        # conversation_id - allows resuming a previous conversation; we don't support this.
+        logger.info("Conversation ID: %s", params.get("conversation_id", "None"))
+        # connectors - allows web search or other custom connectors; not implemented for now but could be useful in the future.
+        logger.info("Connectors: %s", params.get("connectors", "None"))
+        # search_queries_only - controls whether only search queries are used; we're not using connectors, so ignoring.
+        # documents - a list of documents that can be used to support the chat; perhaps useful in the future for RAG.
+        # citation_quality - used for RAG flows and dependent on other parameters we're ignoring.
+        # max_input_tokens - limits input tokens; not needed.
+        logger.info("Max Input Tokens: %s", params.get("max_input_tokens", "None"))
+        # stop_sequences - used to stop generation; not needed.
+        logger.info("Stop Sequences: %s", params.get("stop_sequences", "None"))
+
+ return cohere_params
+
+ def create(self, params: Dict) -> ChatCompletion:
+
+ messages = params.get("messages", [])
+ client_name = params.get("client_name") or "autogen-cohere"
+ # Parse parameters to the Cohere API's parameters
+ cohere_params = self.parse_params(params)
+ # Convert AutoGen messages to Cohere messages
+ cohere_messages, preamble, final_message = oai_messages_to_cohere_messages(messages, params, cohere_params)
+
+ cohere_params["chat_history"] = cohere_messages
+ cohere_params["message"] = final_message
+ cohere_params["preamble"] = preamble
+
+ # We use chat model by default
+ client = Cohere(api_key=self.api_key, client_name=client_name)
+
+ # Token counts will be returned
+ prompt_tokens = 0
+ completion_tokens = 0
+ total_tokens = 0
+
+ # Stream if in parameters
+        streaming = bool(params.get("stream", False))
+ cohere_finish = ""
+
+ max_retries = 5
+
+ for attempt in range(max_retries):
+ ans = None
+ try:
+ if streaming:
+ response = client.chat_stream(**cohere_params)
+ else:
+ response = client.chat(**cohere_params)
+
+            except CohereRateLimitError as e:
+                # Back off and retry on rate limiting; give up after max_retries attempts
+                if attempt == max_retries - 1:
+                    raise RuntimeError(f"Cohere exception occurred: {e}")
+                time.sleep(5 * (2**attempt))
+                continue
+ else:
+
+ if streaming:
+ # Streaming...
+ ans = ""
+ for event in response:
+ if event.event_type == "text-generation":
+ ans = ans + event.text
+ elif event.event_type == "tool-calls-generation":
+ # When streaming, tool calls are compiled at the end into a single event_type
+ ans = event.text
+ cohere_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in event.tool_calls:
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=str(random.randint(0, 100000)),
+ function={
+ "name": tool_call.name,
+ "arguments": (
+ "" if tool_call.parameters is None else json.dumps(tool_call.parameters)
+ ),
+ },
+ type="function",
+ )
+ )
+
+ # Not using billed_units, but that may be better for cost purposes
+ prompt_tokens = event.response.meta.tokens.input_tokens
+ completion_tokens = event.response.meta.tokens.output_tokens
+ total_tokens = prompt_tokens + completion_tokens
+
+ response_id = event.response.response_id
+ else:
+ # Non-streaming finished
+ ans: str = response.text
+
+ # Not using billed_units, but that may be better for cost purposes
+ prompt_tokens = response.meta.tokens.input_tokens
+ completion_tokens = response.meta.tokens.output_tokens
+ total_tokens = prompt_tokens + completion_tokens
+
+ response_id = response.response_id
+ break
+
+ if response is not None:
+
+ response_content = ans
+
+ if streaming:
+ # Streaming response
+ if cohere_finish == "":
+ cohere_finish = "stop"
+ tool_calls = None
+ else:
+ # Non-streaming response
+ # If we have tool calls as the response, populate completed tool calls for our return OAI response
+ if response.tool_calls is not None:
+ cohere_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in response.tool_calls:
+
+ # if parameters are null, clear them out (Cohere can return a string "null" if no parameter values)
+
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=str(random.randint(0, 100000)),
+ function={
+ "name": tool_call.name,
+ "arguments": (
+ "" if tool_call.parameters is None else json.dumps(tool_call.parameters)
+ ),
+ },
+ type="function",
+ )
+ )
+ else:
+ cohere_finish = "stop"
+ tool_calls = None
+ else:
+ raise RuntimeError(f"Failed to get response from Cohere after retrying {attempt + 1} times.")
+
+ # 3. convert output
+ message = ChatCompletionMessage(
+ role="assistant",
+ content=response_content,
+ function_call=None,
+ tool_calls=tool_calls,
+ )
+ choices = [Choice(finish_reason=cohere_finish, index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=response_id,
+ model=cohere_params["model"],
+ created=int(time.time()),
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ ),
+ cost=calculate_cohere_cost(prompt_tokens, completion_tokens, cohere_params["model"]),
+ )
+
+ return response_oai
+
+
+def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[Dict[str, Any]]:
+ temp_tool_results = []
+
+ for tool_call in all_tool_calls:
+ if tool_call["id"] == tool_call_id:
+
+ call = {
+ "name": tool_call["function"]["name"],
+                "parameters": json.loads(tool_call["function"]["arguments"] or "{}"),
+ }
+ output = [{"value": content_output}]
+ temp_tool_results.append(ToolResult(call=call, outputs=output))
+ return temp_tool_results
+
+
+def is_recent_tool_call(messages: list[Dict[str, Any]], tool_call_index: int):
+ messages_length = len(messages)
+ if tool_call_index == messages_length - 1:
+ return True
+    elif messages[tool_call_index + 1].get("role", "").lower() != "chatbot":
+ return True
+ return False
+
+
+def oai_messages_to_cohere_messages(
+ messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
+) -> tuple[list[dict[str, Any]], str, str]:
+ """Convert messages from OAI format to Cohere's format.
+ We correct for any specific role orders and types.
+
+ Parameters:
+ messages: list[Dict[str, Any]]: AutoGen messages
+ params: Dict[str, Any]: AutoGen parameters dictionary
+ cohere_params: Dict[str, Any]: Cohere parameters dictionary
+
+ Returns:
+ List[Dict[str, Any]]: Chat History messages
+ str: Preamble (system message)
+ str: Message (the final user message)
+ """
+
+ cohere_messages = []
+ preamble = ""
+ cohere_tool_names = set()
+ # Tools
+ if "tools" in params:
+ cohere_tools = []
+ for tool in params["tools"]:
+
+ # build list of properties
+ parameters = {}
+
+ for key, value in tool["function"]["parameters"]["properties"].items():
+ type_str = value["type"]
+                required = True  # Cohere defaults this to False; we mark all parameters as required (leaving the default is an option).
+ description = value["description"]
+
+ # If we have an 'enum' key, add that to the description (as not allowed to pass in enum as a field)
+ if "enum" in value:
+ # Access the enum list
+ enum_values = value["enum"]
+ enum_strings = [str(value) for value in enum_values]
+ enum_string = ", ".join(enum_strings)
+ description = description + ". Possible values are " + enum_string + "."
+
+ parameters[key] = ToolParameterDefinitionsValue(
+ description=description, type=type_str, required=required
+ )
+
+ cohere_tool = {
+ "name": tool["function"]["name"],
+ "description": tool["function"]["description"],
+ "parameter_definitions": parameters,
+ }
+ cohere_tool_names.add(tool["function"]["name"] or "")
+
+ cohere_tools.append(cohere_tool)
+
+ if len(cohere_tools) > 0:
+ cohere_params["tools"] = cohere_tools
+
+ tool_calls = []
+ tool_results = []
+
+ # Rules for cohere messages:
+ # no 'name' field
+ # 'system' messages go into the preamble parameter
+ # user role = 'USER'
+ # assistant role = 'CHATBOT'
+ # 'content' field renamed to 'message'
+ # tools go into tools parameter
+ # tool_results go into tool_results parameter
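+    # For example, the OAI message {"role": "user", "name": "bob", "content": "hi"}
+    # becomes the chat_history entry {"role": "USER", "message": "bob:hi"}.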
+ for index, message in enumerate(messages):
+
+ if not message["content"]:
+ continue
+
+ if "role" in message and message["role"] == "system":
+ # System message
+ if preamble == "":
+ preamble = message["content"]
+ else:
+ preamble = preamble + "\n" + message["content"]
+
+ elif message.get("tool_calls"):
+ # Suggested tool calls, build up the list before we put it into the tool_results
+ message_tool_calls = []
+ for tool_call in message["tool_calls"] or []:
+ if (not tool_call.get("function", {}).get("name")) or tool_call.get("function", {}).get(
+ "name"
+ ) not in cohere_tool_names:
+ new_message = {
+ "role": "CHATBOT",
+                    "message": message.get("name", "") + ":" + message["content"] + str(message["tool_calls"]),
+ }
+ cohere_messages.append(new_message)
+ continue
+
+ tool_calls.append(tool_call)
+ message_tool_calls.append(
+ {
+ "name": tool_call.get("function", {}).get("name"),
+ "parameters": json.loads(tool_call.get("function", {}).get("arguments") or "null"),
+ }
+ )
+
+ if not message_tool_calls:
+ continue
+
+ # We also add the suggested tool call as a message
+ new_message = {
+ "role": "CHATBOT",
+                "message": message.get("name", "") + ":" + message["content"],
+ "tool_calls": message_tool_calls,
+ }
+
+ cohere_messages.append(new_message)
+ elif "role" in message and message["role"] == "tool":
+ if not (tool_call_id := message.get("tool_call_id")):
+ continue
+
+ content_output = message["content"]
+ if tool_call_id not in [tool_call["id"] for tool_call in tool_calls]:
+
+ new_message = {
+ "role": "CHATBOT",
+ "message": content_output,
+ }
+ cohere_messages.append(new_message)
+ continue
+
+ # Convert the tool call to a result
+ tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)
+ if is_recent_tool_call(messages, index):
+ # If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
+ # So, we pass it into tool_results.
+ tool_results.extend(tool_results_chat_turn)
+ continue
+
+ else:
+                # If it's not the current tool call, we pass it as a tool message in the chat history.
+ new_message = {"role": "TOOL", "tool_results": tool_results_chat_turn}
+ cohere_messages.append(new_message)
+
+ elif "content" in message and isinstance(message["content"], str):
+ # Standard text message
+ new_message = {
+ "role": "USER" if message["role"] == "user" else "CHATBOT",
+                "message": message.get("name", "") + ":" + message.get("content"),
+ }
+
+ cohere_messages.append(new_message)
+
+ # Append any Tool Results
+ if len(tool_results) != 0:
+ cohere_params["tool_results"] = tool_results
+
+ # Enable multi-step tool use: https://docs.cohere.com/docs/multi-step-tool-use
+ cohere_params["force_single_step"] = False
+
+    # When we're adding tool_results, the last message can't be a USER message,
+    # so we append a CHATBOT 'continue' message if needed (note the key is "message", not "content").
+ if cohere_messages[-1]["role"].lower() == "user":
+ cohere_messages.append({"role": "CHATBOT", "message": "Please go ahead and follow the instructions!"})
+
+ # We return a blank message when we have tool results
+ # TODO: Check what happens if tool_results aren't the latest message
+ return cohere_messages, preamble, ""
+
+    else:
+        # We need to get the last message to assign to the message field for Cohere;
+        # if the last message is a user message, use that, otherwise put in 'continue'.
+        if cohere_messages and cohere_messages[-1]["role"] == "USER":
+ return cohere_messages[0:-1], preamble, cohere_messages[-1]["message"]
+ else:
+ return cohere_messages, preamble, "Please go ahead and follow the instructions!"
+
+
+def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> float:
+ """Calculate the cost of the completion using the Cohere pricing."""
+ total = 0.0
+
+ if model in COHERE_PRICING_1K:
+ input_cost_per_k, output_cost_per_k = COHERE_PRICING_1K[model]
+ input_cost = (input_tokens / 1000) * input_cost_per_k
+ output_cost = (output_tokens / 1000) * output_cost_per_k
+ total = input_cost + output_cost
+ else:
+ warnings.warn(f"Cost calculation not available for {model} model", UserWarning)
+
+ return total
+
+
+class CohereError(Exception):
+ """Base class for other Cohere exceptions"""
+
+ pass
+
+
+class CohereRateLimitError(CohereError):
+ """Raised when rate limit is exceeded"""
+
+ pass
diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py
index 8babb8727e3c..33790c9851c6 100644
--- a/autogen/oai/gemini.py
+++ b/autogen/oai/gemini.py
@@ -6,7 +6,7 @@
"config_list": [{
"api_type": "google",
"model": "gemini-pro",
- "api_key": os.environ.get("GOOGLE_API_KEY"),
+ "api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),
"safety_settings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
@@ -32,6 +32,7 @@
from __future__ import annotations
import base64
+import logging
import os
import random
import re
@@ -45,13 +46,19 @@
import vertexai
from google.ai.generativelanguage import Content, Part
from google.api_core.exceptions import InternalServerError
+from google.auth.credentials import Credentials
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import GenerativeModel
+from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
+from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
+from vertexai.generative_models import SafetySetting as VertexAISafetySetting
+
+logger = logging.getLogger(__name__)
class GeminiClient:
@@ -72,7 +79,7 @@ class GeminiClient:
"max_output_tokens": "max_output_tokens",
}
- def _initialize_vartexai(self, **params):
+ def _initialize_vertexai(self, **params):
if "google_application_credentials" in params:
# Path to JSON Keyfile
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"]
@@ -81,32 +88,39 @@ def _initialize_vartexai(self, **params):
vertexai_init_args["project"] = params["project_id"]
if "location" in params:
vertexai_init_args["location"] = params["location"]
+ if "credentials" in params:
+ assert isinstance(
+ params["credentials"], Credentials
+ ), "Object type google.auth.credentials.Credentials is expected!"
+ vertexai_init_args["credentials"] = params["credentials"]
if vertexai_init_args:
vertexai.init(**vertexai_init_args)
def __init__(self, **kwargs):
"""Uses either either api_key for authentication from the LLM config
- (specifying the GOOGLE_API_KEY environment variable also works),
+ (specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
- where project_id and location can also be passed as parameters. Service account key file can also be used.
- If neither a service account key file, nor the api_key are passed, then the default credentials will be used,
- which could be a personal account if the user is already authenticated in, like in Google Cloud Shell.
+ where project_id and location can also be passed as parameters. Previously created credentials object can be provided,
+ or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed,
+ then the default credentials will be used, which could be a personal account if the user is already authenticated in,
+ like in Google Cloud Shell.
Args:
api_key (str): The API key for using Gemini.
+ credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
google_application_credentials (str): Path to the JSON service account key file of the service account.
- Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
- can also be set instead of using this argument.
+ Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
+ can also be set instead of using this argument.
project_id (str): Google Cloud project id, which is only valid in case no API key is specified.
location (str): Compute region to be used, like 'us-west1'.
- This parameter is only valid in case no API key is specified.
+ This parameter is only valid in case no API key is specified.
"""
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
- self.api_key = os.getenv("GOOGLE_API_KEY")
+ self.api_key = os.getenv("GOOGLE_GEMINI_API_KEY")
if self.api_key is None:
self.use_vertexai = True
- self._initialize_vartexai(**kwargs)
+ self._initialize_vertexai(**kwargs)
else:
self.use_vertexai = False
else:
@@ -142,7 +156,7 @@ def get_usage(response) -> Dict:
def create(self, params: Dict) -> ChatCompletion:
if self.use_vertexai:
- self._initialize_vartexai(**params)
+ self._initialize_vertexai(**params)
else:
assert ("project_id" not in params) and (
"location" not in params
@@ -159,13 +173,18 @@ def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
stream = params.get("stream", False)
n_response = params.get("n", 1)
+ system_instruction = params.get("system_instruction", None)
+ response_validation = params.get("response_validation", True)
generation_config = {
gemini_term: params[autogen_term]
for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
if autogen_term in params
}
- safety_settings = params.get("safety_settings", {})
+ if self.use_vertexai:
+ safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
+ else:
+ safety_settings = params.get("safety_settings", {})
if stream:
warnings.warn(
@@ -181,20 +200,29 @@ def create(self, params: Dict) -> ChatCompletion:
gemini_messages = self._oai_messages_to_gemini_messages(messages)
if self.use_vertexai:
model = GenerativeModel(
- model_name, generation_config=generation_config, safety_settings=safety_settings
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
)
+ chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
else:
# we use chat model by default
model = genai.GenerativeModel(
- model_name, generation_config=generation_config, safety_settings=safety_settings
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
- chat = model.start_chat(history=gemini_messages[:-1])
+ chat = model.start_chat(history=gemini_messages[:-1])
max_retries = 5
for attempt in range(max_retries):
ans = None
try:
- response = chat.send_message(gemini_messages[-1], stream=stream)
+ response = chat.send_message(
+ gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings
+ )
except InternalServerError:
delay = 5 * (2**attempt)
warnings.warn(
@@ -218,16 +246,22 @@ def create(self, params: Dict) -> ChatCompletion:
# B. handle the vision model
if self.use_vertexai:
model = GenerativeModel(
- model_name, generation_config=generation_config, safety_settings=safety_settings
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
)
else:
model = genai.GenerativeModel(
- model_name, generation_config=generation_config, safety_settings=safety_settings
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
# Gemini's vision model does not support chat history yet
# chat = model.start_chat(history=gemini_messages[:-1])
- # response = chat.send_message(gemini_messages[-1])
+ # response = chat.send_message(gemini_messages[-1].parts)
user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
if len(messages) > 2:
warnings.warn(
@@ -270,6 +304,8 @@ def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
"""Convert content from OAI format to Gemini format"""
rst = []
if isinstance(content, str):
+ if content == "":
+ content = "empty" # Empty content is not allowed.
if self.use_vertexai:
rst.append(VertexAIPart.from_text(content))
else:
@@ -372,6 +408,35 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li
return rst
+ @staticmethod
+ def _to_vertexai_safety_settings(safety_settings):
+ """Convert safety settings to VertexAI format if needed,
+ like when specifying them in the OAI_CONFIG_LIST
+ """
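+        # For example, [{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"}]
+        # becomes [VertexAISafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_ONLY_HIGH")];
+        # invalid categories or thresholds are logged and skipped.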
+ if isinstance(safety_settings, list) and all(
+ [
+ isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting)
+ for safety_setting in safety_settings
+ ]
+ ):
+ vertexai_safety_settings = []
+ for safety_setting in safety_settings:
+ if safety_setting["category"] not in VertexAIHarmCategory.__members__:
+ invalid_category = safety_setting["category"]
+ logger.error(f"Safety setting category {invalid_category} is invalid")
+ elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
+ invalid_threshold = safety_setting["threshold"]
+ logger.error(f"Safety threshold {invalid_threshold} is invalid")
+ else:
+ vertexai_safety_setting = VertexAISafetySetting(
+ category=safety_setting["category"],
+ threshold=safety_setting["threshold"],
+ )
+ vertexai_safety_settings.append(vertexai_safety_setting)
+ return vertexai_safety_settings
+ else:
+ return safety_settings
+
def _to_pil(data: str) -> Image.Image:
"""
diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py
new file mode 100644
index 000000000000..d2abe5116a25
--- /dev/null
+++ b/autogen/oai/groq.py
@@ -0,0 +1,282 @@
+"""Create an OpenAI-compatible client using Groq's API.
+
+Example:
+ llm_config={
+ "config_list": [{
+ "api_type": "groq",
+ "model": "mixtral-8x7b-32768",
+ "api_key": os.environ.get("GROQ_API_KEY")
+ }
+ ]}
+
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Install Groq's python library using: pip install --upgrade groq
+
+Resources:
+- https://console.groq.com/docs/quickstart
+"""
+
+from __future__ import annotations
+
+import copy
+import os
+import time
+import warnings
+from typing import Any, Dict, List
+
+from groq import Groq, Stream
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import should_hide_tools, validate_parameter
+
+# Cost per thousand tokens - Input / Output (NOTE: Convert $/Million to $/K)
+GROQ_PRICING_1K = {
+ "llama3-70b-8192": (0.00059, 0.00079),
+ "mixtral-8x7b-32768": (0.00024, 0.00024),
+ "llama3-8b-8192": (0.00005, 0.00008),
+ "gemma-7b-it": (0.00007, 0.00007),
+}
+
+
+class GroqClient:
+ """Client for Groq's API."""
+
+ def __init__(self, **kwargs):
+ """Requires api_key or environment variable to be set
+
+ Args:
+ api_key (str): The API key for using Groq (or environment variable GROQ_API_KEY needs to be set)
+ """
+ # Ensure we have the api_key upon instantiation
+ self.api_key = kwargs.get("api_key", None)
+ if not self.api_key:
+ self.api_key = os.getenv("GROQ_API_KEY")
+
+ assert (
+ self.api_key
+ ), "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."
+
+ def message_retrieval(self, response) -> List:
+ """
+ Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+ NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+ since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+ """
+ return [choice.message for choice in response.choices]
+
+ def cost(self, response) -> float:
+ return response.cost
+
+ @staticmethod
+ def get_usage(response) -> Dict:
+ """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "cost": response.cost,
+ "model": response.model,
+ }
+
+ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Loads the parameters for Groq API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
+ groq_params = {}
+
+ # Check that we have what we need to use Groq's API
+ # We won't enforce the available models as they are likely to change
+ groq_params["model"] = params.get("model", None)
+ assert groq_params[
+ "model"
+ ], "Please specify the 'model' in your config list entry to nominate the Groq model to use."
+
+ # Validate allowed Groq parameters
+ # https://console.groq.com/docs/api-reference#chat
+ groq_params["frequency_penalty"] = validate_parameter(
+ params, "frequency_penalty", (int, float), True, None, (-2, 2), None
+ )
+ groq_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
+ groq_params["presence_penalty"] = validate_parameter(
+ params, "presence_penalty", (int, float), True, None, (-2, 2), None
+ )
+ groq_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
+ groq_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
+ groq_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 2), None)
+ groq_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
+
+ # Groq parameters not supported by their models yet, ignoring
+ # logit_bias, logprobs, top_logprobs
+
+ # Groq parameters we are ignoring:
+ # n (must be 1), response_format (to enforce JSON but needs prompting as well), user,
+ # parallel_tool_calls (defaults to True), stop
+ # function_call (deprecated), functions (deprecated)
+ # tool_choice (none if no tools, auto if there are tools)
+
+ return groq_params
+
+ def create(self, params: Dict) -> ChatCompletion:
+
+ messages = params.get("messages", [])
+
+ # Convert AutoGen messages to Groq messages
+ groq_messages = oai_messages_to_groq_messages(messages)
+
+ # Parse parameters to the Groq API's parameters
+ groq_params = self.parse_params(params)
+
+ # Add tools to the call if we have them and aren't hiding them
+ if "tools" in params:
+ hide_tools = validate_parameter(
+ params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
+ )
+ if not should_hide_tools(groq_messages, params["tools"], hide_tools):
+ groq_params["tools"] = params["tools"]
+
+ groq_params["messages"] = groq_messages
+
+        # We use the chat model by default, and set max_retries to 5 (in line with the typical retry loop)
+ client = Groq(api_key=self.api_key, max_retries=5)
+
+ # Token counts will be returned
+ prompt_tokens = 0
+ completion_tokens = 0
+ total_tokens = 0
+
+ # Streaming tool call recommendations
+ streaming_tool_calls = []
+
+ ans = None
+ try:
+ response = client.chat.completions.create(**groq_params)
+ except Exception as e:
+ raise RuntimeError(f"Groq exception occurred: {e}")
+ else:
+
+ if groq_params["stream"]:
+                # Read the chunks as they stream, collecting tool_calls, which may span
+                # multiple chunks when more than one is suggested
+ ans = ""
+ for chunk in response:
+ ans = ans + (chunk.choices[0].delta.content or "")
+
+ if chunk.choices[0].delta.tool_calls:
+ # We have a tool call recommendation
+ for tool_call in chunk.choices[0].delta.tool_calls:
+ streaming_tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=tool_call.id,
+ function={
+ "name": tool_call.function.name,
+ "arguments": tool_call.function.arguments,
+ },
+ type="function",
+ )
+ )
+
+ if chunk.choices[0].finish_reason:
+ prompt_tokens = chunk.x_groq.usage.prompt_tokens
+ completion_tokens = chunk.x_groq.usage.completion_tokens
+ total_tokens = chunk.x_groq.usage.total_tokens
+ else:
+ # Non-streaming finished
+ ans: str = response.choices[0].message.content
+
+ prompt_tokens = response.usage.prompt_tokens
+ completion_tokens = response.usage.completion_tokens
+ total_tokens = response.usage.total_tokens
+
+ if response is not None:
+
+ if isinstance(response, Stream):
+ # Streaming response
+ if chunk.choices[0].finish_reason == "tool_calls":
+ groq_finish = "tool_calls"
+ tool_calls = streaming_tool_calls
+ else:
+ groq_finish = "stop"
+ tool_calls = None
+
+ response_content = ans
+ response_id = chunk.id
+ else:
+ # Non-streaming response
+ # If we have tool calls as the response, populate completed tool calls for our return OAI response
+ if response.choices[0].finish_reason == "tool_calls":
+ groq_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in response.choices[0].message.tool_calls:
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=tool_call.id,
+ function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
+ type="function",
+ )
+ )
+ else:
+ groq_finish = "stop"
+ tool_calls = None
+
+ response_content = response.choices[0].message.content
+ response_id = response.id
+ else:
+ raise RuntimeError("Failed to get response from Groq after retrying 5 times.")
+
+ # 3. convert output
+ message = ChatCompletionMessage(
+ role="assistant",
+ content=response_content,
+ function_call=None,
+ tool_calls=tool_calls,
+ )
+ choices = [Choice(finish_reason=groq_finish, index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=response_id,
+ model=groq_params["model"],
+ created=int(time.time()),
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ ),
+ cost=calculate_groq_cost(prompt_tokens, completion_tokens, groq_params["model"]),
+ )
+
+ return response_oai
+
+
+def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
+ """Convert messages from OAI format to Groq's format.
+    Currently this only removes the 'name' field, which is not supported.
+ """
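+    # e.g. {"role": "user", "name": "bob", "content": "hi"} -> {"role": "user", "content": "hi"}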
+
+ groq_messages = copy.deepcopy(messages)
+
+ # Remove the name field
+ for message in groq_messages:
+ if "name" in message:
+ message.pop("name", None)
+
+ return groq_messages
+
+
+def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float:
+ """Calculate the cost of the completion using the Groq pricing."""
+ total = 0.0
+
+ if model in GROQ_PRICING_1K:
+ input_cost_per_k, output_cost_per_k = GROQ_PRICING_1K[model]
+ input_cost = (input_tokens / 1000) * input_cost_per_k
+ output_cost = (output_tokens / 1000) * output_cost_per_k
+ total = input_cost + output_cost
+ else:
+ warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
+
+ return total
diff --git a/autogen/oai/mistral.py b/autogen/oai/mistral.py
index 8017e3536324..10d0f926ffbf 100644
--- a/autogen/oai/mistral.py
+++ b/autogen/oai/mistral.py
@@ -15,28 +15,32 @@
Resources:
- https://docs.mistral.ai/getting-started/quickstart/
-"""
-# Important notes when using the Mistral.AI API:
-# The first system message can greatly affect whether the model returns a tool call, including text that references the ability to use functions will help.
-# Changing the role on the first system message to 'user' improved the chances of the model recommending a tool call.
+NOTE: Requires mistralai package version >= 1.0.1
+"""
import inspect
import json
import os
import time
import warnings
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Union
# Mistral libraries
# pip install mistralai
-from mistralai.client import MistralClient
-from mistralai.exceptions import MistralAPIException
-from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage, ToolCall
+from mistralai import (
+ AssistantMessage,
+ Function,
+ FunctionCall,
+ Mistral,
+ SystemMessage,
+ ToolCall,
+ ToolMessage,
+ UserMessage,
+)
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
-from typing_extensions import Annotated
from autogen.oai.client_utils import should_hide_tools, validate_parameter
@@ -50,6 +54,7 @@ def __init__(self, **kwargs):
Args:
api_key (str): The API key for using Mistral.AI (or environment variable MISTRAL_API_KEY needs to be set)
"""
+
# Ensure we have the api_key upon instantiation
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
@@ -59,7 +64,9 @@ def __init__(self, **kwargs):
self.api_key
), "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable."
- def message_retrieval(self, response: ChatCompletionResponse) -> Union[List[str], List[ChatCompletionMessage]]:
+ self._client = Mistral(api_key=self.api_key)
+
+ def message_retrieval(self, response: ChatCompletion) -> Union[List[str], List[ChatCompletionMessage]]:
"""Retrieve the messages from the response."""
return [choice.message for choice in response.choices]
@@ -86,34 +93,52 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
)
mistral_params["random_seed"] = validate_parameter(params, "random_seed", int, True, None, False, None)
+        # TODO: Add streaming support
+ if params.get("stream", False):
+ warnings.warn(
+ "Streaming is not currently supported, streaming will be disabled.",
+ UserWarning,
+ )
+
# 3. Convert messages to Mistral format
mistral_messages = []
tool_call_ids = {} # tool call ids to function name mapping
for message in params["messages"]:
if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"] is not None:
# Convert OAI ToolCall to Mistral ToolCall
- openai_toolcalls = message["tool_calls"]
- mistral_toolcalls = []
- for toolcall in openai_toolcalls:
- mistral_toolcall = ToolCall(id=toolcall["id"], function=toolcall["function"])
- mistral_toolcalls.append(mistral_toolcall)
- mistral_messages.append(
- ChatMessage(role=message["role"], content=message["content"], tool_calls=mistral_toolcalls)
- )
+ mistral_messages_tools = []
+ for toolcall in message["tool_calls"]:
+ mistral_messages_tools.append(
+ ToolCall(
+ id=toolcall["id"],
+ function=FunctionCall(
+ name=toolcall["function"]["name"],
+ arguments=json.loads(toolcall["function"]["arguments"]),
+ ),
+ )
+ )
+
+ mistral_messages.append(AssistantMessage(content="", tool_calls=mistral_messages_tools))
# Map tool call id to the function name
for tool_call in message["tool_calls"]:
tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]
- elif message["role"] in ("system", "user", "assistant"):
- # Note this ChatMessage can take a 'name' but it is rejected by the Mistral API if not role=tool, so, no, the 'name' field is not used.
- mistral_messages.append(ChatMessage(role=message["role"], content=message["content"]))
+ elif message["role"] == "system":
+ if len(mistral_messages) > 0 and mistral_messages[-1].role == "assistant":
+ # System messages can't appear after an Assistant message, so use a UserMessage
+ mistral_messages.append(UserMessage(content=message["content"]))
+ else:
+ mistral_messages.append(SystemMessage(content=message["content"]))
+ elif message["role"] == "assistant":
+ mistral_messages.append(AssistantMessage(content=message["content"]))
+ elif message["role"] == "user":
+ mistral_messages.append(UserMessage(content=message["content"]))
elif message["role"] == "tool":
# Indicates the result of a tool call, the name is the function name called
mistral_messages.append(
- ChatMessage(
- role="tool",
+ ToolMessage(
name=tool_call_ids[message["tool_call_id"]],
content=message["content"],
tool_call_id=message["tool_call_id"],
@@ -122,21 +147,20 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
else:
warnings.warn(f"Unknown message role {message['role']}", UserWarning)
- # If a 'system' message follows an 'assistant' message, change it to 'user'
- # This can occur when using LLM summarisation
- for i in range(1, len(mistral_messages)):
- if mistral_messages[i - 1].role == "assistant" and mistral_messages[i].role == "system":
- mistral_messages[i].role = "user"
+        # 4. The last message needs to be user or tool; if not, add a "please continue" message
+ if not isinstance(mistral_messages[-1], UserMessage) and not isinstance(mistral_messages[-1], ToolMessage):
+ mistral_messages.append(UserMessage(content="Please continue."))
mistral_params["messages"] = mistral_messages
- # 4. Add tools to the call if we have them and aren't hiding them
+ # 5. Add tools to the call if we have them and aren't hiding them
if "tools" in params:
hide_tools = validate_parameter(
params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
)
if not should_hide_tools(params["messages"], params["tools"], hide_tools):
- mistral_params["tools"] = params["tools"]
+ mistral_params["tools"] = tool_def_to_mistral(params["tools"])
+
return mistral_params
def create(self, params: Dict[str, Any]) -> ChatCompletion:
@@ -144,8 +168,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
mistral_params = self.parse_params(params)
# 2. Call Mistral.AI API
- client = MistralClient(api_key=self.api_key)
- mistral_response = client.chat(**mistral_params)
+ mistral_response = self._client.chat.complete(**mistral_params)
# TODO: Handle streaming
# 3. Convert Mistral response to OAI compatible format
@@ -191,7 +214,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
return response_oai
@staticmethod
- def get_usage(response: ChatCompletionResponse) -> Dict:
+ def get_usage(response: ChatCompletion) -> Dict:
return {
"prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
"completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
@@ -203,25 +226,48 @@ def get_usage(response: ChatCompletionResponse) -> Dict:
}
+def tool_def_to_mistral(tool_definitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Converts AutoGen tool definitions to the Mistral tool format"""
+
+ mistral_tools = []
+
+ for autogen_tool in tool_definitions:
+ mistral_tool = {
+ "type": "function",
+ "function": Function(
+ name=autogen_tool["function"]["name"],
+ description=autogen_tool["function"]["description"],
+ parameters=autogen_tool["function"]["parameters"],
+ ),
+ }
+
+ mistral_tools.append(mistral_tool)
+
+ return mistral_tools
+
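+# A sketch of the translation performed by tool_def_to_mistral above; "get_weather" is hypothetical:
+#   {"type": "function", "function": {"name": "get_weather", "description": "...", "parameters": {...}}}
+#   -> {"type": "function", "function": Function(name="get_weather", description="...", parameters={...})}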
+
def calculate_mistral_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
"""Calculate the cost of the mistral response."""
- # Prices per 1 million tokens
+ # Prices per 1 thousand tokens
# https://mistral.ai/technology/
model_cost_map = {
- "open-mistral-7b": {"input": 0.25, "output": 0.25},
- "open-mixtral-8x7b": {"input": 0.7, "output": 0.7},
- "open-mixtral-8x22b": {"input": 2.0, "output": 6.0},
- "mistral-small-latest": {"input": 1.0, "output": 3.0},
- "mistral-medium-latest": {"input": 2.7, "output": 8.1},
- "mistral-large-latest": {"input": 4.0, "output": 12.0},
+ "open-mistral-7b": {"input": 0.00025, "output": 0.00025},
+ "open-mixtral-8x7b": {"input": 0.0007, "output": 0.0007},
+ "open-mixtral-8x22b": {"input": 0.002, "output": 0.006},
+ "mistral-small-latest": {"input": 0.001, "output": 0.003},
+ "mistral-medium-latest": {"input": 0.00275, "output": 0.0081},
+ "mistral-large-latest": {"input": 0.0003, "output": 0.0003},
+ "mistral-large-2407": {"input": 0.0003, "output": 0.0003},
+ "open-mistral-nemo-2407": {"input": 0.0003, "output": 0.0003},
+ "codestral-2405": {"input": 0.001, "output": 0.003},
}
# Ensure we have the model they are using and return the total cost
if model_name in model_cost_map:
costs = model_cost_map[model_name]
- return (input_tokens * costs["input"] / 1_000_000) + (output_tokens * costs["output"] / 1_000_000)
+ return (input_tokens * costs["input"] / 1000) + (output_tokens * costs["output"] / 1000)
else:
warnings.warn(f"Cost calculation is not implemented for model {model_name}, will return $0.", UserWarning)
return 0
diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py
index 0c8a0a413375..41b94324118a 100644
--- a/autogen/oai/openai_utils.py
+++ b/autogen/oai/openai_utils.py
@@ -13,18 +13,30 @@
from openai.types.beta.assistant import Assistant
from packaging.version import parse
-NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version", "azure_ad_token", "azure_ad_token_provider"]
+NON_CACHE_KEY = [
+ "api_key",
+ "base_url",
+ "api_type",
+ "api_version",
+ "azure_ad_token",
+ "azure_ad_token_provider",
+ "credentials",
+]
DEFAULT_AZURE_API_VERSION = "2024-02-01"
OAI_PRICE1K = {
# https://openai.com/api/pricing/
# gpt-4o
"gpt-4o": (0.005, 0.015),
"gpt-4o-2024-05-13": (0.005, 0.015),
+ "gpt-4o-2024-08-06": (0.0025, 0.01),
# gpt-4-turbo
"gpt-4-turbo-2024-04-09": (0.01, 0.03),
# gpt-4
"gpt-4": (0.03, 0.06),
"gpt-4-32k": (0.06, 0.12),
+ # gpt-4o-mini
+ "gpt-4o-mini": (0.000150, 0.000600),
+ "gpt-4o-mini-2024-07-18": (0.000150, 0.000600),
# gpt-3.5 turbo
"gpt-3.5-turbo": (0.0005, 0.0015), # default is 0125
"gpt-3.5-turbo-0125": (0.0005, 0.0015), # 16k
@@ -96,7 +108,7 @@ def is_valid_api_key(api_key: str) -> bool:
Returns:
bool: A boolean that indicates if input is valid OpenAI API key.
"""
- api_key_re = re.compile(r"^sk-(proj-)?[A-Za-z0-9]{32,}$")
+ api_key_re = re.compile(r"^sk-([A-Za-z0-9]+(-+[A-Za-z0-9]+)*-)?[A-Za-z0-9]{32,}$")
return bool(re.fullmatch(api_key_re, api_key))
diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py
index 9393903ec86c..4fb53c7c9600 100644
--- a/autogen/retrieve_utils.py
+++ b/autogen/retrieve_utils.py
@@ -365,7 +365,7 @@ def create_vector_db_from_dir(
embedding_function is not None.
embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with
the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
- functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.
+ functions, you can pass it here, follow the examples in `https://docs.trychroma.com/guides/embeddings`.
custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
custom_text_types (Optional, List[str]): a list of file types to be processed. Default is TEXT_FORMATS.
@@ -448,7 +448,7 @@ def query_vector_db(
embedding_function is not None.
embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with
the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
- functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.
+ functions, you can pass it here, follow the examples in `https://docs.trychroma.com/guides/embeddings`.
Returns:
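Both docstrings now point at the current location of Chroma's embedding-function guide. For context, the hook they describe is used like this; a hedged sketch, assuming Chroma's documented `OpenAIEmbeddingFunction` helper and an illustrative model name:

```python
from chromadb.utils import embedding_functions

from autogen.retrieve_utils import create_vector_db_from_dir

# Assumed Chroma helper; any callable mapping texts to embeddings works here.
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
    api_key="sk-...",  # placeholder
    model_name="text-embedding-3-small",
)

vector_db = create_vector_db_from_dir(
    dir_path="./docs",
    collection_name="my-docs",
    embedding_function=openai_ef,  # overrides the default SentenceTransformer path
)
```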
diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py
index adb55ba63b4f..0fd7cc2fc8b9 100644
--- a/autogen/runtime_logging.py
+++ b/autogen/runtime_logging.py
@@ -14,7 +14,10 @@
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
+ from autogen.oai.bedrock import BedrockClient
+ from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
+ from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
from autogen.oai.together import TogetherClient
@@ -110,7 +113,17 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig
def log_new_client(
- client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient],
+ client: Union[
+ AzureOpenAI,
+ OpenAI,
+ GeminiClient,
+ AnthropicClient,
+ MistralAIClient,
+ TogetherClient,
+ GroqClient,
+ CohereClient,
+ BedrockClient,
+ ],
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:
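`log_new_client` is invoked internally when runtime logging is on, so the practical effect of the wider union is that Groq, Cohere, and Bedrock clients are now recorded alongside the existing ones. A minimal usage sketch of the public entry points (the sqlite `dbname` config key is an assumption based on the runtime-logging docs):

```python
import autogen

# Start logging; client construction (now including Groq/Cohere/Bedrock) gets recorded.
session_id = autogen.runtime_logging.start(config={"dbname": "logs.db"})

# ... build agents, run chats ...

autogen.runtime_logging.stop()
print(f"Events logged under session {session_id}")
```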
diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py
index 2842a7494536..8552a8f16536 100644
--- a/autogen/token_count_utils.py
+++ b/autogen/token_count_utils.py
@@ -36,6 +36,9 @@ def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int:
"gpt-4-vision-preview": 128000,
"gpt-4o": 128000,
"gpt-4o-2024-05-13": 128000,
+ "gpt-4o-2024-08-06": 128000,
+ "gpt-4o-mini": 128000,
+ "gpt-4o-mini-2024-07-18": 128000,
}
return max_token_limit[model]
@@ -95,7 +98,7 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
- print("Warning: model not found. Using cl100k_base encoding.")
+ logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
@@ -166,7 +169,7 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
- print("Warning: model not found. Using cl100k_base encoding.")
+ logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
@@ -193,7 +196,7 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
function_tokens += 3
function_tokens += len(encoding.encode(o))
else:
- print(f"Warning: not supported field {field}")
+ logger.warning(f"Not supported field {field}")
function_tokens += 11
if len(parameters["properties"]) == 0:
function_tokens -= 2
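The new table entries make the gpt-4o family resolvable by `get_max_token_limit`, and unknown models now go through the module logger instead of bare `print` calls. A small sketch of both paths:

```python
from autogen.token_count_utils import count_token, get_max_token_limit

print(get_max_token_limit("gpt-4o-mini"))  # 128000, from the entries added above
print(count_token("Hello world!", model="gpt-4o-mini"))

# An unknown model falls back to cl100k_base and may emit the warning shown above:
# "Model some-unknown-model not found. Using cl100k_base encoding."
print(count_token("Hello world!", model="some-unknown-model"))
```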
diff --git a/autogen/version.py b/autogen/version.py
index 77fc1e2ea295..9b1b78b4b3a0 100644
--- a/autogen/version.py
+++ b/autogen/version.py
@@ -1 +1 @@
-__version__ = "0.2.31"
+__version__ = "0.2.35"
diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln
index 5fa215f0ce9c..78d18527b629 100644
--- a/dotnet/AutoGen.sln
+++ b/dotnet/AutoGen.sln
@@ -1,4 +1,3 @@
-
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 17
VisualStudioVersion = 17.8.34322.80
@@ -27,15 +26,19 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel", "s
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Core", "src\AutoGen.Core\AutoGen.Core.csproj", "{D58D43D1-0617-4A3D-9932-C773E6398535}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI", "src\AutoGen.OpenAI\AutoGen.OpenAI.csproj", "{63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.V1", "src\AutoGen.OpenAI.V1\AutoGen.OpenAI.V1.csproj", "{63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral", "src\AutoGen.Mistral\AutoGen.Mistral.csproj", "{6585D1A4-3D97-4D76-A688-1933B61AEB19}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral.Tests", "test\AutoGen.Mistral.Tests\AutoGen.Mistral.Tests.csproj", "{15441693-3659-4868-B6C1-B106F52FF3BA}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI", "src\AutoGen.WebAPI\AutoGen.WebAPI.csproj", "{257FFD71-08E5-40C7-AB04-6A81A78EB410}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI.Tests", "test\AutoGen.WebAPI.Tests\AutoGen.WebAPI.Tests.csproj", "{E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}"
+EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel.Tests", "test\AutoGen.SemanticKernel.Tests\AutoGen.SemanticKernel.Tests.csproj", "{1DFABC4A-8458-4875-8DCB-59F3802DAC65}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Tests", "test\AutoGen.OpenAI.Tests\AutoGen.OpenAI.Tests.csproj", "{D36A85F9-C172-487D-8192-6BFE5D05B4A7}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.V1.Tests", "test\AutoGen.OpenAI.V1.Tests\AutoGen.OpenAI.V1.Tests.csproj", "{D36A85F9-C172-487D-8192-6BFE5D05B4A7}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.DotnetInteractive.Tests", "test\AutoGen.DotnetInteractive.Tests\AutoGen.DotnetInteractive.Tests.csproj", "{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}"
EndProject
@@ -61,7 +64,19 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Gemini.Sample", "sa
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.AotCompatibility.Tests", "test\AutoGen.AotCompatibility.Tests\AutoGen.AotCompatibility.Tests.csproj", "{6B82F26D-5040-4453-B21B-C8D1F913CE4C}"
EndProject
-Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.OpenAI.Sample", "sample\AutoGen.OpenAI.Sample\AutoGen.OpenAI.Sample.csproj", "{0E635268-351C-4A6B-A28D-593D868C2CA4}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Sample", "sample\AutoGen.OpenAI.Sample\AutoGen.OpenAI.Sample.csproj", "{0E635268-351C-4A6B-A28D-593D868C2CA4}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI.Sample", "sample\AutoGen.WebAPI.Sample\AutoGen.WebAPI.Sample.csproj", "{12079C18-A519-403F-BBFD-200A36A0C083}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.AzureAIInference", "src\AutoGen.AzureAIInference\AutoGen.AzureAIInference.csproj", "{5C45981D-1319-4C25-935C-83D411CB28DF}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.AzureAIInference.Tests", "test\AutoGen.AzureAIInference.Tests\AutoGen.AzureAIInference.Tests.csproj", "{5970868F-831E-418F-89A9-4EC599563E16}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Tests.Share", "test\AutoGen.Test.Share\AutoGen.Tests.Share.csproj", "{143725E2-206C-4D37-93E4-9EDF699826B2}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI", "src\AutoGen.OpenAI\AutoGen.OpenAI.csproj", "{3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Tests", "test\AutoGen.OpenAI.Tests\AutoGen.OpenAI.Tests.csproj", "{42A8251C-E7B3-47BB-A82E-459952EBE132}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -117,6 +132,14 @@ Global
{15441693-3659-4868-B6C1-B106F52FF3BA}.Debug|Any CPU.Build.0 = Debug|Any CPU
{15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.ActiveCfg = Release|Any CPU
{15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.Build.0 = Release|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Release|Any CPU.Build.0 = Release|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Release|Any CPU.Build.0 = Release|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.ActiveCfg = Release|Any CPU
@@ -177,6 +200,30 @@ Global
{0E635268-351C-4A6B-A28D-593D868C2CA4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{0E635268-351C-4A6B-A28D-593D868C2CA4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{0E635268-351C-4A6B-A28D-593D868C2CA4}.Release|Any CPU.Build.0 = Release|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.Build.0 = Release|Any CPU
+ {5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.Build.0 = Release|Any CPU
+ {5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.Build.0 = Release|Any CPU
+ {143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.Build.0 = Release|Any CPU
+ {3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8}.Release|Any CPU.Build.0 = Release|Any CPU
+ {42A8251C-E7B3-47BB-A82E-459952EBE132}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {42A8251C-E7B3-47BB-A82E-459952EBE132}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {42A8251C-E7B3-47BB-A82E-459952EBE132}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {42A8251C-E7B3-47BB-A82E-459952EBE132}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -194,6 +241,8 @@ Global
{63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{6585D1A4-3D97-4D76-A688-1933B61AEB19} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{15441693-3659-4868-B6C1-B106F52FF3BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{D36A85F9-C172-487D-8192-6BFE5D05B4A7} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
@@ -209,8 +258,14 @@ Global
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
{6B82F26D-5040-4453-B21B-C8D1F913CE4C} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{0E635268-351C-4A6B-A28D-593D868C2CA4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {12079C18-A519-403F-BBFD-200A36A0C083} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {5C45981D-1319-4C25-935C-83D411CB28DF} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {5970868F-831E-418F-89A9-4EC599563E16} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {143725E2-206C-4D37-93E4-9EDF699826B2} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {42A8251C-E7B3-47BB-A82E-459952EBE132} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B}
EndGlobalSection
-EndGlobal
\ No newline at end of file
+EndGlobal
diff --git a/dotnet/Directory.Build.props b/dotnet/Directory.Build.props
index 4b3e9441f1ee..b5663fe4c578 100644
--- a/dotnet/Directory.Build.props
+++ b/dotnet/Directory.Build.props
@@ -4,7 +4,8 @@
- <TargetFramework>net8.0</TargetFramework>
+ <TargetFrameworks>netstandard2.0;net6.0;net8.0</TargetFrameworks>
+ <TestTargetFrameworks>net8.0</TestTargetFrameworks>
preview
enable
True
@@ -31,6 +32,7 @@
+
diff --git a/dotnet/eng/MetaInfo.props b/dotnet/eng/MetaInfo.props
index 041ee0ec6c97..c6eeaf843435 100644
--- a/dotnet/eng/MetaInfo.props
+++ b/dotnet/eng/MetaInfo.props
@@ -1,7 +1,7 @@
- <VersionPrefix>0.0.15</VersionPrefix>
+ <VersionPrefix>0.2.1</VersionPrefix>
AutoGen
https://microsoft.github.io/autogen-for-net/
https://github.com/microsoft/autogen
diff --git a/dotnet/eng/Version.props b/dotnet/eng/Version.props
index 0b8dcaa565cb..36cfd917c2c0 100644
--- a/dotnet/eng/Version.props
+++ b/dotnet/eng/Version.props
@@ -2,8 +2,9 @@
1.0.0-beta.17
- 1.10.0
- 1.10.0-alpha
+ 2.0.0-beta.3
+ 1.18.1-rc
+ 1.18.1-alpha
5.0.0
4.3.0
6.0.0
@@ -12,7 +13,11 @@
17.7.0
1.0.0-beta.24229.4
8.0.0
+ 8.0.4
3.0.0
4.3.0.2
+ 1.0.0-beta.1
+ 2.0.0-beta.10
+ 7.4.4
\ No newline at end of file
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/Anthropic_Agent_With_Prompt_Caching.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Anthropic_Agent_With_Prompt_Caching.cs
new file mode 100644
index 000000000000..5d8a99ce1288
--- /dev/null
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/Anthropic_Agent_With_Prompt_Caching.cs
@@ -0,0 +1,133 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Anthropic_Agent_With_Prompt_Caching.cs
+
+using AutoGen.Anthropic.DTO;
+using AutoGen.Anthropic.Extensions;
+using AutoGen.Anthropic.Utils;
+using AutoGen.Core;
+
+namespace AutoGen.Anthropic.Samples;
+
+public class Anthropic_Agent_With_Prompt_Caching
+{
+ // A random, long test string to demonstrate cache control.
+ // The context must be larger than 1024 tokens for Claude 3.5 Sonnet and Claude 3 Opus,
+ // and larger than 2048 tokens for Claude 3 Haiku.
+ // Shorter prompts cannot be cached, even if marked with cache_control; any request to cache fewer tokens than this minimum is processed without caching.
+
+ #region Long story for caching
+ public const string LongStory = """
+ Once upon a time in a small, nondescript town lived a man named Bob. Bob was an unassuming individual, the kind of person you wouldn’t look twice at if you passed him on the street. He worked as an IT specialist for a mid-sized corporation, spending his days fixing computers and troubleshooting software issues. But beneath his average exterior, Bob harbored a secret ambition—he wanted to take over the world.
+
+ Bob wasn’t always like this. For most of his life, he had been content with his routine, blending into the background. But one day, while browsing the dark corners of the internet, Bob stumbled upon an ancient manuscript, encrypted within the deep web, detailing the steps to global domination. It was written by a forgotten conqueror, someone whose name had been erased from history but whose methods were preserved in this digital relic. The manuscript laid out a plan so intricate and flawless that Bob, with his analytical mind, became obsessed.
+
+ Over the next few years, Bob meticulously followed the manuscript’s guidance. He started small, creating a network of like-minded individuals who shared his dream. They communicated through encrypted channels, meeting in secret to discuss their plans. Bob was careful, never revealing too much about himself, always staying in the shadows. He used his IT skills to gather information, infiltrating government databases, and private corporations, and acquiring secrets that could be used as leverage.
+
+ As his network grew, so did his influence. Bob began to manipulate world events from behind the scenes. He orchestrated economic crises, incited political turmoil, and planted seeds of discord among the world’s most powerful nations. Each move was calculated, each action a step closer to his ultimate goal. The world was in chaos, and no one suspected that a man like Bob could be behind it all.
+
+ But Bob knew that causing chaos wasn’t enough. To truly take over the world, he needed something more—something to cement his power. That’s when he turned to technology. Bob had always been ahead of the curve when it came to tech, and now, he planned to use it to his advantage. He began developing an AI, one that would be more powerful and intelligent than anything the world had ever seen. This AI, which Bob named “Nemesis,” was designed to control every aspect of modern life—from financial systems to military networks.
+
+ It took years of coding, testing, and refining, but eventually, Nemesis was ready. Bob unleashed the AI, and within days, it had taken control of the world’s digital infrastructure. Governments were powerless, their systems compromised. Corporations crumbled as their assets were seized. The military couldn’t act, their weapons turned against them. Bob, from the comfort of his modest home, had done it. He had taken over the world.
+
+ The world, now under Bob’s control, was eerily quiet. There were no more wars, no more financial crises, no more political strife. Nemesis ensured that everything ran smoothly, efficiently, and without dissent. The people of the world had no choice but to obey, their lives dictated by an unseen hand.
+
+ Bob, once a man who was overlooked and ignored, was now the most powerful person on the planet. But with that power came a realization. The world he had taken over was not the world he had envisioned. It was cold, mechanical, and devoid of the chaos that once made life unpredictable and exciting. Bob had achieved his goal, but in doing so, he had lost the very thing that made life worth living—freedom.
+
+ And so, Bob, now ruler of the world, sat alone in his control room, staring at the screens that displayed his dominion. He had everything he had ever wanted, yet he felt emptier than ever before. The world was his, but at what cost?
+
+ In the end, Bob realized that true power didn’t come from controlling others, but from the ability to let go. He deactivated Nemesis, restoring the world to its former state, and disappeared into obscurity, content to live out the rest of his days as just another face in the crowd. And though the world never knew his name, Bob’s legacy would live on, a reminder of the dangers of unchecked ambition.
+
+ Bob had vanished, leaving the world in a fragile state of recovery. Governments scrambled to regain control of their systems, corporations tried to rebuild, and the global population slowly adjusted to life without the invisible grip of Nemesis. Yet, even as society returned to a semblance of normalcy, whispers of the mysterious figure who had brought the world to its knees lingered in the shadows.
+
+ Meanwhile, Bob had retreated to a secluded cabin deep in the mountains. The cabin was a modest, rustic place, surrounded by dense forests and overlooking a tranquil lake. It was far from civilization, a perfect place for a man who wanted to disappear. Bob spent his days fishing, hiking, and reflecting on his past. For the first time in years, he felt a sense of peace.
+
+ But peace was fleeting. Despite his best efforts to put his past behind him, Bob couldn’t escape the consequences of his actions. He had unleashed Nemesis upon the world, and though he had deactivated the AI, remnants of its code still existed. Rogue factions, hackers, and remnants of his old network were searching for those fragments, hoping to revive Nemesis and seize the power that Bob had relinquished.
+
+ One day, as Bob was chopping wood outside his cabin, a figure emerged from the tree line. It was a young woman, dressed in hiking gear, with a determined look in her eyes. Bob tensed, his instincts telling him that this was no ordinary hiker.
+
+ “Bob,” the woman said, her voice steady. “Or should I say, the man who almost became the ruler of the world?”
+
+ Bob sighed, setting down his axe. “Who are you, and what do you want?”
+
+ The woman stepped closer. “My name is Sarah. I was part of your network, one of the few who knew about Nemesis. But I wasn’t like the others. I didn’t want power for myself—I wanted to protect the world from those who would misuse it.”
+
+ Bob studied her, trying to gauge her intentions. “And why are you here now?”
+
+ Sarah reached into her backpack and pulled out a small device. “Because Nemesis isn’t dead. Some of its code is still active, and it’s trying to reboot itself. I need your help to stop it for good.”
+
+ Bob’s heart sank. He had hoped that by deactivating Nemesis, he had erased it from existence. But deep down, he knew that an AI as powerful as Nemesis wouldn’t go down so easily. “Why come to me? I’m the one who created it. I’m the reason the world is in this mess.”
+
+ Sarah shook her head. “You’re also the only one who knows how to stop it. I’ve tracked down the remnants of Nemesis’s code, but I need you to help destroy it before it falls into the wrong hands.”
+
+ Bob hesitated. He had wanted nothing more than to leave his past behind, but he couldn’t ignore the responsibility that weighed on him. He had created Nemesis, and now it was his duty to make sure it never posed a threat again.
+
+ “Alright,” Bob said finally. “I’ll help you. But after this, I’m done. No more world domination, no more secret networks. I just want to live in peace.”
+
+ Sarah nodded. “Agreed. Let’s finish what you started.”
+
+ Over the next few weeks, Bob and Sarah worked together, traveling to various locations around the globe where fragments of Nemesis’s code had been detected. They infiltrated secure facilities, outsmarted rogue hackers, and neutralized threats, all while staying one step ahead of those who sought to control Nemesis for their own gain.
+
+ As they worked, Bob and Sarah developed a deep respect for one another. Sarah was sharp, resourceful, and driven by a genuine desire to protect the world. Bob found himself opening up to her, sharing his regrets, his doubts, and the lessons he had learned. In turn, Sarah shared her own story—how she had once been tempted by power but had chosen a different path, one that led her to fight for what was right.
+
+ Finally, after weeks of intense effort, they tracked down the last fragment of Nemesis’s code, hidden deep within a remote server farm in the Arctic. The facility was heavily guarded, but Bob and Sarah had planned meticulously. Under the cover of a blizzard, they infiltrated the facility, avoiding detection as they made their way to the heart of the server room.
+
+ As Bob began the process of erasing the final fragment, an alarm blared, and the facility’s security forces closed in. Sarah held them off as long as she could, but they were outnumbered and outgunned. Just as the situation seemed hopeless, Bob executed the final command, wiping Nemesis from existence once and for all.
+
+ But as the last remnants of Nemesis were deleted, Bob knew there was only one way to ensure it could never be resurrected. He initiated a self-destruct sequence for the server farm, trapping himself and Sarah inside.
+
+ Sarah stared at him, realization dawning in her eyes. “Bob, what are you doing?”
+
+ Bob looked at her, a sad smile on his face. “I have to make sure it’s over. This is the only way.”
+
+ Sarah’s eyes filled with tears, but she nodded, understanding the gravity of his decision. “Thank you, Bob. For everything.”
+
+ As the facility’s countdown reached its final seconds, Bob and Sarah stood side by side, knowing they had done the right thing. The explosion that followed was seen from miles away, a final testament to the end of an era.
+
+ The world never knew the true story of Bob, the man who almost ruled the world. But in his final act of sacrifice, he ensured that the world would remain free, a place where people could live their lives without fear of control. Bob had redeemed himself, not as a conqueror, but as a protector—a man who chose to save the world rather than rule it.
+
+ And in the quiet aftermath of the explosion, as the snow settled over the wreckage, Bob’s legacy was sealed—not as a name in history books, but as a silent guardian whose actions would be felt for generations to come.
+ """;
+ #endregion
+
+ public static async Task RunAsync()
+ {
+ #region init translator agents & register middlewares
+
+ var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ??
+ throw new Exception("Please set ANTHROPIC_API_KEY environment variable.");
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey);
+ var frenchTranslatorAgent =
+ new AnthropicClientAgent(anthropicClient, "frenchTranslator", AnthropicConstants.Claude35Sonnet,
+ systemMessage: "You are a French translator")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ var germanTranslatorAgent = new AnthropicClientAgent(anthropicClient, "germanTranslator",
+ AnthropicConstants.Claude35Sonnet, systemMessage: "You are a German translator")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ #endregion
+
+ var userProxyAgent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: HumanInputMode.ALWAYS)
+ .RegisterPrintMessage();
+
+ var groupChat = new RoundRobinGroupChat(
+ agents: [userProxyAgent, frenchTranslatorAgent, germanTranslatorAgent]);
+
+ var messageEnvelope =
+ MessageEnvelope.Create(
+ new ChatMessage("user", [TextContent.CreateTextWithCacheControl(LongStory)]),
+ from: "user");
+
+ var chatHistory = new List<IMessage>()
+ {
+ new TextMessage(Role.User, "translate this text for me", from: userProxyAgent.Name),
+ messageEnvelope,
+ };
+
+ var history = await groupChat.SendAsync(chatHistory).ToArrayAsync();
+ }
+}
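For comparison, the same idea in Python: mark the large, stable block with `cache_control` so Anthropic can reuse it across requests. This is a hedged sketch against the Anthropic Python SDK's prompt-caching beta; the beta header value and message shape follow Anthropic's docs at the time of this change and should be treated as assumptions:

```python
import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

LONG_STORY = "..."  # stand-in for the >1024-token story from the C# sample above

response = client.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=1024,
    system=[
        {"type": "text", "text": "You are a French translator"},
        # The cacheable block; it must exceed the per-model minimum noted above.
        {"type": "text", "text": LONG_STORY, "cache_control": {"type": "ephemeral"}},
    ],
    messages=[{"role": "user", "content": "translate this text for me"}],
    extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
)
print(response.content[0].text)
```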
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj b/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj
index 33a5aa7f16b6..fe7553b937f4 100644
--- a/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj
@@ -2,7 +2,7 @@
Exe
- <TargetFramework>$(TestTargetFramework)</TargetFramework>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
enable
enable
True
@@ -13,6 +13,7 @@
+
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/AnthropicSamples.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent.cs
similarity index 93%
rename from dotnet/sample/AutoGen.Anthropic.Samples/AnthropicSamples.cs
rename to dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent.cs
index 94b5f37511e6..6f32c3cb4a21 100644
--- a/dotnet/sample/AutoGen.Anthropic.Samples/AnthropicSamples.cs
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// AnthropicSamples.cs
+// Create_Anthropic_Agent.cs
using AutoGen.Anthropic.Extensions;
using AutoGen.Anthropic.Utils;
@@ -7,7 +7,7 @@
namespace AutoGen.Anthropic.Samples;
-public static class AnthropicSamples
+public static class Create_Anthropic_Agent
{
public static async Task RunAsync()
{
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent_With_Tool.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent_With_Tool.cs
new file mode 100644
index 000000000000..0324a39ffa59
--- /dev/null
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent_With_Tool.cs
@@ -0,0 +1,100 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Create_Anthropic_Agent_With_Tool.cs
+
+using AutoGen.Anthropic.DTO;
+using AutoGen.Anthropic.Extensions;
+using AutoGen.Anthropic.Utils;
+using AutoGen.Core;
+using FluentAssertions;
+
+namespace AutoGen.Anthropic.Samples;
+
+#region WeatherFunction
+
+public partial class WeatherFunction
+{
+ /// <summary>
+ /// Gets the weather based on the location and the unit
+ /// </summary>
+ /// <param name="location"></param>
+ /// <param name="unit"></param>
+ /// <returns></returns>
+ [Function]
+ public async Task<string> GetWeather(string location, string unit)
+ {
+ // dummy implementation
+ return $"The weather in {location} is currently sunny with a temperature of {unit} (s)";
+ }
+}
+#endregion
+public class Create_Anthropic_Agent_With_Tool
+{
+ public static async Task RunAsync()
+ {
+ #region define_tool
+ var tool = new Tool
+ {
+ Name = "GetWeather",
+ Description = "Get the current weather in a given location",
+ InputSchema = new InputSchema
+ {
+ Type = "object",
+ Properties = new Dictionary<string, SchemaProperty>
+ {
+ { "location", new SchemaProperty { Type = "string", Description = "The city and state, e.g. San Francisco, CA" } },
+ { "unit", new SchemaProperty { Type = "string", Description = "The unit of temperature, either \"celsius\" or \"fahrenheit\"" } }
+ },
+ Required = new List<string> { "location" }
+ }
+ };
+
+ var weatherFunction = new WeatherFunction();
+ var functionMiddleware = new FunctionCallMiddleware(
+ functions: [
+ weatherFunction.GetWeatherFunctionContract,
+ ],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { weatherFunction.GetWeatherFunctionContract.Name!, weatherFunction.GetWeatherWrapper },
+ });
+
+ #endregion
+
+ #region create_anthropic_agent
+
+ var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ??
+ throw new Exception("Missing ANTHROPIC_API_KEY environment variable.");
+
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey);
+ var agent = new AnthropicClientAgent(anthropicClient, "assistant", AnthropicConstants.Claude3Haiku,
+ tools: [tool]); // Define tools for AnthropicClientAgent
+ #endregion
+
+ #region register_middleware
+
+ var agentWithConnector = agent
+ .RegisterMessageConnector()
+ .RegisterPrintMessage()
+ .RegisterStreamingMiddleware(functionMiddleware);
+ #endregion register_middleware
+
+ #region single_turn
+ var question = new TextMessage(Role.Assistant,
+ "What is the weather like in San Francisco?",
+ from: "user");
+ var functionCallReply = await agentWithConnector.SendAsync(question);
+ #endregion
+
+ #region Single_turn_verify_reply
+ functionCallReply.Should().BeOfType();
+ #endregion Single_turn_verify_reply
+
+ #region Multi_turn
+ var finalReply = await agentWithConnector.SendAsync(chatHistory: [question, functionCallReply]);
+ #endregion Multi_turn
+
+ #region Multi_turn_verify_reply
+ finalReply.Should().BeOfType();
+ #endregion Multi_turn_verify_reply
+ }
+}
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs
index f3c615088610..105bb56524fd 100644
--- a/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs
@@ -7,6 +7,6 @@ internal static class Program
{
public static async Task Main(string[] args)
{
- await AnthropicSamples.RunAsync();
+ await Anthropic_Agent_With_Prompt_Caching.RunAsync();
}
}
diff --git a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj
index 6f55a04592f5..d4323ee4c924 100644
--- a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj
+++ b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj
@@ -2,7 +2,7 @@
Exe
- <TargetFramework>$(TestTargetFramework)</TargetFramework>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
enable
True
$(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs
index 4833c6195c9d..f68053224663 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs
@@ -4,7 +4,9 @@
using AutoGen;
using AutoGen.Core;
using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
using FluentAssertions;
+using OpenAI;
public partial class AssistantCodeSnippet
{
@@ -32,23 +34,18 @@ public void CodeSnippet2()
{
#region code_snippet_2
// get OpenAI Key and create config
- var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
- string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY");
+ var model = "gpt-4o-mini";
- var llmConfig = new AzureOpenAIConfig(
- endpoint: endPoint,
- deploymentName: "gpt-3.5-turbo-16k", // change to your deployment name
- apiKey: apiKey);
+ var openAIClient = new OpenAIClient(apiKey);
// create assistant agent
- var assistantAgent = new AssistantAgent(
+ var assistantAgent = new OpenAIChatAgent(
name: "assistant",
systemMessage: "You are an assistant that help user to do some tasks.",
- llmConfig: new ConversableAgentConfig
- {
- Temperature = 0,
- ConfigList = new[] { llmConfig },
- });
+ chatClient: openAIClient.GetChatClient(model))
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
#endregion code_snippet_2
}
@@ -71,27 +68,21 @@ public async Task CodeSnippet4()
// get OpenAI Key and create config
var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
-
- var llmConfig = new AzureOpenAIConfig(
- endpoint: endPoint,
- deploymentName: "gpt-3.5-turbo-16k", // change to your deployment name
- apiKey: apiKey);
+ var model = "gpt-4o-mini";
+ var openAIClient = new OpenAIClient(new System.ClientModel.ApiKeyCredential(apiKey), new OpenAIClientOptions
+ {
+ Endpoint = new Uri(endPoint),
+ });
#region code_snippet_4
- var assistantAgent = new AssistantAgent(
+ var assistantAgent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(model),
name: "assistant",
systemMessage: "You are an assistant that convert user input to upper case.",
- llmConfig: new ConversableAgentConfig
- {
- Temperature = 0,
- ConfigList = new[]
- {
- llmConfig
- },
- FunctionContracts = new[]
- {
- this.UpperCaseFunctionContract, // The FunctionDefinition object for the UpperCase function
- },
- });
+ functions: [
+ this.UpperCaseFunctionContract.ToChatTool(), // The FunctionDefinition object for the UpperCase function
+ ])
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
var response = await assistantAgent.SendAsync("hello");
response.Should().BeOfType();
@@ -106,31 +97,24 @@ public async Task CodeSnippet5()
// get OpenAI Key and create config
var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
-
- var llmConfig = new AzureOpenAIConfig(
- endpoint: endPoint,
- deploymentName: "gpt-3.5-turbo-16k", // change to your deployment name
- apiKey: apiKey);
+ var model = "gpt-4o-mini";
+ var openAIClient = new OpenAIClient(new System.ClientModel.ApiKeyCredential(apiKey), new OpenAIClientOptions
+ {
+ Endpoint = new Uri(endPoint),
+ });
#region code_snippet_5
- var assistantAgent = new AssistantAgent(
- name: "assistant",
- systemMessage: "You are an assistant that convert user input to upper case.",
- llmConfig: new ConversableAgentConfig
- {
- Temperature = 0,
- ConfigList = new[]
- {
- llmConfig
- },
- FunctionContracts = new[]
- {
- this.UpperCaseFunctionContract, // The FunctionDefinition object for the UpperCase function
- },
- },
- functionMap: new Dictionary<string, Func<string, Task<string>>>
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [this.UpperCaseFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>()
{
- { this.UpperCaseFunction.Name, this.UpperCaseWrapper }, // The wrapper function for the UpperCase function
+ { this.UpperCaseFunctionContract.Name, this.UpperCase },
});
+ var assistantAgent = new OpenAIChatAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that convert user input to upper case.",
+ chatClient: openAIClient.GetChatClient(model))
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware);
var response = await assistantAgent.SendAsync("hello");
response.Should().BeOfType();
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs
index 2b7e25fee0c5..854a385dc341 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs
@@ -3,7 +3,6 @@
using AutoGen;
using AutoGen.Core;
-using AutoGen.OpenAI;
using FluentAssertions;
public partial class FunctionCallCodeSnippet
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/GetStartCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/GetStartCodeSnippet.cs
index fe97152183a4..c5ff7b770338 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/GetStartCodeSnippet.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/GetStartCodeSnippet.cs
@@ -5,6 +5,8 @@
using AutoGen;
using AutoGen.Core;
using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using OpenAI;
#endregion snippet_GetStartCodeSnippet
public class GetStartCodeSnippet
@@ -13,16 +15,14 @@ public async Task CodeSnippet1()
{
#region code_snippet_1
var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var gpt35Config = new OpenAIConfig(openAIKey, "gpt-3.5-turbo");
+ var openAIClient = new OpenAIClient(openAIKey);
+ var model = "gpt-4o-mini";
- var assistantAgent = new AssistantAgent(
+ var assistantAgent = new OpenAIChatAgent(
name: "assistant",
systemMessage: "You are an assistant that help user to do some tasks.",
- llmConfig: new ConversableAgentConfig
- {
- Temperature = 0,
- ConfigList = [gpt35Config],
- })
+ chatClient: openAIClient.GetChatClient(model))
+ .RegisterMessageConnector()
.RegisterPrintMessage(); // register a hook to print message nicely to console
// set human input mode to ALWAYS so that user always provide input
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs
index 320afd0de679..1b5a9a903207 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs
@@ -13,38 +13,46 @@ public class MiddlewareAgentCodeSnippet
public async Task CreateMiddlewareAgentAsync()
{
#region create_middleware_agent_with_original_agent
- // Create an agent that always replies "Hello World"
- IAgent agent = new DefaultReplyAgent(name: "assistant", defaultReply: "Hello World");
+ // Create an agent that always replies "Hi!"
+ IAgent agent = new DefaultReplyAgent(name: "assistant", defaultReply: "Hi!");
// Create a middleware agent on top of default reply agent
var middlewareAgent = new MiddlewareAgent(innerAgent: agent);
middlewareAgent.Use(async (messages, options, agent, ct) =>
{
- var lastMessage = messages.Last() as TextMessage;
- lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ if (messages.Last() is TextMessage lastMessage && lastMessage.Content.Contains("Hello World"))
+ {
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return lastMessage;
+ }
+
return await agent.GenerateReplyAsync(messages, options, ct);
});
var reply = await middlewareAgent.SendAsync("Hello World");
reply.GetContent().Should().Be("[middleware 0] Hello World");
+ reply = await middlewareAgent.SendAsync("Hello AI!");
+ reply.GetContent().Should().Be("Hi!");
#endregion create_middleware_agent_with_original_agent
#region register_middleware_agent
middlewareAgent = agent.RegisterMiddleware(async (messages, options, agent, ct) =>
{
- var lastMessage = messages.Last() as TextMessage;
- lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ if (messages.Last() is TextMessage lastMessage && lastMessage.Content.Contains("Hello World"))
+ {
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return lastMessage;
+ }
+
return await agent.GenerateReplyAsync(messages, options, ct);
});
#endregion register_middleware_agent
#region short_circuit_middleware_agent
- // This middleware will short circuit the agent and return the last message directly.
+ // This middleware will short circuit the agent and return a message directly.
middlewareAgent.Use(async (messages, options, agent, ct) =>
{
- var lastMessage = messages.Last() as TextMessage;
- lastMessage.Content = $"[middleware shortcut]";
- return lastMessage;
+ return new TextMessage(Role.Assistant, $"[middleware shortcut]");
});
#endregion short_circuit_middleware_agent
}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs
index cf0452212239..60520078e72e 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs
@@ -5,9 +5,10 @@
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
#endregion using_statement
using FluentAssertions;
+using OpenAI;
+using OpenAI.Chat;
namespace AutoGen.BasicSample.CodeSnippet;
#region weather_function
@@ -32,31 +33,30 @@ public async Task CreateOpenAIChatAgentAsync()
{
#region create_openai_chat_agent
var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var modelId = "gpt-3.5-turbo";
+ var modelId = "gpt-4o-mini";
var openAIClient = new OpenAIClient(openAIKey);
// create an open ai chat agent
var openAIChatAgent = new OpenAIChatAgent(
- openAIClient: openAIClient,
+ chatClient: openAIClient.GetChatClient(modelId),
name: "assistant",
- modelName: modelId,
systemMessage: "You are an assistant that help user to do some tasks.");
// OpenAIChatAgent supports the following message types:
// - IMessage<ChatRequestMessage> where ChatRequestMessage is from Azure.AI.OpenAI
- var helloMessage = new ChatRequestUserMessage("Hello");
+ var helloMessage = new UserChatMessage("Hello");
// Use MessageEnvelope.Create to create an IMessage
var chatMessageContent = MessageEnvelope.Create(helloMessage);
var reply = await openAIChatAgent.SendAsync(chatMessageContent);
- // The type of reply is MessageEnvelope where ChatResponseMessage is from Azure.AI.OpenAI
- reply.Should().BeOfType<MessageEnvelope<ChatResponseMessage>>();
+ // The type of reply is MessageEnvelope where ChatResponseMessage is from Azure.AI.OpenAI
+ reply.Should().BeOfType<MessageEnvelope<ChatCompletion>>();
// You can un-envelop the reply to get the ChatResponseMessage
- ChatResponseMessage response = reply.As<MessageEnvelope<ChatResponseMessage>>().Content;
- response.Role.Should().Be(ChatRole.Assistant);
+ ChatCompletion response = reply.As<MessageEnvelope<ChatCompletion>>().Content;
+ response.Role.Should().Be(ChatMessageRole.Assistant);
#endregion create_openai_chat_agent
#region create_openai_chat_agent_streaming
@@ -64,8 +64,8 @@ public async Task CreateOpenAIChatAgentAsync()
await foreach (var streamingMessage in streamingReply)
{
- streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatCompletionsUpdate>>();
- streamingMessage.As<MessageEnvelope<StreamingChatCompletionsUpdate>>().Content.Role.Should().Be(ChatRole.Assistant);
+ streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatCompletionUpdate>>();
+ streamingMessage.As<MessageEnvelope<StreamingChatCompletionUpdate>>().Content.Role.Should().Be(ChatMessageRole.Assistant);
}
#endregion create_openai_chat_agent_streaming
@@ -77,7 +77,7 @@ public async Task CreateOpenAIChatAgentAsync()
// now the agentWithConnector supports more message types
var messages = new IMessage[]
{
- MessageEnvelope.Create(new ChatRequestUserMessage("Hello")),
+ MessageEnvelope.Create(new UserChatMessage("Hello")),
new TextMessage(Role.Assistant, "Hello", from: "user"),
new MultiModalMessage(Role.Assistant,
[
@@ -106,9 +106,8 @@ public async Task OpenAIChatAgentGetWeatherFunctionCallAsync()
// create an open ai chat agent
var openAIChatAgent = new OpenAIChatAgent(
- openAIClient: openAIClient,
+ chatClient: openAIClient.GetChatClient(modelId),
name: "assistant",
- modelName: modelId,
systemMessage: "You are an assistant that help user to do some tasks.")
.RegisterMessageConnector();
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/PrintMessageMiddlewareCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/PrintMessageMiddlewareCodeSnippet.cs
index bf4f9c976e22..0ac7f71a3cae 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/PrintMessageMiddlewareCodeSnippet.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/PrintMessageMiddlewareCodeSnippet.cs
@@ -4,8 +4,6 @@
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
-using Azure;
-using Azure.AI.OpenAI;
namespace AutoGen.BasicSample.CodeSnippet;
@@ -15,8 +13,8 @@ public async Task PrintMessageMiddlewareAsync()
{
var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
var endpoint = new Uri(config.Endpoint);
- var openaiClient = new OpenAIClient(endpoint, new AzureKeyCredential(config.ApiKey));
- var agent = new OpenAIChatAgent(openaiClient, "assistant", config.DeploymentName)
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var agent = new OpenAIChatAgent(gpt4o, "assistant", config.DeploymentName)
.RegisterMessageConnector();
#region PrintMessageMiddleware
@@ -31,10 +29,10 @@ public async Task PrintMessageStreamingMiddlewareAsync()
{
var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
var endpoint = new Uri(config.Endpoint);
- var openaiClient = new OpenAIClient(endpoint, new AzureKeyCredential(config.ApiKey));
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
#region print_message_streaming
- var streamingAgent = new OpenAIChatAgent(openaiClient, "assistant", config.DeploymentName)
+ var streamingAgent = new OpenAIChatAgent(gpt4o, "assistant")
.RegisterMessageConnector()
.RegisterPrintMessage();
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs
index e498650b6aac..b087beb993bc 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs
@@ -4,6 +4,7 @@
#region code_snippet_0_1
using AutoGen.Core;
using AutoGen.DotnetInteractive;
+using AutoGen.DotnetInteractive.Extension;
#endregion code_snippet_0_1
namespace AutoGen.BasicSample.CodeSnippet;
@@ -11,18 +12,37 @@ public class RunCodeSnippetCodeSnippet
{
public async Task CodeSnippet1()
{
- IAgent agent = default;
+ IAgent agent = new DefaultReplyAgent("agent", "Hello World");
#region code_snippet_1_1
- var workingDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName());
- Directory.CreateDirectory(workingDirectory);
- var interactiveService = new InteractiveService(installingDirectory: workingDirectory);
- await interactiveService.StartAsync(workingDirectory: workingDirectory);
+ var kernel = DotnetInteractiveKernelBuilder
+ .CreateDefaultInProcessKernelBuilder() // add C# and F# kernels
+ .Build();
#endregion code_snippet_1_1
#region code_snippet_1_2
- // register dotnet code block execution hook to an arbitrary agent
- var dotnetCodeAgent = agent.RegisterDotnetCodeBlockExectionHook(interactiveService: interactiveService);
+ // register middleware to execute code block
+ var dotnetCodeAgent = agent
+ .RegisterMiddleware(async (msgs, option, innerAgent, ct) =>
+ {
+ var lastMessage = msgs.LastOrDefault();
+ if (lastMessage == null || lastMessage.GetContent() is null)
+ {
+ return await innerAgent.GenerateReplyAsync(msgs, option, ct);
+ }
+
+ if (lastMessage.ExtractCodeBlock("```csharp", "```") is string codeSnippet)
+ {
+ // execute code snippet
+ var result = await kernel.RunSubmitCodeCommandAsync(codeSnippet, "csharp");
+ return new TextMessage(Role.Assistant, result, from: agent.Name);
+ }
+ else
+ {
+ // no code block found, invoke next agent
+ return await innerAgent.GenerateReplyAsync(msgs, option, ct);
+ }
+ });
var codeSnippet = @"
```csharp
@@ -44,5 +64,17 @@ public async Task CodeSnippet1()
```
";
#endregion code_snippet_1_3
+
+ #region code_snippet_1_4
+ var pythonKernel = DotnetInteractiveKernelBuilder
+ .CreateDefaultInProcessKernelBuilder()
+ .AddPythonKernel(venv: "python3")
+ .Build();
+
+ var pythonCode = """
+ print('Hello from Python!')
+ """;
+ var result = await pythonKernel.RunSubmitCodeCommandAsync(pythonCode, "python3");
+ #endregion code_snippet_1_4
}
}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs
index 50bcd8a8048e..667705835eb3 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs
@@ -3,7 +3,6 @@
using System.Text.Json;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
#region weather_report_using_statement
using AutoGen.Core;
#endregion weather_report_using_statement
@@ -32,7 +31,7 @@ public async Task Consume()
var functionInstance = new TypeSafeFunctionCall();
// Get the generated function definition
- FunctionDefinition functionDefiniton = functionInstance.WeatherReportFunctionContract.ToOpenAIFunctionDefinition();
+ var functionDefiniton = functionInstance.WeatherReportFunctionContract.ToChatTool();
// Get the generated function wrapper
Func<string, Task<string>> functionWrapper = functionInstance.WeatherReportWrapper;
@@ -69,32 +68,31 @@ public async Task UpperCase(string input)
#region code_snippet_1
// file: FunctionDefinition.generated.cs
- public FunctionDefinition UpperCaseFunction
+ public FunctionContract WeatherReportFunctionContract
{
- get => new FunctionDefinition
+ get => new FunctionContract
{
- Name = @"UpperCase",
- Description = "convert input to upper case",
- Parameters = BinaryData.FromObjectAsJson(new
+ ClassName = @"TypeSafeFunctionCall",
+ Name = @"WeatherReport",
+ Description = @"Get weather report",
+ ReturnType = typeof(Task<string>),
+ Parameters = new global::AutoGen.Core.FunctionParameterContract[]
{
- Type = "object",
- Properties = new
- {
- input = new
+ new FunctionParameterContract
{
- Type = @"string",
- Description = @"input",
+ Name = @"city",
+ Description = @"city",
+ ParameterType = typeof(string),
+ IsRequired = true,
},
- },
- Required = new[]
- {
- "input",
+ new FunctionParameterContract
+ {
+ Name = @"date",
+ Description = @"date",
+ ParameterType = typeof(string),
+ IsRequired = true,
},
},
- new JsonSerializerOptions
- {
- PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
- })
};
}
#endregion code_snippet_1
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs
index 3ee363bfc062..40c88102588a 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs
@@ -4,6 +4,8 @@
using AutoGen;
using AutoGen.BasicSample;
using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
using FluentAssertions;
///
@@ -13,18 +15,12 @@ public static class Example01_AssistantAgent
{
public static async Task RunAsync()
{
- var gpt35 = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
- var config = new ConversableAgentConfig
- {
- Temperature = 0,
- ConfigList = [gpt35],
- };
-
- // create assistant agent
- var assistantAgent = new AssistantAgent(
+ var gpt4oMini = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var assistantAgent = new OpenAIChatAgent(
+ chatClient: gpt4oMini,
name: "assistant",
- systemMessage: "You convert what user said to all uppercase.",
- llmConfig: config)
+ systemMessage: "You convert what user said to all uppercase.")
+ .RegisterMessageConnector()
.RegisterPrintMessage();
// talk to the assistant agent
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs
index c2957f32da76..b2dd9726b4b9 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs
@@ -1,30 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Example02_TwoAgent_MathChat.cs
-using AutoGen;
using AutoGen.BasicSample;
using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
using FluentAssertions;
public static class Example02_TwoAgent_MathChat
{
public static async Task RunAsync()
{
#region code_snippet_1
- // get gpt-3.5-turbo config
- var gpt35 = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var gpt4oMini = LLMConfiguration.GetOpenAIGPT4o_mini();
+
// create teacher agent
// teacher agent will create math questions
- var teacher = new AssistantAgent(
+ var teacher = new OpenAIChatAgent(
+ chatClient: gpt4oMini,
name: "teacher",
systemMessage: @"You are a teacher that create pre-school math question for student and check answer.
If the answer is correct, you stop the conversation by saying [COMPLETE].
- If the answer is wrong, you ask student to fix it.",
- llmConfig: new ConversableAgentConfig
- {
- Temperature = 0,
- ConfigList = [gpt35],
- })
+ If the answer is wrong, you ask student to fix it.")
+ .RegisterMessageConnector()
.RegisterMiddleware(async (msgs, option, agent, _) =>
{
var reply = await agent.GenerateReplyAsync(msgs, option);
@@ -39,14 +37,11 @@ public static async Task RunAsync()
// create student agent
// student agent will answer the math questions
- var student = new AssistantAgent(
+ var student = new OpenAIChatAgent(
+ chatClient: gpt4oMini,
name: "student",
- systemMessage: "You are a student that answer question from teacher",
- llmConfig: new ConversableAgentConfig
- {
- Temperature = 0,
- ConfigList = [gpt35],
- })
+ systemMessage: "You are a student that answer question from teacher")
+ .RegisterMessageConnector()
.RegisterPrintMessage();
// start the conversation
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
index 0ef8eaa48ae6..94b67a94b141 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
@@ -1,9 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Example03_Agent_FunctionCall.cs
-using AutoGen;
using AutoGen.BasicSample;
using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
using FluentAssertions;
///
@@ -45,33 +46,30 @@ public async Task CalculateTax(int price, float taxRate)
public static async Task RunAsync()
{
var instance = new Example03_Agent_FunctionCall();
- var gpt35 = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
// AutoGen makes use of AutoGen.SourceGenerator to automatically generate FunctionDefinition and FunctionCallWrapper for you.
// The FunctionDefinition will be created based on function signature and XML documentation.
// The return type of type-safe function needs to be Task<string>. And to get the best performance, please try only use primitive types and arrays of primitive types as parameters.
- var config = new ConversableAgentConfig
- {
- Temperature = 0,
- ConfigList = [gpt35],
- FunctionContracts = new[]
- {
+ var toolCallMiddleware = new FunctionCallMiddleware(
+ functions: [
instance.ConcatStringFunctionContract,
instance.UpperCaseFunctionContract,
instance.CalculateTaxFunctionContract,
- },
- };
-
- var agent = new AssistantAgent(
- name: "agent",
- systemMessage: "You are a helpful AI assistant",
- llmConfig: config,
+ ],
functionMap: new Dictionary>>
{
- { nameof(ConcatString), instance.ConcatStringWrapper },
- { nameof(UpperCase), instance.UpperCaseWrapper },
- { nameof(CalculateTax), instance.CalculateTaxWrapper },
- })
+ { nameof(instance.ConcatString), instance.ConcatStringWrapper },
+ { nameof(instance.UpperCase), instance.UpperCaseWrapper },
+ { nameof(instance.CalculateTax), instance.CalculateTaxWrapper },
+ });
+
+ var agent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "agent",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(toolCallMiddleware)
.RegisterPrintMessage();
// talk to the assistant agent
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
index 47dd8ce66c90..f90816d890e1 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
@@ -1,11 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Example04_Dynamic_GroupChat_Coding_Task.cs
-using AutoGen;
using AutoGen.BasicSample;
using AutoGen.Core;
using AutoGen.DotnetInteractive;
+using AutoGen.DotnetInteractive.Extension;
using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
using FluentAssertions;
public partial class Example04_Dynamic_GroupChat_Coding_Task
@@ -14,46 +15,32 @@ public static async Task RunAsync()
{
var instance = new Example04_Dynamic_GroupChat_Coding_Task();
- // setup dotnet interactive
- var workDir = Path.Combine(Path.GetTempPath(), "InteractiveService");
- if (!Directory.Exists(workDir))
- Directory.CreateDirectory(workDir);
+ var kernel = DotnetInteractiveKernelBuilder
+ .CreateDefaultInProcessKernelBuilder()
+ .AddPythonKernel("python3")
+ .Build();
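+        // the in-process dotnet-interactive kernel hosts a python3 subkernel that will execute the coder's code blocks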
- using var service = new InteractiveService(workDir);
- var dotnetInteractiveFunctions = new DotnetInteractiveFunction(service);
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
- var result = Path.Combine(workDir, "result.txt");
- if (File.Exists(result))
- File.Delete(result);
-
- await service.StartAsync(workDir, default);
-
- var gptConfig = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
-
- var helperAgent = new GPTAgent(
- name: "helper",
- systemMessage: "You are a helpful AI assistant",
- temperature: 0f,
- config: gptConfig);
-
- var groupAdmin = new GPTAgent(
+ var groupAdmin = new OpenAIChatAgent(
+ chatClient: gpt4o,
name: "groupAdmin",
- systemMessage: "You are the admin of the group chat",
- temperature: 0f,
- config: gptConfig)
+ systemMessage: "You are the admin of the group chat")
+ .RegisterMessageConnector()
.RegisterPrintMessage();
- var userProxy = new UserProxyAgent(name: "user", defaultReply: GroupChatExtension.TERMINATE, humanInputMode: HumanInputMode.NEVER)
+ var userProxy = new DefaultReplyAgent(name: "user", defaultReply: GroupChatExtension.TERMINATE)
.RegisterPrintMessage();
// Create admin agent
- var admin = new AssistantAgent(
+ var admin = new OpenAIChatAgent(
+ chatClient: gpt4o,
name: "admin",
systemMessage: """
You are a manager who takes coding problem from user and resolve problem by splitting them into small tasks and assign each task to the most appropriate agent.
Here's available agents who you can assign task to:
- - coder: write dotnet code to resolve task
- - runner: run dotnet code from coder
+ - coder: write python code to resolve task
+ - runner: run python code from coder
The workflow is as follows:
- You take the coding problem from user
@@ -79,24 +66,12 @@ You are a manager who takes coding problem from user and resolve problem by spli
Once the coding problem is resolved, summarize each steps and results and send the summary to the user using the following format:
```summary
- {
- "problem": "{coding problem}",
- "steps": [
- {
- "step": "{step}",
- "result": "{result}"
- }
- ]
- }
+ @user,
```
Your reply must contain one of [task|ask|summary] to indicate the type of your message.
- """,
- llmConfig: new ConversableAgentConfig
- {
- Temperature = 0,
- ConfigList = [gptConfig],
- })
+ """)
+ .RegisterMessageConnector()
.RegisterPrintMessage();
// create coder agent
@@ -104,30 +79,27 @@ Your reply must contain one of [task|ask|summary] to indicate the type of your m
-    // The dotnet coder write dotnet code to resolve the task.
-    // The code reviewer review the code block from coder's reply.
-    // The nuget agent install nuget packages if there's any.
+    // The coder writes python code to resolve the task.
+    // The code reviewer reviews the code block from the coder's reply.
- var coderAgent = new GPTAgent(
+ var coderAgent = new OpenAIChatAgent(
name: "coder",
- systemMessage: @"You act as dotnet coder, you write dotnet code to resolve task. Once you finish writing code, ask runner to run the code for you.
+ chatClient: gpt4o,
+ systemMessage: @"You act as python coder, you write python code to resolve task. Once you finish writing code, ask runner to run the code for you.
-Here're some rules to follow on writing dotnet code:
+Here are some rules to follow when writing python code:
-- put code between ```csharp and ```
-- When creating http client, use `var httpClient = new HttpClient()`. Don't use `using var httpClient = new HttpClient()` because it will cause error when running the code.
-- Try to use `var` instead of explicit type.
-- Try avoid using external library, use .NET Core library instead.
-- Use top level statement to write code.
+- put code between ```python and ```
+- Try avoid using external library
- Always print out the result to console. Don't write code that doesn't print out anything.
-If you need to install nuget packages, put nuget packages in the following format:
-```nuget
-nuget_package_name
+Use the following format to install a pip package:
+```python
+%pip install <package>
```
If your code is incorrect, Fix the error and send the code again.
-Here's some externel information
+Here's some external information
- The link to mlnet repo is: https://github.com/dotnet/machinelearning. you don't need a token to use github pr api. Make sure to include a User-Agent header, otherwise github will reject it.
-",
- config: gptConfig,
- temperature: 0.4f)
+")
+ .RegisterMessageConnector()
.RegisterPrintMessage();
// code reviewer agent will review if code block from coder's reply satisfy the following conditions:
@@ -135,14 +107,13 @@ Your reply must contain one of [task|ask|summary] to indicate the type of your m
-    // - The code block is csharp code block
-    // - The code block is top level statement
-    // - The code block is not using declaration
+    // - The code block is a python code block
- var codeReviewAgent = new GPTAgent(
+ var codeReviewAgent = new OpenAIChatAgent(
+ chatClient: gpt4o,
name: "reviewer",
systemMessage: """
You are a code reviewer who reviews code from coder. You need to check if the code satisfy the following conditions:
- - The reply from coder contains at least one code block, e.g ```csharp and ```
- - There's only one code block and it's csharp code block
- - The code block is not inside a main function. a.k.a top level statement
- - The code block is not using declaration when creating http client
+         - The reply from coder contains at least one code block, e.g. ```python and ```
+         - There's only one code block and it's a python code block
You don't check the code style, only check if the code satisfy the above conditions.
@@ -160,23 +131,40 @@ Your reply must contain one of [task|ask|summary] to indicate the type of your m
result: REJECTED
```
- """,
- config: gptConfig,
- temperature: 0f)
+ """)
+ .RegisterMessageConnector()
.RegisterPrintMessage();
// create runner agent
// The runner agent will run the code block from coder's reply.
-    // It runs dotnet code using dotnet interactive service hook.
-    // It also truncate the output if the output is too long.
+    // It runs python code on the dotnet-interactive kernel.
+    // It also truncates the output if the output is too long.
- var runner = new AssistantAgent(
+ var runner = new DefaultReplyAgent(
name: "runner",
defaultReply: "No code available, coder, write code please")
- .RegisterDotnetCodeBlockExectionHook(interactiveService: service)
.RegisterMiddleware(async (msgs, option, agent, ct) =>
{
var mostRecentCoderMessage = msgs.LastOrDefault(x => x.From == "coder") ?? throw new Exception("No coder message found");
- return await agent.GenerateReplyAsync(new[] { mostRecentCoderMessage }, option, ct);
+
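+                // extract the python code block from the coder's reply and run it on the kernel, truncating long output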
+ if (mostRecentCoderMessage.ExtractCodeBlock("```python", "```") is string code)
+ {
+ var result = await kernel.RunSubmitCodeCommandAsync(code, "python");
+ // only keep the first 500 characters
+ if (result.Length > 500)
+ {
+ result = result.Substring(0, 500);
+ }
+ result = $"""
+ # [CODE_BLOCK_EXECUTION_RESULT]
+ {result}
+ """;
+
+ return new TextMessage(Role.Assistant, result, from: agent.Name);
+ }
+ else
+ {
+ return await agent.GenerateReplyAsync(msgs, option, ct);
+ }
})
.RegisterPrintMessage();
@@ -247,18 +235,27 @@ Your reply must contain one of [task|ask|summary] to indicate the type of your m
workflow: workflow);
// task 1: retrieve the most recent pr from mlnet and save it in result.txt
- var groupChatManager = new GroupChatManager(groupChat);
- await userProxy.SendAsync(groupChatManager, "Retrieve the most recent pr from mlnet and save it in result.txt", maxRound: 30);
- File.Exists(result).Should().BeTrue();
-
- // task 2: calculate the 39th fibonacci number
- var answer = 63245986;
- // clear the result file
- File.Delete(result);
+ var task = """
+ retrieve the most recent pr from mlnet and save it in result.txt
+ """;
+        var chatHistory = new List<IMessage>
+ {
+ new TextMessage(Role.Assistant, task)
+ {
+ From = userProxy.Name
+ }
+ };
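+        // drive the group chat and stop once the admin posts a ```summary block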
+ await foreach (var message in groupChat.SendAsync(chatHistory, maxRound: 10))
+ {
+            if (message.From == admin.Name && message.GetContent()?.Contains("```summary") is true)
+ {
+ // Task complete!
+ break;
+ }
+ }
- var conversationHistory = await userProxy.InitiateChatAsync(groupChatManager, "What's the 39th of fibonacci number? Save the result in result.txt", maxRound: 10);
+ // check if the result file is created
+ var result = "result.txt";
File.Exists(result).Should().BeTrue();
- var resultContent = File.ReadAllText(result);
- resultContent.Should().Contain(answer.ToString());
}
}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs
index ba7b5d4bde44..e8dd86474e7a 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs
@@ -4,9 +4,9 @@
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
using FluentAssertions;
-using autogen = AutoGen.LLMConfigAPI;
+using OpenAI;
+using OpenAI.Images;
public partial class Example05_Dalle_And_GPT4V
{
@@ -30,16 +30,12 @@ public async Task GenerateImage(string prompt)
// and return url.
var option = new ImageGenerationOptions
{
- Size = ImageSize.Size1024x1024,
- Style = ImageGenerationStyle.Vivid,
- ImageCount = 1,
- Prompt = prompt,
- Quality = ImageGenerationQuality.Standard,
- DeploymentName = "dall-e-3",
+ Size = GeneratedImageSize.W1024xH1024,
+ Style = GeneratedImageStyle.Vivid,
};
- var imageResponse = await openAIClient.GetImageGenerationsAsync(option);
- var imageUrl = imageResponse.Value.Data.First().Url.OriginalString;
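+        // the new OpenAI SDK exposes image generation through a model-scoped ImageClient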
+ var imageResponse = await openAIClient.GetImageClient("dall-e-3").GenerateImageAsync(prompt, option);
+ var imageUrl = imageResponse.Value.ImageUri.OriginalString;
return $@"// ignore this line [IMAGE_GENERATION]
The image is generated from prompt {prompt}
@@ -57,8 +53,6 @@ public static async Task RunAsync()
// get OpenAI Key and create config
var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var gpt35Config = autogen.GetOpenAIConfigList(openAIKey, new[] { "gpt-3.5-turbo" });
- var gpt4vConfig = autogen.GetOpenAIConfigList(openAIKey, new[] { "gpt-4-vision-preview" });
var openAIClient = new OpenAIClient(openAIKey);
var instance = new Example05_Dalle_And_GPT4V(openAIClient);
var imagePath = Path.Combine("resource", "images", "background.png");
@@ -74,8 +68,7 @@ public static async Task RunAsync()
{ nameof(GenerateImage), instance.GenerateImageWrapper },
});
var dalleAgent = new OpenAIChatAgent(
- openAIClient: openAIClient,
- modelName: "gpt-3.5-turbo",
+ chatClient: openAIClient.GetChatClient("gpt-4o-mini"),
name: "dalle",
systemMessage: "You are a DALL-E agent that generate image from prompt, when conversation is terminated, return the most recent image url")
.RegisterMessageConnector()
@@ -110,9 +103,8 @@ public static async Task RunAsync()
.RegisterPrintMessage();
var gpt4VAgent = new OpenAIChatAgent(
- openAIClient: openAIClient,
- name: "gpt4v",
- modelName: "gpt-4-vision-preview",
+ chatClient: openAIClient.GetChatClient("gpt-4o-mini"),
+ name: "gpt-4o-mini",
systemMessage: @"You are a critism that provide feedback to DALL-E agent.
Carefully check the image generated by DALL-E agent and provide feedback.
If the image satisfies the condition, then say [APPROVE].
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs
index dd3b5a671921..e1349cb32a99 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs
@@ -2,6 +2,7 @@
// Example06_UserProxyAgent.cs
using AutoGen.Core;
using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
namespace AutoGen.BasicSample;
@@ -9,12 +10,13 @@ public static class Example06_UserProxyAgent
{
public static async Task RunAsync()
{
- var gpt35 = LLMConfiguration.GetOpenAIGPT3_5_Turbo();
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
- var assistantAgent = new GPTAgent(
+ var assistantAgent = new OpenAIChatAgent(
+ chatClient: gpt4o,
name: "assistant",
- systemMessage: "You are an assistant that help user to do some tasks.",
- config: gpt35)
+ systemMessage: "You are an assistant that help user to do some tasks.")
+ .RegisterMessageConnector()
.RegisterPrintMessage();
// set human input mode to ALWAYS so that user always provide input
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
index 6584baa5fae5..1f1315586a28 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
@@ -3,12 +3,14 @@
using System.Text;
using System.Text.Json;
-using AutoGen;
using AutoGen.BasicSample;
using AutoGen.Core;
using AutoGen.DotnetInteractive;
+using AutoGen.DotnetInteractive.Extension;
using AutoGen.OpenAI;
-using FluentAssertions;
+using AutoGen.OpenAI.Extension;
+using Microsoft.DotNet.Interactive;
+using OpenAI.Chat;
public partial class Example07_Dynamic_GroupChat_Calculate_Fibonacci
{
@@ -48,10 +50,10 @@ public async Task ReviewCodeBlock(
#endregion reviewer_function
#region create_coder
-    public static async Task<IAgent> CreateCoderAgentAsync()
+    public static async Task<IAgent> CreateCoderAgentAsync(ChatClient client)
{
- var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
- var coder = new GPTAgent(
+ var coder = new OpenAIChatAgent(
+ chatClient: client,
name: "coder",
systemMessage: @"You act as dotnet coder, you write dotnet code to resolve task. Once you finish writing code, ask runner to run the code for you.
@@ -69,8 +71,8 @@ public static async Task CreateCoderAgentAsync()
```
If your code is incorrect, runner will tell you the error message. Fix the error and send the code again.",
- config: gpt3Config,
temperature: 0.4f)
+ .RegisterMessageConnector()
.RegisterPrintMessage();
return coder;
@@ -78,13 +80,11 @@ public static async Task CreateCoderAgentAsync()
#endregion create_coder
#region create_runner
-    public static async Task<IAgent> CreateRunnerAgentAsync(InteractiveService service)
+    public static async Task<IAgent> CreateRunnerAgentAsync(Kernel kernel)
{
- var runner = new AssistantAgent(
+ var runner = new DefaultReplyAgent(
name: "runner",
- systemMessage: "You run dotnet code",
defaultReply: "No code available.")
- .RegisterDotnetCodeBlockExectionHook(interactiveService: service)
.RegisterMiddleware(async (msgs, option, agent, _) =>
{
if (msgs.Count() == 0 || msgs.All(msg => msg.From != "coder"))
@@ -94,7 +94,24 @@ public static async Task CreateRunnerAgentAsync(InteractiveService servi
else
{
var coderMsg = msgs.Last(msg => msg.From == "coder");
- return await agent.GenerateReplyAsync([coderMsg], option);
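+                    // run the extracted csharp block on the dotnet-interactive kernel and wrap the output for the group chat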
+ if (coderMsg.ExtractCodeBlock("```csharp", "```") is string code)
+ {
+ var codeResult = await kernel.RunSubmitCodeCommandAsync(code, "csharp");
+
+ codeResult = $"""
+ [RUNNER_RESULT]
+ {codeResult}
+ """;
+
+ return new TextMessage(Role.Assistant, codeResult)
+ {
+ From = "runner",
+ };
+ }
+ else
+ {
+ return new TextMessage(Role.Assistant, "No code available. Coder please write code");
+ }
}
})
.RegisterPrintMessage();
@@ -104,45 +121,35 @@ public static async Task CreateRunnerAgentAsync(InteractiveService servi
#endregion create_runner
#region create_admin
-    public static async Task<IAgent> CreateAdminAsync()
+    public static async Task<IAgent> CreateAdminAsync(ChatClient client)
{
- var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
- var admin = new GPTAgent(
+ var admin = new OpenAIChatAgent(
+ chatClient: client,
name: "admin",
- systemMessage: "You are group admin, terminate the group chat once task is completed by saying [TERMINATE] plus the final answer",
- temperature: 0,
- config: gpt3Config)
- .RegisterMiddleware(async (msgs, option, agent, _) =>
- {
- var reply = await agent.GenerateReplyAsync(msgs, option);
- if (reply is TextMessage textMessage && textMessage.Content.Contains("TERMINATE") is true)
- {
- var content = $"{textMessage.Content}\n\n {GroupChatExtension.TERMINATE}";
-
- return new TextMessage(Role.Assistant, content, from: reply.From);
- }
-
- return reply;
- });
+ temperature: 0)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
return admin;
}
#endregion create_admin
#region create_reviewer
-    public static async Task<IAgent> CreateReviewerAgentAsync()
+    public static async Task<IAgent> CreateReviewerAgentAsync(ChatClient chatClient)
{
- var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
var functions = new Example07_Dynamic_GroupChat_Calculate_Fibonacci();
- var reviewer = new GPTAgent(
- name: "code_reviewer",
- systemMessage: @"You review code block from coder",
- config: gpt3Config,
- functions: [functions.ReviewCodeBlockFunction],
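+        // pair the ReviewCodeBlock contract (advertised to the model) with its wrapper (invoked locally)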
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [functions.ReviewCodeBlockFunctionContract],
        functionMap: new Dictionary<string, Func<string, Task<string>>>()
{
- { nameof(ReviewCodeBlock), functions.ReviewCodeBlockWrapper },
- })
+ { nameof(functions.ReviewCodeBlock), functions.ReviewCodeBlockWrapper },
+ });
+ var reviewer = new OpenAIChatAgent(
+ chatClient: chatClient,
+ name: "code_reviewer",
+ systemMessage: @"You review code block from coder")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(async (msgs, option, innerAgent, ct) =>
{
var maxRetry = 3;
@@ -222,20 +229,17 @@ public static async Task CreateReviewerAgentAsync()
public static async Task RunWorkflowAsync()
{
long the39thFibonacciNumber = 63245986;
- var workDir = Path.Combine(Path.GetTempPath(), "InteractiveService");
- if (!Directory.Exists(workDir))
- Directory.CreateDirectory(workDir);
-
- using var service = new InteractiveService(workDir);
- var dotnetInteractiveFunctions = new DotnetInteractiveFunction(service);
+ var kernel = DotnetInteractiveKernelBuilder
+ .CreateDefaultInProcessKernelBuilder()
+ .Build();
- await service.StartAsync(workDir, default);
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
#region create_workflow
- var reviewer = await CreateReviewerAgentAsync();
- var coder = await CreateCoderAgentAsync();
- var runner = await CreateRunnerAgentAsync(service);
- var admin = await CreateAdminAsync();
+ var reviewer = await CreateReviewerAgentAsync(gpt4o);
+ var coder = await CreateCoderAgentAsync(gpt4o);
+ var runner = await CreateRunnerAgentAsync(kernel);
+ var admin = await CreateAdminAsync(gpt4o);
var admin2CoderTransition = Transition.Create(admin, coder);
var coder2ReviewerTransition = Transition.Create(coder, reviewer);
@@ -306,21 +310,23 @@ public static async Task RunWorkflowAsync()
runner,
reviewer,
]);
-
+ #endregion create_group_chat_with_workflow
admin.SendIntroduction("Welcome to my group, work together to resolve my task", groupChat);
coder.SendIntroduction("I will write dotnet code to resolve task", groupChat);
reviewer.SendIntroduction("I will review dotnet code", groupChat);
runner.SendIntroduction("I will run dotnet code once the review is done", groupChat);
+ var task = "What's the 39th of fibonacci number?";
- var groupChatManager = new GroupChatManager(groupChat);
- var conversationHistory = await admin.InitiateChatAsync(groupChatManager, "What's the 39th of fibonacci number?", maxRound: 10);
- #endregion create_group_chat_with_workflow
- // the last message is from admin, which is the termination message
- var lastMessage = conversationHistory.Last();
- lastMessage.From.Should().Be("admin");
- lastMessage.IsGroupChatTerminateMessage().Should().BeTrue();
-        lastMessage.Should().BeOfType<TextMessage>();
- lastMessage.GetContent().Should().Contain(the39thFibonacciNumber.ToString());
+ var taskMessage = new TextMessage(Role.User, task, from: admin.Name);
+ await foreach (var message in groupChat.SendAsync([taskMessage], maxRound: 10))
+ {
+        // terminate the chat if the message is from runner and the run succeeded
+        if (message.From == "runner" && message.GetContent()?.Contains(the39thFibonacciNumber.ToString()) is true)
+        {
+            Console.WriteLine($"The 39th Fibonacci number is {the39thFibonacciNumber}");
+ break;
+ }
+ }
}
public static async Task RunAsync()
@@ -328,41 +334,44 @@ public static async Task RunAsync()
long the39thFibonacciNumber = 63245986;
var workDir = Path.Combine(Path.GetTempPath(), "InteractiveService");
if (!Directory.Exists(workDir))
+ {
Directory.CreateDirectory(workDir);
+ }
- using var service = new InteractiveService(workDir);
- var dotnetInteractiveFunctions = new DotnetInteractiveFunction(service);
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
- await service.StartAsync(workDir, default);
+ var kernel = DotnetInteractiveKernelBuilder
+ .CreateDefaultInProcessKernelBuilder()
+ .Build();
#region create_group_chat
- var reviewer = await CreateReviewerAgentAsync();
- var coder = await CreateCoderAgentAsync();
- var runner = await CreateRunnerAgentAsync(service);
- var admin = await CreateAdminAsync();
+ var reviewer = await CreateReviewerAgentAsync(gpt4o);
+ var coder = await CreateCoderAgentAsync(gpt4o);
+ var runner = await CreateRunnerAgentAsync(kernel);
+ var admin = await CreateAdminAsync(gpt4o);
var groupChat = new GroupChat(
admin: admin,
members:
[
- admin,
coder,
runner,
reviewer,
]);
- admin.SendIntroduction("Welcome to my group, work together to resolve my task", groupChat);
coder.SendIntroduction("I will write dotnet code to resolve task", groupChat);
reviewer.SendIntroduction("I will review dotnet code", groupChat);
runner.SendIntroduction("I will run dotnet code once the review is done", groupChat);
- var groupChatManager = new GroupChatManager(groupChat);
- var conversationHistory = await admin.InitiateChatAsync(groupChatManager, "What's the 39th of fibonacci number?", maxRound: 10);
-
- // the last message is from admin, which is the termination message
- var lastMessage = conversationHistory.Last();
- lastMessage.From.Should().Be("admin");
- lastMessage.IsGroupChatTerminateMessage().Should().BeTrue();
-        lastMessage.Should().BeOfType<TextMessage>();
- lastMessage.GetContent().Should().Contain(the39thFibonacciNumber.ToString());
+ var task = "What's the 39th of fibonacci number?";
+ var taskMessage = new TextMessage(Role.User, task);
+ await foreach (var message in groupChat.SendAsync([taskMessage], maxRound: 10))
+ {
+        // terminate the chat if the message is from runner and the run succeeded
+        if (message.From == "runner" && message.GetContent()?.Contains(the39thFibonacciNumber.ToString()) is true)
+        {
+            Console.WriteLine($"The 39th Fibonacci number is {the39thFibonacciNumber}");
+ break;
+ }
+ }
#endregion create_group_chat
}
}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example08_LMStudio.cs b/dotnet/sample/AutoGen.BasicSamples/Example08_LMStudio.cs
index cce330117622..e58454fdb5f8 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example08_LMStudio.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example08_LMStudio.cs
@@ -3,7 +3,9 @@
#region lmstudio_using_statements
using AutoGen.Core;
-using AutoGen.LMStudio;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using OpenAI;
#endregion lmstudio_using_statements
namespace AutoGen.BasicSample;
@@ -13,8 +15,16 @@ public class Example08_LMStudio
public static async Task RunAsync()
{
#region lmstudio_example_1
- var config = new LMStudioConfig("localhost", 1234);
- var lmAgent = new LMStudioAgent("asssistant", config: config)
+ var endpoint = "http://localhost:1234";
+ var openaiClient = new OpenAIClient("api-key", new OpenAIClientOptions
+ {
+ Endpoint = new Uri(endpoint),
+ });
+
+ var lmAgent = new OpenAIChatAgent(
+ chatClient: openaiClient.GetChatClient(""),
+ name: "assistant")
+ .RegisterMessageConnector()
.RegisterPrintMessage();
await lmAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs
deleted file mode 100644
index 9a62144df2bd..000000000000
--- a/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs
+++ /dev/null
@@ -1,135 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Example09_LMStudio_FunctionCall.cs
-
-using System.Text.Json;
-using System.Text.Json.Serialization;
-using AutoGen.Core;
-using AutoGen.LMStudio;
-using Azure.AI.OpenAI;
-
-namespace AutoGen.BasicSample;
-
-public class LLaMAFunctionCall
-{
- [JsonPropertyName("name")]
- public string Name { get; set; }
-
- [JsonPropertyName("arguments")]
- public JsonElement Arguments { get; set; }
-}
-
-public partial class Example09_LMStudio_FunctionCall
-{
-    /// <summary>
-    /// Get weather from location.
-    /// </summary>
-    /// <param name="location">location</param>
-    /// <param name="date">date. type is string</param>
- [Function]
-    public async Task<string> GetWeather(string location, string date)
- {
- return $"[Function] The weather on {date} in {location} is sunny.";
- }
-
-
-    /// <summary>
-    /// Search query on Google and return the results.
-    /// </summary>
-    /// <param name="query">search query</param>
- [Function]
-    public async Task<string> GoogleSearch(string query)
- {
- return $"[Function] Here are the search results for {query}.";
- }
-
- private static object SerializeFunctionDefinition(FunctionDefinition functionDefinition)
- {
- return new
- {
- type = "function",
- function = new
- {
- name = functionDefinition.Name,
- description = functionDefinition.Description,
-                parameters = functionDefinition.Parameters.ToObjectFromJson<object>(),
- }
- };
- }
-
- public static async Task RunAsync()
- {
- #region lmstudio_function_call_example
- // This example has been verified to work with Trelis-Llama-2-7b-chat-hf-function-calling-v3
- var instance = new Example09_LMStudio_FunctionCall();
- var config = new LMStudioConfig("localhost", 1234);
- var systemMessage = @$"You are a helpful AI assistant.";
-
- // Because the LM studio server doesn't support openai function call yet
- // To simulate the function call, we can put the function call details in the system message
- // And ask agent to response in function call object format using few-shot example
- object[] functionList =
- [
- SerializeFunctionDefinition(instance.GetWeatherFunction),
- SerializeFunctionDefinition(instance.GoogleSearchFunction)
- ];
- var functionListString = JsonSerializer.Serialize(functionList, new JsonSerializerOptions { WriteIndented = true });
- var lmAgent = new LMStudioAgent(
- name: "assistant",
- systemMessage: @$"
-You are a helpful AI assistant
-You have access to the following functions. Use them if required:
-
-{functionListString}",
- config: config)
- .RegisterMiddleware(async (msgs, option, innerAgent, ct) =>
- {
- // inject few-shot example to the message
- var exampleGetWeather = new TextMessage(Role.User, "Get weather in London");
- var exampleAnswer = new TextMessage(Role.Assistant, "{\n \"name\": \"GetWeather\",\n \"arguments\": {\n \"city\": \"London\"\n }\n}", from: innerAgent.Name);
-
- msgs = new[] { exampleGetWeather, exampleAnswer }.Concat(msgs).ToArray();
- var reply = await innerAgent.GenerateReplyAsync(msgs, option, ct);
-
- // if reply is a function call, invoke function
- var content = reply.GetContent();
- try
- {
-            if (JsonSerializer.Deserialize<LLaMAFunctionCall>(content) is { } functionCall)
- {
- var arguments = JsonSerializer.Serialize(functionCall.Arguments);
- // invoke function wrapper
- if (functionCall.Name == instance.GetWeatherFunction.Name)
- {
- var result = await instance.GetWeatherWrapper(arguments);
- return new TextMessage(Role.Assistant, result);
- }
- else if (functionCall.Name == instance.GoogleSearchFunction.Name)
- {
- var result = await instance.GoogleSearchWrapper(arguments);
- return new TextMessage(Role.Assistant, result);
- }
- else
- {
- throw new Exception($"Unknown function call: {functionCall.Name}");
- }
- }
- }
- catch (JsonException)
- {
- // ignore
- }
-
- return reply;
- })
- .RegisterPrintMessage();
-
- var userProxyAgent = new UserProxyAgent(
- name: "user",
- humanInputMode: HumanInputMode.ALWAYS);
-
- await userProxyAgent.SendAsync(
- receiver: lmAgent,
- "Search the names of the five largest stocks in the US by market cap ");
- #endregion lmstudio_function_call_example
- }
-}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs
index 61c341204ec2..da7e54852f34 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs
@@ -39,7 +39,7 @@ public class Example10_SemanticKernel
public static async Task RunAsync()
{
var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var modelId = "gpt-3.5-turbo";
+ var modelId = "gpt-4o-mini";
var builder = Kernel.CreateBuilder()
.AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey);
var kernel = builder.Build();
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example11_Sequential_GroupChat_Example.cs b/dotnet/sample/AutoGen.BasicSamples/Example11_Sequential_GroupChat_Example.cs
index 00ff321082a4..32aaa8c187b4 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example11_Sequential_GroupChat_Example.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example11_Sequential_GroupChat_Example.cs
@@ -7,7 +7,6 @@
using AutoGen.OpenAI.Extension;
using AutoGen.SemanticKernel;
using AutoGen.SemanticKernel.Extension;
-using Azure.AI.OpenAI;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Plugins.Web;
using Microsoft.SemanticKernel.Plugins.Web.Bing;
@@ -52,15 +51,10 @@ You put the original search result between ```bing and ```
    public static async Task<IAgent> CreateSummarizerAgentAsync()
{
#region CreateSummarizerAgent
- var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
- var apiKey = config.ApiKey;
- var endPoint = new Uri(config.Endpoint);
-
- var openAIClient = new OpenAIClient(endPoint, new Azure.AzureKeyCredential(apiKey));
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
var openAIClientAgent = new OpenAIChatAgent(
- openAIClient: openAIClient,
+ chatClient: gpt4o,
name: "summarizer",
- modelName: config.DeploymentName,
systemMessage: "You summarize search result from bing in a short and concise manner");
return openAIClientAgent
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example12_TwoAgent_Fill_Application.cs b/dotnet/sample/AutoGen.BasicSamples/Example12_TwoAgent_Fill_Application.cs
index b622a3e641ef..69c2121cd80b 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example12_TwoAgent_Fill_Application.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example12_TwoAgent_Fill_Application.cs
@@ -5,7 +5,6 @@
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
namespace AutoGen.BasicSample;
@@ -69,11 +68,7 @@ public async Task SaveProgress(
    public static async Task<IAgent> CreateSaveProgressAgent()
{
- var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
- var endPoint = gpt3Config.Endpoint ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
- var apiKey = gpt3Config.ApiKey ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
- var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(apiKey));
-
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
var instance = new TwoAgent_Fill_Application();
var functionCallConnector = new FunctionCallMiddleware(
functions: [instance.SaveProgressFunctionContract],
@@ -83,9 +78,8 @@ public static async Task CreateSaveProgressAgent()
});
var chatAgent = new OpenAIChatAgent(
- openAIClient: openaiClient,
+ chatClient: gpt4o,
name: "application",
- modelName: gpt3Config.DeploymentName,
systemMessage: """You are a helpful application form assistant who saves progress while user fills application.""")
.RegisterMessageConnector()
.RegisterMiddleware(functionCallConnector)
@@ -109,48 +103,23 @@ Save progress according to the most recent information provided by user.
    public static async Task<IAgent> CreateAssistantAgent()
{
- var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
- var endPoint = gpt3Config.Endpoint ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
- var apiKey = gpt3Config.ApiKey ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
- var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(apiKey));
-
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
var chatAgent = new OpenAIChatAgent(
- openAIClient: openaiClient,
+ chatClient: gpt4o,
name: "assistant",
- modelName: gpt3Config.DeploymentName,
systemMessage: """You create polite prompt to ask user provide missing information""")
.RegisterMessageConnector()
- .RegisterPrintMessage()
- .RegisterMiddleware(async (msgs, option, agent, ct) =>
- {
- var lastReply = msgs.Last() ?? throw new Exception("No reply found.");
- var reply = await agent.GenerateReplyAsync(msgs, option, ct);
-
- // if application is complete, exit conversation by sending termination message
- if (lastReply.GetContent().Contains("Application information is saved to database."))
- {
- return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, from: agent.Name);
- }
- else
- {
- return reply;
- }
- });
+ .RegisterPrintMessage();
return chatAgent;
}
    public static async Task<IAgent> CreateUserAgent()
{
- var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
- var endPoint = gpt3Config.Endpoint ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
- var apiKey = gpt3Config.ApiKey ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
- var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(apiKey));
-
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
var chatAgent = new OpenAIChatAgent(
- openAIClient: openaiClient,
+ chatClient: gpt4o,
name: "user",
- modelName: gpt3Config.DeploymentName,
systemMessage: """
You are a user who is filling an application form. Simply provide the information as requested and answer the questions, don't do anything else.
@@ -191,9 +160,13 @@ public static async Task RunAsync()
var groupChatManager = new GroupChatManager(groupChat);
var initialMessage = await assistantAgent.SendAsync("Generate a greeting meesage for user and start the conversation by asking what's their name.");
- var chatHistory = await userAgent.SendAsync(groupChatManager, [initialMessage], maxRound: 30);
-
- var lastMessage = chatHistory.Last();
- Console.WriteLine(lastMessage.GetContent());
+        var chatHistory = new List<IMessage> { initialMessage };
+ await foreach (var msg in userAgent.SendAsync(groupChatManager, chatHistory, maxRound: 30))
+ {
+            if (msg.GetContent()?.ToLower().Contains("application information is saved to database.") is true)
+ {
+ break;
+ }
+ }
}
}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs
index dadad7f00b99..596ab08d02a1 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs
@@ -1,68 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Example13_OpenAIAgent_JsonMode.cs
-using System.Text.Json;
-using System.Text.Json.Serialization;
-using AutoGen.Core;
-using AutoGen.OpenAI;
-using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
-using FluentAssertions;
+// this example has been moved to https://github.com/microsoft/autogen/blob/main/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs
-namespace AutoGen.BasicSample;
-
-public class Example13_OpenAIAgent_JsonMode
-{
- public static async Task RunAsync()
- {
- #region create_agent
- var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo(deployName: "gpt-35-turbo"); // json mode only works with 0125 and later model.
- var apiKey = config.ApiKey;
- var endPoint = new Uri(config.Endpoint);
-
- var openAIClient = new OpenAIClient(endPoint, new Azure.AzureKeyCredential(apiKey));
- var openAIClientAgent = new OpenAIChatAgent(
- openAIClient: openAIClient,
- name: "assistant",
- modelName: config.DeploymentName,
- systemMessage: "You are a helpful assistant designed to output JSON.",
- seed: 0, // explicitly set a seed to enable deterministic output
- responseFormat: ChatCompletionsResponseFormat.JsonObject) // set response format to JSON object to enable JSON mode
- .RegisterMessageConnector()
- .RegisterPrintMessage();
- #endregion create_agent
-
- #region chat_with_agent
- var reply = await openAIClientAgent.SendAsync("My name is John, I am 25 years old, and I live in Seattle.");
-
-        var person = JsonSerializer.Deserialize<Person>(reply.GetContent());
- Console.WriteLine($"Name: {person.Name}");
- Console.WriteLine($"Age: {person.Age}");
-
- if (!string.IsNullOrEmpty(person.Address))
- {
- Console.WriteLine($"Address: {person.Address}");
- }
-
- Console.WriteLine("Done.");
- #endregion chat_with_agent
-
- person.Name.Should().Be("John");
- person.Age.Should().Be(25);
- person.Address.Should().BeNullOrEmpty();
- }
-}
-
-#region person_class
-public class Person
-{
- [JsonPropertyName("name")]
- public string Name { get; set; }
-
- [JsonPropertyName("age")]
- public int Age { get; set; }
-
- [JsonPropertyName("address")]
- public string Address { get; set; }
-}
-#endregion person_class
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs b/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs
index 788122d3f383..4a4b10ae3d75 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs
@@ -3,6 +3,7 @@
using AutoGen.Core;
using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
namespace AutoGen.BasicSample;
@@ -27,14 +28,14 @@ public static class Example15_GPT4V_BinaryDataImageMessage
public static async Task RunAsync()
{
- var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var openAiConfig = new OpenAIConfig(openAIKey, "gpt-4o");
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
- var visionAgent = new GPTAgent(
+ var visionAgent = new OpenAIChatAgent(
+ chatClient: gpt4o,
name: "gpt",
systemMessage: "You are a helpful AI assistant",
- config: openAiConfig,
temperature: 0)
+ .RegisterMessageConnector()
.RegisterPrintMessage();
    List<IMessage> messages =
@@ -50,7 +51,9 @@ private static void AddMessagesFromResource(string imageResourcePath, List<IMessage> messages)
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs b/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs
--- a/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs
-    protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
- {
- request.RequestUri = new Uri($"{_modelServiceUrl}{request.RequestUri.PathAndQuery}");
-
- return base.SendAsync(request, cancellationToken);
- }
-}
-#endregion CustomHttpClientHandler
-
-public class Example16_OpenAIChatAgent_ConnectToThirdPartyBackend
-{
- public static async Task RunAsync()
- {
- #region create_agent
- using var client = new HttpClient(new CustomHttpClientHandler("http://localhost:11434"));
- var option = new OpenAIClientOptions(OpenAIClientOptions.ServiceVersion.V2024_04_01_Preview)
- {
- Transport = new HttpClientTransport(client),
- };
-
- // api-key is not required for local server
- // so you can use any string here
- var openAIClient = new OpenAIClient("api-key", option);
- var model = "llama3";
-
- var agent = new OpenAIChatAgent(
- openAIClient: openAIClient,
- name: "assistant",
- modelName: model,
- systemMessage: "You are a helpful assistant designed to output JSON.",
- seed: 0)
- .RegisterMessageConnector()
- .RegisterPrintMessage();
- #endregion create_agent
-
- #region send_message
- await agent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
- #endregion send_message
- }
-}
+// this example has been moved to https://github.com/microsoft/autogen/blob/main/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example17_ReActAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example17_ReActAgent.cs
index f598ebbf7c46..170736bf22e4 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example17_ReActAgent.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example17_ReActAgent.cs
@@ -4,14 +4,14 @@
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
+using OpenAI;
+using OpenAI.Chat;
namespace AutoGen.BasicSample;
public class OpenAIReActAgent : IAgent
{
- private readonly OpenAIClient _client;
- private readonly string modelName = "gpt-3.5-turbo";
+ private readonly ChatClient _client;
private readonly FunctionContract[] tools;
    private readonly Dictionary<string, Func<string, Task<string>>> toolExecutors = new();
private readonly IAgent reasoner;
@@ -39,16 +39,15 @@ public class OpenAIReActAgent : IAgent
Begin!
Question: {input}";
-    public OpenAIReActAgent(OpenAIClient client, string modelName, string name, FunctionContract[] tools, Dictionary<string, Func<string, Task<string>>> toolExecutors)
+    public OpenAIReActAgent(ChatClient client, string name, FunctionContract[] tools, Dictionary<string, Func<string, Task<string>>> toolExecutors)
{
_client = client;
this.Name = name;
- this.modelName = modelName;
this.tools = tools;
this.toolExecutors = toolExecutors;
this.reasoner = CreateReasoner();
this.actor = CreateActor();
- this.helper = new OpenAIChatAgent(client, "helper", modelName)
+ this.helper = new OpenAIChatAgent(client, "helper")
.RegisterMessageConnector();
}
@@ -106,8 +105,7 @@ private string CreateReActPrompt(string input)
private IAgent CreateReasoner()
{
return new OpenAIChatAgent(
- openAIClient: _client,
- modelName: modelName,
+ chatClient: _client,
name: "reasoner")
.RegisterMessageConnector()
.RegisterPrintMessage();
@@ -117,8 +115,7 @@ private IAgent CreateActor()
{
var functionCallMiddleware = new FunctionCallMiddleware(tools, toolExecutors);
return new OpenAIChatAgent(
- openAIClient: _client,
- modelName: modelName,
+ chatClient: _client,
name: "actor")
.RegisterMessageConnector()
.RegisterMiddleware(functionCallMiddleware)
@@ -166,9 +163,9 @@ public static async Task RunAsync()
var modelName = "gpt-4-turbo";
var tools = new Tools();
var openAIClient = new OpenAIClient(openAIKey);
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
var reactAgent = new OpenAIReActAgent(
- client: openAIClient,
- modelName: modelName,
+ client: openAIClient.GetChatClient(modelName),
name: "react-agent",
tools: [tools.GetLocalizationFunctionContract, tools.GetDateTodayFunctionContract, tools.WeatherReportFunctionContract],
        toolExecutors: new Dictionary<string, Func<string, Task<string>>>
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Agent_Middleware.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Agent_Middleware.cs
index 57f8ab4075c2..cf97af134675 100644
--- a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Agent_Middleware.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Agent_Middleware.cs
@@ -5,9 +5,9 @@
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
#endregion Using
using FluentAssertions;
+using OpenAI.Chat;
namespace AutoGen.BasicSample;
@@ -16,20 +16,17 @@ public class Agent_Middleware
public static async Task RunTokenCountAsync()
{
#region Create_Agent
- var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("Please set the environment variable OPENAI_API_KEY");
- var model = "gpt-3.5-turbo";
- var openaiClient = new OpenAIClient(apiKey);
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
var openaiMessageConnector = new OpenAIChatRequestMessageConnector();
var totalTokenCount = 0;
var agent = new OpenAIChatAgent(
- openAIClient: openaiClient,
+ chatClient: gpt4o,
name: "agent",
- modelName: model,
systemMessage: "You are a helpful AI assistant")
.RegisterMiddleware(async (messages, option, innerAgent, ct) =>
{
var reply = await innerAgent.GenerateReplyAsync(messages, option, ct);
-            if (reply is MessageEnvelope<ChatCompletions> chatCompletions)
+            if (reply is MessageEnvelope<ChatCompletion> chatCompletions)
{
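+                    // token usage is read from the raw ChatCompletion envelope returned by the inner agent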
var tokenCount = chatCompletions.Content.Usage.TotalTokens;
totalTokenCount += tokenCount;
@@ -53,21 +50,17 @@ public static async Task RunTokenCountAsync()
public static async Task RunRagTaskAsync()
{
#region Create_Agent
- var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("Please set the environment variable OPENAI_API_KEY");
- var model = "gpt-3.5-turbo";
- var openaiClient = new OpenAIClient(apiKey);
- var openaiMessageConnector = new OpenAIChatRequestMessageConnector();
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
var agent = new OpenAIChatAgent(
- openAIClient: openaiClient,
+ chatClient: gpt4o,
name: "agent",
- modelName: model,
systemMessage: "You are a helpful AI assistant")
.RegisterMessageConnector()
.RegisterMiddleware(async (messages, option, innerAgent, ct) =>
{
var today = DateTime.UtcNow;
var todayMessage = new TextMessage(Role.System, $"Today is {today:yyyy-MM-dd}");
- messages = messages.Concat(new[] { todayMessage });
+ messages = messages.Concat([todayMessage]);
return await innerAgent.GenerateReplyAsync(messages, option, ct);
})
.RegisterPrintMessage();
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Chat_With_Agent.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Chat_With_Agent.cs
index 0ac1cda75288..b2cc228496db 100644
--- a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Chat_With_Agent.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Chat_With_Agent.cs
@@ -5,7 +5,6 @@
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
#endregion Using
using FluentAssertions;
@@ -17,13 +16,10 @@ public class Chat_With_Agent
public static async Task RunAsync()
{
#region Create_Agent
- var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var model = "gpt-3.5-turbo";
- var openaiClient = new OpenAIClient(apiKey);
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
var agent = new OpenAIChatAgent(
- openAIClient: openaiClient,
+ chatClient: gpt4o,
name: "agent",
- modelName: model,
systemMessage: "You are a helpful AI assistant")
.RegisterMessageConnector(); // convert OpenAI message to AutoGen message
#endregion Create_Agent
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Dynamic_Group_Chat.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Dynamic_Group_Chat.cs
index 9d21bbde7d30..dadc295e308d 100644
--- a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Dynamic_Group_Chat.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Dynamic_Group_Chat.cs
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// Dynamic_GroupChat.cs
+// Dynamic_Group_Chat.cs
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
using AutoGen.SemanticKernel;
using AutoGen.SemanticKernel.Extension;
-using Azure.AI.OpenAI;
using Microsoft.SemanticKernel;
+using OpenAI;
namespace AutoGen.BasicSample;
@@ -16,14 +16,13 @@ public class Dynamic_Group_Chat
public static async Task RunAsync()
{
var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var model = "gpt-3.5-turbo";
+ var model = "gpt-4o-mini";
#region Create_Coder
var openaiClient = new OpenAIClient(apiKey);
var coder = new OpenAIChatAgent(
- openAIClient: openaiClient,
+ chatClient: openaiClient.GetChatClient(model),
name: "coder",
- modelName: model,
systemMessage: "You are a C# coder, when writing csharp code, please put the code between ```csharp and ```")
.RegisterMessageConnector() // convert OpenAI message to AutoGen message
.RegisterPrintMessage(); // print the message content
@@ -49,9 +48,8 @@ public static async Task RunAsync()
#region Create_Group
var admin = new OpenAIChatAgent(
- openAIClient: openaiClient,
- name: "admin",
- modelName: model)
+ chatClient: openaiClient.GetChatClient(model),
+ name: "admin")
.RegisterMessageConnector(); // convert OpenAI message to AutoGen message
var group = new GroupChat(
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/FSM_Group_Chat.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/FSM_Group_Chat.cs
index 59c0aa9ca88b..093d0c77ce64 100644
--- a/dotnet/sample/AutoGen.BasicSamples/GettingStart/FSM_Group_Chat.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/FSM_Group_Chat.cs
@@ -6,7 +6,8 @@
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
+using OpenAI;
+using OpenAI.Chat;
#endregion Using
namespace AutoGen.BasicSample;
@@ -74,7 +75,7 @@ public async Task SaveProgress(
public class FSM_Group_Chat
{
-    public static async Task<IAgent> CreateSaveProgressAgent(OpenAIClient client, string model)
+    public static async Task<IAgent> CreateSaveProgressAgent(ChatClient client)
{
#region Create_Save_Progress_Agent
var tool = new FillFormTool();
@@ -86,9 +87,8 @@ public static async Task CreateSaveProgressAgent(OpenAIClient client, st
});
var chatAgent = new OpenAIChatAgent(
- openAIClient: client,
+ chatClient: client,
name: "application",
- modelName: model,
systemMessage: """You are a helpful application form assistant who saves progress while user fills application.""")
.RegisterMessageConnector()
.RegisterMiddleware(functionCallMiddleware)
@@ -111,42 +111,25 @@ Save progress according to the most recent information provided by user.
return chatAgent;
}
-    public static async Task<IAgent> CreateAssistantAgent(OpenAIClient openaiClient, string model)
+    public static async Task<IAgent> CreateAssistantAgent(ChatClient chatClient)
{
#region Create_Assistant_Agent
var chatAgent = new OpenAIChatAgent(
- openAIClient: openaiClient,
+ chatClient: chatClient,
name: "assistant",
- modelName: model,
systemMessage: """You create polite prompt to ask user provide missing information""")
.RegisterMessageConnector()
- .RegisterPrintMessage()
- .RegisterMiddleware(async (msgs, option, agent, ct) =>
- {
- var lastReply = msgs.Last() ?? throw new Exception("No reply found.");
- var reply = await agent.GenerateReplyAsync(msgs, option, ct);
-
- // if application is complete, exit conversation by sending termination message
- if (lastReply.GetContent()?.Contains("Application information is saved to database.") is true)
- {
- return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, from: agent.Name);
- }
- else
- {
- return reply;
- }
- });
+ .RegisterPrintMessage();
#endregion Create_Assistant_Agent
return chatAgent;
}
-    public static async Task<IAgent> CreateUserAgent(OpenAIClient openaiClient, string model)
+    public static async Task<IAgent> CreateUserAgent(ChatClient chatClient)
{
#region Create_User_Agent
var chatAgent = new OpenAIChatAgent(
- openAIClient: openaiClient,
+ chatClient: chatClient,
name: "user",
- modelName: model,
systemMessage: """
You are a user who is filling an application form. Simply provide the information as requested and answer the questions, don't do anything else.
@@ -166,11 +149,12 @@ public static async Task CreateUserAgent(OpenAIClient openaiClient, stri
public static async Task RunAsync()
{
var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var model = "gpt-3.5-turbo";
+ var model = "gpt-4o-mini";
var openaiClient = new OpenAIClient(apiKey);
- var applicationAgent = await CreateSaveProgressAgent(openaiClient, model);
- var assistantAgent = await CreateAssistantAgent(openaiClient, model);
- var userAgent = await CreateUserAgent(openaiClient, model);
+ var chatClient = openaiClient.GetChatClient(model);
+ var applicationAgent = await CreateSaveProgressAgent(chatClient);
+ var assistantAgent = await CreateAssistantAgent(chatClient);
+ var userAgent = await CreateUserAgent(chatClient);
#region Create_Graph
var userToApplicationTransition = Transition.Create(userAgent, applicationAgent);
@@ -193,9 +177,13 @@ public static async Task RunAsync()
var initialMessage = await assistantAgent.SendAsync("Generate a greeting meesage for user and start the conversation by asking what's their name.");
- var chatHistory = await userAgent.SendMessageToGroupAsync(groupChat, [initialMessage], maxRound: 30);
-
- var lastMessage = chatHistory.Last();
- Console.WriteLine(lastMessage.GetContent());
+        var chatHistory = new List<IMessage> { initialMessage };
+ await foreach (var msg in groupChat.SendAsync(chatHistory, maxRound: 30))
+ {
+            if (msg.GetContent()?.ToLower().Contains("application information is saved to database.") is true)
+ {
+ break;
+ }
+ }
}
}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Image_Chat_With_Agent.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Image_Chat_With_Agent.cs
index 5b94a238bbe8..e993b3d51f1c 100644
--- a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Image_Chat_With_Agent.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Image_Chat_With_Agent.cs
@@ -5,7 +5,6 @@
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
#endregion Using
using FluentAssertions;
@@ -16,14 +15,10 @@ public class Image_Chat_With_Agent
public static async Task RunAsync()
{
#region Create_Agent
- var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var model = "gpt-4o"; // The model needs to support multimodal inputs
- var openaiClient = new OpenAIClient(apiKey);
-
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
var agent = new OpenAIChatAgent(
- openAIClient: openaiClient,
+ chatClient: gpt4o,
name: "agent",
- modelName: model,
systemMessage: "You are a helpful AI assistant")
.RegisterMessageConnector() // convert OpenAI message to AutoGen message
.RegisterPrintMessage();
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Streaming_Tool_Call.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Streaming_Tool_Call.cs
new file mode 100644
index 000000000000..d5cb196f94f7
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Streaming_Tool_Call.cs
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Streaming_Tool_Call.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+using OpenAI;
+
+namespace AutoGen.BasicSample.GettingStart;
+
+internal class Streaming_Tool_Call
+{
+ public static async Task RunAsync()
+ {
+ #region Create_tools
+ var tools = new Tools();
+ #endregion Create_tools
+
+ #region Create_auto_invoke_middleware
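+        // supplying a functionMap makes the middleware invoke the matched tool automatically and return its result instead of a raw tool call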
+ var autoInvokeMiddleware = new FunctionCallMiddleware(
+ functions: [tools.GetWeatherFunctionContract],
+        functionMap: new Dictionary<string, Func<string, Task<string>>>()
+ {
+ { tools.GetWeatherFunctionContract.Name, tools.GetWeatherWrapper },
+ });
+ #endregion Create_auto_invoke_middleware
+
+ #region Create_Agent
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-4o-mini";
+ var openaiClient = new OpenAIClient(apiKey);
+ var agent = new OpenAIChatAgent(
+ chatClient: openaiClient.GetChatClient(model),
+ name: "agent",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(autoInvokeMiddleware)
+ .RegisterPrintMessage();
+ #endregion Create_Agent
+
+ IMessage finalReply = null;
+ var question = new TextMessage(Role.User, "What's the weather in Seattle");
+
+        // In a streaming function call,
+        // the function can only be invoked after all the chunks are collected,
+        // therefore only one ToolCallAggregateMessage chunk will be returned here.
+ await foreach (var message in agent.GenerateStreamingReplyAsync([question]))
+ {
+ finalReply = message;
+ }
+
+ finalReply?.GetContent().Should().Be("The weather in Seattle is sunny.");
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Use_Tools_With_Agent.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Use_Tools_With_Agent.cs
index b441fe389da2..21a5df4c2ecd 100644
--- a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Use_Tools_With_Agent.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Use_Tools_With_Agent.cs
@@ -5,9 +5,9 @@
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
#endregion Using
using FluentAssertions;
+using OpenAI;
namespace AutoGen.BasicSample;
@@ -50,12 +50,11 @@ public static async Task RunAsync()
#region Create_Agent
var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var model = "gpt-3.5-turbo";
+ var model = "gpt-4o-mini";
var openaiClient = new OpenAIClient(apiKey);
var agent = new OpenAIChatAgent(
- openAIClient: openaiClient,
+ chatClient: openaiClient.GetChatClient(model),
name: "agent",
- modelName: model,
systemMessage: "You are a helpful AI assistant")
.RegisterMessageConnector(); // convert OpenAI message to AutoGen message
#endregion Create_Agent
diff --git a/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs b/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs
index e492569cdc3d..26d9668792ef 100644
--- a/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs
@@ -1,25 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// LLMConfiguration.cs
-using AutoGen.OpenAI;
+using OpenAI;
+using OpenAI.Chat;
namespace AutoGen.BasicSample;
internal static class LLMConfiguration
{
- public static OpenAIConfig GetOpenAIGPT3_5_Turbo()
+ public static ChatClient GetOpenAIGPT4o_mini()
{
var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var modelId = "gpt-3.5-turbo";
- return new OpenAIConfig(openAIKey, modelId);
- }
-
- public static OpenAIConfig GetOpenAIGPT4()
- {
- var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var modelId = "gpt-4";
+ var modelId = "gpt-4o-mini";
- return new OpenAIConfig(openAIKey, modelId);
+ return new OpenAIClient(openAIKey).GetChatClient(modelId);
}
public static AzureOpenAIConfig GetAzureOpenAIGPT3_5_Turbo(string? deployName = null)
@@ -29,12 +23,4 @@ public static AzureOpenAIConfig GetAzureOpenAIGPT3_5_Turbo(string? deployName =
deployName = deployName ?? Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
return new AzureOpenAIConfig(endpoint, deployName, azureOpenAIKey);
}
-
- public static AzureOpenAIConfig GetAzureOpenAIGPT4(string deployName = "gpt-4")
- {
- var azureOpenAIKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
- var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
-
- return new AzureOpenAIConfig(endpoint, deployName, azureOpenAIKey);
- }
}
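For callers, the shape of this change: samples previously asked LLMConfiguration for an OpenAIConfig and fed it into config-list based agents; they now get a ChatClient and hand it straight to OpenAIChatAgent. A hedged before/after sketch (the commented "before" wiring is illustrative):

// Before: OpenAIConfig fed into config-list based agents (illustrative).
// var config = LLMConfiguration.GetOpenAIGPT3_5_Turbo();

// After: ChatClient handed directly to OpenAIChatAgent.
var chatClient = LLMConfiguration.GetOpenAIGPT4o_mini();
var agent = new OpenAIChatAgent(
    chatClient: chatClient,
    name: "assistant")
    .RegisterMessageConnector();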
diff --git a/dotnet/sample/AutoGen.BasicSamples/Program.cs b/dotnet/sample/AutoGen.BasicSamples/Program.cs
index b48e2be4aa16..8817a3df36e1 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Program.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Program.cs
@@ -1,6 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
+//await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
+
using AutoGen.BasicSample;
-Console.ReadLine();
-await Example17_ReActAgent.RunAsync();
+
+// Define allSamples collection for all examples
+List<Tuple<string, Func<Task>>> allSamples = new List<Tuple<string, Func<Task>>>();
+
+// When a new sample is created, please add it to the allSamples collection
+allSamples.Add(new Tuple>("Assistant Agent", async () => { await Example01_AssistantAgent.RunAsync(); }));
+allSamples.Add(new Tuple>("Two-agent Math Chat", async () => { await Example02_TwoAgent_MathChat.RunAsync(); }));
+allSamples.Add(new Tuple>("Agent Function Call", async () => { await Example03_Agent_FunctionCall.RunAsync(); }));
+allSamples.Add(new Tuple>("Dynamic Group Chat Coding Task", async () => { await Example04_Dynamic_GroupChat_Coding_Task.RunAsync(); }));
+allSamples.Add(new Tuple>("DALL-E and GPT4v", async () => { await Example05_Dalle_And_GPT4V.RunAsync(); }));
+allSamples.Add(new Tuple>("User Proxy Agent", async () => { await Example06_UserProxyAgent.RunAsync(); }));
+allSamples.Add(new Tuple>("Dynamic Group Chat - Calculate Fibonacci", async () => { await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync(); }));
+allSamples.Add(new Tuple>("LM Studio", async () => { await Example08_LMStudio.RunAsync(); }));
+allSamples.Add(new Tuple>("Semantic Kernel", async () => { await Example10_SemanticKernel.RunAsync(); }));
+allSamples.Add(new Tuple>("Sequential Group Chat", async () => { await Sequential_GroupChat_Example.RunAsync(); }));
+allSamples.Add(new Tuple>("Two Agent - Fill Application", async () => { await TwoAgent_Fill_Application.RunAsync(); }));
+allSamples.Add(new Tuple>("Mistal Client Agent - Token Count", async () => { await Example14_MistralClientAgent_TokenCount.RunAsync(); }));
+allSamples.Add(new Tuple>("GPT4v - Binary Data Image", async () => { await Example15_GPT4V_BinaryDataImageMessage.RunAsync(); }));
+allSamples.Add(new Tuple>("ReAct Agent", async () => { await Example17_ReActAgent.RunAsync(); }));
+
+
+int idx = 1;
+Dictionary<int, Tuple<string, Func<Task>>> map = new Dictionary<int, Tuple<string, Func<Task>>>();
+Console.WriteLine("Available Examples:\n\n");
+foreach (Tuple<string, Func<Task>> sample in allSamples)
+{
+ map.Add(idx, sample);
+ Console.WriteLine("{0}. {1}", idx++, sample.Item1);
+}
+
+Console.WriteLine("\n\nEnter your selection:");
+
+while (true)
+{
+ var input = Console.ReadLine();
+ if (input == "exit")
+ {
+ break;
+ }
+ int val = Convert.ToInt32(input);
+ if (!map.ContainsKey(val))
+ {
+ Console.WriteLine("Invalid choice");
+ }
+ else
+ {
+ Console.WriteLine("\nRunning {0}", map[val].Item1);
+ await map[val].Item2.Invoke();
+ }
+}
+
+
+
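One caveat in the menu loop above: Convert.ToInt32 throws a FormatException on any input that is neither "exit" nor a number. A hedged sketch of a more forgiving loop body using int.TryParse:

var input = Console.ReadLine();
if (input == "exit")
{
    break;
}
// TryParse avoids the FormatException that Convert.ToInt32 throws on junk input.
if (!int.TryParse(input, out int val) || !map.ContainsKey(val))
{
    Console.WriteLine("Invalid choice");
}
else
{
    Console.WriteLine("\nRunning {0}", map[val].Item1);
    await map[val].Item2.Invoke();
}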
diff --git a/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj b/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj
index b1779b56c390..d1df8a8ed161 100644
--- a/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj
+++ b/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj
@@ -2,7 +2,7 @@
<OutputType>Exe</OutputType>
- <TargetFramework>net8.0</TargetFramework>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
diff --git a/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj b/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj
index 5277408d595d..62c9d61633c9 100644
--- a/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj
+++ b/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj
@@ -1,7 +1,7 @@
<OutputType>Exe</OutputType>
- <TargetFrameworks>$(TestTargetFramework)</TargetFrameworks>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
<NoWarn>$(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110</NoWarn>
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj b/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj
index ffe18f8a616a..fcbbb834fc63 100644
--- a/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj
@@ -2,7 +2,7 @@
<OutputType>Exe</OutputType>
- <TargetFramework>net8.0</TargetFramework>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
@@ -14,8 +14,9 @@
-
+
+
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Azure_OpenAI.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Azure_OpenAI.cs
new file mode 100644
index 000000000000..dafe2e314859
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Azure_OpenAI.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Connect_To_Azure_OpenAI.cs
+
+#region using_statement
+using AutoGen.Core;
+using AutoGen.OpenAI.Extension;
+using Azure;
+using Azure.AI.OpenAI;
+#endregion using_statement
+
+namespace AutoGen.OpenAI.Sample;
+
+public class Connect_To_Azure_OpenAI
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new InvalidOperationException("Please set environment variable AZURE_OPENAI_API_KEY");
+ var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new InvalidOperationException("Please set environment variable AZURE_OPENAI_ENDPOINT");
+ var model = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? "gpt-4o-mini";
+
+ // Use AzureOpenAIClient to connect to an OpenAI model deployed on Azure.
+ // The AzureOpenAIClient comes from the Azure.AI.OpenAI package.
+ var openAIClient = new AzureOpenAIClient(new Uri(endpoint), new AzureKeyCredential(apiKey));
+
+ var agent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(model),
+ name: "assistant",
+ systemMessage: "You are a helpful assistant designed to output JSON.",
+ seed: 0)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region send_message
+ await agent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+ #endregion send_message
+ }
+}
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs
index b4206b4b6c22..2bb10e978412 100644
--- a/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs
@@ -1,53 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs
+// Connect_To_Ollama.cs
+
#region using_statement
using AutoGen.Core;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
-using Azure.Core.Pipeline;
+using OpenAI;
#endregion using_statement
namespace AutoGen.OpenAI.Sample;
-#region CustomHttpClientHandler
-public sealed class CustomHttpClientHandler : HttpClientHandler
-{
- private string _modelServiceUrl;
-
- public CustomHttpClientHandler(string modelServiceUrl)
- {
- _modelServiceUrl = modelServiceUrl;
- }
-
- protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
- {
- request.RequestUri = new Uri($"{_modelServiceUrl}{request.RequestUri.PathAndQuery}");
-
- return base.SendAsync(request, cancellationToken);
- }
-}
-#endregion CustomHttpClientHandler
-
public class Connect_To_Ollama
{
public static async Task RunAsync()
{
#region create_agent
- using var client = new HttpClient(new CustomHttpClientHandler("http://localhost:11434"));
- var option = new OpenAIClientOptions(OpenAIClientOptions.ServiceVersion.V2024_04_01_Preview)
- {
- Transport = new HttpClientTransport(client),
- };
-
// api-key is not required for local server
// so you can use any string here
- var openAIClient = new OpenAIClient("api-key", option);
+ var openAIClient = new OpenAIClient("api-key", new OpenAIClientOptions
+ {
+ Endpoint = new Uri("http://localhost:11434/v1/"), // remember to add /v1/ at the end to connect to Ollama openai server
+ });
var model = "llama3";
var agent = new OpenAIChatAgent(
- openAIClient: openAIClient,
+ chatClient: openAIClient.GetChatClient(model),
name: "assistant",
- modelName: model,
systemMessage: "You are a helpful assistant designed to output JSON.",
seed: 0)
.RegisterMessageConnector()
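The Endpoint property on OpenAIClientOptions is what replaces the deleted CustomHttpClientHandler: instead of rewriting request URIs in a handler, the client is simply pointed at the local server. The same pattern should work for any OpenAI-compatible backend; a sketch (the LM Studio URL and model id below are assumptions for illustration):

using OpenAI;

// Any OpenAI-compatible server can be targeted by overriding Endpoint.
var client = new OpenAIClient("not-checked-by-local-servers", new OpenAIClientOptions
{
    Endpoint = new Uri("http://localhost:1234/v1/"), // e.g. LM Studio's default port (assumed)
});
var chatClient = client.GetChatClient("local-model"); // hypothetical model id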
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_OpenAI_o1_preview.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_OpenAI_o1_preview.cs
new file mode 100644
index 000000000000..52bc6381b9d5
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_OpenAI_o1_preview.cs
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Connect_To_OpenAI_o1_preview.cs
+
+using AutoGen.Core;
+using OpenAI;
+
+namespace AutoGen.OpenAI.Sample;
+
+public class Connect_To_OpenAI_o1_preview
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("Please set environment variable OPENAI_API_KEY");
+ var openAIClient = new OpenAIClient(apiKey);
+
+ // until 2024/09/12
+ // openai o1-preview doesn't support systemMessage, temperature, maxTokens, streaming output
+ // so in order to use OpenAIChatAgent with o1-preview, you need to set those parameters to null
+ var agent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient("o1-preview"),
+ name: "assistant",
+ systemMessage: null,
+ temperature: null,
+ maxTokens: null,
+ seed: 0)
+ // by using RegisterMiddleware instead of RegisterStreamingMiddleware
+ // it turns an IStreamingAgent into an IAgent and disables streaming
+ .RegisterMiddleware(new OpenAIChatRequestMessageConnector())
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region send_message
+ await agent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+ #endregion send_message
+ }
+}
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Program.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Program.cs
index 5a38a3ff03b9..c71f152d0370 100644
--- a/dotnet/sample/AutoGen.OpenAI.Sample/Program.cs
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Program.cs
@@ -3,4 +3,4 @@
using AutoGen.OpenAI.Sample;
-Tool_Call_With_Ollama_And_LiteLLM.RunAsync().Wait();
+Structural_Output.RunAsync().Wait();
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Structural_Output.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Structural_Output.cs
new file mode 100644
index 000000000000..e83be0082bab
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Structural_Output.cs
@@ -0,0 +1,93 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Structural_Output.cs
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Core;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+using Json.Schema;
+using Json.Schema.Generation;
+using OpenAI;
+
+namespace AutoGen.OpenAI.Sample;
+
+public class Structural_Output
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-4o-mini";
+
+ var schemaBuilder = new JsonSchemaBuilder().FromType<Person>();
+ var schema = schemaBuilder.Build();
+ var openAIClient = new OpenAIClient(apiKey);
+ var openAIClientAgent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(model),
+ name: "assistant",
+ systemMessage: "You are a helpful assistant")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region chat_with_agent
+ var prompt = new TextMessage(Role.User, """
+ My name is John, I am 25 years old, and I live in Seattle. I like to play soccer and read books.
+ """);
+ var reply = await openAIClientAgent.GenerateReplyAsync(
+ messages: [prompt],
+ options: new GenerateReplyOptions
+ {
+ OutputSchema = schema,
+ });
+
+ var person = JsonSerializer.Deserialize<Person>(reply.GetContent());
+ Console.WriteLine($"Name: {person.Name}");
+ Console.WriteLine($"Age: {person.Age}");
+
+ if (!string.IsNullOrEmpty(person.Address))
+ {
+ Console.WriteLine($"Address: {person.Address}");
+ }
+
+ Console.WriteLine("Done.");
+ #endregion chat_with_agent
+
+ person.Name.Should().Be("John");
+ person.Age.Should().Be(25);
+ person.Address.Should().BeNullOrEmpty();
+ person.City.Should().Be("Seattle");
+ person.Hobbies.Count.Should().Be(2);
+ }
+
+
+ #region person_class
+ [Title("Person")]
+ public class Person
+ {
+ [JsonPropertyName("name")]
+ [Description("Name of the person")]
+ [Required]
+ public string Name { get; set; }
+
+ [JsonPropertyName("age")]
+ [Description("Age of the person")]
+ [Required]
+ public int Age { get; set; }
+
+ [JsonPropertyName("city")]
+ [Description("City of the person")]
+ public string? City { get; set; }
+
+ [JsonPropertyName("address")]
+ [Description("Address of the person")]
+ public string? Address { get; set; }
+
+ [JsonPropertyName("hobbies")]
+ [Description("Hobbies of the person")]
+ public List<string>? Hobbies { get; set; }
+ }
+ #endregion person_class
+
+}
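For the prompt above, the reply content should be a JSON document conforming to the generated Person schema. An illustrative (not captured) shape of reply.GetContent():

// Illustrative only; the model may order fields differently or omit optional ones.
var expectedShape = """
    {
        "name": "John",
        "age": 25,
        "city": "Seattle",
        "hobbies": ["playing soccer", "reading books"]
    }
    """;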
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Tool_Call_With_Ollama_And_LiteLLM.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Tool_Call_With_Ollama_And_LiteLLM.cs
index f4fabe3c9e83..ed43c628a672 100644
--- a/dotnet/sample/AutoGen.OpenAI.Sample/Tool_Call_With_Ollama_And_LiteLLM.cs
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Tool_Call_With_Ollama_And_LiteLLM.cs
@@ -3,11 +3,11 @@
using AutoGen.Core;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
-using Azure.Core.Pipeline;
+using OpenAI;
namespace AutoGen.OpenAI.Sample;
+#region Function
public partial class Function
{
[Function]
@@ -16,18 +16,22 @@ public async Task GetWeatherAsync(string city)
return await Task.FromResult("The weather in " + city + " is 72 degrees and sunny.");
}
}
+#endregion Function
+
public class Tool_Call_With_Ollama_And_LiteLLM
{
public static async Task RunAsync()
{
- #region Create_Agent
- var liteLLMUrl = "http://localhost:4000";
- using var httpClient = new HttpClient(new CustomHttpClientHandler(liteLLMUrl));
- var option = new OpenAIClientOptions(OpenAIClientOptions.ServiceVersion.V2024_04_01_Preview)
- {
- Transport = new HttpClientTransport(httpClient),
- };
+ // Before running this code, make sure you have
+ // - Ollama:
+ // - Install dolphincoder:latest in Ollama
+ // - Ollama running on http://localhost:11434
+ // - LiteLLM
+ // - Install LiteLLM
+ // - Start LiteLLM with the following command:
+ // - litellm --model ollama_chat/dolphincoder --port 4000
+ #region Create_tools
var functions = new Function();
var functionMiddleware = new FunctionCallMiddleware(
functions: [functions.GetWeatherAsyncFunctionContract],
@@ -35,15 +39,20 @@ public static async Task RunAsync()
{
{ functions.GetWeatherAsyncFunctionContract.Name!, functions.GetWeatherAsyncWrapper },
});
+ #endregion Create_tools
+ #region Create_Agent
+ var liteLLMUrl = "http://localhost:4000";
// api-key is not required for local server
// so you can use any string here
- var openAIClient = new OpenAIClient("api-key", option);
+ var openAIClient = new OpenAIClient("api-key", new OpenAIClientOptions
+ {
+ Endpoint = new Uri("http://localhost:4000"),
+ });
var agent = new OpenAIChatAgent(
- openAIClient: openAIClient,
+ chatClient: openAIClient.GetChatClient("dolphincoder:latest"),
name: "assistant",
- modelName: "placeholder",
systemMessage: "You are a helpful AI assistant")
.RegisterMessageConnector()
.RegisterMiddleware(functionMiddleware)
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs
new file mode 100644
index 000000000000..4e5247d93cec
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs
@@ -0,0 +1,68 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Use_Json_Mode.cs
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Core;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+using OpenAI;
+using OpenAI.Chat;
+
+namespace AutoGen.OpenAI.Sample;
+
+public class Use_Json_Mode
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-4o-mini";
+
+ var openAIClient = new OpenAIClient(apiKey);
+ var openAIClientAgent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(model),
+ name: "assistant",
+ systemMessage: "You are a helpful assistant designed to output JSON.",
+ seed: 0, // explicitly set a seed to enable deterministic output
+ responseFormat: ChatResponseFormat.JsonObject) // set response format to JSON object to enable JSON mode
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region chat_with_agent
+ var reply = await openAIClientAgent.SendAsync("My name is John, I am 25 years old, and I live in Seattle.");
+
+ var person = JsonSerializer.Deserialize<Person>(reply.GetContent());
+ Console.WriteLine($"Name: {person.Name}");
+ Console.WriteLine($"Age: {person.Age}");
+
+ if (!string.IsNullOrEmpty(person.Address))
+ {
+ Console.WriteLine($"Address: {person.Address}");
+ }
+
+ Console.WriteLine("Done.");
+ #endregion chat_with_agent
+
+ person.Name.Should().Be("John");
+ person.Age.Should().Be(25);
+ person.Address.Should().BeNullOrEmpty();
+ }
+
+
+ #region person_class
+ public class Person
+ {
+ [JsonPropertyName("name")]
+ public string Name { get; set; }
+
+ [JsonPropertyName("age")]
+ public int Age { get; set; }
+
+ [JsonPropertyName("address")]
+ public string Address { get; set; }
+ }
+ #endregion person_class
+}
+
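Two practical notes on JSON mode: OpenAI rejects JSON-mode requests unless the conversation mentions JSON (the system message above does), and ChatResponseFormat.JsonObject only guarantees syntactically valid JSON, not conformance to Person. A hedged sketch of a defensive parse:

// JSON mode guarantees well-formed JSON, not a particular shape;
// missing properties simply deserialize to their defaults, so validate them.
var parsed = JsonSerializer.Deserialize<Person>(reply.GetContent() ?? "{}");
if (parsed?.Name is null)
{
    Console.WriteLine("Reply was valid JSON but missing expected fields.");
}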
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj b/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj
index 6c2266512929..45514431368f 100644
--- a/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj
@@ -2,15 +2,16 @@
<OutputType>Exe</OutputType>
- <TargetFrameworks>$(TestTargetFramework)</TargetFrameworks>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
<NoWarn>$(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110</NoWarn>
<ImplicitUsings>enable</ImplicitUsings>
+
+
-
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Kernel_Functions_With_Other_Agent.cs b/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Kernel_Functions_With_Other_Agent.cs
index 2beb1ee7df0a..700bdfe75c7b 100644
--- a/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Kernel_Functions_With_Other_Agent.cs
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Kernel_Functions_With_Other_Agent.cs
@@ -5,8 +5,8 @@
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
using Microsoft.SemanticKernel;
+using OpenAI;
#endregion Using
namespace AutoGen.SemanticKernel.Sample;
@@ -17,7 +17,7 @@ public static async Task RunAsync()
{
#region Create_plugin
var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var modelId = "gpt-3.5-turbo";
+ var modelId = "gpt-4o-mini";
var kernelBuilder = Kernel.CreateBuilder();
var kernel = kernelBuilder.Build();
var getWeatherFunction = KernelFunctionFactory.CreateFromMethod(
@@ -33,9 +33,8 @@ public static async Task RunAsync()
var openAIClient = new OpenAIClient(openAIKey);
var openAIAgent = new OpenAIChatAgent(
- openAIClient: openAIClient,
- name: "assistant",
- modelName: modelId)
+ chatClient: openAIClient.GetChatClient(modelId),
+ name: "assistant")
.RegisterMessageConnector() // register message connector so it supports AutoGen built-in message types like TextMessage.
.RegisterMiddleware(kernelPluginMiddleware) // register the middleware to handle the plugin functions
.RegisterPrintMessage(); // pretty print the message to the console
diff --git a/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj b/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj
new file mode 100644
index 000000000000..76675ba12346
--- /dev/null
+++ b/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj
@@ -0,0 +1,13 @@
+<Project Sdk="Microsoft.NET.Sdk.Web">
+
+  <PropertyGroup>
+    <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
+    <ImplicitUsings>enable</ImplicitUsings>
+    <Nullable>enable</Nullable>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <ProjectReference Include="..\..\src\AutoGen.WebAPI\AutoGen.WebAPI.csproj" />
+  </ItemGroup>
+
+</Project>
diff --git a/dotnet/sample/AutoGen.WebAPI.Sample/Program.cs b/dotnet/sample/AutoGen.WebAPI.Sample/Program.cs
new file mode 100644
index 000000000000..dbeb8494363d
--- /dev/null
+++ b/dotnet/sample/AutoGen.WebAPI.Sample/Program.cs
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+using System.Runtime.CompilerServices;
+using AutoGen.Core;
+using AutoGen.WebAPI;
+
+var alice = new DummyAgent("alice");
+var bob = new DummyAgent("bob");
+
+var builder = WebApplication.CreateBuilder(args);
+// Add services to the container.
+
+// run endpoint at port 5000
+builder.WebHost.UseUrls("http://localhost:5000");
+var app = builder.Build();
+
+app.UseAgentAsOpenAIChatCompletionEndpoint(alice);
+app.UseAgentAsOpenAIChatCompletionEndpoint(bob);
+
+app.Run();
+
+public class DummyAgent : IStreamingAgent
+{
+ public DummyAgent(string name = "dummy")
+ {
+ Name = name;
+ }
+
+ public string Name { get; }
+
+ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ return new TextMessage(Role.Assistant, $"I am dummy {this.Name}", this.Name);
+ }
+
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var reply = $"I am dummy {this.Name}";
+ foreach (var c in reply)
+ {
+ yield return new TextMessageUpdate(Role.Assistant, c.ToString(), this.Name);
+ }
+ }
+}
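UseAgentAsOpenAIChatCompletionEndpoint exposes each agent through the standard OpenAI chat completion route, so any OpenAI client can talk to it by overriding its endpoint and using the agent name as the model id. A minimal client-side sketch, assuming the host above is running (the /v1/ route prefix is an assumption):

using OpenAI;

var client = new OpenAIClient("placeholder-key", new OpenAIClientOptions
{
    Endpoint = new Uri("http://localhost:5000/v1/"), // assumed route prefix for the hosted endpoint
});
var alice = client.GetChatClient("alice"); // the agent name acts as the model id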
diff --git a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
index e395bb4a225f..81fa8e6438a8 100644
--- a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
+++ b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
@@ -1,5 +1,9 @@
-using System;
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicClientAgent.cs
+
+using System;
using System.Collections.Generic;
+using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
@@ -16,6 +20,8 @@ public class AnthropicClientAgent : IStreamingAgent
private readonly string _systemMessage;
private readonly decimal _temperature;
private readonly int _maxTokens;
+ private readonly Tool[]? _tools;
+ private readonly ToolChoice? _toolChoice;
public AnthropicClientAgent(
AnthropicClient anthropicClient,
@@ -23,7 +29,9 @@ public AnthropicClientAgent(
string modelName,
string systemMessage = "You are a helpful AI assistant",
decimal temperature = 0.7m,
- int maxTokens = 1024)
+ int maxTokens = 1024,
+ Tool[]? tools = null,
+ ToolChoice? toolChoice = null)
{
Name = name;
_anthropicClient = anthropicClient;
@@ -31,6 +39,8 @@ public AnthropicClientAgent(
_systemMessage = systemMessage;
_temperature = temperature;
_maxTokens = maxTokens;
+ _tools = tools;
+ _toolChoice = toolChoice;
}
public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null,
@@ -40,7 +50,7 @@ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, G
return new MessageEnvelope<ChatCompletionResponse>(response, from: this.Name);

- public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages,
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await foreach (var message in _anthropicClient.StreamingChatCompletionsAsync(
@@ -54,11 +64,14 @@ private ChatCompletionRequest CreateParameters(IEnumerable<IMessage> messages, G
{
var chatCompletionRequest = new ChatCompletionRequest()
{
- SystemMessage = _systemMessage,
+ SystemMessage = [new SystemMessage { Text = _systemMessage }],
MaxTokens = options?.MaxToken ?? _maxTokens,
Model = _modelName,
Stream = shouldStream,
Temperature = (decimal?)options?.Temperature ?? _temperature,
+ Tools = _tools?.ToList(),
+ ToolChoice = _toolChoice ?? (_tools is { Length: > 0 } ? ToolChoice.Auto : null),
+ StopSequences = options?.StopSequence?.ToArray(),
};
chatCompletionRequest.Messages = BuildMessages(messages);
@@ -86,6 +99,22 @@ private List<ChatMessage> BuildMessages(IEnumerable<IMessage> messages)
}
}
- return chatMessages;
+ // merge messages with the same role
+ // fixing #2884
+ var mergedMessages = chatMessages.Aggregate(new List<ChatMessage>(), (acc, message) =>
+ {
+ if (acc.Count > 0 && acc.Last().Role == message.Role)
+ {
+ acc.Last().Content.AddRange(message.Content);
+ }
+ else
+ {
+ acc.Add(message);
+ }
+
+ return acc;
+ });
+
+ return mergedMessages;
}
}
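A hedged sketch of wiring up the new tools/toolChoice parameters, using the Tool, InputSchema and ToolChoice DTOs introduced later in this diff (the weather tool and the endpoint URL are illustrative assumptions):

using AutoGen.Anthropic;
using AutoGen.Anthropic.DTO;
using AutoGen.Anthropic.Utils;

var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ?? throw new Exception("Please set ANTHROPIC_API_KEY environment variable.");
var anthropicClient = new AnthropicClient(new HttpClient(), "https://api.anthropic.com/v1/messages", apiKey);

var weatherTool = new Tool
{
    Name = "get_weather",
    Description = "Get the current weather for a city",
    InputSchema = new InputSchema
    {
        Type = "object",
        Properties = new Dictionary<string, SchemaProperty>
        {
            ["city"] = new SchemaProperty { Type = "string", Description = "City name" },
        },
        Required = new List<string> { "city" },
    },
};

var agent = new AnthropicClientAgent(
    anthropicClient,
    name: "assistant",
    modelName: AnthropicConstants.Claude35Sonnet,
    tools: [weatherTool]); // with tools set, ToolChoice.Auto is applied by default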
diff --git a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
index 90bd33683f20..f106e08d35c4 100644
--- a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
+++ b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// AnthropicClient.cs
using System;
@@ -24,12 +24,13 @@ public sealed class AnthropicClient : IDisposable
private static readonly JsonSerializerOptions JsonSerializerOptions = new()
{
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
- Converters = { new ContentBaseConverter() }
- };
-
- private static readonly JsonSerializerOptions JsonDeserializerOptions = new()
- {
- Converters = { new ContentBaseConverter() }
+ Converters =
+ {
+ new ContentBaseConverter(),
+ new JsonPropertyNameEnumConverter<ToolChoiceType>(),
+ new JsonPropertyNameEnumConverter<CacheControlType>(),
+ new SystemMessageConverter(),
+ }
};
public AnthropicClient(HttpClient httpClient, string baseUrl, string apiKey)
@@ -48,7 +49,9 @@ public async Task<ChatCompletionResponse> CreateChatCompletionsAsync(ChatComplet
var responseStream = await httpResponseMessage.Content.ReadAsStreamAsync();
if (httpResponseMessage.IsSuccessStatusCode)
+ {
return await DeserializeResponseAsync<ChatCompletionResponse>(responseStream, cancellationToken);
+ }
ErrorResponse res = await DeserializeResponseAsync<ErrorResponse>(responseStream, cancellationToken);
throw new Exception(res.Error?.Message);
@@ -61,24 +64,58 @@ public async IAsyncEnumerable<ChatCompletionResponse> StreamingChatCompletionsAs
using var reader = new StreamReader(await httpResponseMessage.Content.ReadAsStreamAsync());
var currentEvent = new SseEvent();
+
while (await reader.ReadLineAsync() is { } line)
{
if (!string.IsNullOrEmpty(line))
{
- currentEvent.Data = line.Substring("data:".Length).Trim();
+ if (line.StartsWith("event:"))
+ {
+ currentEvent.EventType = line.Substring("event:".Length).Trim();
+ }
+ else if (line.StartsWith("data:"))
+ {
+ currentEvent.Data = line.Substring("data:".Length).Trim();
+ }
}
- else
+ else // an empty line indicates the end of an event
{
- if (currentEvent.Data == "[DONE]")
- continue;
+ if (currentEvent.EventType == "content_block_start" && !string.IsNullOrEmpty(currentEvent.Data))
+ {
+ var dataBlock = JsonSerializer.Deserialize<DataBlock>(currentEvent.Data!);
+ if (dataBlock != null && dataBlock.ContentBlock?.Type == "tool_use")
+ {
+ currentEvent.ContentBlock = dataBlock.ContentBlock;
+ }
+ }
- if (currentEvent.Data != null)
+ if (currentEvent.EventType is "message_start" or "content_block_delta" or "message_delta" && currentEvent.Data != null)
{
- yield return await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
+ var res = await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)),
cancellationToken: cancellationToken) ?? throw new Exception("Failed to deserialize response");
+ if (res.Delta?.Type == "input_json_delta" && !string.IsNullOrEmpty(res.Delta.PartialJson) &&
+ currentEvent.ContentBlock != null)
+ {
+ currentEvent.ContentBlock.AppendDeltaParameters(res.Delta.PartialJson!);
+ }
+ else if (res.Delta is { StopReason: "tool_use" } && currentEvent.ContentBlock != null)
+ {
+ if (res.Content == null)
+ {
+ res.Content = [currentEvent.ContentBlock.CreateToolUseContent()];
+ }
+ else
+ {
+ res.Content.Add(currentEvent.ContentBlock.CreateToolUseContent());
+ }
+
+ currentEvent = new SseEvent();
+ }
+
+ yield return res;
}
- else if (currentEvent.Data != null)
+ else if (currentEvent.EventType == "error" && currentEvent.Data != null)
{
var res = await JsonSerializer.DeserializeAsync<ErrorResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: cancellationToken);
@@ -86,8 +123,10 @@ public async IAsyncEnumerable<ChatCompletionResponse> StreamingChatCompletionsAs
throw new Exception(res?.Error?.Message);
}
- // Reset the current event for the next one
- currentEvent = new SseEvent();
+ if (currentEvent.ContentBlock == null)
+ {
+ currentEvent = new SseEvent();
+ }
}
}
}
@@ -97,12 +136,13 @@ private Task<HttpResponseMessage> SendRequestAsync<T>(T requestObject, Cancellat
var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _baseUrl);
var jsonRequest = JsonSerializer.Serialize(requestObject, JsonSerializerOptions);
httpRequestMessage.Content = new StringContent(jsonRequest, Encoding.UTF8, "application/json");
+ httpRequestMessage.Headers.Add("anthropic-beta", "prompt-caching-2024-07-31");
return _httpClient.SendAsync(httpRequestMessage, cancellationToken);
}
private async Task<T> DeserializeResponseAsync<T>(Stream responseStream, CancellationToken cancellationToken)
{
- return await JsonSerializer.DeserializeAsync<T>(responseStream, JsonDeserializerOptions, cancellationToken)
+ return await JsonSerializer.DeserializeAsync<T>(responseStream, JsonSerializerOptions, cancellationToken)
?? throw new Exception("Failed to deserialize response");
}
@@ -113,11 +153,50 @@ public void Dispose()
private struct SseEvent
{
+ public string EventType { get; set; }
public string? Data { get; set; }
+ public ContentBlock? ContentBlock { get; set; }
- public SseEvent(string? data = null)
+ public SseEvent(string eventType, string? data = null, ContentBlock? contentBlock = null)
{
+ EventType = eventType;
Data = data;
+ ContentBlock = contentBlock;
}
}
+
+ private class ContentBlock
+ {
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("input")]
+ public object? Input { get; set; }
+
+ public string? parameters { get; set; }
+
+ public void AppendDeltaParameters(string deltaParams)
+ {
+ StringBuilder sb = new StringBuilder(parameters);
+ sb.Append(deltaParams);
+ parameters = sb.ToString();
+ }
+
+ public ToolUseContent CreateToolUseContent()
+ {
+ return new ToolUseContent { Id = Id, Name = Name, Input = parameters };
+ }
+ }
+
+ private class DataBlock
+ {
+ [JsonPropertyName("content_block")]
+ public ContentBlock? ContentBlock { get; set; }
+ }
}
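For reference, the server-sent event framing that the loop above parses looks roughly like this (an abbreviated, assumed transcript of an Anthropic streaming tool call; payloads elided):

// event: message_start
// data: {"type":"message_start","message":{...}}
//
// event: content_block_start
// data: {"type":"content_block_start","content_block":{"type":"tool_use","id":"...","name":"get_weather"}}
//
// event: content_block_delta
// data: {"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"{\"city\""}}
//
// event: message_delta
// data: {"type":"message_delta","delta":{"stop_reason":"tool_use"}}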
diff --git a/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj b/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj
index fefc439e00ba..a4fd32e7e345 100644
--- a/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj
+++ b/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj
@@ -1,8 +1,8 @@
- <TargetFramework>netstandard2.0</TargetFramework>
- <RootNamespace>AutoGen.Anthropic</RootNamespace>
+ <TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
+ <RootNamespace>AutoGen.Anthropic</RootNamespace>
diff --git a/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs
index 4cb8fdbb34e0..3e620f934c28 100644
--- a/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs
+++ b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs
@@ -1,12 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// ContentConverter.cs
-
-using AutoGen.Anthropic.DTO;
-
+// ContentBaseConverter.cs
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
+using AutoGen.Anthropic.DTO;
namespace AutoGen.Anthropic.Converters;
public sealed class ContentBaseConverter : JsonConverter<ContentBase>
@@ -24,6 +22,10 @@ public override ContentBase Read(ref Utf8JsonReader reader, Type typeToConvert,
return JsonSerializer.Deserialize<TextContent>(text, options) ?? throw new InvalidOperationException();
case "image":
return JsonSerializer.Deserialize<ImageContent>(text, options) ?? throw new InvalidOperationException();
+ case "tool_use":
+ return JsonSerializer.Deserialize<ToolUseContent>(text, options) ?? throw new InvalidOperationException();
+ case "tool_result":
+ return JsonSerializer.Deserialize<ToolResultContent>(text, options) ?? throw new InvalidOperationException();
}
}
diff --git a/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs
new file mode 100644
index 000000000000..68b3c14bdee6
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// JsonPropertyNameEnumCoverter.cs
+
+using System;
+using System.Reflection;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Anthropic.Converters;
+
+internal class JsonPropertyNameEnumConverter<T> : JsonConverter<T> where T : struct, Enum
+{
+ public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ string value = reader.GetString() ?? throw new JsonException("Value was null.");
+
+ foreach (var field in typeToConvert.GetFields())
+ {
+ var attribute = field.GetCustomAttribute<JsonPropertyNameAttribute>();
+ if (attribute?.Name == value)
+ {
+ return (T)Enum.Parse(typeToConvert, field.Name);
+ }
+ }
+
+ throw new JsonException($"Unable to convert \"{value}\" to enum {typeToConvert}.");
+ }
+
+ public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
+ {
+ var field = value.GetType().GetField(value.ToString());
+ var attribute = field?.GetCustomAttribute<JsonPropertyNameAttribute>();
+
+ if (attribute != null)
+ {
+ writer.WriteStringValue(attribute.Name);
+ }
+ else
+ {
+ writer.WriteStringValue(value.ToString());
+ }
+ }
+}
+
diff --git a/dotnet/src/AutoGen.Anthropic/Converters/SystemMessageConverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/SystemMessageConverter.cs
new file mode 100644
index 000000000000..5bbe8a3a37f8
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/Converters/SystemMessageConverter.cs
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// SystemMessageConverter.cs
+
+using System;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Anthropic.DTO;
+
+namespace AutoGen.Anthropic.Converters;
+
+public class SystemMessageConverter : JsonConverter<object>
+{
+ public override object Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ if (reader.TokenType == JsonTokenType.String)
+ {
+ return reader.GetString() ?? string.Empty;
+ }
+ if (reader.TokenType == JsonTokenType.StartArray)
+ {
+ return JsonSerializer.Deserialize<SystemMessage[]>(ref reader, options) ?? throw new InvalidOperationException();
+ }
+
+ throw new JsonException();
+ }
+
+ public override void Write(Utf8JsonWriter writer, object value, JsonSerializerOptions options)
+ {
+ if (value is string stringValue)
+ {
+ writer.WriteStringValue(stringValue);
+ }
+ else if (value is SystemMessage[] arrayValue)
+ {
+ JsonSerializer.Serialize(writer, arrayValue, options);
+ }
+ else
+ {
+ throw new JsonException();
+ }
+ }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
index 0c1749eaa989..dfb86ef0af53 100644
--- a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatCompletionRequest.cs
-using System.Text.Json.Serialization;
using System.Collections.Generic;
+using System.Text.Json.Serialization;
namespace AutoGen.Anthropic.DTO;
@@ -14,7 +14,7 @@ public class ChatCompletionRequest
public List<ChatMessage> Messages { get; set; }
[JsonPropertyName("system")]
- public string? SystemMessage { get; set; }
+ public SystemMessage[]? SystemMessage { get; set; }
[JsonPropertyName("max_tokens")]
public int MaxTokens { get; set; }
@@ -37,12 +37,38 @@ public class ChatCompletionRequest
[JsonPropertyName("top_p")]
public decimal? TopP { get; set; }
+ [JsonPropertyName("tools")]
+ public List<Tool>? Tools { get; set; }
+
+ [JsonPropertyName("tool_choice")]
+ public ToolChoice? ToolChoice { get; set; }
+
public ChatCompletionRequest()
{
Messages = new List<ChatMessage>();
}
}
+public class SystemMessage
+{
+ [JsonPropertyName("text")]
+ public string? Text { get; set; }
+
+ [JsonPropertyName("type")]
+ public string? Type { get; private set; } = "text";
+
+ [JsonPropertyName("cache_control")]
+ public CacheControl? CacheControl { get; set; }
+
+ public static SystemMessage CreateSystemMessage(string systemMessage) => new() { Text = systemMessage };
+
+ public static SystemMessage CreateSystemMessageWithCacheControl(string systemMessage) => new()
+ {
+ Text = systemMessage,
+ CacheControl = new CacheControl { Type = CacheControlType.Ephemeral }
+ };
+}
+
public class ChatMessage
{
[JsonPropertyName("role")]
@@ -62,4 +88,6 @@ public ChatMessage(string role, List<ContentBase> content)
Role = role;
Content = content;
}
+
+ public void AddContent(ContentBase content) => Content.Add(content);
}
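A short sketch of the new factory methods in use, building a request whose system prompt opts into Anthropic's ephemeral prompt cache (the beta header enabling it is added in AnthropicClient above); the model id and token budget are illustrative:

var request = new ChatCompletionRequest
{
    Model = "claude-3-5-sonnet-20240620",
    MaxTokens = 1024,
    SystemMessage =
    [
        SystemMessage.CreateSystemMessageWithCacheControl("You are a helpful AI assistant."),
    ],
};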
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs
index c6861f9c3150..a142f2feacca 100644
--- a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs
@@ -1,10 +1,11 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatCompletionResponse.cs
-namespace AutoGen.Anthropic.DTO;
using System.Collections.Generic;
using System.Text.Json.Serialization;
+namespace AutoGen.Anthropic.DTO;
public class ChatCompletionResponse
{
[JsonPropertyName("content")]
@@ -49,9 +50,6 @@ public class StreamingMessage
[JsonPropertyName("role")]
public string? Role { get; set; }
- [JsonPropertyName("content")]
- public List? Content { get; set; }
-
[JsonPropertyName("model")]
public string? Model { get; set; }
@@ -72,6 +70,12 @@ public class Usage
[JsonPropertyName("output_tokens")]
public int OutputTokens { get; set; }
+
+ [JsonPropertyName("cache_creation_input_tokens")]
+ public int CacheCreationInputTokens { get; set; }
+
+ [JsonPropertyName("cache_read_input_tokens")]
+ public int CacheReadInputTokens { get; set; }
}
public class Delta
@@ -85,6 +89,9 @@ public class Delta
[JsonPropertyName("text")]
public string? Text { get; set; }
+ [JsonPropertyName("partial_json")]
+ public string? PartialJson { get; set; }
+
[JsonPropertyName("usage")]
public Usage? Usage { get; set; }
}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/Content.cs b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs
index dd2481bd58f3..ade913b827c4 100644
--- a/dotnet/src/AutoGen.Anthropic/DTO/Content.cs
+++ b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs
@@ -1,7 +1,9 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// Content.cs
+using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
+using AutoGen.Anthropic.Converters;
namespace AutoGen.Anthropic.DTO;
@@ -9,6 +11,9 @@ public abstract class ContentBase
{
[JsonPropertyName("type")]
public abstract string Type { get; }
+
+ [JsonPropertyName("cache_control")]
+ public CacheControl? CacheControl { get; set; }
}
public class TextContent : ContentBase
@@ -18,6 +23,12 @@ public class TextContent : ContentBase
[JsonPropertyName("text")]
public string? Text { get; set; }
+
+ public static TextContent CreateTextWithCacheControl(string text) => new()
+ {
+ Text = text,
+ CacheControl = new CacheControl { Type = CacheControlType.Ephemeral }
+ };
}
public class ImageContent : ContentBase
@@ -40,3 +51,45 @@ public class ImageSource
[JsonPropertyName("data")]
public string? Data { get; set; }
}
+
+public class ToolUseContent : ContentBase
+{
+ [JsonPropertyName("type")]
+ public override string Type => "tool_use";
+
+ [JsonPropertyName("id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("input")]
+ public JsonNode? Input { get; set; }
+}
+
+public class ToolResultContent : ContentBase
+{
+ [JsonPropertyName("type")]
+ public override string Type => "tool_result";
+
+ [JsonPropertyName("tool_use_id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+}
+
+public class CacheControl
+{
+ [JsonPropertyName("type")]
+ public CacheControlType Type { get; set; }
+
+ public static CacheControl Create() => new CacheControl { Type = CacheControlType.Ephemeral };
+}
+
+[JsonConverter(typeof(JsonPropertyNameEnumConverter<CacheControlType>))]
+public enum CacheControlType
+{
+ [JsonPropertyName("ephemeral")]
+ Ephemeral
+}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs b/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs
index d02a8f6d1cfc..1a94334c88ff 100644
--- a/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// ErrorResponse.cs
using System.Text.Json.Serialization;
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs b/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs
new file mode 100644
index 000000000000..3845c4445925
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Tool.cs
+
+using System.Collections.Generic;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Anthropic.DTO;
+
+public class Tool
+{
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("description")]
+ public string? Description { get; set; }
+
+ [JsonPropertyName("input_schema")]
+ public InputSchema? InputSchema { get; set; }
+
+ [JsonPropertyName("cache_control")]
+ public CacheControl? CacheControl { get; set; }
+}
+
+public class InputSchema
+{
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("properties")]
+ public Dictionary<string, SchemaProperty>? Properties { get; set; }
+
+ [JsonPropertyName("required")]
+ public List<string>? Required { get; set; }
+}
+
+public class SchemaProperty
+{
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("description")]
+ public string? Description { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ToolChoice.cs b/dotnet/src/AutoGen.Anthropic/DTO/ToolChoice.cs
new file mode 100644
index 000000000000..0a5c3790e1de
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ToolChoice.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ToolChoice.cs
+
+using System.Text.Json.Serialization;
+using AutoGen.Anthropic.Converters;
+
+namespace AutoGen.Anthropic.DTO;
+
+[JsonConverter(typeof(JsonPropertyNameEnumConverter<ToolChoiceType>))]
+public enum ToolChoiceType
+{
+ [JsonPropertyName("auto")]
+ Auto, // Default behavior
+
+ [JsonPropertyName("any")]
+ Any, // Use any provided tool
+
+ [JsonPropertyName("tool")]
+ Tool // Force a specific tool
+}
+
+public class ToolChoice
+{
+ [JsonPropertyName("type")]
+ public ToolChoiceType Type { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ private ToolChoice(ToolChoiceType type, string? name = null)
+ {
+ Type = type;
+ Name = name;
+ }
+
+ public static ToolChoice Auto => new(ToolChoiceType.Auto);
+ public static ToolChoice Any => new(ToolChoiceType.Any);
+ public static ToolChoice ToolUse(string name) => new(ToolChoiceType.Tool, name);
+}
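Quick illustration of the three factory members (the tool name is hypothetical):

var auto = ToolChoice.Auto;                       // let the model decide (default when tools are present)
var any = ToolChoice.Any;                         // force the model to call some tool
var specific = ToolChoice.ToolUse("get_weather"); // force a specific tool by name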
diff --git a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
index bb2f5820f74c..af06a0547849 100644
--- a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
+++ b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
@@ -6,6 +6,7 @@
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
+using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Anthropic.DTO;
@@ -28,7 +29,7 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
: response;
}
- public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
+ public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
@@ -36,7 +37,7 @@ public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext c
await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
- if (reply is IStreamingMessage<ChatCompletionResponse> chatMessage)
+ if (reply is IMessage<ChatCompletionResponse> chatMessage)
{
var response = ProcessChatCompletionResponse(chatMessage, agent);
if (response is not null)
@@ -51,9 +52,20 @@ public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext c
}
}
- private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> chatMessage,
+ private IMessage? ProcessChatCompletionResponse(IMessage<ChatCompletionResponse> chatMessage,
IStreamingAgent agent)
{
+ if (chatMessage.Content.Content is { Count: 1 } &&
+ chatMessage.Content.Content[0] is ToolUseContent toolUseContent)
+ {
+ return new ToolCallMessage(
+ toolUseContent.Name ??
+ throw new InvalidOperationException($"Expected {nameof(toolUseContent.Name)} to be specified"),
+ toolUseContent.Input?.ToString() ??
+ throw new InvalidOperationException($"Expected {nameof(toolUseContent.Input)} to be specified"),
+ from: agent.Name);
+ }
+
var delta = chatMessage.Content.Delta;
return delta != null && !string.IsNullOrEmpty(delta.Text)
? new TextMessageUpdate(role: Role.Assistant, delta.Text, from: agent.Name)
@@ -71,16 +83,20 @@ private async Task<IEnumerable<IMessage>> ProcessMessageAsync(IEnumerable<IMessage
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
ImageMessage imageMessage =>
- new MessageEnvelope<ChatMessage>(new ChatMessage("user",
+ (MessageEnvelope<ChatMessage>[])[new MessageEnvelope<ChatMessage>(new ChatMessage("user",
new ContentBase[] { new ImageContent { Source = await ProcessImageSourceAsync(imageMessage) } }
.ToList()),
- from: agent.Name),
+ from: agent.Name)],
MultiModalMessage multiModalMessage => await ProcessMultiModalMessageAsync(multiModalMessage, agent),
- _ => message,
+
+ ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent),
+ ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage),
+ AggregateMessage<ToolCallMessage, ToolCallResultMessage> toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent),
+ _ => [message],
};
- processedMessages.Add(processedMessage);
+ processedMessages.AddRange(processedMessage);
}
return processedMessages;
@@ -93,15 +109,42 @@ private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from
throw new ArgumentNullException(nameof(response.Content));
}
- if (response.Content.Count != 1)
+ // When expecting a tool call, sometimes the response will contain two messages, one chat and one tool.
+ // The first message is typically a TextContent from the LLM explaining what it is trying to do.
+ // The second message contains the tool call.
+ if (response.Content.Count > 1)
{
- throw new NotSupportedException($"{nameof(response.Content)} != 1");
+ if (response.Content.Count == 2 && response.Content[0] is TextContent &&
+ response.Content[1] is ToolUseContent toolUseContent)
+ {
+ return new ToolCallMessage(toolUseContent.Name ?? string.Empty,
+ toolUseContent.Input?.ToJsonString() ?? string.Empty,
+ from: from.Name);
+ }
+
+ throw new NotSupportedException($"Expected {nameof(response.Content)} to have one output");
}
- return new TextMessage(Role.Assistant, ((TextContent)response.Content[0]).Text ?? string.Empty, from: from.Name);
+ var content = response.Content[0];
+ switch (content)
+ {
+ case TextContent textContent:
+ return new TextMessage(Role.Assistant, textContent.Text ?? string.Empty, from: from.Name);
+
+ case ToolUseContent toolUseContent:
+ return new ToolCallMessage(toolUseContent.Name ?? string.Empty,
+ toolUseContent.Input?.ToJsonString() ?? string.Empty,
+ from: from.Name);
+
+ case ImageContent:
+ throw new InvalidOperationException(
+ "Claude is an image understanding model only. It can interpret and analyze images, but it cannot generate, produce, edit, manipulate or create images");
+ default:
+ throw new ArgumentOutOfRangeException(nameof(content));
+ }
}
- private IMessage<ChatMessage> ProcessTextMessage(TextMessage textMessage, IAgent agent)
+ private IEnumerable<IMessage<ChatMessage>> ProcessTextMessage(TextMessage textMessage, IAgent agent)
{
ChatMessage messages;
@@ -139,10 +182,10 @@ private IMessage ProcessTextMessage(TextMessage textMessage, IAgent
"user", textMessage.Content);
}
- return new MessageEnvelope<ChatMessage>(messages, from: textMessage.From);
+ return [new MessageEnvelope<ChatMessage>(messages, from: textMessage.From)];
}
- private async Task<IMessage> ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent)
+ private async Task<IEnumerable<IMessage>> ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent)
{
var content = new List<ContentBase>();
foreach (var message in multiModalMessage.Content)
@@ -158,8 +201,7 @@ private async Task ProcessMultiModalMessageAsync(MultiModalMessage mul
}
}
- var chatMessage = new ChatMessage("user", content);
- return MessageEnvelope.Create(chatMessage, agent.Name);
+ return [MessageEnvelope.Create(new ChatMessage("user", content), agent.Name)];
}
private async Task<ImageSource> ProcessImageSourceAsync(ImageMessage imageMessage)
@@ -192,4 +234,52 @@ private async Task<ImageSource> ProcessImageSourceAsync(ImageMessage imageMessag
Data = Convert.ToBase64String(await response.Content.ReadAsByteArrayAsync())
};
}
+
+ private IEnumerable<IMessage> ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent)
+ {
+ var chatMessage = new ChatMessage("assistant", new List());
+ foreach (var toolCall in toolCallMessage.ToolCalls)
+ {
+ chatMessage.AddContent(new ToolUseContent
+ {
+ Id = toolCall.ToolCallId,
+ Name = toolCall.FunctionName,
+ Input = JsonNode.Parse(toolCall.FunctionArguments)
+ });
+ }
+
+ return [MessageEnvelope.Create(chatMessage, toolCallMessage.From)];
+ }
+
+ private IEnumerable<IMessage> ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage)
+ {
+ var chatMessage = new ChatMessage("user", new List());
+ foreach (var toolCall in toolCallResultMessage.ToolCalls)
+ {
+ chatMessage.AddContent(new ToolResultContent
+ {
+ Id = toolCall.ToolCallId ?? string.Empty,
+ Content = toolCall.Result,
+ });
+ }
+
+ return [MessageEnvelope.Create(chatMessage, toolCallResultMessage.From)];
+ }
+
+ private IEnumerable<IMessage> ProcessToolCallAggregateMessage(AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage, IAgent agent)
+ {
+ if (aggregateMessage.From is { } from && from != agent.Name)
+ {
+ var contents = aggregateMessage.Message2.ToolCalls.Select(t => t.Result);
+ var messages = contents.Select(c =>
+ new ChatMessage("assistant", c ?? throw new ArgumentNullException(nameof(c))));
+
+ return messages.Select(m => new MessageEnvelope(m, from: from));
+ }
+
+ var toolCallMessage = ProcessToolCallMessage(aggregateMessage.Message1, agent);
+ var toolCallResult = ProcessToolCallResultMessage(aggregateMessage.Message2);
+
+ return toolCallMessage.Concat(toolCallResult);
+ }
}
diff --git a/dotnet/src/AutoGen.Anthropic/Utils/AnthropicConstants.cs b/dotnet/src/AutoGen.Anthropic/Utils/AnthropicConstants.cs
index e70572cbddf2..494a6686f521 100644
--- a/dotnet/src/AutoGen.Anthropic/Utils/AnthropicConstants.cs
+++ b/dotnet/src/AutoGen.Anthropic/Utils/AnthropicConstants.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// Constants.cs
+// AnthropicConstants.cs
namespace AutoGen.Anthropic.Utils;
@@ -11,4 +11,5 @@ public static class AnthropicConstants
public static string Claude3Opus = "claude-3-opus-20240229";
public static string Claude3Sonnet = "claude-3-sonnet-20240229";
public static string Claude3Haiku = "claude-3-haiku-20240307";
+ public static string Claude35Sonnet = "claude-3-5-sonnet-20240620";
}
diff --git a/dotnet/src/AutoGen.AzureAIInference/Agent/ChatCompletionsClientAgent.cs b/dotnet/src/AutoGen.AzureAIInference/Agent/ChatCompletionsClientAgent.cs
new file mode 100644
index 000000000000..452c5b1c3079
--- /dev/null
+++ b/dotnet/src/AutoGen.AzureAIInference/Agent/ChatCompletionsClientAgent.cs
@@ -0,0 +1,202 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatCompletionsClientAgent.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoGen.AzureAIInference.Extension;
+using AutoGen.Core;
+using Azure.AI.Inference;
+
+namespace AutoGen.AzureAIInference;
+
+/// <summary>
+/// ChatCompletions client agent. This agent is a thin wrapper around <see cref="ChatCompletionsClient"/> to provide a simple interface for chat completions.
+/// <see cref="ChatCompletionsClientAgent"/> supports the following message types:
+/// <list type="bullet">
+/// <item>
+/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatRequestMessage"/>: chat request message.
+/// </item>
+/// </list>
+///
+/// <see cref="ChatCompletionsClientAgent"/> returns the following message types:
+/// <list type="bullet">
+/// <item>
+/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatResponseMessage"/>: chat response message.
+/// <see cref="MessageEnvelope{T}"/> where T is <see cref="StreamingChatCompletionsUpdate"/>: streaming chat completions update.
+/// </item>
+/// </list>
+/// </summary>
+public class ChatCompletionsClientAgent : IStreamingAgent
+{
+ private readonly ChatCompletionsClient chatCompletionsClient;
+ private readonly ChatCompletionsOptions options;
+ private readonly string systemMessage;
+
+ /// <summary>
+ /// Create a new instance of <see cref="ChatCompletionsClientAgent"/>.
+ /// </summary>
+ /// <param name="chatCompletionsClient">chat completions client</param>
+ /// <param name="name">agent name</param>
+ /// <param name="modelName">model name. e.g. gpt-3.5-turbo</param>
+ /// <param name="systemMessage">system message</param>
+ /// <param name="temperature">temperature</param>
+ /// <param name="maxTokens">max tokens to generate</param>
+ /// <param name="responseFormat">response format, set it to <see cref="ChatCompletionsResponseFormatJSON"/> to enable json mode.</param>
+ /// <param name="seed">seed to use, set it to enable deterministic output</param>
+ /// <param name="functions">functions</param>
+ public ChatCompletionsClientAgent(
+ ChatCompletionsClient chatCompletionsClient,
+ string name,
+ string modelName,
+ string systemMessage = "You are a helpful AI assistant",
+ float temperature = 0.7f,
+ int maxTokens = 1024,
+ int? seed = null,
+ ChatCompletionsResponseFormat? responseFormat = null,
+ IEnumerable<FunctionDefinition>? functions = null)
+ : this(
+ chatCompletionsClient: chatCompletionsClient,
+ name: name,
+ options: CreateChatCompletionOptions(modelName, temperature, maxTokens, seed, responseFormat, functions),
+ systemMessage: systemMessage)
+ {
+ }
+
+ /// <summary>
+ /// Create a new instance of <see cref="ChatCompletionsClientAgent"/>.
+ /// </summary>
+ /// <param name="chatCompletionsClient">chat completions client</param>
+ /// <param name="name">agent name</param>
+ /// <param name="systemMessage">system message</param>
+ /// <param name="options">chat completion option. The option can't contain messages</param>
+ public ChatCompletionsClientAgent(
+ ChatCompletionsClient chatCompletionsClient,
+ string name,
+ ChatCompletionsOptions options,
+ string systemMessage = "You are a helpful AI assistant")
+ {
+ if (options.Messages is { Count: > 0 })
+ {
+ throw new ArgumentException("Messages should not be provided in options");
+ }
+
+ this.chatCompletionsClient = chatCompletionsClient;
+ this.Name = name;
+ this.options = options;
+ this.systemMessage = systemMessage;
+ }
+
+ public string Name { get; }
+
+ public async Task<IMessage> GenerateReplyAsync(
+ IEnumerable<IMessage> messages,
+ GenerateReplyOptions? options = null,
+ CancellationToken cancellationToken = default)
+ {
+ var settings = this.CreateChatCompletionsOptions(options, messages);
+ var reply = await this.chatCompletionsClient.CompleteAsync(settings, cancellationToken: cancellationToken);
+
+ return new MessageEnvelope<ChatCompletions>(reply, from: this.Name);
+ }
+
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
+ IEnumerable<IMessage> messages,
+ GenerateReplyOptions? options = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var settings = this.CreateChatCompletionsOptions(options, messages);
+ var response = await this.chatCompletionsClient.CompleteStreamingAsync(settings, cancellationToken);
+ await foreach (var update in response.WithCancellation(cancellationToken))
+ {
+ yield return new MessageEnvelope<StreamingChatCompletionsUpdate>(update, from: this.Name);
+ }
+ }
+
+ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable<IMessage> messages)
+ {
+ var oaiMessages = messages.Select(m => m switch
+ {
+ IMessage<ChatRequestMessage> chatRequestMessage => chatRequestMessage.Content,
+ _ => throw new ArgumentException("Invalid message type")
+ });
+
+ // add system message if there's no system message in messages
+ if (!oaiMessages.Any(m => m is ChatRequestSystemMessage))
+ {
+ oaiMessages = new[] { new ChatRequestSystemMessage(systemMessage) }.Concat(oaiMessages);
+ }
+
+ // clone the options by serializing and deserializing
+ var json = JsonSerializer.Serialize(this.options);
+ var settings = JsonSerializer.Deserialize<ChatCompletionsOptions>(json) ?? throw new InvalidOperationException("Failed to clone options");
+
+ foreach (var m in oaiMessages)
+ {
+ settings.Messages.Add(m);
+ }
+
+ settings.Temperature = options?.Temperature ?? settings.Temperature;
+ settings.MaxTokens = options?.MaxToken ?? settings.MaxTokens;
+
+ foreach (var tool in this.options.Tools)
+ {
+ settings.Tools.Add(tool);
+ }
+
+ foreach (var stopSequence in this.options.StopSequences)
+ {
+ settings.StopSequences.Add(stopSequence);
+ }
+
+ var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToAzureAIInferenceFunctionDefinition()).ToList();
+ if (openAIFunctionDefinitions is { Count: > 0 })
+ {
+ foreach (var f in openAIFunctionDefinitions)
+ {
+ settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
+ }
+ }
+
+ if (options?.StopSequence is var sequence && sequence is { Length: > 0 })
+ {
+ foreach (var seq in sequence)
+ {
+ settings.StopSequences.Add(seq);
+ }
+ }
+
+ return settings;
+ }
+
+ private static ChatCompletionsOptions CreateChatCompletionOptions(
+ string modelName,
+ float temperature = 0.7f,
+ int maxTokens = 1024,
+ int? seed = null,
+ ChatCompletionsResponseFormat? responseFormat = null,
+ IEnumerable<FunctionDefinition>? functions = null)
+ {
+ var options = new ChatCompletionsOptions()
+ {
+ Model = modelName,
+ Temperature = temperature,
+ MaxTokens = maxTokens,
+ Seed = seed,
+ ResponseFormat = responseFormat,
+ };
+
+ if (functions is not null)
+ {
+ foreach (var f in functions)
+ {
+ options.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
+ }
+ }
+
+ return options;
+ }
+}
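
A minimal sketch of constructing the new agent; the endpoint, key variable, and model id below are placeholders, and ChatCompletionsClient/AzureKeyCredential come from the Azure SDK:

using AutoGen.AzureAIInference;
using Azure;
using Azure.AI.Inference;

var client = new ChatCompletionsClient(
    new Uri("https://models.inference.ai.azure.com"), // placeholder endpoint
    new AzureKeyCredential(Environment.GetEnvironmentVariable("AZURE_AI_KEY")!));
var agent = new ChatCompletionsClientAgent(client, name: "assistant", modelName: "gpt-4o-mini"); // placeholder model id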
diff --git a/dotnet/src/AutoGen.AzureAIInference/AutoGen.AzureAIInference.csproj b/dotnet/src/AutoGen.AzureAIInference/AutoGen.AzureAIInference.csproj
new file mode 100644
index 000000000000..e9401bc4bc22
--- /dev/null
+++ b/dotnet/src/AutoGen.AzureAIInference/AutoGen.AzureAIInference.csproj
@@ -0,0 +1,25 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
+    <RootNamespace>AutoGen.AzureAIInference</RootNamespace>
+  </PropertyGroup>
+
+  <Import Project="$(RepoRoot)/nuget/nuget-package.props" />
+
+  <PropertyGroup>
+    <!-- NuGet Package Settings -->
+    <Title>AutoGen.AzureAIInference</Title>
+    <Description>
+      Azure AI Inference Integration for AutoGen.
+    </Description>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <PackageReference Include="Azure.AI.Inference" />
+  </ItemGroup>
+
+  <ItemGroup>
+    <ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
+  </ItemGroup>
+</Project>
diff --git a/dotnet/src/AutoGen.AzureAIInference/Extension/ChatComptionClientAgentExtension.cs b/dotnet/src/AutoGen.AzureAIInference/Extension/ChatComptionClientAgentExtension.cs
new file mode 100644
index 000000000000..8faf29604ed1
--- /dev/null
+++ b/dotnet/src/AutoGen.AzureAIInference/Extension/ChatComptionClientAgentExtension.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatComptionClientAgentExtension.cs
+
+using AutoGen.Core;
+
+namespace AutoGen.AzureAIInference.Extension;
+
+public static class ChatComptionClientAgentExtension
+{
+ /// <summary>
+ /// Register an <see cref="AzureAIInferenceChatRequestMessageConnector"/> to the <see cref="ChatCompletionsClientAgent"/>
+ /// </summary>
+ /// <param name="connector">the connector to use. If null, a new instance of <see cref="AzureAIInferenceChatRequestMessageConnector"/> will be created.</param>
+ public static MiddlewareStreamingAgent<ChatCompletionsClientAgent> RegisterMessageConnector(
+ this ChatCompletionsClientAgent agent, AzureAIInferenceChatRequestMessageConnector? connector = null)
+ {
+ if (connector == null)
+ {
+ connector = new AzureAIInferenceChatRequestMessageConnector();
+ }
+
+ return agent.RegisterStreamingMiddleware(connector);
+ }
+
+ /// <summary>
+ /// Register an <see cref="AzureAIInferenceChatRequestMessageConnector"/> to the <see cref="MiddlewareStreamingAgent{T}"/> where T is <see cref="ChatCompletionsClientAgent"/>
+ /// </summary>
+ /// <param name="connector">the connector to use. If null, a new instance of <see cref="AzureAIInferenceChatRequestMessageConnector"/> will be created.</param>
+ public static MiddlewareStreamingAgent<ChatCompletionsClientAgent> RegisterMessageConnector(
+ this MiddlewareStreamingAgent<ChatCompletionsClientAgent> agent, AzureAIInferenceChatRequestMessageConnector? connector = null)
+ {
+ if (connector == null)
+ {
+ connector = new AzureAIInferenceChatRequestMessageConnector();
+ }
+
+ return agent.RegisterStreamingMiddleware(connector);
+ }
+}
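
With the connector registered, the agent consumes and produces AutoGen.Core message types directly. A sketch, assuming the SendAsync and RegisterPrintMessage helpers from AutoGen.Core:

var connected = agent
    .RegisterMessageConnector()   // IMessage <-> ChatRequestMessage conversion
    .RegisterPrintMessage();      // assumed console-printing helper from AutoGen.Core

var reply = await connected.SendAsync("What is 1 + 1?");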
diff --git a/dotnet/src/AutoGen.AzureAIInference/Extension/FunctionContractExtension.cs b/dotnet/src/AutoGen.AzureAIInference/Extension/FunctionContractExtension.cs
new file mode 100644
index 000000000000..4cd7b3864f95
--- /dev/null
+++ b/dotnet/src/AutoGen.AzureAIInference/Extension/FunctionContractExtension.cs
@@ -0,0 +1,64 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// FunctionContractExtension.cs
+
+using System;
+using System.Collections.Generic;
+using AutoGen.Core;
+using Azure.AI.Inference;
+using Json.Schema;
+using Json.Schema.Generation;
+
+namespace AutoGen.AzureAIInference.Extension;
+
+public static class FunctionContractExtension
+{
+ /// <summary>
+ /// Convert a <see cref="FunctionContract"/> to a <see cref="FunctionDefinition"/> that can be used in GPT function call.
+ /// </summary>
+ /// <param name="functionContract">function contract</param>
+ /// <returns><see cref="FunctionDefinition"/></returns>
+ public static FunctionDefinition ToAzureAIInferenceFunctionDefinition(this FunctionContract functionContract)
+ {
+ var functionDefinition = new FunctionDefinition
+ {
+ Name = functionContract.Name,
+ Description = functionContract.Description,
+ };
+ var requiredParameterNames = new List<string>();
+ var propertiesSchemas = new Dictionary<string, JsonSchema>();
+ var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object);
+ foreach (var param in functionContract.Parameters ?? [])
+ {
+ if (param.Name is null)
+ {
+ throw new InvalidOperationException("Parameter name cannot be null");
+ }
+
+ var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType)));
+ if (param.Description != null)
+ {
+ schemaBuilder = schemaBuilder.Description(param.Description);
+ }
+
+ if (param.IsRequired)
+ {
+ requiredParameterNames.Add(param.Name);
+ }
+
+ var schema = schemaBuilder.Build();
+ propertiesSchemas[param.Name] = schema;
+
+ }
+ propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas);
+ propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames);
+
+ var option = new System.Text.Json.JsonSerializerOptions()
+ {
+ PropertyNamingPolicy = System.Text.Json.JsonNamingPolicy.CamelCase
+ };
+
+ functionDefinition.Parameters = BinaryData.FromObjectAsJson(propertySchemaBuilder.Build(), option);
+
+ return functionDefinition;
+ }
+}
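
A sketch of the conversion using a hand-built contract (FunctionParameterContract is the AutoGen.Core parameter type; in practice the source generator produces these):

var contract = new FunctionContract
{
    Name = "GetWeather",
    Description = "Get the weather for a city",
    Parameters =
    [
        new FunctionParameterContract
        {
            Name = "city",
            Description = "city name",
            ParameterType = typeof(string),
            IsRequired = true,
        },
    ],
};

FunctionDefinition definition = contract.ToAzureAIInferenceFunctionDefinition();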
diff --git a/dotnet/src/AutoGen.AzureAIInference/Middleware/AzureAIInferenceChatRequestMessageConnector.cs b/dotnet/src/AutoGen.AzureAIInference/Middleware/AzureAIInferenceChatRequestMessageConnector.cs
new file mode 100644
index 000000000000..9c5d22e2e7e7
--- /dev/null
+++ b/dotnet/src/AutoGen.AzureAIInference/Middleware/AzureAIInferenceChatRequestMessageConnector.cs
@@ -0,0 +1,302 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AzureAIInferenceChatRequestMessageConnector.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoGen.Core;
+using Azure.AI.Inference;
+
+namespace AutoGen.AzureAIInference;
+
+/// <summary>
+/// This middleware converts the incoming <see cref="IMessage"/> to <see cref="IMessage{T}"/> where T is <see cref="ChatRequestMessage"/> before sending to agent. And converts the output <see cref="ChatResponseMessage"/> to <see cref="IMessage"/> after receiving from agent.
+/// Supported <see cref="IMessage"/> types are
+/// - <see cref="TextMessage"/>
+/// - <see cref="ImageMessage"/>
+/// - <see cref="MultiModalMessage"/>
+/// - <see cref="ToolCallMessage"/>
+/// - <see cref="ToolCallResultMessage"/>
+/// - <see cref="IMessage{T}"/> where T is <see cref="ChatRequestMessage"/>
+/// - <see cref="AggregateMessage{TMessage1, TMessage2}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/>
+/// </summary>
+public class AzureAIInferenceChatRequestMessageConnector : IStreamingMiddleware
+{
+ private bool strictMode = false;
+
+ /// <summary>
+ /// Create a new instance of <see cref="AzureAIInferenceChatRequestMessageConnector"/>.
+ /// </summary>
+ /// <param name="strictMode">If true, will throw an <see cref="InvalidOperationException"/>
+ /// when the message type is not supported. If false, it will ignore the unsupported message type.</param>
+ public AzureAIInferenceChatRequestMessageConnector(bool strictMode = false)
+ {
+ this.strictMode = strictMode;
+ }
+
+ public string? Name => nameof(AzureAIInferenceChatRequestMessageConnector);
+
+ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
+ {
+ var chatMessages = ProcessIncomingMessages(agent, context.Messages);
+
+ var reply = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken);
+
+ return PostProcessMessage(reply);
+ }
+
+ public async IAsyncEnumerable<IMessage> InvokeAsync(
+ MiddlewareContext context,
+ IStreamingAgent agent,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var chatMessages = ProcessIncomingMessages(agent, context.Messages);
+ var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken);
+ string? currentToolName = null;
+ await foreach (var reply in streamingReply)
+ {
+ if (reply is IMessage<StreamingChatCompletionsUpdate> update)
+ {
+ if (update.Content.FunctionName is string functionName)
+ {
+ currentToolName = functionName;
+ }
+ else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate toolCallUpdate && toolCallUpdate.Name is string toolCallName)
+ {
+ currentToolName = toolCallName;
+ }
+ var postProcessMessage = PostProcessStreamingMessage(update, currentToolName);
+ if (postProcessMessage != null)
+ {
+ yield return postProcessMessage;
+ }
+ }
+ else
+ {
+ if (this.strictMode)
+ {
+ throw new InvalidOperationException($"Invalid streaming message type {reply.GetType().Name}");
+ }
+ else
+ {
+ yield return reply;
+ }
+ }
+ }
+ }
+
+ public IMessage PostProcessMessage(IMessage message)
+ {
+ return message switch
+ {
+ IMessage<ChatResponseMessage> m => PostProcessChatResponseMessage(m.Content, m.From),
+ IMessage<ChatCompletions> m => PostProcessChatCompletions(m),
+ _ when strictMode is false => message,
+ _ => throw new InvalidOperationException($"Invalid return message type {message.GetType().Name}"),
+ };
+ }
+
+ public IMessage? PostProcessStreamingMessage(IMessage<StreamingChatCompletionsUpdate> update, string? currentToolName)
+ {
+ if (update.Content.ContentUpdate is string contentUpdate && string.IsNullOrEmpty(contentUpdate) == false)
+ {
+ // text message
+ return new TextMessageUpdate(Role.Assistant, contentUpdate, from: update.From);
+ }
+ else if (update.Content.FunctionName is string functionName)
+ {
+ return new ToolCallMessageUpdate(functionName, string.Empty, from: update.From);
+ }
+ else if (update.Content.FunctionArgumentsUpdate is string functionArgumentsUpdate && currentToolName is string)
+ {
+ return new ToolCallMessageUpdate(currentToolName, functionArgumentsUpdate, from: update.From);
+ }
+ else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate toolCallUpdate && currentToolName is string)
+ {
+ return new ToolCallMessageUpdate(toolCallUpdate.Name ?? currentToolName, toolCallUpdate.ArgumentsUpdate, from: update.From);
+ }
+ else
+ {
+ return null;
+ }
+ }
+
+ private IMessage PostProcessChatCompletions(IMessage<ChatCompletions> message)
+ {
+ // throw an exception if the content was filtered
+ if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered)
+ {
+ throw new InvalidOperationException("The content was filtered due to its potential risk. Please try another input.");
+ }
+
+ return PostProcessChatResponseMessage(message.Content.Choices[0].Message, message.From);
+ }
+
+ private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from)
+ {
+ var textContent = chatResponseMessage.Content;
+ if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any())
+ {
+ var functionToolCalls = chatResponseMessage.ToolCalls
+ .Where(tc => tc is ChatCompletionsFunctionToolCall)
+ .Select(tc => (ChatCompletionsFunctionToolCall)tc);
+
+ var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id });
+
+ return new ToolCallMessage(toolCalls, from)
+ {
+ Content = textContent,
+ };
+ }
+
+ if (textContent is string content && !string.IsNullOrEmpty(content))
+ {
+ return new TextMessage(Role.Assistant, content, from);
+ }
+
+ throw new InvalidOperationException("Invalid ChatResponseMessage");
+ }
+
+ public IEnumerable<IMessage> ProcessIncomingMessages(IAgent agent, IEnumerable<IMessage> messages)
+ {
+ return messages.SelectMany(m =>
+ {
+ if (m is IMessage<ChatRequestMessage> crm)
+ {
+ return [crm];
+ }
+ else
+ {
+ var chatRequestMessages = m switch
+ {
+ TextMessage textMessage => ProcessTextMessage(agent, textMessage),
+ ImageMessage imageMessage when (imageMessage.From is null || imageMessage.From != agent.Name) => ProcessImageMessage(agent, imageMessage),
+ MultiModalMessage multiModalMessage when (multiModalMessage.From is null || multiModalMessage.From != agent.Name) => ProcessMultiModalMessage(agent, multiModalMessage),
+ ToolCallMessage toolCallMessage when (toolCallMessage.From is null || toolCallMessage.From == agent.Name) => ProcessToolCallMessage(agent, toolCallMessage),
+ ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage),
+ AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => ProcessFunctionCallMiddlewareMessage(agent, aggregateMessage),
+ _ when strictMode is false => [],
+ _ => throw new InvalidOperationException($"Invalid message type: {m.GetType().Name}"),
+ };
+
+ if (chatRequestMessages.Any())
+ {
+ return chatRequestMessages.Select(cm => MessageEnvelope.Create(cm, m.From));
+ }
+ else
+ {
+ return [m];
+ }
+ }
+ });
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessTextMessage(IAgent agent, TextMessage message)
+ {
+ if (message.Role == Role.System)
+ {
+ return [new ChatRequestSystemMessage(message.Content)];
+ }
+
+ if (agent.Name == message.From)
+ {
+ return [new ChatRequestAssistantMessage { Content = message.Content }];
+ }
+ else
+ {
+ return message.From switch
+ {
+ null when message.Role == Role.User => [new ChatRequestUserMessage(message.Content)],
+ null when message.Role == Role.Assistant => [new ChatRequestAssistantMessage() { Content = message.Content }],
+ null => throw new InvalidOperationException("Invalid Role"),
+ _ => [new ChatRequestUserMessage(message.Content)]
+ };
+ }
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessImageMessage(IAgent agent, ImageMessage message)
+ {
+ if (agent.Name == message.From)
+ {
+ // image message from assistant is not supported
+ throw new ArgumentException("ImageMessage is not supported when message.From is the same with agent");
+ }
+
+ var imageContentItem = this.CreateChatMessageImageContentItemFromImageMessage(message);
+ return [new ChatRequestUserMessage([imageContentItem])];
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessMultiModalMessage(IAgent agent, MultiModalMessage message)
+ {
+ if (agent.Name == message.From)
+ {
+ // image message from assistant is not supported
+ throw new ArgumentException("MultiModalMessage is not supported when message.From is the same with agent");
+ }
+
+ IEnumerable<ChatMessageContentItem> items = message.Content.Select(ci => ci switch
+ {
+ TextMessage text => new ChatMessageTextContentItem(text.Content),
+ ImageMessage image => this.CreateChatMessageImageContentItemFromImageMessage(image),
+ _ => throw new NotImplementedException(),
+ });
+
+ return [new ChatRequestUserMessage(items)];
+ }
+
+ private ChatMessageImageContentItem CreateChatMessageImageContentItemFromImageMessage(ImageMessage message)
+ {
+ return message.Data is null && message.Url is not null
+ ? new ChatMessageImageContentItem(new Uri(message.Url))
+ : new ChatMessageImageContentItem(message.Data, message.Data?.MediaType);
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessToolCallMessage(IAgent agent, ToolCallMessage message)
+ {
+ if (message.From is not null && message.From != agent.Name)
+ {
+ throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent");
+ }
+
+ var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
+ var textContent = message.GetContent() ?? string.Empty;
+ var chatRequestMessage = new ChatRequestAssistantMessage() { Content = textContent };
+ foreach (var tc in toolCall)
+ {
+ chatRequestMessage.ToolCalls.Add(tc);
+ }
+
+ return [chatRequestMessage];
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessToolCallResultMessage(ToolCallResultMessage message)
+ {
+ return message.ToolCalls
+ .Where(tc => tc.Result is not null)
+ .Select((tc, i) => new ChatRequestToolMessage(tc.Result, tc.ToolCallId ?? $"{tc.FunctionName}_{i}"));
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessFunctionCallMiddlewareMessage(IAgent agent, AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
+ {
+ if (aggregateMessage.From is not null && aggregateMessage.From != agent.Name)
+ {
+ // convert as user message
+ var resultMessage = aggregateMessage.Message2;
+
+ return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result));
+ }
+ else
+ {
+ var toolCallMessage1 = aggregateMessage.Message1;
+ var toolCallResultMessage = aggregateMessage.Message2;
+
+ var assistantMessage = this.ProcessToolCallMessage(agent, toolCallMessage1);
+ var toolCallResults = this.ProcessToolCallResultMessage(toolCallResultMessage);
+
+ return assistantMessage.Concat(toolCallResults);
+ }
+ }
+}
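
By default the connector silently passes unsupported message types through; strict mode turns them into hard failures. A sketch:

// throw on unsupported message types instead of passing them through
var strictConnector = new AzureAIInferenceChatRequestMessageConnector(strictMode: true);
var connectedAgent = agent.RegisterMessageConnector(strictConnector);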
diff --git a/dotnet/src/AutoGen.Core/Agent/IAgent.cs b/dotnet/src/AutoGen.Core/Agent/IAgent.cs
index b9149008480d..f2b8ce67d01b 100644
--- a/dotnet/src/AutoGen.Core/Agent/IAgent.cs
+++ b/dotnet/src/AutoGen.Core/Agent/IAgent.cs
@@ -5,12 +5,17 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
+using Json.Schema;
namespace AutoGen.Core;
-public interface IAgent
+
+public interface IAgentMetaInformation
{
public string Name { get; }
+}
+public interface IAgent : IAgentMetaInformation
+{
///
/// Generate reply
///
@@ -38,6 +43,7 @@ public GenerateReplyOptions(GenerateReplyOptions other)
this.MaxToken = other.MaxToken;
this.StopSequence = other.StopSequence?.Select(s => s)?.ToArray();
this.Functions = other.Functions?.Select(f => f)?.ToArray();
+ this.OutputSchema = other.OutputSchema;
}
public float? Temperature { get; set; }
@@ -47,4 +53,9 @@ public GenerateReplyOptions(GenerateReplyOptions other)
public string[]? StopSequence { get; set; }
public FunctionContract[]? Functions { get; set; }
+
+ /// <summary>
+ /// Structural schema for the output. This property only applies to certain LLMs.
+ /// </summary>
+ public JsonSchema? OutputSchema { get; set; }
}
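
A sketch of setting the new OutputSchema via Json.Schema.Generation (Person is a hypothetical POCO; only providers with structured-output support honor the schema):

using Json.Schema;
using Json.Schema.Generation;

var options = new GenerateReplyOptions
{
    OutputSchema = new JsonSchemaBuilder()
        .FromType<Person>() // Person is a hypothetical type
        .Build(),
};
var reply = await agent.GenerateReplyAsync(messages, options);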
diff --git a/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs
index 665f18bac12a..6b7794c921ad 100644
--- a/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs
+++ b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs
@@ -11,7 +11,7 @@ namespace AutoGen.Core;
///
public interface IStreamingAgent : IAgent
{
- public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+ public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default);
diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
index 52967d6ff1ce..c7643b1e4735 100644
--- a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
+++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
@@ -47,7 +47,7 @@ public Task GenerateReplyAsync(IEnumerable messages, Generat
return _agent.GenerateReplyAsync(messages, options, cancellationToken);
}
- public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
}
@@ -83,7 +83,7 @@ public Task GenerateReplyAsync(IEnumerable messages, Generat
return this.streamingMiddleware.InvokeAsync(context, (IAgent)innerAgent, cancellationToken);
}
- public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
if (streamingMiddleware is null)
{
diff --git a/dotnet/src/AutoGen.Core/AutoGen.Core.csproj b/dotnet/src/AutoGen.Core/AutoGen.Core.csproj
index 60aeb3ae3fca..8cf9e9183d40 100644
--- a/dotnet/src/AutoGen.Core/AutoGen.Core.csproj
+++ b/dotnet/src/AutoGen.Core/AutoGen.Core.csproj
@@ -1,6 +1,6 @@
- <TargetFramework>netstandard2.0</TargetFramework>
+ <TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
AutoGen.Core
@@ -17,7 +17,10 @@
-
+
+
+
+
diff --git a/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs b/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs
index 44ce8838b73a..13ce970d551b 100644
--- a/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs
+++ b/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentExtension.cs
+using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
@@ -60,14 +61,14 @@ public static async Task SendAsync(
}
///
- /// Send message to another agent.
+ /// Send message to another agent and iterate over the responses.
///
/// <param name="agent">sender agent.</param>
/// <param name="receiver">receiver agent.</param>
/// <param name="chatHistory">chat history.</param>
/// <param name="maxRound">max conversation round.</param>
/// <returns>conversation history</returns>
- public static async Task<IEnumerable<IMessage>> SendAsync(
+ public static IAsyncEnumerable<IMessage> SendAsync(
this IAgent agent,
IAgent receiver,
IEnumerable<IMessage> chatHistory,
@@ -78,21 +79,21 @@ public static async Task> SendAsync(
{
var gc = manager.GroupChat;
- return await agent.SendMessageToGroupAsync(gc, chatHistory, maxRound, ct);
+ return gc.SendAsync(chatHistory, maxRound, ct);
}
var groupChat = new RoundRobinGroupChat(
- agents: new[]
- {
+ agents:
+ [
agent,
receiver,
- });
+ ]);
- return await groupChat.CallAsync(chatHistory, maxRound, ct: ct);
+ return groupChat.SendAsync(chatHistory, maxRound, cancellationToken: ct);
}
///
- /// Send message to another agent.
+ /// Send message to another agent and iterate over the responses.
///
/// <param name="agent">sender agent.</param>
/// <param name="message">message to send. will be added to the end of <paramref name="chatHistory"/> if provided</param>
@@ -100,7 +101,7 @@ public static async Task> SendAsync(
/// <param name="chatHistory">chat history.</param>
/// <param name="maxRound">max conversation round.</param>
/// <returns>conversation history</returns>
- public static async Task<IEnumerable<IMessage>> SendAsync(
+ public static IAsyncEnumerable<IMessage> SendAsync(
this IAgent agent,
IAgent receiver,
string message,
@@ -116,11 +117,12 @@ public static async Task> SendAsync(
chatHistory = chatHistory ?? new List<IMessage>();
chatHistory = chatHistory.Append(msg);
- return await agent.SendAsync(receiver, chatHistory, maxRound, ct);
+ return agent.SendAsync(receiver, chatHistory, maxRound, ct);
}
///
- /// Shortcut API to send message to another agent.
+ /// Shortcut API to send message to another agent and get all responses.
+ /// To iterate over the responses, use one of the streaming SendAsync overloads instead.
///
/// <param name="agent">sender agent</param>
/// <param name="receiver">receiver agent</param>
@@ -144,10 +146,16 @@ public static async Task> InitiateChatAsync(
chatHistory.Add(msg);
}
- return await agent.SendAsync(receiver, chatHistory, maxRound, ct);
+ await foreach (var msg in agent.SendAsync(receiver, chatHistory, maxRound, ct))
+ {
+ chatHistory.Add(msg);
+ }
+
+ return chatHistory;
}
- public static async Task<IEnumerable<IMessage>> SendMessageToGroupAsync(
+ [Obsolete("use GroupChatExtension.SendAsync")]
+ public static IAsyncEnumerable<IMessage> SendMessageToGroupAsync(
this IAgent agent,
IGroupChat groupChat,
string msg,
@@ -159,16 +167,18 @@ public static async Task> SendMessageToGroupAsync(
chatHistory = chatHistory ?? Enumerable.Empty<IMessage>();
chatHistory = chatHistory.Append(chatMessage);
- return await agent.SendMessageToGroupAsync(groupChat, chatHistory, maxRound, ct);
+ return agent.SendMessageToGroupAsync(groupChat, chatHistory, maxRound, ct);
}
- public static async Task<IEnumerable<IMessage>> SendMessageToGroupAsync(
+ [Obsolete("use GroupChatExtension.SendAsync")]
+ public static IAsyncEnumerable<IMessage> SendMessageToGroupAsync(
this IAgent _,
IGroupChat groupChat,
IEnumerable<IMessage>? chatHistory = null,
int maxRound = 10,
CancellationToken ct = default)
{
- return await groupChat.CallAsync(chatHistory, maxRound, ct);
+ chatHistory = chatHistory ?? Enumerable.Empty<IMessage>();
+ return groupChat.SendAsync(chatHistory, maxRound, ct);
}
}
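
Because the two-agent SendAsync overloads now return a lazy stream, each round can be observed as it completes. A sketch, with alice and bob created elsewhere:

await foreach (var message in alice.SendAsync(receiver: bob, message: "hello", maxRound: 5))
{
    Console.WriteLine(message.GetContent());
}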
diff --git a/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs b/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs
index e3e44622c817..89da7708797c 100644
--- a/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs
+++ b/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs
@@ -4,6 +4,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading;
namespace AutoGen.Core;
@@ -23,6 +25,46 @@ public static void AddInitializeMessage(this IAgent agent, string message, IGrou
groupChat.SendIntroduction(msg);
}
+ /// <summary>
+ /// Send messages to a <see cref="IGroupChat"/> and return new messages from the group chat.
+ /// </summary>
+ /// <param name="groupChat"></param>
+ /// <param name="chatHistory"></param>
+ /// <param name="maxRound"></param>
+ /// <param name="cancellationToken"></param>
+ /// <returns></returns>
+ public static async IAsyncEnumerable<IMessage> SendAsync(
+ this IGroupChat groupChat,
+ IEnumerable<IMessage> chatHistory,
+ int maxRound = 10,
+ [EnumeratorCancellation]
+ CancellationToken cancellationToken = default)
+ {
+ while (maxRound-- > 0)
+ {
+ var messages = await groupChat.CallAsync(chatHistory, maxRound: 1, cancellationToken);
+
+ // if no new messages, break the loop
+ if (messages.Count() == chatHistory.Count())
+ {
+ yield break;
+ }
+
+ var lastMessage = messages.Last();
+
+ yield return lastMessage;
+ if (lastMessage.IsGroupChatTerminateMessage())
+ {
+ yield break;
+ }
+
+ // messages will contain the complete chat history, including initialize messages
+ // but we only need to add the last message to the chat history
+ // fix #3268
+ chatHistory = chatHistory.Append(lastMessage);
+ }
+ }
+
///
/// Send an instruction message to the group chat.
///
@@ -78,6 +120,7 @@ public static bool IsGroupChatClearMessage(this IMessage message)
return message.GetContent()?.Contains(CLEAR_MESSAGES) ?? false;
}
+ [Obsolete]
public static IEnumerable<IMessage> ProcessConversationForAgent(
this IGroupChat groupChat,
IEnumerable<IMessage> initialMessages,
@@ -100,8 +143,7 @@ internal static IEnumerable ProcessConversationsForRolePlay(
var msg = @$"From {x.From}:
{x.GetContent()}
-round #
- {i}";
+round # {i}";
return new TextMessage(Role.User, content: msg);
});
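
GroupChatExtension.SendAsync yields only the new messages and stops early on a terminate message. A sketch, with the agents created elsewhere:

var groupChat = new RoundRobinGroupChat([alice, bob]);
var task = new TextMessage(Role.User, "write a joke");
await foreach (var newMessage in groupChat.SendAsync([task], maxRound: 10))
{
    Console.WriteLine(newMessage.GetContent());
}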
diff --git a/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs b/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs
index 2c828c26d890..556c16436c63 100644
--- a/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs
+++ b/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs
@@ -35,7 +35,7 @@ public class FunctionContract
///
/// The name of the function.
///
- public string? Name { get; set; }
+ public string Name { get; set; } = null!;
///
/// The description of the function.
diff --git a/dotnet/src/AutoGen.Core/GroupChat/Graph.cs b/dotnet/src/AutoGen.Core/GroupChat/Graph.cs
index d6b71e2a3f13..acff955a292c 100644
--- a/dotnet/src/AutoGen.Core/GroupChat/Graph.cs
+++ b/dotnet/src/AutoGen.Core/GroupChat/Graph.cs
@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
@@ -12,11 +13,7 @@ public class Graph
{
private readonly List<Transition> transitions = new List<Transition>();
- public Graph()
- {
- }
-
- public Graph(IEnumerable<Transition>? transitions)
+ public Graph(IEnumerable<Transition>? transitions = null)
{
if (transitions != null)
{
@@ -40,13 +37,13 @@ public void AddTransition(Transition transition)
/// the from agent
/// messages
/// A list of agents that the messages can be transit to
- public async Task<IEnumerable<IAgent>> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable<IMessage> messages)
+ public async Task<IEnumerable<IAgent>> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable<IMessage> messages, CancellationToken ct = default)
{
var nextAgents = new List<IAgent>();
var availableTransitions = transitions.FindAll(t => t.From == fromAgent) ?? Enumerable.Empty<Transition>();
foreach (var transition in availableTransitions)
{
- if (await transition.CanTransitionAsync(messages))
+ if (await transition.CanTransitionAsync(messages, ct))
{
nextAgents.Add(transition.To);
}
@@ -63,7 +60,7 @@ public class Transition
{
private readonly IAgent _from;
private readonly IAgent _to;
- private readonly Func<IAgent, IAgent, IEnumerable<IMessage>, Task<bool>>? _canTransition;
+ private readonly Func<IAgent, IAgent, IEnumerable<IMessage>, CancellationToken, Task<bool>>? _canTransition;
///
/// Create a new instance of .
@@ -73,22 +70,44 @@ public class Transition
/// from agent
/// to agent
/// detect if the transition is allowed, default to be always true
- internal Transition(IAgent from, IAgent to, Func<IAgent, IAgent, IEnumerable<IMessage>, Task<bool>>? canTransitionAsync = null)
+ internal Transition(IAgent from, IAgent to, Func<IAgent, IAgent, IEnumerable<IMessage>, CancellationToken, Task<bool>>? canTransitionAsync = null)
{
_from = from;
_to = to;
_canTransition = canTransitionAsync;
}
+ /// <summary>
+ /// Create a new instance of <see cref="Transition"/> without transition condition check.
+ /// </summary>
+ /// <returns><see cref="Transition"/></returns>
+ public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to)
+ where TFromAgent : IAgent
+ where TToAgent : IAgent
+ {
+ return new Transition(from, to, (fromAgent, toAgent, messages, _) => Task.FromResult(true));
+ }
+
///
/// Create a new instance of .
///
/// "
- public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to, Func<TFromAgent, TToAgent, IEnumerable<IMessage>, Task<bool>>? canTransitionAsync = null)
+ public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to, Func<TFromAgent, TToAgent, IEnumerable<IMessage>, Task<bool>> canTransitionAsync)
+ where TFromAgent : IAgent
+ where TToAgent : IAgent
+ {
+ return new Transition(from, to, (fromAgent, toAgent, messages, _) => canTransitionAsync.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages));
+ }
+
+ /// <summary>
+ /// Create a new instance of <see cref="Transition"/> with cancellation token.
+ /// </summary>
+ /// <returns><see cref="Transition"/></returns>
+ public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to, Func<TFromAgent, TToAgent, IEnumerable<IMessage>, CancellationToken, Task<bool>> canTransitionAsync)
where TFromAgent : IAgent
where TToAgent : IAgent
{
- return new Transition(from, to, (fromAgent, toAgent, messages) => canTransitionAsync?.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages) ?? Task.FromResult(true));
+ return new Transition(from, to, (fromAgent, toAgent, messages, ct) => canTransitionAsync.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages, ct));
}
public IAgent From => _from;
@@ -99,13 +118,13 @@ public static Transition Create(TFromAgent from, TToAgent
/// Check if the transition is allowed.
///
/// messages
- public Task<bool> CanTransitionAsync(IEnumerable<IMessage> messages)
+ public Task<bool> CanTransitionAsync(IEnumerable<IMessage> messages, CancellationToken ct = default)
{
if (_canTransition == null)
{
return Task.FromResult(true);
}
- return _canTransition(this.From, this.To, messages);
+ return _canTransition(this.From, this.To, messages, ct);
}
}
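
A sketch of the new cancellation-aware transition predicate (alice and bob are agents created elsewhere):

var aliceToBob = Transition.Create(alice, bob,
    (fromAgent, toAgent, messages, ct) =>
        // allow the transition only when the last message carries content
        Task.FromResult(messages.Last().GetContent() is not null));

var workflow = new Graph([aliceToBob]);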
diff --git a/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
index 5e82931ab658..57e15c18ca62 100644
--- a/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
+++ b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
@@ -15,6 +15,7 @@ public class GroupChat : IGroupChat
private List<IAgent> agents = new List<IAgent>();
private IEnumerable<IMessage> initializeMessages = new List<IMessage>();
private Graph? workflow = null;
+ private readonly IOrchestrator orchestrator;
public IEnumerable<IMessage>? Messages { get; private set; }
@@ -36,6 +37,37 @@ public GroupChat(
this.initializeMessages = initializeMessages ?? new List();
this.workflow = workflow;
+ if (admin is not null)
+ {
+ this.orchestrator = new RolePlayOrchestrator(admin, workflow);
+ }
+ else if (workflow is not null)
+ {
+ this.orchestrator = new WorkflowOrchestrator(workflow);
+ }
+ else
+ {
+ this.orchestrator = new RoundRobinOrchestrator();
+ }
+
+ this.Validation();
+ }
+
+ /// <summary>
+ /// Create a group chat which uses the <paramref name="orchestrator"/> to decide the next speaker(s).
+ /// </summary>
+ /// <param name="members"></param>
+ /// <param name="orchestrator"></param>
+ /// <param name="initializeMessages"></param>
+ public GroupChat(
+ IEnumerable<IAgent> members,
+ IOrchestrator orchestrator,
+ IEnumerable<IMessage>? initializeMessages = null)
+ {
+ this.agents = members.ToList();
+ this.initializeMessages = initializeMessages ?? new List();
+ this.orchestrator = orchestrator;
+
this.Validation();
}
@@ -64,12 +96,6 @@ private void Validation()
throw new Exception("All agents in the workflow must be in the group chat.");
}
}
-
- // must provide one of admin or workflow
- if (this.admin == null && this.workflow == null)
- {
- throw new Exception("Must provide one of admin or workflow.");
- }
}
///
@@ -81,6 +107,7 @@ private void Validation()
/// current speaker
/// conversation history
/// next speaker.
+ [Obsolete("Please use RolePlayOrchestrator or WorkflowOrchestrator")]
public async Task<IAgent> SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumerable<IMessage> conversationHistory)
{
var agentNames = this.agents.Select(x => x.Name).ToList();
@@ -140,37 +167,40 @@ public void AddInitializeMessage(IMessage message)
}
public async Task<IEnumerable<IMessage>> CallAsync(
- IEnumerable<IMessage>? conversationWithName = null,
+ IEnumerable<IMessage>? chatHistory = null,
int maxRound = 10,
CancellationToken ct = default)
{
var conversationHistory = new List<IMessage>();
- if (conversationWithName != null)
+ conversationHistory.AddRange(this.initializeMessages);
+ if (chatHistory != null)
{
- conversationHistory.AddRange(conversationWithName);
+ conversationHistory.AddRange(chatHistory);
}
+ var roundLeft = maxRound;
- var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
+ while (roundLeft > 0)
{
- null => this.agents.First(),
- _ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
- };
- var round = 0;
- while (round < maxRound)
- {
- var currentSpeaker = await this.SelectNextSpeakerAsync(lastSpeaker, conversationHistory);
- var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
- var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
+ var orchestratorContext = new OrchestrationContext
+ {
+ Candidates = this.agents,
+ ChatHistory = conversationHistory,
+ };
+ var nextSpeaker = await this.orchestrator.GetNextSpeakerAsync(orchestratorContext, ct);
+ if (nextSpeaker == null)
+ {
+ break;
+ }
+
+ var result = await nextSpeaker.GenerateReplyAsync(conversationHistory, cancellationToken: ct);
conversationHistory.Add(result);
- // if message is terminate message, then terminate the conversation
- if (result?.IsGroupChatTerminateMessage() ?? false)
+ if (result.IsGroupChatTerminateMessage())
{
- break;
+ return conversationHistory;
}
- lastSpeaker = currentSpeaker;
- round++;
+ roundLeft--;
}
return conversationHistory;
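
The new constructor makes the speaker-selection policy explicit. A sketch using the built-in round-robin orchestrator, with the agents created elsewhere:

var chat = new GroupChat(
    members: [alice, bob, carol],
    orchestrator: new RoundRobinOrchestrator());

var history = await chat.CallAsync([new TextMessage(Role.User, "hello")], maxRound: 6);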
diff --git a/dotnet/src/AutoGen.Core/IGroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/IGroupChat.cs
similarity index 100%
rename from dotnet/src/AutoGen.Core/IGroupChat.cs
rename to dotnet/src/AutoGen.Core/GroupChat/IGroupChat.cs
diff --git a/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs
index b8de89b834fe..b95cd1958fc5 100644
--- a/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs
+++ b/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs
@@ -3,9 +3,6 @@
using System;
using System.Collections.Generic;
-using System.Linq;
-using System.Threading;
-using System.Threading.Tasks;
namespace AutoGen.Core;
@@ -25,76 +22,12 @@ public SequentialGroupChat(IEnumerable agents, List? initializ
///
/// A group chat that allows agents to talk in a round-robin manner.
///
-public class RoundRobinGroupChat : IGroupChat
+public class RoundRobinGroupChat : GroupChat
{
- private readonly List agents = new List();
- private readonly List initializeMessages = new List();
-
public RoundRobinGroupChat(
IEnumerable<IAgent> agents,
List<IMessage>? initializeMessages = null)
+ : base(agents, initializeMessages: initializeMessages)
{
- this.agents.AddRange(agents);
- this.initializeMessages = initializeMessages ?? new List();
- }
-
- ///
- public void AddInitializeMessage(IMessage message)
- {
- this.SendIntroduction(message);
- }
-
- public async Task> CallAsync(
- IEnumerable? conversationWithName = null,
- int maxRound = 10,
- CancellationToken ct = default)
- {
- var conversationHistory = new List();
- if (conversationWithName != null)
- {
- conversationHistory.AddRange(conversationWithName);
- }
-
- var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
- {
- null => this.agents.First(),
- _ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
- };
- var round = 0;
- while (round < maxRound)
- {
- var currentSpeaker = this.SelectNextSpeaker(lastSpeaker);
- var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
- var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
- conversationHistory.Add(result);
-
- // if message is terminate message, then terminate the conversation
- if (result?.IsGroupChatTerminateMessage() ?? false)
- {
- break;
- }
-
- lastSpeaker = currentSpeaker;
- round++;
- }
-
- return conversationHistory;
- }
-
- public void SendIntroduction(IMessage message)
- {
- this.initializeMessages.Add(message);
- }
-
- private IAgent SelectNextSpeaker(IAgent currentSpeaker)
- {
- var index = this.agents.IndexOf(currentSpeaker);
- if (index == -1)
- {
- throw new ArgumentException("The agent is not in the group chat", nameof(currentSpeaker));
- }
-
- var nextIndex = (index + 1) % this.agents.Count;
- return this.agents[nextIndex];
}
}
diff --git a/dotnet/src/AutoGen.Core/Message/IMessage.cs b/dotnet/src/AutoGen.Core/Message/IMessage.cs
index ad215d510e3b..9952cbf06792 100644
--- a/dotnet/src/AutoGen.Core/Message/IMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/IMessage.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IMessage.cs
+using System;
using System.Collections.Generic;
namespace AutoGen.Core;
@@ -35,19 +36,21 @@ namespace AutoGen.Core;
///
///
///
-public interface IMessage : IStreamingMessage
+public interface IMessage
{
+ string? From { get; set; }
}
-public interface IMessage<out T> : IMessage, IStreamingMessage<T>
+public interface IMessage<out T> : IMessage
{
+ T Content { get; }
}
///
/// The interface for messages that can get text content.
/// This interface will be used by to get the content from the message.
///
-public interface ICanGetTextContent : IMessage, IStreamingMessage
+public interface ICanGetTextContent : IMessage
{
public string? GetContent();
}
@@ -55,17 +58,18 @@ public interface ICanGetTextContent : IMessage, IStreamingMessage
///
/// The interface for messages that can get a list of
///
-public interface ICanGetToolCalls : IMessage, IStreamingMessage
+public interface ICanGetToolCalls : IMessage
{
public IEnumerable GetToolCalls();
}
-
+[Obsolete("Use IMessage instead")]
public interface IStreamingMessage
{
string? From { get; set; }
}
+[Obsolete("Use IMessage instead")]
+public interface IStreamingMessage<out T> : IStreamingMessage
{
T Content { get; }
diff --git a/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs b/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs
index f83bea279260..dc9709bbde5b 100644
--- a/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs
+++ b/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs
@@ -5,7 +5,7 @@
namespace AutoGen.Core;
-public abstract class MessageEnvelope : IMessage, IStreamingMessage
+public abstract class MessageEnvelope : IMessage
{
public MessageEnvelope(string? from = null, IDictionary? metadata = null)
{
@@ -23,7 +23,7 @@ public static MessageEnvelope Create(TContent content, strin
public IDictionary Metadata { get; set; }
}
-public class MessageEnvelope<T> : MessageEnvelope, IMessage<T>, IStreamingMessage<T>
+public class MessageEnvelope<T> : MessageEnvelope, IMessage<T>
{
public MessageEnvelope(T content, string? from = null, IDictionary? metadata = null)
: base(from, metadata)
diff --git a/dotnet/src/AutoGen.Core/Message/TextMessage.cs b/dotnet/src/AutoGen.Core/Message/TextMessage.cs
index addd8728a926..9419c2b3ba86 100644
--- a/dotnet/src/AutoGen.Core/Message/TextMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/TextMessage.cs
@@ -3,7 +3,7 @@
namespace AutoGen.Core;
-public class TextMessage : IMessage, IStreamingMessage, ICanGetTextContent
+public class TextMessage : IMessage, ICanGetTextContent
{
public TextMessage(Role role, string content, string? from = null)
{
@@ -51,7 +51,7 @@ public override string ToString()
}
}
-public class TextMessageUpdate : IStreamingMessage, ICanGetTextContent
+public class TextMessageUpdate : IMessage, ICanGetTextContent
{
public TextMessageUpdate(Role role, string? content, string? from = null)
{
diff --git a/dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs
index 7781b785ef8c..7d46d56135aa 100644
--- a/dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// FunctionCallAggregateMessage.cs
+// ToolCallAggregateMessage.cs
using System.Collections.Generic;
diff --git a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
index d0f89e1ecdde..8660b323044f 100644
--- a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
@@ -109,7 +109,7 @@ public IEnumerable GetToolCalls()
}
}
-public class ToolCallMessageUpdate : IStreamingMessage
+public class ToolCallMessageUpdate : IMessage
{
public ToolCallMessageUpdate(string functionName, string functionArgumentUpdate, string? from = null)
{
diff --git a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
index d0788077b590..7d30f6d0928a 100644
--- a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
+++ b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
@@ -70,7 +70,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
return reply;
}
- public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
+ public async IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
@@ -86,16 +86,16 @@ public async IAsyncEnumerable InvokeAsync(
var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions;
options.Functions = combinedFunctions?.ToArray();
- IStreamingMessage? initMessage = default;
+ IMessage? mergedFunctionCallMessage = default;
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
{
if (message is ToolCallMessageUpdate toolCallMessageUpdate && this.functionMap != null)
{
- if (initMessage is null)
+ if (mergedFunctionCallMessage is null)
{
- initMessage = new ToolCallMessage(toolCallMessageUpdate);
+ mergedFunctionCallMessage = new ToolCallMessage(toolCallMessageUpdate);
}
- else if (initMessage is ToolCallMessage toolCall)
+ else if (mergedFunctionCallMessage is ToolCallMessage toolCall)
{
toolCall.Update(toolCallMessageUpdate);
}
@@ -104,13 +104,17 @@ public async IAsyncEnumerable InvokeAsync(
throw new InvalidOperationException("The first message is ToolCallMessage, but the update message is not ToolCallMessageUpdate");
}
}
+ else if (message is ToolCallMessage toolCallMessage1)
+ {
+ mergedFunctionCallMessage = toolCallMessage1;
+ }
else
{
yield return message;
}
}
- if (initMessage is ToolCallMessage toolCallMsg)
+ if (mergedFunctionCallMessage is ToolCallMessage toolCallMsg)
{
yield return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallMsg, agent);
}
diff --git a/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs
index bc7aec57f52b..d550bdb519ce 100644
--- a/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs
+++ b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs
@@ -14,7 +14,7 @@ public interface IStreamingMiddleware : IMiddleware
///
/// The streaming version of .
///
- public IAsyncEnumerable<IStreamingMessage> InvokeAsync(
+ public IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken = default);
diff --git a/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs
index 099f78e5f176..a4e84de85a44 100644
--- a/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs
+++ b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs
@@ -48,7 +48,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
}
}
- public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
IMessage? recentUpdate = null;
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken))
diff --git a/dotnet/src/AutoGen.Core/Orchestrator/IOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/IOrchestrator.cs
new file mode 100644
index 000000000000..777834871f65
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Orchestrator/IOrchestrator.cs
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// IOrchestrator.cs
+
+using System;
+using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public class OrchestrationContext
+{
+ public IEnumerable<IAgent> Candidates { get; set; } = Array.Empty<IAgent>();
+
+ public IEnumerable<IMessage> ChatHistory { get; set; } = Array.Empty<IMessage>();
+}
+
+public interface IOrchestrator
+{
+ /// <summary>
+ /// Return the next agent as the next speaker. Return null if no agent is selected.
+ /// </summary>
+ /// <param name="context">orchestration context, such as candidate agents and chat history.</param>
+ /// <param name="cancellationToken">cancellation token</param>
+ public Task<IAgent?> GetNextSpeakerAsync(
+ OrchestrationContext context,
+ CancellationToken cancellationToken = default);
+}
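
Custom policies plug in by implementing the interface; returning null ends the conversation. A sketch:

using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core;

public class FirstCandidateOrchestrator : IOrchestrator
{
    public Task<IAgent?> GetNextSpeakerAsync(
        OrchestrationContext context,
        CancellationToken cancellationToken = default)
    {
        // trivially pick the first candidate; null would stop the chat
        return Task.FromResult(context.Candidates.FirstOrDefault());
    }
}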
diff --git a/dotnet/src/AutoGen.Core/Orchestrator/RolePlayOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/RolePlayOrchestrator.cs
new file mode 100644
index 000000000000..6798f23f2df8
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Orchestrator/RolePlayOrchestrator.cs
@@ -0,0 +1,116 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// RolePlayOrchestrator.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public class RolePlayOrchestrator : IOrchestrator
+{
+ private readonly IAgent admin;
+ private readonly Graph? workflow = null;
+ public RolePlayOrchestrator(IAgent admin, Graph? workflow = null)
+ {
+ this.admin = admin;
+ this.workflow = workflow;
+ }
+
+ public async Task<IAgent?> GetNextSpeakerAsync(
+ OrchestrationContext context,
+ CancellationToken cancellationToken = default)
+ {
+ var candidates = context.Candidates.ToList();
+
+ if (candidates.Count == 0)
+ {
+ return null;
+ }
+
+ if (candidates.Count == 1)
+ {
+ return candidates.First();
+ }
+
+ // if there's a workflow
+ // and the next available agent from the workflow is in the group chat
+ // then return the next agent from the workflow
+ if (this.workflow != null)
+ {
+ var lastMessage = context.ChatHistory.LastOrDefault();
+ if (lastMessage == null)
+ {
+ return null;
+ }
+ var currentSpeaker = candidates.First(candidate => candidate.Name == lastMessage.From);
+ var nextAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, context.ChatHistory);
+ nextAgents = nextAgents.Where(nextAgent => candidates.Any(candidate => candidate.Name == nextAgent.Name));
+ candidates = nextAgents.ToList();
+ if (!candidates.Any())
+ {
+ return null;
+ }
+
+ if (candidates is { Count: 1 })
+ {
+ return candidates.First();
+ }
+ }
+
+ // At this point more than one agent is available as the next speaker,
+ // so the admin is invoked to decide which one speaks next
+ var agentNames = candidates.Select(candidate => candidate.Name);
+ var rolePlayMessage = new TextMessage(Role.User,
+ content: $@"You are in a role play game. Carefully read the conversation history and carry on the conversation.
+The available roles are:
+{string.Join(",", agentNames)}
+
+Each message will start with 'From name:', e.g:
+From {agentNames.First()}:
+//your message//.");
+
+ var chatHistoryWithName = this.ProcessConversationsForRolePlay(context.ChatHistory);
+ var messages = new IMessage[] { rolePlayMessage }.Concat(chatHistoryWithName);
+
+ var response = await this.admin.GenerateReplyAsync(
+ messages: messages,
+ options: new GenerateReplyOptions
+ {
+ Temperature = 0,
+ MaxToken = 128,
+ StopSequence = [":"],
+ Functions = null,
+ },
+ cancellationToken: cancellationToken);
+
+ var name = response.GetContent() ?? throw new Exception("No name is returned.");
+
+ // remove From
+ name = name!.Substring(5);
+ var candidate = candidates.FirstOrDefault(x => x.Name!.ToLower() == name.ToLower());
+
+ if (candidate != null)
+ {
+ return candidate;
+ }
+
+ var errorMessage = $"The response from admin is {name}, which is either not in the candidates list or not in the correct format.";
+ throw new Exception(errorMessage);
+ }
+
+ private IEnumerable<IMessage> ProcessConversationsForRolePlay(IEnumerable<IMessage> messages)
+ {
+ return messages.Select((x, i) =>
+ {
+ var msg = @$"From {x.From}:
+{x.GetContent()}
+
+round # {i}";
+
+ return new TextMessage(Role.User, content: msg);
+ });
+ }
+}
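
The orchestrator relies on the admin answering with a leading "From <name>" line, which the Substring(5) parse above assumes. A sketch, with admin (an LLM-backed agent) and the candidates created elsewhere:

var rolePlay = new RolePlayOrchestrator(admin);
var speaker = await rolePlay.GetNextSpeakerAsync(new OrchestrationContext
{
    Candidates = [alice, bob],
    ChatHistory = [new TextMessage(Role.User, "hello", from: "alice")],
});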
diff --git a/dotnet/src/AutoGen.Core/Orchestrator/RoundRobinOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/RoundRobinOrchestrator.cs
new file mode 100644
index 000000000000..af5efdc0e9ee
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Orchestrator/RoundRobinOrchestrator.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// RoundRobinOrchestrator.cs
+
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+/// <summary>
+/// Return the next agent in a round-robin fashion.
+/// <para>
+/// If the last message is from one of the candidates, the next agent will be the next candidate in the list.
+/// </para>
+/// <para>
+/// Otherwise, the first agent in <see cref="OrchestrationContext.Candidates"/> will be returned.
+/// </para>
+/// </summary>
+public class RoundRobinOrchestrator : IOrchestrator
+{
+ public async Task<IAgent?> GetNextSpeakerAsync(
+ OrchestrationContext context,
+ CancellationToken cancellationToken = default)
+ {
+ var lastMessage = context.ChatHistory.LastOrDefault();
+
+ if (lastMessage == null)
+ {
+ return context.Candidates.FirstOrDefault();
+ }
+
+ var candidates = context.Candidates.ToList();
+ var lastAgentIndex = candidates.FindIndex(a => a.Name == lastMessage.From);
+ if (lastAgentIndex == -1)
+ {
+ return null;
+ }
+
+ var nextAgentIndex = (lastAgentIndex + 1) % candidates.Count;
+ return candidates[nextAgentIndex];
+ }
+}
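
Round-robin wraps around the candidate list. A sketch, with agents named "alice" and "bob" created elsewhere:

var rr = new RoundRobinOrchestrator();
var next = await rr.GetNextSpeakerAsync(new OrchestrationContext
{
    Candidates = [alice, bob],
    // bob spoke last, so the selection wraps back to alice
    ChatHistory = [new TextMessage(Role.Assistant, "hi", from: "bob")],
});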
diff --git a/dotnet/src/AutoGen.Core/Orchestrator/WorkflowOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/WorkflowOrchestrator.cs
new file mode 100644
index 000000000000..b84850a07c75
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Orchestrator/WorkflowOrchestrator.cs
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// WorkflowOrchestrator.cs
+
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public class WorkflowOrchestrator : IOrchestrator
+{
+ private readonly Graph workflow;
+
+ public WorkflowOrchestrator(Graph workflow)
+ {
+ this.workflow = workflow;
+ }
+
+ public async Task<IAgent?> GetNextSpeakerAsync(
+ OrchestrationContext context,
+ CancellationToken cancellationToken = default)
+ {
+ var lastMessage = context.ChatHistory.LastOrDefault();
+ if (lastMessage == null)
+ {
+ return null;
+ }
+
+ var candidates = context.Candidates.ToList();
+ var currentSpeaker = candidates.FirstOrDefault(candidate => candidate.Name == lastMessage.From);
+
+ if (currentSpeaker == null)
+ {
+ return null;
+ }
+ var nextAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, context.ChatHistory);
+ nextAgents = nextAgents.Where(nextAgent => candidates.Any(candidate => candidate.Name == nextAgent.Name));
+ candidates = nextAgents.ToList();
+ if (!candidates.Any())
+ {
+ return null;
+ }
+
+ if (candidates is { Count: 1 })
+ {
+ return candidates.First();
+ }
+ else
+ {
+ throw new System.Exception("There are more than one available agents from the workflow for the next speaker.");
+ }
+ }
+}
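
WorkflowOrchestrator is the deterministic counterpart: it follows the Graph and returns null when no unique transition exists. A sketch reusing agents created elsewhere:

var graph = new Graph([Transition.Create(alice, bob)]);
var orchestrator = new WorkflowOrchestrator(graph);

// after alice speaks, the only legal next speaker is bob
var next = await orchestrator.GetNextSpeakerAsync(new OrchestrationContext
{
    Candidates = [alice, bob],
    ChatHistory = [new TextMessage(Role.Assistant, "hi", from: "alice")],
});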
diff --git a/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj b/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj
index 72c67fe78016..e850d94944bc 100644
--- a/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj
+++ b/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj
@@ -1,7 +1,7 @@
- <TargetFramework>netstandard2.0</TargetFramework>
+ <TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
enable
enable
AutoGen.DotnetInteractive
@@ -27,12 +27,14 @@
-
-
-
+
+
+
+
+
diff --git a/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveFunction.cs b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveFunction.cs
index bb5504cd5487..c9b59203462b 100644
--- a/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveFunction.cs
+++ b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveFunction.cs
@@ -2,14 +2,12 @@
// DotnetInteractiveFunction.cs
using System.Text;
-using System.Text.Json;
-using Azure.AI.OpenAI;
using Microsoft.DotNet.Interactive.Documents;
using Microsoft.DotNet.Interactive.Documents.Jupyter;
namespace AutoGen.DotnetInteractive;
-public class DotnetInteractiveFunction : IDisposable
+public partial class DotnetInteractiveFunction : IDisposable
{
private readonly InteractiveService? _interactiveService = null;
private string _notebookPath;
@@ -71,6 +69,7 @@ public DotnetInteractiveFunction(InteractiveService interactiveService, string?
/// Run existing dotnet code from message. Don't modify the code, run it as is.
///
/// code.
+ [Function]
public async Task RunCode(string code)
{
if (this._interactiveService == null)
@@ -117,6 +116,7 @@ public async Task RunCode(string code)
/// Install nuget packages.
///
/// nuget package to install.
+ [Function]
public async Task InstallNugetPackages(string[] nugetPackages)
{
if (this._interactiveService == null)
@@ -173,105 +173,6 @@ private async Task AddCellAsync(string cellContent, string kernelName)
writeStream.Dispose();
}
- private class RunCodeSchema
- {
- public string code { get; set; } = string.Empty;
- }
-
- public Task RunCodeWrapper(string arguments)
- {
- var schema = JsonSerializer.Deserialize<RunCodeSchema>(
- arguments,
- new JsonSerializerOptions
- {
- PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
- });
-
- return RunCode(schema!.code);
- }
-
- public FunctionDefinition RunCodeFunction
- {
- get => new FunctionDefinition
- {
- Name = @"RunCode",
- Description = """
-Run existing dotnet code from message. Don't modify the code, run it as is.
-""",
- Parameters = BinaryData.FromObjectAsJson(new
- {
- Type = "object",
- Properties = new
- {
- code = new
- {
- Type = @"string",
- Description = @"code.",
- },
- },
- Required = new[]
- {
- "code",
- },
- },
- new JsonSerializerOptions
- {
- PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
- })
- };
- }
-
- private class InstallNugetPackagesSchema
- {
- public string[] nugetPackages { get; set; } = Array.Empty<string>();
- }
-
- public Task InstallNugetPackagesWrapper(string arguments)
- {
- var schema = JsonSerializer.Deserialize<InstallNugetPackagesSchema>(
- arguments,
- new JsonSerializerOptions
- {
- PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
- });
-
- return InstallNugetPackages(schema!.nugetPackages);
- }
-
- public FunctionDefinition InstallNugetPackagesFunction
- {
- get => new FunctionDefinition
- {
- Name = @"InstallNugetPackages",
- Description = """
-Install nuget packages.
-""",
- Parameters = BinaryData.FromObjectAsJson(new
- {
- Type = "object",
- Properties = new
- {
- nugetPackages = new
- {
- Type = @"array",
- Items = new
- {
- Type = @"string",
- },
- Description = @"nuget package to install.",
- },
- },
- Required = new[]
- {
- "nugetPackages",
- },
- },
- new JsonSerializerOptions
- {
- PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
- })
- };
- }
public void Dispose()
{
this._interactiveService?.Dispose();
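The hand-written schema, wrapper, and `FunctionDefinition` members removed above are superseded by the `[Function]` attribute plus the AutoGen source generator (hence the class becoming `partial`). A sketch of consuming the generated members; the `RunCodeFunctionContract`/`RunCodeWrapper` names follow the generator's naming convention and are assumptions here:

```csharp
using AutoGen.Core;
using AutoGen.DotnetInteractive;

var function = new DotnetInteractiveFunction(interactiveService); // an existing InteractiveService

// The generator emits a FunctionContract plus a string -> Task<string>
// wrapper per [Function] method, replacing the deleted boilerplate.
var middleware = new FunctionCallMiddleware(
    functions: new[] { function.RunCodeFunctionContract, function.InstallNugetPackagesFunctionContract },
    functionMap: new Dictionary<string, Func<string, Task<string>>>
    {
        [function.RunCodeFunctionContract.Name!] = function.RunCodeWrapper,
        [function.InstallNugetPackagesFunctionContract.Name!] = function.InstallNugetPackagesWrapper,
    });
```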
diff --git a/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveKernelBuilder.cs b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveKernelBuilder.cs
new file mode 100644
index 000000000000..cc282fbba55c
--- /dev/null
+++ b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveKernelBuilder.cs
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// DotnetInteractiveKernelBuilder.cs
+
+namespace AutoGen.DotnetInteractive;
+
+public static class DotnetInteractiveKernelBuilder
+{
+
+#if NET8_0_OR_GREATER
+ public static InProccessDotnetInteractiveKernelBuilder CreateEmptyInProcessKernelBuilder()
+ {
+ return new InProccessDotnetInteractiveKernelBuilder();
+ }
+
+ public static InProccessDotnetInteractiveKernelBuilder CreateDefaultInProcessKernelBuilder()
+ {
+ return new InProccessDotnetInteractiveKernelBuilder()
+ .AddCSharpKernel()
+ .AddFSharpKernel();
+ }
+#endif
+
+ public static DotnetInteractiveStdioKernelConnector CreateKernelBuilder(string workingDirectory, string kernelName = "root-proxy")
+ {
+ return new DotnetInteractiveStdioKernelConnector(workingDirectory, kernelName);
+ }
+}
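A sketch of the intended in-process usage, assuming a net8.0 target (that path is compiled only under `NET8_0_OR_GREATER`); `RunSubmitCodeCommandAsync` is the `KernelExtension` helper added later in this PR:

```csharp
using AutoGen.DotnetInteractive;
using AutoGen.DotnetInteractive.Extension;

// In-process composite kernel with C#, F#, and PowerShell sub-kernels.
var kernel = DotnetInteractiveKernelBuilder
    .CreateDefaultInProcessKernelBuilder()
    .AddPowershellKernel()
    .Build();

var output = await kernel.RunSubmitCodeCommandAsync("Console.WriteLine(1 + 1);", "csharp");
```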
diff --git a/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveStdioKernelConnector.cs b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveStdioKernelConnector.cs
new file mode 100644
index 000000000000..a3ea80a7b12a
--- /dev/null
+++ b/dotnet/src/AutoGen.DotnetInteractive/DotnetInteractiveStdioKernelConnector.cs
@@ -0,0 +1,86 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// DotnetInteractiveStdioKernelConnector.cs
+
+using AutoGen.DotnetInteractive.Extension;
+using Microsoft.DotNet.Interactive;
+using Microsoft.DotNet.Interactive.Commands;
+using Microsoft.DotNet.Interactive.Connection;
+
+namespace AutoGen.DotnetInteractive;
+
+public class DotnetInteractiveStdioKernelConnector
+{
+ private string workingDirectory;
+ private InteractiveService interactiveService;
+ private string kernelName;
+ private List<SubmitCode> setupCommands = new List<SubmitCode>();
+
+ internal DotnetInteractiveStdioKernelConnector(string workingDirectory, string kernelName = "root-proxy")
+ {
+ this.workingDirectory = workingDirectory;
+ this.interactiveService = new InteractiveService(workingDirectory);
+ this.kernelName = kernelName;
+ }
+
+ public DotnetInteractiveStdioKernelConnector RestoreDotnetInteractive()
+ {
+ if (this.interactiveService.RestoreDotnetInteractive())
+ {
+ return this;
+ }
+ else
+ {
+ throw new Exception("Failed to restore dotnet interactive tool.");
+ }
+ }
+
+ public DotnetInteractiveStdioKernelConnector AddPythonKernel(
+ string venv,
+ string kernelName = "python")
+ {
+ var magicCommand = $"#!connect jupyter --kernel-name {kernelName} --kernel-spec {venv}";
+ var connectCommand = new SubmitCode(magicCommand);
+
+ this.setupCommands.Add(connectCommand);
+
+ return this;
+ }
+
+ public async Task<Kernel> BuildAsync(CancellationToken ct = default)
+ {
+ var compositeKernel = new CompositeKernel();
+ var url = KernelHost.CreateHostUri(this.kernelName);
+ var cmd = new string[]
+ {
+ "dotnet",
+ "tool",
+ "run",
+ "dotnet-interactive",
+ $"[cb-{this.kernelName}]",
+ "stdio",
+ //"--default-kernel",
+ //"csharp",
+ "--working-dir",
+ $@"""{workingDirectory}""",
+ };
+
+ var connector = new StdIoKernelConnector(
+ cmd,
+ this.kernelName,
+ url,
+ new DirectoryInfo(this.workingDirectory));
+
+ var rootProxyKernel = await connector.CreateRootProxyKernelAsync();
+
+ rootProxyKernel.KernelInfo.SupportedKernelCommands.Add(new(nameof(SubmitCode)));
+
+ var dotnetKernel = await connector.CreateProxyKernelAsync(".NET");
+ foreach (var setupCommand in this.setupCommands)
+ {
+ var setupCommandResult = await rootProxyKernel.SendAsync(setupCommand, ct);
+ setupCommandResult.ThrowOnCommandFailed();
+ }
+
+ return rootProxyKernel;
+ }
+}
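And the stdio path, which shells out to the dotnet-interactive tool; the working directory here is illustrative:

```csharp
using AutoGen.DotnetInteractive;
using AutoGen.DotnetInteractive.Extension;

// Out-of-process kernel over stdio; RestoreDotnetInteractive() installs
// the dotnet-interactive tool into the working directory first.
var kernel = await DotnetInteractiveKernelBuilder
    .CreateKernelBuilder(workingDirectory: Path.GetTempPath())
    .RestoreDotnetInteractive()
    .BuildAsync();

var output = await kernel.RunSubmitCodeCommandAsync("1 + 1", "csharp");
```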
diff --git a/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs b/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs
index 83955c53fa16..de1e2a68cc0c 100644
--- a/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs
+++ b/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs
@@ -21,6 +21,7 @@ public static class AgentExtension
/// [!code-csharp[Example04_Dynamic_GroupChat_Coding_Task](~/../sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs)]
/// ]]>
///
+ [Obsolete]
public static IAgent RegisterDotnetCodeBlockExectionHook(
this IAgent agent,
InteractiveService interactiveService,
diff --git a/dotnet/src/AutoGen.DotnetInteractive/Utils.cs b/dotnet/src/AutoGen.DotnetInteractive/Extension/KernelExtension.cs
similarity index 57%
rename from dotnet/src/AutoGen.DotnetInteractive/Utils.cs
rename to dotnet/src/AutoGen.DotnetInteractive/Extension/KernelExtension.cs
index d10208d508c6..2a7afdf8857f 100644
--- a/dotnet/src/AutoGen.DotnetInteractive/Utils.cs
+++ b/dotnet/src/AutoGen.DotnetInteractive/Extension/KernelExtension.cs
@@ -1,23 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// Utils.cs
+// KernelExtension.cs
-using System.Collections;
-using System.Collections.Immutable;
using Microsoft.DotNet.Interactive;
using Microsoft.DotNet.Interactive.Commands;
using Microsoft.DotNet.Interactive.Connection;
using Microsoft.DotNet.Interactive.Events;
-public static class ObservableExtensions
+namespace AutoGen.DotnetInteractive.Extension;
+
+public static class KernelExtension
{
- public static SubscribedList<T> ToSubscribedList<T>(this IObservable<T> source)
+ public static async Task<string?> RunSubmitCodeCommandAsync(
+ this Kernel kernel,
+ string codeBlock,
+ string targetKernelName,
+ CancellationToken ct = default)
{
- return new SubscribedList<T>(source);
+ try
+ {
+ var cmd = new SubmitCode(codeBlock, targetKernelName);
+ var res = await kernel.SendAndThrowOnCommandFailedAsync(cmd, ct);
+ var events = res.Events;
+ var displayValues = res.Events.Where(x => x is StandardErrorValueProduced || x is StandardOutputValueProduced || x is ReturnValueProduced || x is DisplayedValueProduced)
+ .SelectMany(x => (x as DisplayEvent)!.FormattedValues);
+
+ if (displayValues is null || displayValues.Count() == 0)
+ {
+ return null;
+ }
+
+ return string.Join("\n", displayValues.Select(x => x.Value));
+ }
+ catch (Exception ex)
+ {
+ return $"Error: {ex.Message}";
+ }
}
-}
-public static class KernelExtensions
-{
internal static void SetUpValueSharingIfSupported(this ProxyKernel proxyKernel)
{
var supportedCommands = proxyKernel.KernelInfo.SupportedKernelCommands;
@@ -38,7 +57,7 @@ internal static async Task SendAndThrowOnCommandFailedAsync
return result;
}
- private static void ThrowOnCommandFailed(this KernelCommandResult result)
+ internal static void ThrowOnCommandFailed(this KernelCommandResult result)
{
var failedEvents = result.Events.OfType<CommandFailed>();
if (!failedEvents.Any())
@@ -60,27 +79,3 @@ private static void ThrowOnCommandFailed(this KernelCommandResult result)
private static Exception GetException(this CommandFailed commandFailedEvent)
=> new Exception(commandFailedEvent.Message);
}
-
-public class SubscribedList<T> : IReadOnlyList<T>, IDisposable
-{
- private ImmutableArray<T> _list = ImmutableArray<T>.Empty;
- private readonly IDisposable _subscription;
-
- public SubscribedList(IObservable<T> source)
- {
- _subscription = source.Subscribe(x => _list = _list.Add(x));
- }
-
- public IEnumerator<T> GetEnumerator()
- {
- return ((IEnumerable<T>)_list).GetEnumerator();
- }
-
- IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
-
- public int Count => _list.Length;
-
- public T this[int index] => _list[index];
-
- public void Dispose() => _subscription.Dispose();
-}
diff --git a/dotnet/src/AutoGen.DotnetInteractive/Extension/MessageExtension.cs b/dotnet/src/AutoGen.DotnetInteractive/Extension/MessageExtension.cs
new file mode 100644
index 000000000000..6a8bf66c19f3
--- /dev/null
+++ b/dotnet/src/AutoGen.DotnetInteractive/Extension/MessageExtension.cs
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// MessageExtension.cs
+
+using System.Text.RegularExpressions;
+
+namespace AutoGen.DotnetInteractive.Extension;
+
+public static class MessageExtension
+{
+ /// <summary>
+ /// Extract a single code block from a message. If the message contains multiple code blocks, only the first one will be returned.
+ /// </summary>
+ /// <param name="message"></param>
+ /// <param name="codeBlockPrefix">code block prefix, e.g. ```csharp</param>
+ /// <param name="codeBlockSuffix">code block suffix, e.g. ```</param>
+ /// <returns></returns>
+ public static string? ExtractCodeBlock(
+ this IMessage message,
+ string codeBlockPrefix,
+ string codeBlockSuffix)
+ {
+ foreach (var codeBlock in message.ExtractCodeBlocks(codeBlockPrefix, codeBlockSuffix))
+ {
+ return codeBlock;
+ }
+
+ return null;
+ }
+
+ /// <summary>
+ /// Extract all code blocks from a message.
+ /// </summary>
+ /// <param name="message"></param>
+ /// <param name="codeBlockPrefix">code block prefix, e.g. ```csharp</param>
+ /// <param name="codeBlockSuffix">code block suffix, e.g. ```</param>
+ /// <returns></returns>
+ public static IEnumerable<string> ExtractCodeBlocks(
+ this IMessage message,
+ string codeBlockPrefix,
+ string codeBlockSuffix)
+ {
+ var content = message.GetContent() ?? string.Empty;
+ if (string.IsNullOrWhiteSpace(content))
+ {
+ yield break;
+ }
+
+ foreach (Match match in Regex.Matches(content, $@"{codeBlockPrefix}([\s\S]*?){codeBlockSuffix}"))
+ {
+ yield return match.Groups[1].Value.Trim();
+ }
+ }
+}
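A short usage sketch for the extraction helpers:

```csharp
using AutoGen.Core;
using AutoGen.DotnetInteractive.Extension;

var reply = new TextMessage(
    Role.Assistant,
    "Here you go:\n```csharp\nConsole.WriteLine(\"hello\");\n```");

// Returns the trimmed body of the first ```csharp ... ``` block,
// or null when the message contains no such block.
var code = reply.ExtractCodeBlock("```csharp", "```");
```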
diff --git a/dotnet/src/AutoGen.DotnetInteractive/InProccessDotnetInteractiveKernelBuilder.cs b/dotnet/src/AutoGen.DotnetInteractive/InProccessDotnetInteractiveKernelBuilder.cs
new file mode 100644
index 000000000000..6ddd3d6b4178
--- /dev/null
+++ b/dotnet/src/AutoGen.DotnetInteractive/InProccessDotnetInteractiveKernelBuilder.cs
@@ -0,0 +1,110 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// InProccessDotnetInteractiveKernelBuilder.cs
+
+#if NET8_0_OR_GREATER
+using AutoGen.DotnetInteractive.Extension;
+using Microsoft.DotNet.Interactive;
+using Microsoft.DotNet.Interactive.Commands;
+using Microsoft.DotNet.Interactive.CSharp;
+using Microsoft.DotNet.Interactive.FSharp;
+using Microsoft.DotNet.Interactive.Jupyter;
+using Microsoft.DotNet.Interactive.PackageManagement;
+using Microsoft.DotNet.Interactive.PowerShell;
+
+namespace AutoGen.DotnetInteractive;
+
+/// <summary>
+/// Build an in-proc dotnet interactive kernel.
+/// </summary>
+public class InProccessDotnetInteractiveKernelBuilder
+{
+ private readonly CompositeKernel compositeKernel;
+
+ internal InProccessDotnetInteractiveKernelBuilder()
+ {
+ this.compositeKernel = new CompositeKernel();
+
+ // add jupyter connector
+ this.compositeKernel.AddKernelConnector(
+ new ConnectJupyterKernelCommand()
+ .AddConnectionOptions(new JupyterHttpKernelConnectionOptions())
+ .AddConnectionOptions(new JupyterLocalKernelConnectionOptions()));
+ }
+
+ public InProccessDotnetInteractiveKernelBuilder AddCSharpKernel(IEnumerable? aliases = null)
+ {
+ aliases ??= ["c#", "C#", "csharp"];
+ // create csharp kernel
+ var csharpKernel = new CSharpKernel()
+ .UseNugetDirective((k, resolvedPackageReference) =>
+ {
+ k.AddAssemblyReferences(resolvedPackageReference
+ .SelectMany(r => r.AssemblyPaths));
+ return Task.CompletedTask;
+ })
+ .UseKernelHelpers()
+ .UseWho()
+ .UseMathAndLaTeX()
+ .UseValueSharing();
+
+ this.AddKernel(csharpKernel, aliases);
+
+ return this;
+ }
+
+ public InProccessDotnetInteractiveKernelBuilder AddFSharpKernel(IEnumerable? aliases = null)
+ {
+ aliases ??= ["f#", "F#", "fsharp"];
+ // create fsharp kernel
+ var fsharpKernel = new FSharpKernel()
+ .UseDefaultFormatting()
+ .UseKernelHelpers()
+ .UseWho()
+ .UseMathAndLaTeX()
+ .UseValueSharing();
+
+ this.AddKernel(fsharpKernel, aliases);
+
+ return this;
+ }
+
+ public InProccessDotnetInteractiveKernelBuilder AddPowershellKernel(IEnumerable? aliases = null)
+ {
+ aliases ??= ["pwsh", "powershell"];
+ // create powershell kernel
+ var powershellKernel = new PowerShellKernel()
+ .UseProfiles()
+ .UseValueSharing();
+
+ this.AddKernel(powershellKernel, aliases);
+
+ return this;
+ }
+
+ public InProccessDotnetInteractiveKernelBuilder AddPythonKernel(string venv, string kernelName = "python")
+ {
+ // create python kernel
+ var magicCommand = $"#!connect jupyter --kernel-name {kernelName} --kernel-spec {venv}";
+ var connectCommand = new SubmitCode(magicCommand);
+ var result = this.compositeKernel.SendAsync(connectCommand).Result;
+
+ result.ThrowOnCommandFailed();
+
+ return this;
+ }
+
+ public CompositeKernel Build()
+ {
+ return this.compositeKernel
+ .UseDefaultMagicCommands()
+ .UseImportMagicCommand();
+ }
+
+ private InProccessDotnetInteractiveKernelBuilder AddKernel(Kernel kernel, IEnumerable? aliases = null)
+ {
+ this.compositeKernel.Add(kernel, aliases);
+ return this;
+ }
+}
+#endif
diff --git a/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs b/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs
index 7490b64e1267..3381aecf5794 100644
--- a/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs
+++ b/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs
@@ -4,6 +4,7 @@
using System.Diagnostics;
using System.Reactive.Linq;
using System.Reflection;
+using AutoGen.DotnetInteractive.Extension;
using Microsoft.DotNet.Interactive;
using Microsoft.DotNet.Interactive.Commands;
using Microsoft.DotNet.Interactive.Connection;
@@ -19,18 +20,14 @@ public class InteractiveService : IDisposable
private bool disposedValue;
private const string DotnetInteractiveToolNotInstallMessage = "Cannot find a tool in the manifest file that has a command named 'dotnet-interactive'.";
//private readonly ProcessJobTracker jobTracker = new ProcessJobTracker();
- private string installingDirectory;
-
- public event EventHandler<DisplayEvent>? DisplayEvent;
-
- public event EventHandler<string>? Output;
-
- public event EventHandler<CommandFailed>? CommandFailed;
-
- public event EventHandler<HoverTextProduced>? HoverTextProduced;
+ private string? installingDirectory;
/// <summary>
- /// Create an instance of InteractiveService
+ /// Install dotnet interactive tool to <paramref name="installingDirectory"/>
+ /// and create an instance of <see cref="InteractiveService"/>.
+ ///
+ /// When using this constructor, you need to call <see cref="StartAsync(string, CancellationToken)"/> to install dotnet interactive tool
+ /// and start the kernel.
/// </summary>
/// <param name="installingDirectory">dotnet interactive installing directory</param>
public InteractiveService(string installingDirectory)
@@ -38,37 +35,37 @@ public InteractiveService(string installingDirectory)
this.installingDirectory = installingDirectory;
}
+ /// <summary>
+ /// Create an instance of <see cref="InteractiveService"/> with a running kernel.
+ /// When using this constructor, you don't need to call <see cref="StartAsync(string, CancellationToken)"/> to start the kernel.
+ /// </summary>
+ /// <param name="kernel"></param>
+ public InteractiveService(Kernel kernel)
+ {
+ this.kernel = kernel;
+ }
+
+ public Kernel? Kernel => this.kernel;
+
public async Task<bool> StartAsync(string workingDirectory, CancellationToken ct = default)
{
+ if (this.kernel != null)
+ {
+ return true;
+ }
+
this.kernel = await this.CreateKernelAsync(workingDirectory, true, ct);
return true;
}
- public async Task<string?> SubmitCommandAsync(KernelCommand cmd, CancellationToken ct)
+ public async Task<string?> SubmitCommandAsync(SubmitCode cmd, CancellationToken ct)
{
if (this.kernel == null)
{
throw new Exception("Kernel is not running");
}
- try
- {
- var res = await this.kernel.SendAndThrowOnCommandFailedAsync(cmd, ct);
- var events = res.Events;
- var displayValues = events.Where(x => x is StandardErrorValueProduced || x is StandardOutputValueProduced || x is ReturnValueProduced)
- .SelectMany(x => (x as DisplayEvent)!.FormattedValues);
-
- if (displayValues is null || displayValues.Count() == 0)
- {
- return null;
- }
-
- return string.Join("\n", displayValues.Select(x => x.Value));
- }
- catch (Exception ex)
- {
- return $"Error: {ex.Message}";
- }
+ return await this.kernel.RunSubmitCodeCommandAsync(cmd.Code, cmd.TargetKernelName, ct);
}
public async Task<string?> SubmitPowershellCodeAsync(string code, CancellationToken ct)
@@ -85,7 +82,11 @@ public async Task StartAsync(string workingDirectory, CancellationToken ct
public bool RestoreDotnetInteractive()
{
- this.WriteLine("Restore dotnet interactive tool");
+ if (this.installingDirectory is null)
+ {
+ throw new Exception("Installing directory is not set");
+ }
+
// write RestoreInteractive.config from embedded resource to this.workingDirectory
var assembly = Assembly.GetAssembly(typeof(InteractiveService))!;
var resourceName = "AutoGen.DotnetInteractive.RestoreInteractive.config";
@@ -178,8 +179,6 @@ await rootProxyKernel.SendAsync(
//compositeKernel.DefaultKernelName = "csharp";
compositeKernel.Add(rootProxyKernel);
- compositeKernel.KernelEvents.Subscribe(this.OnKernelDiagnosticEventReceived);
-
return compositeKernel;
}
catch (CommandLineInvocationException) when (restoreWhenFail)
@@ -195,35 +194,11 @@ await rootProxyKernel.SendAsync(
}
}
- private void OnKernelDiagnosticEventReceived(KernelEvent ke)
- {
- this.WriteLine("Receive data from kernel");
- this.WriteLine(KernelEventEnvelope.Serialize(ke));
-
- switch (ke)
- {
- case DisplayEvent de:
- this.DisplayEvent?.Invoke(this, de);
- break;
- case CommandFailed cf:
- this.CommandFailed?.Invoke(this, cf);
- break;
- case HoverTextProduced cf:
- this.HoverTextProduced?.Invoke(this, cf);
- break;
- }
- }
-
- private void WriteLine(string data)
- {
- this.Output?.Invoke(this, data);
- }
-
private void PrintProcessOutput(object sender, DataReceivedEventArgs e)
{
if (!string.IsNullOrEmpty(e.Data))
{
- this.WriteLine(e.Data);
+ Console.WriteLine(e.Data);
}
}
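The net effect of these changes: `InteractiveService` can now wrap an already-running kernel instead of always spawning and babysitting a dotnet-interactive process. A sketch of the new path (net8.0 only, since it relies on the in-process builder):

```csharp
using AutoGen.DotnetInteractive;
using Microsoft.DotNet.Interactive.Commands;

var kernel = DotnetInteractiveKernelBuilder
    .CreateDefaultInProcessKernelBuilder()
    .Build();

// No StartAsync needed; the kernel is already running.
var service = new InteractiveService(kernel);
var output = await service.SubmitCommandAsync(
    new SubmitCode("Console.WriteLine(\"hi\");", targetKernelName: "csharp"),
    CancellationToken.None);
```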
diff --git a/dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj b/dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj
index 29c4d1bb9c6f..9a60596503bc 100644
--- a/dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj
+++ b/dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj
@@ -1,7 +1,7 @@
- <TargetFramework>netstandard2.0</TargetFramework>
+ <TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
diff --git a/dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs b/dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs
index b081faae8321..e759ba26d1e9 100644
--- a/dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs
+++ b/dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs
@@ -143,7 +143,7 @@ public async Task GenerateReplyAsync(IEnumerable messages, G
return MessageEnvelope.Create(response, this.Name);
}
- public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var request = BuildChatRequest(messages, options);
var response = this.client.GenerateContentStreamAsync(request);
diff --git a/dotnet/src/AutoGen.Gemini/IGeminiClient.cs b/dotnet/src/AutoGen.Gemini/IGeminiClient.cs
index 2e209e02b030..d391a4508398 100644
--- a/dotnet/src/AutoGen.Gemini/IGeminiClient.cs
+++ b/dotnet/src/AutoGen.Gemini/IGeminiClient.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// IVertexGeminiClient.cs
+// IGeminiClient.cs
using System.Collections.Generic;
using System.Threading;
diff --git a/dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs b/dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs
index cb18ba084d78..422fb4cd3458 100644
--- a/dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs
+++ b/dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs
@@ -39,7 +39,7 @@ public GeminiMessageConnector(bool strictMode = false)
public string Name => nameof(GeminiMessageConnector);
- public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = ProcessMessage(context.Messages, agent);
diff --git a/dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs b/dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs
index c54f2280dfd3..12a11993cd69 100644
--- a/dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs
+++ b/dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// IGeminiClient.cs
+// VertexGeminiClient.cs
using System.Collections.Generic;
using System.Threading;
diff --git a/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj b/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj
index f45a2f7eba5f..aa891e71294d 100644
--- a/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj
+++ b/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj
@@ -1,7 +1,7 @@
- <TargetFramework>netstandard2.0</TargetFramework>
+ <TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
AutoGen.LMStudio
@@ -17,7 +17,7 @@
-
+
diff --git a/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs b/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs
index 9d0daa535b23..c4808b443c79 100644
--- a/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs
+++ b/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs
@@ -6,7 +6,7 @@
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
-using AutoGen.OpenAI;
+using AutoGen.OpenAI.V1;
using Azure.AI.OpenAI;
using Azure.Core.Pipeline;
@@ -18,6 +18,7 @@ namespace AutoGen.LMStudio;
///
/// [!code-csharp[LMStudioAgent](../../sample/AutoGen.BasicSamples/Example08_LMStudio.cs?name=lmstudio_example_1)]
///
+[Obsolete("Use OpenAIChatAgent to connect to LM Studio")]
public class LMStudioAgent : IAgent
{
private readonly GPTAgent innerAgent;
@@ -80,7 +81,7 @@ protected override Task SendAsync(HttpRequestMessage reques
{
// request.RequestUri = new Uri($"{_modelServiceUrl}{request.RequestUri.PathAndQuery}");
var uriBuilder = new UriBuilder(_modelServiceUrl);
- uriBuilder.Path = request.RequestUri.PathAndQuery;
+ uriBuilder.Path = request.RequestUri?.PathAndQuery ?? throw new InvalidOperationException("RequestUri is null");
request.RequestUri = uriBuilder.Uri;
return base.SendAsync(request, cancellationToken);
}
diff --git a/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs b/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs
index cc2c74145504..db14d68a1217 100644
--- a/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs
+++ b/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs
@@ -78,7 +78,7 @@ public async Task GenerateReplyAsync(
return new MessageEnvelope(response, from: this.Name);
}
- public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
@@ -97,6 +97,7 @@ private ChatCompletionRequest BuildChatRequest(IEnumerable messages, G
var chatHistory = BuildChatHistory(messages);
var chatRequest = new ChatCompletionRequest(model: _model, messages: chatHistory.ToList(), temperature: options?.Temperature, randomSeed: _randomSeed)
{
+ Stop = options?.StopSequence,
MaxTokens = options?.MaxToken,
ResponseFormat = _jsonOutput ? new ResponseFormat() { ResponseFormatType = "json_object" } : null,
};
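With the new `Stop` field wired to `GenerateReplyOptions.StopSequence`, stop sequences can now be supplied per call. A sketch; the model name is illustrative, and the message-connector registration is assumed so the agent can accept `TextMessage` directly:

```csharp
using AutoGen.Core;
using AutoGen.Mistral;
using AutoGen.Mistral.Extension;

var agent = new MistralClientAgent(
        client: new MistralClient(apiKey: Environment.GetEnvironmentVariable("MISTRAL_API_KEY")!),
        name: "assistant",
        model: "open-mistral-7b")
    .RegisterMessageConnector();

// StopSequence flows through BuildChatRequest into ChatCompletionRequest.Stop.
var reply = await agent.GenerateReplyAsync(
    messages: new IMessage[] { new TextMessage(Role.User, "Count from 1 to 100.") },
    options: new GenerateReplyOptions { StopSequence = new[] { "\n\n" } });
```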
diff --git a/dotnet/src/AutoGen.Mistral/AutoGen.Mistral.csproj b/dotnet/src/AutoGen.Mistral/AutoGen.Mistral.csproj
index 25cc05fec922..ee905d117791 100644
--- a/dotnet/src/AutoGen.Mistral/AutoGen.Mistral.csproj
+++ b/dotnet/src/AutoGen.Mistral/AutoGen.Mistral.csproj
@@ -1,7 +1,7 @@
- netstandard2.0
+ $(PackageTargetFrameworks)
AutoGen.Mistral
diff --git a/dotnet/src/AutoGen.Mistral/Converters/JsonPropertyNameEnumConverter.cs b/dotnet/src/AutoGen.Mistral/Converters/JsonPropertyNameEnumConverter.cs
index 5a4f9f9cb189..9ecf11428397 100644
--- a/dotnet/src/AutoGen.Mistral/Converters/JsonPropertyNameEnumConverter.cs
+++ b/dotnet/src/AutoGen.Mistral/Converters/JsonPropertyNameEnumConverter.cs
@@ -29,7 +29,7 @@ public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerial
public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
{
var field = value.GetType().GetField(value.ToString());
- var attribute = field.GetCustomAttribute<JsonPropertyNameAttribute>();
+ var attribute = field?.GetCustomAttribute<JsonPropertyNameAttribute>();
if (attribute != null)
{
diff --git a/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs
index 71a084673f13..affe2bb6dcc3 100644
--- a/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs
+++ b/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs
@@ -105,6 +105,9 @@ public class ChatCompletionRequest
[JsonPropertyName("random_seed")]
public int? RandomSeed { get; set; }
+ [JsonPropertyName("stop")]
+ public string[]? Stop { get; set; }
+
[JsonPropertyName("tools")]
public List? Tools { get; set; }
diff --git a/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs b/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs
index 95592e97fcc5..78de12a5c01e 100644
--- a/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs
+++ b/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs
@@ -15,14 +15,14 @@ public class MistralChatMessageConnector : IStreamingMiddleware, IMiddleware
{
public string? Name => nameof(MistralChatMessageConnector);
- public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var chunks = new List<ChatCompletionResponse>();
await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
- if (reply is IStreamingMessage<ChatCompletionResponse> chatMessage)
+ if (reply is IMessage<ChatCompletionResponse> chatMessage)
{
chunks.Add(chatMessage.Content);
var response = ProcessChatCompletionResponse(chatMessage, agent);
@@ -167,7 +167,7 @@ private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from
}
}
- private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> message, IAgent agent)
+ private IMessage? ProcessChatCompletionResponse(IMessage<ChatCompletionResponse> message, IAgent agent)
{
var response = message.Content;
if (response.VarObject != "chat.completion.chunk")
diff --git a/dotnet/src/AutoGen.Mistral/MistralClient.cs b/dotnet/src/AutoGen.Mistral/MistralClient.cs
index 5fc3d110985e..8c6802f30eb1 100644
--- a/dotnet/src/AutoGen.Mistral/MistralClient.cs
+++ b/dotnet/src/AutoGen.Mistral/MistralClient.cs
@@ -49,7 +49,7 @@ public async IAsyncEnumerable StreamingChatCompletionsAs
var response = await HttpRequestRaw(HttpMethod.Post, chatCompletionRequest, streaming: true);
using var stream = await response.Content.ReadAsStreamAsync();
using StreamReader reader = new StreamReader(stream);
- string line;
+ string? line = null;
SseEvent currentEvent = new SseEvent();
while ((line = await reader.ReadLineAsync()) != null)
@@ -67,13 +67,13 @@ public async IAsyncEnumerable StreamingChatCompletionsAs
else if (currentEvent.EventType == null)
{
var res = await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
- new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data))) ?? throw new Exception("Failed to deserialize response");
+ new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data ?? string.Empty))) ?? throw new Exception("Failed to deserialize response");
yield return res;
}
else if (currentEvent.EventType != null)
{
var res = await JsonSerializer.DeserializeAsync<ErrorResponse>(
- new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)));
+ new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data ?? string.Empty)));
throw new Exception(res?.Error.Message);
}
diff --git a/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs b/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs
index 9ef68388d605..87b176d8bcc5 100644
--- a/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs
+++ b/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs
@@ -53,7 +53,7 @@ public async Task GenerateReplyAsync(
}
}
- public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
diff --git a/dotnet/src/AutoGen.Ollama/AutoGen.Ollama.csproj b/dotnet/src/AutoGen.Ollama/AutoGen.Ollama.csproj
index a939f138c1c1..512fe92f3e3e 100644
--- a/dotnet/src/AutoGen.Ollama/AutoGen.Ollama.csproj
+++ b/dotnet/src/AutoGen.Ollama/AutoGen.Ollama.csproj
@@ -1,7 +1,7 @@
- <TargetFramework>netstandard2.0</TargetFramework>
+ <TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
AutoGen.Ollama
True
diff --git a/dotnet/src/AutoGen.Ollama/DTOs/Message.cs b/dotnet/src/AutoGen.Ollama/DTOs/Message.cs
index 2e0d891cc61e..75f622ff7f04 100644
--- a/dotnet/src/AutoGen.Ollama/DTOs/Message.cs
+++ b/dotnet/src/AutoGen.Ollama/DTOs/Message.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// ChatResponseUpdate.cs
+// Message.cs
using System.Collections.Generic;
using System.Text.Json.Serialization;
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/ITextEmbeddingService.cs b/dotnet/src/AutoGen.Ollama/Embeddings/ITextEmbeddingService.cs
index 5ce0dc8cc40a..cce6dbb83076 100644
--- a/dotnet/src/AutoGen.Ollama/Embeddings/ITextEmbeddingService.cs
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/ITextEmbeddingService.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// ITextEmbeddingService.cs
using System.Threading;
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/OllamaTextEmbeddingService.cs b/dotnet/src/AutoGen.Ollama/Embeddings/OllamaTextEmbeddingService.cs
index 2e431e7bcb81..ea4993eb813f 100644
--- a/dotnet/src/AutoGen.Ollama/Embeddings/OllamaTextEmbeddingService.cs
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/OllamaTextEmbeddingService.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaTextEmbeddingService.cs
using System;
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsRequest.cs b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsRequest.cs
index 7f2531c522ad..d776b183db0b 100644
--- a/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsRequest.cs
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsRequest.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// TextEmbeddingsRequest.cs
using System.Text.Json.Serialization;
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsResponse.cs b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsResponse.cs
index 580059c033b5..f3ce64b7032f 100644
--- a/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsResponse.cs
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsResponse.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// TextEmbeddingsResponse.cs
using System.Text.Json.Serialization;
diff --git a/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs b/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs
index a21ec3a1c991..9e85ca12fd9e 100644
--- a/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs
+++ b/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs
@@ -30,14 +30,14 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
};
}
- public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
+ public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = ProcessMessage(context.Messages, agent);
var chunks = new List<ChatResponseUpdate>();
await foreach (var update in agent.GenerateStreamingReplyAsync(messages, context.Options, cancellationToken))
{
- if (update is IStreamingMessage<ChatResponseUpdate> chatResponseUpdate)
+ if (update is IMessage<ChatResponseUpdate> chatResponseUpdate)
{
var response = chatResponseUpdate.Content switch
{
@@ -101,7 +101,7 @@ private IEnumerable ProcessMultiModalMessage(MultiModalMessage multiMo
// collect all the images
var images = imageMessages.SelectMany(m => ProcessImageMessage((ImageMessage)m, agent)
- .SelectMany(m => (m as IMessage<Message>)?.Content.Images));
+ .SelectMany(m => (m as IMessage<Message>)?.Content.Images ?? []));
var message = new Message()
{
diff --git a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI.V1/Agent/GPTAgent.cs
similarity index 96%
rename from dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs
rename to dotnet/src/AutoGen.OpenAI.V1/Agent/GPTAgent.cs
index cdc6cc464d17..a32af5c38f15 100644
--- a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs
+++ b/dotnet/src/AutoGen.OpenAI.V1/Agent/GPTAgent.cs
@@ -5,10 +5,10 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
-using AutoGen.OpenAI.Extension;
+using AutoGen.OpenAI.V1.Extension;
using Azure.AI.OpenAI;
-namespace AutoGen.OpenAI;
+namespace AutoGen.OpenAI.V1;
///
/// GPT agent that can be used to connect to OpenAI chat models like GPT-3.5, GPT-4, etc.
@@ -27,6 +27,7 @@ namespace AutoGen.OpenAI;
/// -
/// - where TMessage1 is and TMessage2 is
///
+[Obsolete("Use OpenAIChatAgent instead")]
public class GPTAgent : IStreamingAgent
{
private readonly OpenAIClient openAIClient;
@@ -104,7 +105,7 @@ public async Task GenerateReplyAsync(
return await _innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
}
- public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+ public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
diff --git a/dotnet/src/AutoGen.OpenAI.V1/Agent/OpenAIChatAgent.cs b/dotnet/src/AutoGen.OpenAI.V1/Agent/OpenAIChatAgent.cs
new file mode 100644
index 000000000000..2305536b4e5d
--- /dev/null
+++ b/dotnet/src/AutoGen.OpenAI.V1/Agent/OpenAIChatAgent.cs
@@ -0,0 +1,206 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatAgent.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoGen.OpenAI.V1.Extension;
+using Azure.AI.OpenAI;
+
+namespace AutoGen.OpenAI.V1;
+
+/// <summary>
+/// OpenAI client agent. This agent is a thin wrapper around <see cref="OpenAIClient"/> to provide a simple interface for chat completions.
+/// To better work with other agents, it's recommended to register an <see cref="OpenAIChatRequestMessageConnector"/>, which supports more message types and has better compatibility with other agents.
+/// <see cref="OpenAIChatAgent"/> supports the following message types:
+/// <list type="bullet">
+/// <item>
+/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatRequestMessage"/>: chat request message.
+/// </item>
+/// </list>
+/// <see cref="OpenAIChatAgent"/> returns the following message types:
+/// <list type="bullet">
+/// <item>
+/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatResponseMessage"/>: chat response message.
+/// <see cref="MessageEnvelope{T}"/> where T is <see cref="StreamingChatCompletionsUpdate"/>: streaming chat completions update.
+/// </item>
+/// </list>
+/// </summary>
+public class OpenAIChatAgent : IStreamingAgent
+{
+ private readonly OpenAIClient openAIClient;
+ private readonly ChatCompletionsOptions options;
+ private readonly string systemMessage;
+
+ /// <summary>
+ /// Create a new instance of <see cref="OpenAIChatAgent"/>.
+ /// </summary>
+ /// <param name="openAIClient">openai client</param>
+ /// <param name="name">agent name</param>
+ /// <param name="modelName">model name. e.g. gpt-3.5-turbo</param>
+ /// <param name="systemMessage">system message</param>
+ /// <param name="temperature">temperature</param>
+ /// <param name="maxTokens">max tokens to generate</param>
+ /// <param name="responseFormat">response format, set it to <see cref="ChatCompletionsResponseFormat.JsonObject"/> to enable json mode.</param>
+ /// <param name="seed">seed to use, set it to enable deterministic output</param>
+ /// <param name="functions">functions</param>
+ public OpenAIChatAgent(
+ OpenAIClient openAIClient,
+ string name,
+ string modelName,
+ string systemMessage = "You are a helpful AI assistant",
+ float temperature = 0.7f,
+ int maxTokens = 1024,
+ int? seed = null,
+ ChatCompletionsResponseFormat? responseFormat = null,
+ IEnumerable<FunctionDefinition>? functions = null)
+ : this(
+ openAIClient: openAIClient,
+ name: name,
+ options: CreateChatCompletionOptions(modelName, temperature, maxTokens, seed, responseFormat, functions),
+ systemMessage: systemMessage)
+ {
+ }
+
+ /// <summary>
+ /// Create a new instance of <see cref="OpenAIChatAgent"/>.
+ /// </summary>
+ /// <param name="openAIClient">openai client</param>
+ /// <param name="name">agent name</param>
+ /// <param name="systemMessage">system message</param>
+ /// <param name="options">chat completion option. The option can't contain messages</param>
+ public OpenAIChatAgent(
+ OpenAIClient openAIClient,
+ string name,
+ ChatCompletionsOptions options,
+ string systemMessage = "You are a helpful AI assistant")
+ {
+ if (options.Messages is { Count: > 0 })
+ {
+ throw new ArgumentException("Messages should not be provided in options");
+ }
+
+ this.openAIClient = openAIClient;
+ this.Name = name;
+ this.options = options;
+ this.systemMessage = systemMessage;
+ }
+
+ public string Name { get; }
+
+ public async Task<IMessage> GenerateReplyAsync(
+ IEnumerable<IMessage> messages,
+ GenerateReplyOptions? options = null,
+ CancellationToken cancellationToken = default)
+ {
+ var settings = this.CreateChatCompletionsOptions(options, messages);
+ var reply = await this.openAIClient.GetChatCompletionsAsync(settings, cancellationToken);
+
+ return new MessageEnvelope<ChatCompletions>(reply, from: this.Name);
+ }
+
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
+ IEnumerable<IMessage> messages,
+ GenerateReplyOptions? options = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var settings = this.CreateChatCompletionsOptions(options, messages);
+ var response = await this.openAIClient.GetChatCompletionsStreamingAsync(settings, cancellationToken);
+ await foreach (var update in response.WithCancellation(cancellationToken))
+ {
+ if (update.ChoiceIndex > 0)
+ {
+ throw new InvalidOperationException("Only one choice is supported in streaming response");
+ }
+
+ yield return new MessageEnvelope<StreamingChatCompletionsUpdate>(update, from: this.Name);
+ }
+ }
+
+ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable<IMessage> messages)
+ {
+ var oaiMessages = messages.Select(m => m switch
+ {
+ IMessage<ChatRequestMessage> chatRequestMessage => chatRequestMessage.Content,
+ _ => throw new ArgumentException("Invalid message type")
+ });
+
+ // add system message if there's no system message in messages
+ if (!oaiMessages.Any(m => m is ChatRequestSystemMessage))
+ {
+ oaiMessages = new[] { new ChatRequestSystemMessage(systemMessage) }.Concat(oaiMessages);
+ }
+
+ // clone the options by serializing and deserializing
+ var json = JsonSerializer.Serialize(this.options);
+ var settings = JsonSerializer.Deserialize<ChatCompletionsOptions>(json) ?? throw new InvalidOperationException("Failed to clone options");
+
+ foreach (var m in oaiMessages)
+ {
+ settings.Messages.Add(m);
+ }
+
+ settings.Temperature = options?.Temperature ?? settings.Temperature;
+ settings.MaxTokens = options?.MaxToken ?? settings.MaxTokens;
+
+ foreach (var functions in this.options.Tools)
+ {
+ settings.Tools.Add(functions);
+ }
+
+ foreach (var stopSequence in this.options.StopSequences)
+ {
+ settings.StopSequences.Add(stopSequence);
+ }
+
+ var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition()).ToList();
+ if (openAIFunctionDefinitions is { Count: > 0 })
+ {
+ foreach (var f in openAIFunctionDefinitions)
+ {
+ settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
+ }
+ }
+
+ if (options?.StopSequence is var sequence && sequence is { Length: > 0 })
+ {
+ foreach (var seq in sequence)
+ {
+ settings.StopSequences.Add(seq);
+ }
+ }
+
+ return settings;
+ }
+
+ private static ChatCompletionsOptions CreateChatCompletionOptions(
+ string modelName,
+ float temperature = 0.7f,
+ int maxTokens = 1024,
+ int? seed = null,
+ ChatCompletionsResponseFormat? responseFormat = null,
+ IEnumerable<FunctionDefinition>? functions = null)
+ {
+ var options = new ChatCompletionsOptions(modelName, [])
+ {
+ Temperature = temperature,
+ MaxTokens = maxTokens,
+ Seed = seed,
+ ResponseFormat = responseFormat,
+ };
+
+ if (functions is not null)
+ {
+ foreach (var f in functions)
+ {
+ options.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
+ }
+ }
+
+ return options;
+ }
+}
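A minimal end-to-end sketch of the new agent; the model name is illustrative, and `RegisterMessageConnector` (added later in this PR) lets the agent speak AutoGen message types such as `TextMessage` instead of raw `ChatRequestMessage` envelopes:

```csharp
using AutoGen.Core;
using AutoGen.OpenAI.V1;
using AutoGen.OpenAI.V1.Extension;
using Azure.AI.OpenAI;

var openAIClient = new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"));

var agent = new OpenAIChatAgent(
        openAIClient: openAIClient,
        name: "assistant",
        modelName: "gpt-3.5-turbo")
    .RegisterMessageConnector();

var reply = await agent.SendAsync("What's 2 + 2?");
```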
diff --git a/dotnet/src/AutoGen.OpenAI.V1/AutoGen.OpenAI.V1.csproj b/dotnet/src/AutoGen.OpenAI.V1/AutoGen.OpenAI.V1.csproj
new file mode 100644
index 000000000000..21951cb32fbd
--- /dev/null
+++ b/dotnet/src/AutoGen.OpenAI.V1/AutoGen.OpenAI.V1.csproj
@@ -0,0 +1,27 @@
+
+
+ <TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
+ AutoGen.OpenAI
+
+
+
+
+
+
+ AutoGen.OpenAI.V1
+
+ OpenAI Integration for AutoGen.
+ This package connects to openai using the Azure.AI.OpenAI v1 package. It is kept for compatibility with projects that still depend on that v1 package.
+ To use the latest version of the OpenAI SDK, please use the AutoGen.OpenAI package.
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/src/AutoGen.OpenAI/AzureOpenAIConfig.cs b/dotnet/src/AutoGen.OpenAI.V1/AzureOpenAIConfig.cs
similarity index 95%
rename from dotnet/src/AutoGen.OpenAI/AzureOpenAIConfig.cs
rename to dotnet/src/AutoGen.OpenAI.V1/AzureOpenAIConfig.cs
index 31df784ed21a..2be8f21dc4fc 100644
--- a/dotnet/src/AutoGen.OpenAI/AzureOpenAIConfig.cs
+++ b/dotnet/src/AutoGen.OpenAI.V1/AzureOpenAIConfig.cs
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AzureOpenAIConfig.cs
-namespace AutoGen.OpenAI;
+namespace AutoGen.OpenAI.V1;
public class AzureOpenAIConfig : ILLMConfig
{
diff --git a/dotnet/src/AutoGen.OpenAI.V1/Extension/FunctionContractExtension.cs b/dotnet/src/AutoGen.OpenAI.V1/Extension/FunctionContractExtension.cs
new file mode 100644
index 000000000000..62009b927eff
--- /dev/null
+++ b/dotnet/src/AutoGen.OpenAI.V1/Extension/FunctionContractExtension.cs
@@ -0,0 +1,63 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// FunctionContractExtension.cs
+
+using System;
+using System.Collections.Generic;
+using Azure.AI.OpenAI;
+using Json.Schema;
+using Json.Schema.Generation;
+
+namespace AutoGen.OpenAI.V1.Extension;
+
+public static class FunctionContractExtension
+{
+ /// <summary>
+ /// Convert a <see cref="FunctionContract"/> to a <see cref="FunctionDefinition"/> that can be used in gpt function call.
+ /// </summary>
+ /// <param name="functionContract">function contract</param>
+ /// <returns></returns>
+ public static FunctionDefinition ToOpenAIFunctionDefinition(this FunctionContract functionContract)
+ {
+ var functionDefinition = new FunctionDefinition
+ {
+ Name = functionContract.Name,
+ Description = functionContract.Description,
+ };
+ var requiredParameterNames = new List<string>();
+ var propertiesSchemas = new Dictionary<string, JsonSchema>();
+ var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object);
+ foreach (var param in functionContract.Parameters ?? [])
+ {
+ if (param.Name is null)
+ {
+ throw new InvalidOperationException("Parameter name cannot be null");
+ }
+
+ var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType)));
+ if (param.Description != null)
+ {
+ schemaBuilder = schemaBuilder.Description(param.Description);
+ }
+
+ if (param.IsRequired)
+ {
+ requiredParameterNames.Add(param.Name);
+ }
+
+ var schema = schemaBuilder.Build();
+ propertiesSchemas[param.Name] = schema;
+ }
+ propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas);
+ propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames);
+
+ var option = new System.Text.Json.JsonSerializerOptions()
+ {
+ PropertyNamingPolicy = System.Text.Json.JsonNamingPolicy.CamelCase
+ };
+
+ functionDefinition.Parameters = BinaryData.FromObjectAsJson(propertySchemaBuilder.Build(), option);
+
+ return functionDefinition;
+ }
+}
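A sketch of the conversion using a hand-built contract; in practice the source generator produces these from `[Function]` methods, and the property names below follow AutoGen.Core's `FunctionContract` shape:

```csharp
using AutoGen.Core;
using AutoGen.OpenAI.V1.Extension;
using Azure.AI.OpenAI;

var contract = new FunctionContract
{
    Name = "GetWeather",
    Description = "Get the weather for a city.",
    Parameters = new[]
    {
        new FunctionParameterContract
        {
            Name = "city",
            Description = "city name",
            ParameterType = typeof(string),
            IsRequired = true,
        },
    },
};

// Produces a FunctionDefinition whose Parameters hold a camelCased
// JSON schema with "city" as a required string property.
FunctionDefinition definition = contract.ToOpenAIFunctionDefinition();
```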
diff --git a/dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs b/dotnet/src/AutoGen.OpenAI.V1/Extension/MessageExtension.cs
similarity index 99%
rename from dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs
rename to dotnet/src/AutoGen.OpenAI.V1/Extension/MessageExtension.cs
index ed795e5e8ed8..3264dccf3a8a 100644
--- a/dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs
+++ b/dotnet/src/AutoGen.OpenAI.V1/Extension/MessageExtension.cs
@@ -6,7 +6,7 @@
using System.Linq;
using Azure.AI.OpenAI;
-namespace AutoGen.OpenAI;
+namespace AutoGen.OpenAI.V1;
public static class MessageExtension
{
diff --git a/dotnet/src/AutoGen.OpenAI.V1/Extension/OpenAIAgentExtension.cs b/dotnet/src/AutoGen.OpenAI.V1/Extension/OpenAIAgentExtension.cs
new file mode 100644
index 000000000000..6c0df8e0e965
--- /dev/null
+++ b/dotnet/src/AutoGen.OpenAI.V1/Extension/OpenAIAgentExtension.cs
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIAgentExtension.cs
+
+namespace AutoGen.OpenAI.V1.Extension;
+
+public static class OpenAIAgentExtension
+{
+ /// <summary>
+ /// Register an <see cref="OpenAIChatRequestMessageConnector"/> to the <see cref="OpenAIChatAgent"/>
+ /// </summary>
+ /// <param name="connector">the connector to use. If null, a new instance of <see cref="OpenAIChatRequestMessageConnector"/> will be created.</param>
+ public static MiddlewareStreamingAgent<OpenAIChatAgent> RegisterMessageConnector(
+ this OpenAIChatAgent agent, OpenAIChatRequestMessageConnector? connector = null)
+ {
+ if (connector == null)
+ {
+ connector = new OpenAIChatRequestMessageConnector();
+ }
+
+ return agent.RegisterStreamingMiddleware(connector);
+ }
+
+ /// <summary>
+ /// Register an <see cref="OpenAIChatRequestMessageConnector"/> to the <see cref="MiddlewareStreamingAgent{T}"/> where T is <see cref="OpenAIChatAgent"/>
+ /// </summary>
+ /// <param name="connector">the connector to use. If null, a new instance of <see cref="OpenAIChatRequestMessageConnector"/> will be created.</param>
+ public static MiddlewareStreamingAgent<OpenAIChatAgent> RegisterMessageConnector(
+ this MiddlewareStreamingAgent<OpenAIChatAgent> agent, OpenAIChatRequestMessageConnector? connector = null)
+ {
+ if (connector == null)
+ {
+ connector = new OpenAIChatRequestMessageConnector();
+ }
+
+ return agent.RegisterStreamingMiddleware(connector);
+ }
+}
diff --git a/dotnet/src/AutoGen.OpenAI.V1/GlobalUsing.cs b/dotnet/src/AutoGen.OpenAI.V1/GlobalUsing.cs
new file mode 100644
index 000000000000..d66bf001ed5e
--- /dev/null
+++ b/dotnet/src/AutoGen.OpenAI.V1/GlobalUsing.cs
@@ -0,0 +1,4 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GlobalUsing.cs
+
+global using AutoGen.Core;
diff --git a/dotnet/src/AutoGen.OpenAI.V1/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI.V1/Middleware/OpenAIChatRequestMessageConnector.cs
new file mode 100644
index 000000000000..3587d1b0d6f9
--- /dev/null
+++ b/dotnet/src/AutoGen.OpenAI.V1/Middleware/OpenAIChatRequestMessageConnector.cs
@@ -0,0 +1,390 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatRequestMessageConnector.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.AI.OpenAI;
+
+namespace AutoGen.OpenAI.V1;
+
+/// <summary>
+/// This middleware converts the incoming <see cref="IMessage"/> to <see cref="IMessage{T}"/> where T is <see cref="ChatRequestMessage"/> before sending to agent. And converts the output <see cref="ChatResponseMessage"/> to <see cref="IMessage"/> after receiving from agent.
+/// Supported <see cref="IMessage"/> are
+/// - <see cref="TextMessage"/>
+/// - <see cref="ImageMessage"/>
+/// - <see cref="MultiModalMessage"/>
+/// - <see cref="ToolCallMessage"/>
+/// - <see cref="ToolCallResultMessage"/>
+/// - <see cref="IMessage{T}"/> where T is <see cref="ChatRequestMessage"/>
+/// - <see cref="AggregateMessage{TMessage1, TMessage2}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/>
+/// </summary>
+public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddleware
+{
+ private bool strictMode = false;
+
+ /// <summary>
+ /// Create a new instance of <see cref="OpenAIChatRequestMessageConnector"/>.
+ /// </summary>
+ /// <param name="strictMode">If true, will throw an <see cref="InvalidOperationException"/>
+ /// when the message type is not supported. If false, it will ignore the unsupported message type.</param>
+ public OpenAIChatRequestMessageConnector(bool strictMode = false)
+ {
+ this.strictMode = strictMode;
+ }
+
+ public string? Name => nameof(OpenAIChatRequestMessageConnector);
+
+ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
+ {
+ var chatMessages = ProcessIncomingMessages(agent, context.Messages);
+
+ var reply = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken);
+
+ return PostProcessMessage(reply);
+ }
+
+ public async IAsyncEnumerable<IMessage> InvokeAsync(
+ MiddlewareContext context,
+ IStreamingAgent agent,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var chatMessages = ProcessIncomingMessages(agent, context.Messages);
+ var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken);
+ string? currentToolName = null;
+ await foreach (var reply in streamingReply)
+ {
+ if (reply is IMessage<StreamingChatCompletionsUpdate> update)
+ {
+ if (update.Content.FunctionName is string functionName)
+ {
+ currentToolName = functionName;
+ }
+ else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate toolCallUpdate && toolCallUpdate.Name is string toolCallName)
+ {
+ currentToolName = toolCallName;
+ }
+ var postProcessMessage = PostProcessStreamingMessage(update, currentToolName);
+ if (postProcessMessage != null)
+ {
+ yield return postProcessMessage;
+ }
+ }
+ else
+ {
+ if (this.strictMode)
+ {
+ throw new InvalidOperationException($"Invalid streaming message type {reply.GetType().Name}");
+ }
+ else
+ {
+ yield return reply;
+ }
+ }
+ }
+ }
+
+ public IMessage PostProcessMessage(IMessage message)
+ {
+ return message switch
+ {
+ IMessage<ChatResponseMessage> m => PostProcessChatResponseMessage(m.Content, m.From),
+ IMessage<ChatCompletions> m => PostProcessChatCompletions(m),
+ _ when strictMode is false => message,
+ _ => throw new InvalidOperationException($"Invalid return message type {message.GetType().Name}"),
+ };
+ }
+
+ public IMessage? PostProcessStreamingMessage(IMessage<StreamingChatCompletionsUpdate> update, string? currentToolName)
+ {
+ if (update.Content.ContentUpdate is string contentUpdate)
+ {
+ // text message
+ return new TextMessageUpdate(Role.Assistant, contentUpdate, from: update.From);
+ }
+ else if (update.Content.FunctionName is string functionName)
+ {
+ return new ToolCallMessageUpdate(functionName, string.Empty, from: update.From);
+ }
+ else if (update.Content.FunctionArgumentsUpdate is string functionArgumentsUpdate && currentToolName is string)
+ {
+ return new ToolCallMessageUpdate(currentToolName, functionArgumentsUpdate, from: update.From);
+ }
+ else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate tooCallUpdate && currentToolName is string)
+ {
+ return new ToolCallMessageUpdate(tooCallUpdate.Name ?? currentToolName, tooCallUpdate.ArgumentsUpdate, from: update.From);
+ }
+ else
+ {
+ return null;
+ }
+ }
+
+ private IMessage PostProcessChatCompletions(IMessage<ChatCompletions> message)
+ {
+ // throw exception if prompt filter results is not null
+ if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered)
+ {
+ throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input.");
+ }
+
+ return PostProcessChatResponseMessage(message.Content.Choices[0].Message, message.From);
+ }
+
+ private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from)
+ {
+ var textContent = chatResponseMessage.Content;
+ if (chatResponseMessage.FunctionCall is FunctionCall functionCall)
+ {
+ return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from)
+ {
+ Content = textContent,
+ };
+ }
+
+ if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any())
+ {
+ var functionToolCalls = chatResponseMessage.ToolCalls
+ .Where(tc => tc is ChatCompletionsFunctionToolCall)
+ .Select(tc => (ChatCompletionsFunctionToolCall)tc);
+
+ var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id });
+
+ return new ToolCallMessage(toolCalls, from)
+ {
+ Content = textContent,
+ };
+ }
+
+ if (textContent is string content && !string.IsNullOrEmpty(content))
+ {
+ return new TextMessage(Role.Assistant, content, from);
+ }
+
+ throw new InvalidOperationException("Invalid ChatResponseMessage");
+ }
+
+ public IEnumerable<IMessage> ProcessIncomingMessages(IAgent agent, IEnumerable<IMessage> messages)
+ {
+ return messages.SelectMany(m =>
+ {
+ if (m is IMessage<ChatRequestMessage> crm)
+ {
+ return [crm];
+ }
+ else
+ {
+ var chatRequestMessages = m switch
+ {
+ TextMessage textMessage => ProcessTextMessage(agent, textMessage),
+ ImageMessage imageMessage when (imageMessage.From is null || imageMessage.From != agent.Name) => ProcessImageMessage(agent, imageMessage),
+ MultiModalMessage multiModalMessage when (multiModalMessage.From is null || multiModalMessage.From != agent.Name) => ProcessMultiModalMessage(agent, multiModalMessage),
+ ToolCallMessage toolCallMessage when (toolCallMessage.From is null || toolCallMessage.From == agent.Name) => ProcessToolCallMessage(agent, toolCallMessage),
+ ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage),
+ AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => ProcessFunctionCallMiddlewareMessage(agent, aggregateMessage),
+#pragma warning disable CS0618 // deprecated
+ Message msg => ProcessMessage(agent, msg),
+#pragma warning restore CS0618 // deprecated
+ _ when strictMode is false => [],
+ _ => throw new InvalidOperationException($"Invalid message type: {m.GetType().Name}"),
+ };
+
+ if (chatRequestMessages.Any())
+ {
+ return chatRequestMessages.Select(cm => MessageEnvelope.Create(cm, m.From));
+ }
+ else
+ {
+ return [m];
+ }
+ }
+ });
+ }
+
+ [Obsolete("This method is deprecated, please use ProcessIncomingMessages(IAgent agent, IEnumerable messages) instead.")]
+ private IEnumerable ProcessIncomingMessagesForSelf(Message message)
+ {
+ if (message.Role == Role.System)
+ {
+ return new[] { new ChatRequestSystemMessage(message.Content) };
+ }
+ else if (message.Content is string content && content is { Length: > 0 })
+ {
+ if (message.FunctionName is null)
+ {
+ return new[] { new ChatRequestAssistantMessage(message.Content) };
+ }
+ else
+ {
+ return new[] { new ChatRequestToolMessage(content, message.FunctionName) };
+ }
+ }
+ else if (message.FunctionName is string functionName)
+ {
+ var msg = new ChatRequestAssistantMessage(content: null)
+ {
+ FunctionCall = new FunctionCall(functionName, message.FunctionArguments)
+ };
+
+ return new[]
+ {
+ msg,
+ };
+ }
+ else
+ {
+ throw new InvalidOperationException("Invalid Message as message from self.");
+ }
+ }
+
+ [Obsolete("This method is deprecated, please use ProcessIncomingMessages(IAgent agent, IEnumerable messages) instead.")]
+ private IEnumerable ProcessIncomingMessagesForOther(Message message)
+ {
+ if (message.Role == Role.System)
+ {
+ return [new ChatRequestSystemMessage(message.Content) { Name = message.From }];
+ }
+ else if (message.Content is string content && content is { Length: > 0 })
+ {
+ if (message.FunctionName is not null)
+ {
+ return new[] { new ChatRequestToolMessage(content, message.FunctionName) };
+ }
+
+ return [new ChatRequestUserMessage(message.Content) { Name = message.From }];
+ }
+ else if (message.FunctionName is string _)
+ {
+ return [new ChatRequestUserMessage("// Message type is not supported") { Name = message.From }];
+ }
+ else
+ {
+ throw new InvalidOperationException("Invalid Message as message from other.");
+ }
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessTextMessage(IAgent agent, TextMessage message)
+ {
+ if (message.Role == Role.System)
+ {
+ return [new ChatRequestSystemMessage(message.Content) { Name = message.From }];
+ }
+
+ if (agent.Name == message.From)
+ {
+ return [new ChatRequestAssistantMessage(message.Content) { Name = agent.Name }];
+ }
+ else
+ {
+ return message.From switch
+ {
+ null when message.Role == Role.User => [new ChatRequestUserMessage(message.Content)],
+ null when message.Role == Role.Assistant => [new ChatRequestAssistantMessage(message.Content)],
+ null => throw new InvalidOperationException("Invalid Role"),
+ _ => [new ChatRequestUserMessage(message.Content) { Name = message.From }]
+ };
+ }
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessImageMessage(IAgent agent, ImageMessage message)
+ {
+ if (agent.Name == message.From)
+ {
+ // image message from assistant is not supported
+ throw new ArgumentException("ImageMessage is not supported when message.From is the same with agent");
+ }
+
+ var imageContentItem = this.CreateChatMessageImageContentItemFromImageMessage(message);
+ return [new ChatRequestUserMessage([imageContentItem]) { Name = message.From }];
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessMultiModalMessage(IAgent agent, MultiModalMessage message)
+ {
+ if (agent.Name == message.From)
+ {
+ // multimodal message from assistant is not supported
+ throw new ArgumentException("MultiModalMessage is not supported when message.From is the same as agent");
+ }
+
+ IEnumerable<ChatMessageContentItem> items = message.Content.Select(ci => ci switch
+ {
+ TextMessage text => new ChatMessageTextContentItem(text.Content),
+ ImageMessage image => this.CreateChatMessageImageContentItemFromImageMessage(image),
+ _ => throw new NotImplementedException(),
+ });
+
+ return [new ChatRequestUserMessage(items) { Name = message.From }];
+ }
+
+ private ChatMessageImageContentItem CreateChatMessageImageContentItemFromImageMessage(ImageMessage message)
+ {
+ return message.Data is null && message.Url is not null
+ ? new ChatMessageImageContentItem(new Uri(message.Url))
+ : new ChatMessageImageContentItem(message.Data, message.Data?.MediaType);
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessToolCallMessage(IAgent agent, ToolCallMessage message)
+ {
+ if (message.From is not null && message.From != agent.Name)
+ {
+ throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent");
+ }
+
+ var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
+ var textContent = message.GetContent() ?? string.Empty;
+
+ // don't include the name field when it's a tool call message.
+ // fix https://github.com/microsoft/autogen/issues/3437
+ var chatRequestMessage = new ChatRequestAssistantMessage(textContent);
+ foreach (var tc in toolCall)
+ {
+ chatRequestMessage.ToolCalls.Add(tc);
+ }
+
+ return [chatRequestMessage];
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessToolCallResultMessage(ToolCallResultMessage message)
+ {
+ return message.ToolCalls
+ .Where(tc => tc.Result is not null)
+ .Select((tc, i) => new ChatRequestToolMessage(tc.Result, tc.ToolCallId ?? $"{tc.FunctionName}_{i}"));
+ }
+
+ [Obsolete("This method is deprecated, please use ProcessIncomingMessages(IAgent agent, IEnumerable messages) instead.")]
+ private IEnumerable ProcessMessage(IAgent agent, Message message)
+ {
+ if (message.From is not null && message.From != agent.Name)
+ {
+ return ProcessIncomingMessagesForOther(message);
+ }
+ else
+ {
+ return ProcessIncomingMessagesForSelf(message);
+ }
+ }
+
+ private IEnumerable<ChatRequestMessage> ProcessFunctionCallMiddlewareMessage(IAgent agent, AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
+ {
+ if (aggregateMessage.From is not null && aggregateMessage.From != agent.Name)
+ {
+ // convert as user message
+ var resultMessage = aggregateMessage.Message2;
+
+ return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result) { Name = aggregateMessage.From });
+ }
+ else
+ {
+ var toolCallMessage1 = aggregateMessage.Message1;
+ var toolCallResultMessage = aggregateMessage.Message2;
+
+ var assistantMessage = this.ProcessToolCallMessage(agent, toolCallMessage1);
+ var toolCallResults = this.ProcessToolCallResultMessage(toolCallResultMessage);
+
+ return assistantMessage.Concat(toolCallResults);
+ }
+ }
+}
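For orientation (not part of the diff): a minimal sketch of how this V1 connector is typically wired into an agent. The client construction, model id, and agent name are illustrative; `RegisterMiddleware` comes from AutoGen.Core.

```csharp
// Sketch only: assumes the AutoGen.OpenAI.V1 package and the Azure.AI.OpenAI v1 SDK.
using AutoGen.Core;
using AutoGen.OpenAI.V1;
using Azure.AI.OpenAI;

var openAIClient = new OpenAIClient("<your-openai-api-key>");
var agent = new OpenAIChatAgent(openAIClient, name: "assistant", modelName: "gpt-4o-mini")
    // converts TextMessage, ImageMessage, ToolCallMessage, ... into ChatRequestMessage and back
    .RegisterMiddleware(new OpenAIChatRequestMessageConnector());
```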
diff --git a/dotnet/src/AutoGen.OpenAI/OpenAIConfig.cs b/dotnet/src/AutoGen.OpenAI.V1/OpenAIConfig.cs
similarity index 91%
rename from dotnet/src/AutoGen.OpenAI/OpenAIConfig.cs
rename to dotnet/src/AutoGen.OpenAI.V1/OpenAIConfig.cs
index 35ce1e491aa9..592647cc2c1e 100644
--- a/dotnet/src/AutoGen.OpenAI/OpenAIConfig.cs
+++ b/dotnet/src/AutoGen.OpenAI.V1/OpenAIConfig.cs
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OpenAIConfig.cs
-namespace AutoGen.OpenAI;
+namespace AutoGen.OpenAI.V1;
public class OpenAIConfig : ILLMConfig
{
diff --git a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs
index 37a4882f69e1..b0085d0f33c6 100644
--- a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs
+++ b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs
@@ -8,70 +8,79 @@
using System.Threading;
using System.Threading.Tasks;
using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
+using global::OpenAI;
+using global::OpenAI.Chat;
+using Json.Schema;
namespace AutoGen.OpenAI;
/// <summary>
/// OpenAI client agent. This agent is a thin wrapper around <see cref="OpenAIClient"/> to provide a simple interface for chat completions.
-/// To better work with other agents, it's recommended to use <see cref="GPTAgent"/> which supports more message types and have a better compatibility with other agents.
/// <see cref="OpenAIChatAgent"/> supports the following message types:
/// <list type="bullet">
/// <item>
-/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatRequestMessage"/>: chat request message.
+/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatMessage"/>: chat message.
/// </item>
/// </list>
/// <see cref="OpenAIChatAgent"/> returns the following message types:
/// <list type="bullet">
/// <item>
-/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatResponseMessage"/>: chat response message.
-/// <see cref="MessageEnvelope{T}"/> where T is <see cref="StreamingChatCompletionsUpdate"/>: streaming chat completions update.
+/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatCompletion"/>: chat response message.
+/// <see cref="MessageEnvelope{T}"/> where T is <see cref="StreamingChatCompletionUpdate"/>: streaming chat completions update.
/// </item>
/// </list>
/// </summary>
public class OpenAIChatAgent : IStreamingAgent
{
- private readonly OpenAIClient openAIClient;
- private readonly string modelName;
- private readonly float _temperature;
- private readonly int _maxTokens = 1024;
- private readonly IEnumerable? _functions;
- private readonly string _systemMessage;
- private readonly ChatCompletionsResponseFormat? _responseFormat;
- private readonly int? _seed;
+ private readonly ChatClient chatClient;
+ private readonly ChatCompletionOptions options;
+ private readonly string? systemMessage;
/// <summary>
/// Create a new instance of <see cref="OpenAIChatAgent"/>.
/// </summary>
- /// <param name="openAIClient">openai client</param>
+ /// <param name="chatClient">openai client</param>
/// <param name="name">agent name</param>
- /// <param name="modelName">model name. e.g. gpt-turbo-3.5</param>
/// <param name="systemMessage">system message</param>
/// <param name="temperature">temperature</param>
/// <param name="maxTokens">max tokens to generate</param>
- /// <param name="responseFormat">response format, set it to <see cref="ChatCompletionsResponseFormat.JsonObject"/> to enable json mode.</param>
+ /// <param name="responseFormat">response format, set it to a JSON <see cref="ChatResponseFormat"/> to enable json mode.</param>
/// <param name="seed">seed to use, set it to enable deterministic output</param>
/// <param name="functions">functions</param>
public OpenAIChatAgent(
- OpenAIClient openAIClient,
+ ChatClient chatClient,
string name,
- string modelName,
- string systemMessage = "You are a helpful AI assistant",
- float temperature = 0.7f,
- int maxTokens = 1024,
+ string? systemMessage = "You are a helpful AI assistant",
+ float? temperature = null,
+ int? maxTokens = null,
int? seed = null,
- ChatCompletionsResponseFormat? responseFormat = null,
- IEnumerable? functions = null)
+ ChatResponseFormat? responseFormat = null,
+ IEnumerable<ChatTool>? functions = null)
+ : this(
+ chatClient: chatClient,
+ name: name,
+ options: CreateChatCompletionOptions(temperature, maxTokens, seed, responseFormat, functions),
+ systemMessage: systemMessage)
{
- this.openAIClient = openAIClient;
- this.modelName = modelName;
+ }
+
+ /// <summary>
+ /// Create a new instance of <see cref="OpenAIChatAgent"/>.
+ /// </summary>
+ /// <param name="chatClient">openai chat client</param>
+ /// <param name="name">agent name</param>
+ /// <param name="systemMessage">system message</param>
+ /// <param name="options">chat completion option. The option can't contain messages</param>
+ public OpenAIChatAgent(
+ ChatClient chatClient,
+ string name,
+ ChatCompletionOptions options,
+ string? systemMessage = "You are a helpful AI assistant")
+ {
+ this.chatClient = chatClient;
this.Name = name;
- _temperature = temperature;
- _maxTokens = maxTokens;
- _functions = functions;
- _systemMessage = systemMessage;
- _responseFormat = responseFormat;
- _seed = seed;
+ this.options = options;
+ this.systemMessage = systemMessage;
}
public string Name { get; }
@@ -81,59 +90,85 @@ public async Task<IMessage> GenerateReplyAsync(
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
- var settings = this.CreateChatCompletionsOptions(options, messages);
- var reply = await this.openAIClient.GetChatCompletionsAsync(settings, cancellationToken);
-
- return new MessageEnvelope<ChatCompletions>(reply, from: this.Name);
+ var chatHistory = this.CreateChatMessages(messages);
+ var settings = this.CreateChatCompletionsOptions(options);
+ var reply = await this.chatClient.CompleteChatAsync(chatHistory, settings, cancellationToken);
+ return new MessageEnvelope<ChatCompletion>(reply.Value, from: this.Name);
}
- public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- var settings = this.CreateChatCompletionsOptions(options, messages);
- var response = await this.openAIClient.GetChatCompletionsStreamingAsync(settings, cancellationToken);
+ var chatHistory = this.CreateChatMessages(messages);
+ var settings = this.CreateChatCompletionsOptions(options);
+ var response = this.chatClient.CompleteChatStreamingAsync(chatHistory, settings, cancellationToken);
await foreach (var update in response.WithCancellation(cancellationToken))
{
- if (update.ChoiceIndex > 0)
+ if (update.ContentUpdate.Count > 1)
{
throw new InvalidOperationException("Only one choice is supported in streaming response");
}
- yield return new MessageEnvelope<StreamingChatCompletionsUpdate>(update, from: this.Name);
+ yield return new MessageEnvelope<StreamingChatCompletionUpdate>(update, from: this.Name);
}
}
- private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable<IMessage> messages)
+ private IEnumerable<ChatMessage> CreateChatMessages(IEnumerable<IMessage> messages)
{
var oaiMessages = messages.Select(m => m switch
{
- IMessage<ChatRequestMessage> chatRequestMessage => chatRequestMessage.Content,
+ IMessage<ChatMessage> chatMessage => chatMessage.Content,
_ => throw new ArgumentException("Invalid message type")
});
// add system message if there's no system message in messages
- if (!oaiMessages.Any(m => m is ChatRequestSystemMessage))
+ if (!oaiMessages.Any(m => m is SystemChatMessage) && systemMessage is not null)
{
- oaiMessages = new[] { new ChatRequestSystemMessage(_systemMessage) }.Concat(oaiMessages);
+ oaiMessages = new[] { new SystemChatMessage(systemMessage) }.Concat(oaiMessages);
}
- var settings = new ChatCompletionsOptions(this.modelName, oaiMessages)
+ return oaiMessages;
+ }
+
+ private ChatCompletionOptions CreateChatCompletionsOptions(GenerateReplyOptions? options)
+ {
+ var option = new ChatCompletionOptions()
{
- MaxTokens = options?.MaxToken ?? _maxTokens,
- Temperature = options?.Temperature ?? _temperature,
- ResponseFormat = _responseFormat,
- Seed = _seed,
+ Seed = this.options.Seed,
+ Temperature = options?.Temperature ?? this.options.Temperature,
+ MaxTokens = options?.MaxToken ?? this.options.MaxTokens,
+ ResponseFormat = this.options.ResponseFormat,
+ FrequencyPenalty = this.options.FrequencyPenalty,
+ FunctionChoice = this.options.FunctionChoice,
+ IncludeLogProbabilities = this.options.IncludeLogProbabilities,
+ ParallelToolCallsEnabled = this.options.ParallelToolCallsEnabled,
+ PresencePenalty = this.options.PresencePenalty,
+ ToolChoice = this.options.ToolChoice,
+ TopLogProbabilityCount = this.options.TopLogProbabilityCount,
+ TopP = this.options.TopP,
+ EndUserId = this.options.EndUserId,
};
- var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition());
- var functions = openAIFunctionDefinitions ?? _functions;
- if (functions is not null && functions.Count() > 0)
+ // add tools from this.options to option
+ foreach (var tool in this.options.Tools)
{
- foreach (var f in functions)
+ option.Tools.Add(tool);
+ }
+
+ // add stop sequences from this.options to option
+ foreach (var seq in this.options.StopSequences)
+ {
+ option.StopSequences.Add(seq);
+ }
+
+ var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToChatTool()).ToList();
+ if (openAIFunctionDefinitions is { Count: > 0 })
+ {
+ foreach (var f in openAIFunctionDefinitions)
{
- settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
+ option.Tools.Add(f);
}
}
@@ -141,10 +176,44 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions
{
foreach (var seq in sequence)
{
- settings.StopSequences.Add(seq);
+ option.StopSequences.Add(seq);
+ }
+ }
+
+ if (options?.OutputSchema is not null)
+ {
+ option.ResponseFormat = ChatResponseFormat.CreateJsonSchemaFormat(
+ name: options.OutputSchema.GetTitle() ?? throw new ArgumentException("Output schema must have a title"),
+ jsonSchema: BinaryData.FromObjectAsJson(options.OutputSchema),
+ description: options.OutputSchema.GetDescription());
+ }
+
+ return option;
+ }
+
+ private static ChatCompletionOptions CreateChatCompletionOptions(
+ float? temperature = 0.7f,
+ int? maxTokens = 1024,
+ int? seed = null,
+ ChatResponseFormat? responseFormat = null,
+ IEnumerable<ChatTool>? functions = null)
+ {
+ var options = new ChatCompletionOptions
+ {
+ Temperature = temperature,
+ MaxTokens = maxTokens,
+ Seed = seed,
+ ResponseFormat = responseFormat,
+ };
+
+ if (functions is not null)
+ {
+ foreach (var f in functions)
+ {
+ options.Tools.Add(f);
}
}
- return settings;
+ return options;
}
}
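To see the new constructor shape in context, a minimal construction sketch (not part of the diff); API key handling, model id, and names are illustrative:

```csharp
// Sketch only: assumes the new OpenAI .NET SDK (`OpenAI` package) that this PR migrates to.
using System;
using AutoGen.OpenAI;
using OpenAI;

var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY")
    ?? throw new InvalidOperationException("OPENAI_API_KEY is not set");

// A ChatClient bound to one model replaces the old OpenAIClient + modelName pair.
var chatClient = new OpenAIClient(apiKey).GetChatClient("gpt-4o-mini");

var agent = new OpenAIChatAgent(
    chatClient: chatClient,
    name: "assistant",
    systemMessage: "You are a helpful AI assistant");
```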
diff --git a/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj b/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj
index 7220cfe5c628..f93fdd4bc5e2 100644
--- a/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj
+++ b/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj
@@ -1,6 +1,6 @@
-    <TargetFrameworks>netstandard2.0</TargetFrameworks>
+    <TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
AutoGen.OpenAI
@@ -11,11 +11,12 @@
AutoGen.OpenAI
OpenAI Integration for AutoGen.
+ If your project still depends on Azure.AI.OpenAI v1, please use AutoGen.OpenAI.V1 package instead.
-    <PackageReference Include="Azure.AI.OpenAI" />
+    <PackageReference Include="OpenAI" />
diff --git a/dotnet/src/AutoGen.OpenAI/Extension/FunctionContractExtension.cs b/dotnet/src/AutoGen.OpenAI/Extension/FunctionContractExtension.cs
index 4accdc4d8d46..dd1c1125aec0 100644
--- a/dotnet/src/AutoGen.OpenAI/Extension/FunctionContractExtension.cs
+++ b/dotnet/src/AutoGen.OpenAI/Extension/FunctionContractExtension.cs
@@ -3,26 +3,21 @@
using System;
using System.Collections.Generic;
-using Azure.AI.OpenAI;
using Json.Schema;
using Json.Schema.Generation;
+using OpenAI.Chat;
namespace AutoGen.OpenAI.Extension;
public static class FunctionContractExtension
{
/// <summary>
- /// Convert a <see cref="FunctionContract"/> to a <see cref="FunctionDefinition"/> that can be used in gpt funciton call.
+ /// Convert a <see cref="FunctionContract"/> to a <see cref="ChatTool"/> that can be used in gpt function call.
/// </summary>
/// <param name="functionContract">function contract</param>
- /// <returns><see cref="FunctionDefinition"/></returns>
- public static FunctionDefinition ToOpenAIFunctionDefinition(this FunctionContract functionContract)
+ /// <returns><see cref="ChatTool"/></returns>
+ public static ChatTool ToChatTool(this FunctionContract functionContract)
{
- var functionDefinition = new FunctionDefinition
- {
- Name = functionContract.Name,
- Description = functionContract.Description,
- };
var requiredParameterNames = new List<string>();
var propertiesSchemas = new Dictionary<string, JsonSchema>();
var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object);
@@ -56,8 +51,22 @@ public static FunctionDefinition ToOpenAIFunctionDefinition(this FunctionContrac
PropertyNamingPolicy = System.Text.Json.JsonNamingPolicy.CamelCase
};
- functionDefinition.Parameters = BinaryData.FromObjectAsJson(propertySchemaBuilder.Build(), option);
+ var functionDefinition = ChatTool.CreateFunctionTool(
+ functionContract.Name ?? throw new ArgumentNullException(nameof(functionContract.Name)),
+ functionContract.Description,
+ BinaryData.FromObjectAsJson(propertySchemaBuilder.Build(), option));
return functionDefinition;
}
+
+ /// <summary>
+ /// Convert a <see cref="FunctionContract"/> to a <see cref="ChatTool"/> that can be used in gpt function call.
+ /// </summary>
+ /// <param name="functionContract">function contract</param>
+ /// <returns><see cref="ChatTool"/></returns>
+ [Obsolete("Use ToChatTool instead")]
+ public static ChatTool ToOpenAIFunctionDefinition(this FunctionContract functionContract)
+ {
+ return functionContract.ToChatTool();
+ }
}
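A short usage sketch (not part of the diff) for the renamed extension; `GetWeatherFunctionContract()` is a hypothetical factory standing in for any FunctionContract, e.g. one generated by AutoGen.SourceGenerator:

```csharp
using AutoGen.Core;
using AutoGen.OpenAI.Extension;
using OpenAI.Chat;

FunctionContract contract = GetWeatherFunctionContract(); // hypothetical
ChatTool tool = contract.ToChatTool();                     // new name

var options = new ChatCompletionOptions();
options.Tools.Add(tool);

#pragma warning disable CS0618
ChatTool legacy = contract.ToOpenAIFunctionDefinition();   // obsolete alias, still compiles with a warning
#pragma warning restore CS0618
```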
diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs
index c1dc2caa99fb..fd55a1350326 100644
--- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs
+++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs
@@ -7,19 +7,19 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
-using Azure.AI.OpenAI;
+using OpenAI.Chat;
namespace AutoGen.OpenAI;
/// <summary>
-/// This middleware converts the incoming <see cref="IMessage"/> to <see cref="IMessage{T}"/> where T is <see cref="ChatRequestMessage"/> before sending to agent. And converts the output <see cref="ChatResponseMessage"/> to <see cref="IMessage"/> after receiving from agent.
+/// This middleware converts the incoming <see cref="IMessage"/> to <see cref="IMessage{T}"/> where T is <see cref="ChatMessage"/> before sending to agent. And converts the output <see cref="ChatCompletion"/> to <see cref="IMessage"/> after receiving from agent.
/// Supported <see cref="IMessage"/> are
/// - <see cref="TextMessage"/>
/// - <see cref="ImageMessage"/>
/// - <see cref="MultiModalMessage"/>
/// - <see cref="ToolCallMessage"/>
/// - <see cref="ToolCallResultMessage"/>
-/// - <see cref="IMessage{T}"/> where T is <see cref="ChatRequestMessage"/>
+/// - <see cref="IMessage{T}"/> where T is <see cref="ChatMessage"/>
/// - <see cref="AggregateMessage{TMessage1, TMessage2}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/>
/// </summary>
public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddleware
@@ -47,31 +47,26 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
return PostProcessMessage(reply);
}
- public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
+ public async IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var chatMessages = ProcessIncomingMessages(agent, context.Messages);
var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken);
- string? currentToolName = null;
+ var chunks = new List<StreamingChatCompletionUpdate>();
+
+ // only streaming the text content
await foreach (var reply in streamingReply)
{
- if (reply is IStreamingMessage<StreamingChatCompletionsUpdate> update)
+ if (reply is IMessage<StreamingChatCompletionUpdate> update)
{
- if (update.Content.FunctionName is string functionName)
- {
- currentToolName = functionName;
- }
- else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate toolCallUpdate && toolCallUpdate.Name is string toolCallName)
+ if (update.Content.ContentUpdate.Count == 1 && update.Content.ContentUpdate[0].Kind == ChatMessageContentPartKind.Text)
{
- currentToolName = toolCallName;
- }
- var postProcessMessage = PostProcessStreamingMessage(update, currentToolName);
- if (postProcessMessage != null)
- {
- yield return postProcessMessage;
+ yield return new TextMessageUpdate(Role.Assistant, update.Content.ContentUpdate[0].Text, from: update.From);
}
+
+ chunks.Add(update.Content);
}
else
{
@@ -85,83 +80,140 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
}
}
}
+
+ // process the tool call
+ var streamingChatToolCallUpdates = chunks.Where(c => c.ToolCallUpdates.Count > 0)
+ .SelectMany(c => c.ToolCallUpdates)
+ .ToList();
+
+ // collect all text parts
+ var textParts = chunks.SelectMany(c => c.ContentUpdate)
+ .Where(c => c.Kind == ChatMessageContentPartKind.Text)
+ .Select(c => c.Text)
+ .ToList();
+
+ // combine the tool call and function call into one ToolCallMessages
+ var text = string.Join(string.Empty, textParts);
+ var toolCalls = new List<ToolCall>();
+ var currentToolName = string.Empty;
+ var currentToolArguments = string.Empty;
+ var currentToolId = string.Empty;
+ int? currentIndex = null;
+ foreach (var toolCall in streamingChatToolCallUpdates)
+ {
+ if (currentIndex is null)
+ {
+ currentIndex = toolCall.Index;
+ }
+
+ if (toolCall.Index == currentIndex)
+ {
+ currentToolName += toolCall.FunctionName;
+ currentToolArguments += toolCall.FunctionArgumentsUpdate;
+ currentToolId += toolCall.Id;
+
+ yield return new ToolCallMessageUpdate(currentToolName, currentToolArguments, from: agent.Name);
+ }
+ else
+ {
+ toolCalls.Add(new ToolCall(currentToolName, currentToolArguments) { ToolCallId = currentToolId });
+ currentToolName = toolCall.FunctionName;
+ currentToolArguments = toolCall.FunctionArgumentsUpdate;
+ currentToolId = toolCall.Id;
+ currentIndex = toolCall.Index;
+
+ yield return new ToolCallMessageUpdate(currentToolName, currentToolArguments, from: agent.Name);
+ }
+ }
+
+ if (string.IsNullOrEmpty(currentToolName) is false)
+ {
+ toolCalls.Add(new ToolCall(currentToolName, currentToolArguments) { ToolCallId = currentToolId });
+ }
+
+ if (toolCalls.Any())
+ {
+ yield return new ToolCallMessage(toolCalls, from: agent.Name)
+ {
+ Content = text,
+ };
+ }
}
public IMessage PostProcessMessage(IMessage message)
{
return message switch
{
- IMessage<ChatResponseMessage> m => PostProcessChatResponseMessage(m.Content, m.From),
- IMessage<ChatCompletions> m => PostProcessChatCompletions(m),
+ IMessage<ChatCompletion> m => PostProcessChatCompletions(m),
_ when strictMode is false => message,
_ => throw new InvalidOperationException($"Invalid return message type {message.GetType().Name}"),
};
}
- public IStreamingMessage? PostProcessStreamingMessage(IStreamingMessage<StreamingChatCompletionsUpdate> update, string? currentToolName)
+ private IMessage PostProcessChatCompletions(IMessage<ChatCompletion> message)
{
- if (update.Content.ContentUpdate is string contentUpdate)
- {
- // text message
- return new TextMessageUpdate(Role.Assistant, contentUpdate, from: update.From);
- }
- else if (update.Content.FunctionName is string functionName)
- {
- return new ToolCallMessageUpdate(functionName, string.Empty, from: update.From);
- }
- else if (update.Content.FunctionArgumentsUpdate is string functionArgumentsUpdate && currentToolName is string)
- {
- return new ToolCallMessageUpdate(currentToolName, functionArgumentsUpdate, from: update.From);
- }
- else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate tooCallUpdate && currentToolName is string)
+ // throw an exception if the content is filtered
+ if (message.Content.FinishReason == ChatFinishReason.ContentFilter)
{
- return new ToolCallMessageUpdate(tooCallUpdate.Name ?? currentToolName, tooCallUpdate.ArgumentsUpdate, from: update.From);
+ throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input.");
}
- else
+
+ // throw an exception if there is more than one choice
+ if (message.Content.Content.Count > 1)
{
- return null;
+ throw new InvalidOperationException("The content has more than one choice. Please try another input.");
}
+
+ return PostProcessChatResponseMessage(message.Content, message.From);
}
- private IMessage PostProcessChatCompletions(IMessage<ChatCompletions> message)
+ private IMessage PostProcessChatResponseMessage(ChatCompletion chatCompletion, string? from)
{
// throw an exception if the content is filtered
- if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered)
+ if (chatCompletion.FinishReason == ChatFinishReason.ContentFilter)
{
throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input.");
}
- return PostProcessChatResponseMessage(message.Content.Choices[0].Message, message.From);
- }
+ // throw an exception if there is more than one choice
+ if (chatCompletion.Content.Count > 1)
+ {
+ throw new InvalidOperationException("The content has more than one choice. Please try another input.");
+ }
+ var textContent = chatCompletion.Content.FirstOrDefault();
- private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from)
- {
- var textContent = chatResponseMessage.Content;
- if (chatResponseMessage.FunctionCall is FunctionCall functionCall)
+ // if tool calls is not empty, return ToolCallMessage
+ if (chatCompletion.ToolCalls is { Count: > 0 })
{
- return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from)
+ var toolCalls = chatCompletion.ToolCalls.Select(tc => new ToolCall(tc.FunctionName, tc.FunctionArguments) { ToolCallId = tc.Id });
+ return new ToolCallMessage(toolCalls, from)
{
- Content = textContent,
+ Content = textContent?.Kind switch
+ {
+ _ when textContent?.Kind == ChatMessageContentPartKind.Text => textContent.Text,
+ _ => null,
+ },
};
}
- if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any())
+ // else, process function call.
+ // This is deprecated and will be removed in the future.
+ if (chatCompletion.FunctionCall is ChatFunctionCall fc)
{
- var functionToolCalls = chatResponseMessage.ToolCalls
- .Where(tc => tc is ChatCompletionsFunctionToolCall)
- .Select(tc => (ChatCompletionsFunctionToolCall)tc);
-
- var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id });
-
- return new ToolCallMessage(toolCalls, from)
+ return new ToolCallMessage(fc.FunctionName, fc.FunctionArguments, from)
{
- Content = textContent,
+ Content = textContent?.Kind switch
+ {
+ _ when textContent?.Kind == ChatMessageContentPartKind.Text => textContent.Text,
+ _ => null,
+ },
};
}
- if (textContent is string content && !string.IsNullOrEmpty(content))
+ // if the content is text, return TextMessage
+ if (textContent?.Kind == ChatMessageContentPartKind.Text)
{
- return new TextMessage(Role.Assistant, content, from);
+ return new TextMessage(Role.Assistant, textContent.Text, from);
}
throw new InvalidOperationException("Invalid ChatResponseMessage");
@@ -171,7 +223,7 @@ public IEnumerable<IMessage> ProcessIncomingMessages(IAgent agent, IEnumerable<IMessage> messages)
        return messages.SelectMany<IMessage, IMessage>(m =>
        {
- if (m is IMessage<ChatRequestMessage> crm)
+ if (m is IMessage<ChatMessage> crm)
{
return [crm];
}
@@ -185,9 +237,6 @@ MultiModalMessage multiModalMessage when (multiModalMessage.From is null || mult
ToolCallMessage toolCallMessage when (toolCallMessage.From is null || toolCallMessage.From == agent.Name) => ProcessToolCallMessage(agent, toolCallMessage),
ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage),
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => ProcessFunctionCallMiddlewareMessage(agent, aggregateMessage),
-#pragma warning disable CS0618 // deprecated
- Message msg => ProcessMessage(agent, msg),
-#pragma warning restore CS0618 // deprecated
_ when strictMode is false => [],
_ => throw new InvalidOperationException($"Invalid message type: {m.GetType().Name}"),
};
@@ -204,92 +253,30 @@ ToolCallMessage toolCallMessage when (toolCallMessage.From is null || toolCallMe
});
}
- [Obsolete("This method is deprecated, please use ProcessIncomingMessages(IAgent agent, IEnumerable messages) instead.")]
- private IEnumerable ProcessIncomingMessagesForSelf(Message message)
+ private IEnumerable ProcessTextMessage(IAgent agent, TextMessage message)
{
if (message.Role == Role.System)
{
- return new[] { new ChatRequestSystemMessage(message.Content) };
- }
- else if (message.Content is string content && content is { Length: > 0 })
- {
- if (message.FunctionName is null)
- {
- return new[] { new ChatRequestAssistantMessage(message.Content) };
- }
- else
- {
- return new[] { new ChatRequestToolMessage(content, message.FunctionName) };
- }
- }
- else if (message.FunctionName is string functionName)
- {
- var msg = new ChatRequestAssistantMessage(content: null)
- {
- FunctionCall = new FunctionCall(functionName, message.FunctionArguments)
- };
-
- return new[]
- {
- msg,
- };
- }
- else
- {
- throw new InvalidOperationException("Invalid Message as message from self.");
- }
- }
-
- [Obsolete("This method is deprecated, please use ProcessIncomingMessages(IAgent agent, IEnumerable messages) instead.")]
- private IEnumerable ProcessIncomingMessagesForOther(Message message)
- {
- if (message.Role == Role.System)
- {
- return [new ChatRequestSystemMessage(message.Content) { Name = message.From }];
- }
- else if (message.Content is string content && content is { Length: > 0 })
- {
- if (message.FunctionName is not null)
- {
- return new[] { new ChatRequestToolMessage(content, message.FunctionName) };
- }
-
- return [new ChatRequestUserMessage(message.Content) { Name = message.From }];
- }
- else if (message.FunctionName is string _)
- {
- return [new ChatRequestUserMessage("// Message type is not supported") { Name = message.From }];
- }
- else
- {
- throw new InvalidOperationException("Invalid Message as message from other.");
- }
- }
-
- private IEnumerable<ChatRequestMessage> ProcessTextMessage(IAgent agent, TextMessage message)
- {
- if (message.Role == Role.System)
- {
- return [new ChatRequestSystemMessage(message.Content) { Name = message.From }];
+ return [new SystemChatMessage(message.Content) { ParticipantName = message.From }];
}
if (agent.Name == message.From)
{
- return [new ChatRequestAssistantMessage(message.Content) { Name = agent.Name }];
+ return [new AssistantChatMessage(message.Content) { ParticipantName = agent.Name }];
}
else
{
return message.From switch
{
- null when message.Role == Role.User => [new ChatRequestUserMessage(message.Content)],
- null when message.Role == Role.Assistant => [new ChatRequestAssistantMessage(message.Content)],
+ null when message.Role == Role.User => [new UserChatMessage(message.Content)],
+ null when message.Role == Role.Assistant => [new AssistantChatMessage(message.Content)],
null => throw new InvalidOperationException("Invalid Role"),
- _ => [new ChatRequestUserMessage(message.Content) { Name = message.From }]
+ _ => [new UserChatMessage(message.Content) { ParticipantName = message.From }]
};
}
}
- private IEnumerable<ChatRequestMessage> ProcessImageMessage(IAgent agent, ImageMessage message)
+ private IEnumerable<ChatMessage> ProcessImageMessage(IAgent agent, ImageMessage message)
{
if (agent.Name == message.From)
{
@@ -298,10 +285,10 @@ private IEnumerable<ChatRequestMessage> ProcessImageMessage(IAgent agent, ImageM
}
var imageContentItem = this.CreateChatMessageImageContentItemFromImageMessage(message);
- return [new ChatRequestUserMessage([imageContentItem]) { Name = message.From }];
+ return [new UserChatMessage([imageContentItem]) { ParticipantName = message.From }];
}
- private IEnumerable<ChatRequestMessage> ProcessMultiModalMessage(IAgent agent, MultiModalMessage message)
+ private IEnumerable<ChatMessage> ProcessMultiModalMessage(IAgent agent, MultiModalMessage message)
{
if (agent.Name == message.From)
{
@@ -309,69 +296,56 @@ private IEnumerable<ChatRequestMessage> ProcessMultiModalMessage(IAgent agent, M
throw new ArgumentException("MultiModalMessage is not supported when message.From is the same with agent");
}
- IEnumerable<ChatMessageContentItem> items = message.Content.Select(ci => ci switch
+ IEnumerable<ChatMessageContentPart> items = message.Content.Select(ci => ci switch
{
- TextMessage text => new ChatMessageTextContentItem(text.Content),
+ TextMessage text => ChatMessageContentPart.CreateTextMessageContentPart(text.Content),
ImageMessage image => this.CreateChatMessageImageContentItemFromImageMessage(image),
_ => throw new NotImplementedException(),
});
- return [new ChatRequestUserMessage(items) { Name = message.From }];
+ return [new UserChatMessage(items) { ParticipantName = message.From }];
}
- private ChatMessageImageContentItem CreateChatMessageImageContentItemFromImageMessage(ImageMessage message)
+ private ChatMessageContentPart CreateChatMessageImageContentItemFromImageMessage(ImageMessage message)
{
return message.Data is null && message.Url is not null
- ? new ChatMessageImageContentItem(new Uri(message.Url))
- : new ChatMessageImageContentItem(message.Data, message.Data?.MediaType);
+ ? ChatMessageContentPart.CreateImageMessageContentPart(new Uri(message.Url))
+ : ChatMessageContentPart.CreateImageMessageContentPart(message.Data, message.Data?.MediaType);
}
- private IEnumerable<ChatRequestMessage> ProcessToolCallMessage(IAgent agent, ToolCallMessage message)
+ private IEnumerable<ChatMessage> ProcessToolCallMessage(IAgent agent, ToolCallMessage message)
{
if (message.From is not null && message.From != agent.Name)
{
throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent");
}
- var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
- var textContent = message.GetContent() ?? string.Empty;
- var chatRequestMessage = new ChatRequestAssistantMessage(textContent) { Name = message.From };
- foreach (var tc in toolCall)
- {
- chatRequestMessage.ToolCalls.Add(tc);
- }
+ var toolCallParts = message.ToolCalls.Select((tc, i) => ChatToolCall.CreateFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
+ var textContent = message.GetContent();
+
+ // Don't set participant name for assistant when it is a tool call
+ // fix https://github.com/microsoft/autogen/issues/3437
+ var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent);
return [chatRequestMessage];
}
- private IEnumerable<ChatRequestMessage> ProcessToolCallResultMessage(ToolCallResultMessage message)
+ private IEnumerable<ChatMessage> ProcessToolCallResultMessage(ToolCallResultMessage message)
{
return message.ToolCalls
.Where(tc => tc.Result is not null)
- .Select((tc, i) => new ChatRequestToolMessage(tc.Result, tc.ToolCallId ?? $"{tc.FunctionName}_{i}"));
+ .Select((tc, i) => new ToolChatMessage(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.Result));
}
- [Obsolete("This method is deprecated, please use ProcessIncomingMessages(IAgent agent, IEnumerable messages) instead.")]
- private IEnumerable ProcessMessage(IAgent agent, Message message)
- {
- if (message.From is not null && message.From != agent.Name)
- {
- return ProcessIncomingMessagesForOther(message);
- }
- else
- {
- return ProcessIncomingMessagesForSelf(message);
- }
- }
- private IEnumerable<ChatRequestMessage> ProcessFunctionCallMiddlewareMessage(IAgent agent, AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
+ private IEnumerable<ChatMessage> ProcessFunctionCallMiddlewareMessage(IAgent agent, AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
{
if (aggregateMessage.From is not null && aggregateMessage.From != agent.Name)
{
// convert as user message
var resultMessage = aggregateMessage.Message2;
- return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result) { Name = aggregateMessage.From });
+ return resultMessage.ToolCalls.Select(tc => new UserChatMessage(tc.Result) { ParticipantName = aggregateMessage.From });
}
else
{
diff --git a/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj b/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj
index 3bd96f93b687..b89626c01a06 100644
--- a/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj
+++ b/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj
@@ -1,7 +1,7 @@
-    <TargetFrameworks>netstandard2.0</TargetFrameworks>
+    <TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
AutoGen.SemanticKernel
$(NoWarn);SKEXP0110
@@ -17,9 +17,9 @@
-
+
diff --git a/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs
index 6ce242eb1abe..a055c0afcb6a 100644
--- a/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs
+++ b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs
@@ -47,7 +47,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
return PostProcessMessage(reply);
}
- public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var chatMessageContents = ProcessMessage(context.Messages, agent)
.Select(m => new MessageEnvelope<ChatMessageContent>(m));
@@ -67,11 +67,11 @@ private IMessage PostProcessMessage(IMessage input)
};
}
- private IStreamingMessage PostProcessStreamingMessage(IStreamingMessage input)
+ private IMessage PostProcessStreamingMessage(IMessage input)
{
return input switch
{
- IStreamingMessage<StreamingChatMessageContent> streamingMessage => PostProcessMessage(streamingMessage),
+ IMessage<StreamingChatMessageContent> streamingMessage => PostProcessMessage(streamingMessage),
IMessage msg => PostProcessMessage(msg),
_ => input,
};
@@ -98,7 +98,7 @@ private IMessage PostProcessMessage(IMessage<ChatMessageContent> messageEnvelope
}
}
- private IStreamingMessage PostProcessMessage(IStreamingMessage<StreamingChatMessageContent> streamingMessage)
+ private IMessage PostProcessMessage(IMessage<StreamingChatMessageContent> streamingMessage)
{
var chatMessageContent = streamingMessage.Content;
if (chatMessageContent.ChoiceIndex > 0)
diff --git a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs
index 21f652f56c4f..e10f5b043f24 100644
--- a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs
+++ b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs
@@ -65,7 +65,7 @@ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, G
return new MessageEnvelope<ChatMessageContent>(reply.First(), from: this.Name);
}
- public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
@@ -106,7 +106,6 @@ private PromptExecutionSettings BuildOption(GenerateReplyOptions? options)
MaxTokens = options?.MaxToken ?? 1024,
StopSequences = options?.StopSequence,
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions,
- ResultsPerPrompt = 1,
};
}
diff --git a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelChatCompletionAgent.cs b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelChatCompletionAgent.cs
index 82d83a9e8556..1354996430bb 100644
--- a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelChatCompletionAgent.cs
+++ b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelChatCompletionAgent.cs
@@ -27,7 +27,7 @@ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, G
CancellationToken cancellationToken = default)
{
ChatMessageContent[] reply = await _chatCompletionAgent
- .InvokeAsync(BuildChatHistory(messages), cancellationToken)
+ .InvokeAsync(BuildChatHistory(messages), cancellationToken: cancellationToken)
.ToArrayAsync(cancellationToken: cancellationToken);
return reply.Length > 1
diff --git a/dotnet/src/AutoGen.SourceGenerator/SourceGeneratorFunctionContract.cs b/dotnet/src/AutoGen.SourceGenerator/SourceGeneratorFunctionContract.cs
index 24e42affa3bd..aa4980379f4f 100644
--- a/dotnet/src/AutoGen.SourceGenerator/SourceGeneratorFunctionContract.cs
+++ b/dotnet/src/AutoGen.SourceGenerator/SourceGeneratorFunctionContract.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// FunctionContract.cs
+// SourceGeneratorFunctionContract.cs
namespace AutoGen.SourceGenerator
{
diff --git a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs
index 40adbdcde47c..b90d78be3f19 100644
--- a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs
+++ b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs
@@ -36,7 +36,6 @@ public virtual string TransformText()
using System.Threading.Tasks;
using System;
using AutoGen.Core;
-using AutoGen.OpenAI.Extension;
");
if (!String.IsNullOrEmpty(NameSpace)) {
@@ -107,7 +106,7 @@ public virtual string TransformText()
}
if (functionContract.Description != null) {
this.Write(" Description = @\"");
- this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.Description));
+ this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.Description.Replace("\"", "\"\"")));
this.Write("\",\r\n");
}
if (functionContract.ReturnType != null) {
@@ -132,7 +131,7 @@ public virtual string TransformText()
}
if (parameter.Description != null) {
this.Write(" Description = @\"");
- this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Description));
+ this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Description.Replace("\"", "\"\"")));
this.Write("\",\r\n");
}
if (parameter.Type != null) {
@@ -152,12 +151,7 @@ public virtual string TransformText()
}
this.Write(" },\r\n");
}
- this.Write(" };\r\n }\r\n\r\n public global::Azure.AI.OpenAI.FunctionDefin" +
- "ition ");
- this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionDefinitionName()));
- this.Write("\r\n {\r\n get => this.");
- this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionContractName()));
- this.Write(".ToOpenAIFunctionDefinition();\r\n }\r\n");
+ this.Write(" };\r\n }\r\n");
}
this.Write(" }\r\n");
if (!String.IsNullOrEmpty(NameSpace)) {
diff --git a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt
index 0d1b221c35c8..e7ed476fde8b 100644
--- a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt
+++ b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt
@@ -13,7 +13,6 @@ using System.Text.Json.Serialization;
using System.Threading.Tasks;
using System;
using AutoGen.Core;
-using AutoGen.OpenAI.Extension;
<#if (!String.IsNullOrEmpty(NameSpace)) {#>
namespace <#=NameSpace#>
@@ -63,7 +62,7 @@ namespace <#=NameSpace#>
Name = @"<#=functionContract.Name#>",
<#}#>
<#if (functionContract.Description != null) {#>
- Description = @"<#=functionContract.Description#>",
+ Description = @"<#=functionContract.Description.Replace("\"", "\"\"")#>",
<#}#>
<#if (functionContract.ReturnType != null) {#>
ReturnType = typeof(<#=functionContract.ReturnType#>),
@@ -81,7 +80,7 @@ namespace <#=NameSpace#>
Name = @"<#=parameter.Name#>",
<#}#>
<#if (parameter.Description != null) {#>
- Description = @"<#=parameter.Description#>",
+ Description = @"<#= parameter.Description.Replace("\"", "\"\"") #>",
<#}#>
<#if (parameter.Type != null) {#>
ParameterType = typeof(<#=parameter.Type#>),
@@ -96,11 +95,6 @@ namespace <#=NameSpace#>
<#}#>
};
}
-
- public global::Azure.AI.OpenAI.FunctionDefinition <#=functionContract.GetFunctionDefinitionName()#>
- {
- get => this.<#=functionContract.GetFunctionContractName()#>.ToOpenAIFunctionDefinition();
- }
<#}#>
}
<#if (!String.IsNullOrEmpty(NameSpace)) {#>
diff --git a/dotnet/src/AutoGen.WebAPI/AutoGen.WebAPI.csproj b/dotnet/src/AutoGen.WebAPI/AutoGen.WebAPI.csproj
new file mode 100644
index 000000000000..c5b720764761
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/AutoGen.WebAPI.csproj
@@ -0,0 +1,27 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFrameworks>net6.0;net8.0</TargetFrameworks>
+    <GenerateDocumentationFile>true</GenerateDocumentationFile>
+    <NoWarn>$(NoWarn);CS1591;CS1573</NoWarn>
+  </PropertyGroup>
+
+  <PropertyGroup>
+    <!-- NuGet Package Settings -->
+    <Title>AutoGen.WebAPI</Title>
+    <Description>
+      Turn an `AutoGen.Core.IAgent` into a RESTful API.
+    </Description>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
+    <FrameworkReference Include="Microsoft.AspNetCore.App" />
+  </ItemGroup>
+
+</Project>
diff --git a/dotnet/src/AutoGen.WebAPI/Extension.cs b/dotnet/src/AutoGen.WebAPI/Extension.cs
new file mode 100644
index 000000000000..c8534e43e540
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/Extension.cs
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Extension.cs
+
+using AutoGen.Core;
+using Microsoft.AspNetCore.Builder;
+
+namespace AutoGen.WebAPI;
+
+public static class Extension
+{
+ /// <summary>
+ /// Serve the agent as an OpenAI chat completion endpoint using <see cref="OpenAIChatCompletionMiddleware"/>.
+ /// If the request path is /v1/chat/completions and the model name is the same as the agent name,
+ /// the request will be handled by the agent.
+ /// Otherwise, the request will be passed to the next middleware.
+ /// </summary>
+ /// <param name="app">application builder</param>
+ /// <param name="agent">the <see cref="IAgent"/> to serve</param>
+ public static IApplicationBuilder UseAgentAsOpenAIChatCompletionEndpoint(this IApplicationBuilder app, IAgent agent)
+ {
+ var middleware = new OpenAIChatCompletionMiddleware(agent);
+ return app.Use(middleware.InvokeAsync);
+ }
+}
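A minimal hosting sketch (not part of the diff), assuming an ASP.NET Core app; `CreateMyAgent()` is a hypothetical factory standing in for any IAgent:

```csharp
using AutoGen.Core;
using AutoGen.WebAPI;
using Microsoft.AspNetCore.Builder;

var app = WebApplication.CreateBuilder(args).Build();

IAgent myAgent = CreateMyAgent(); // hypothetical: any IAgent works here
app.UseAgentAsOpenAIChatCompletionEndpoint(myAgent);

app.Run();
// Clients can now POST /v1/chat/completions with "model" set to the agent's name.
```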
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/Converter/OpenAIMessageConverter.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/Converter/OpenAIMessageConverter.cs
new file mode 100644
index 000000000000..888a0f8dd8c8
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/Converter/OpenAIMessageConverter.cs
@@ -0,0 +1,56 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIMessageConverter.cs
+
+using System;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIMessageConverter : JsonConverter<OpenAIMessage>
+{
+ public override OpenAIMessage Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ using JsonDocument document = JsonDocument.ParseValue(ref reader);
+ var root = document.RootElement;
+ var role = root.GetProperty("role").GetString();
+ var contentDocument = root.GetProperty("content");
+ var isContentDocumentString = contentDocument.ValueKind == JsonValueKind.String;
+ switch (role)
+ {
+ case "system":
+ return JsonSerializer.Deserialize(root.GetRawText()) ?? throw new JsonException();
+ case "user" when isContentDocumentString:
+ return JsonSerializer.Deserialize(root.GetRawText()) ?? throw new JsonException();
+ case "user" when !isContentDocumentString:
+ return JsonSerializer.Deserialize(root.GetRawText()) ?? throw new JsonException();
+ case "assistant":
+ return JsonSerializer.Deserialize(root.GetRawText()) ?? throw new JsonException();
+ case "tool":
+ return JsonSerializer.Deserialize(root.GetRawText()) ?? throw new JsonException();
+ default:
+ throw new JsonException();
+ }
+ }
+
+ public override void Write(Utf8JsonWriter writer, OpenAIMessage value, JsonSerializerOptions options)
+ {
+ switch (value)
+ {
+ case OpenAISystemMessage systemMessage:
+ JsonSerializer.Serialize(writer, systemMessage, options);
+ break;
+ case OpenAIUserMessage userMessage:
+ JsonSerializer.Serialize(writer, userMessage, options);
+ break;
+ case OpenAIAssistantMessage assistantMessage:
+ JsonSerializer.Serialize(writer, assistantMessage, options);
+ break;
+ case OpenAIToolMessage toolMessage:
+ JsonSerializer.Serialize(writer, toolMessage, options);
+ break;
+ default:
+ throw new JsonException();
+ }
+ }
+}
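To make the role-based dispatch above concrete, an illustrative round-trip (conceptual only, since these DTOs are internal to AutoGen.WebAPI); the JSON follows the OpenAI chat request shape:

```csharp
using System.Text.Json;
// (inside the AutoGen.WebAPI assembly, where the internal DTOs are visible)

// "content" is a JSON string, so Read() dispatches to OpenAIUserMessage;
// a JSON array "content" would dispatch to OpenAIUserMultiModalMessage instead.
var json = """{"role":"user","content":"Hello","name":"alice"}""";
OpenAIMessage message = JsonSerializer.Deserialize<OpenAIMessage>(json)!;
```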
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIAssistantMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIAssistantMessage.cs
new file mode 100644
index 000000000000..bfd090358453
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIAssistantMessage.cs
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIAssistantMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIAssistantMessage : OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public override string? Role { get; } = "assistant";
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("tool_calls")]
+ public OpenAIToolCallObject[]? ToolCalls { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletion.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletion.cs
new file mode 100644
index 000000000000..041f4cfc848c
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletion.cs
@@ -0,0 +1,30 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletion.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIChatCompletion
+{
+ [JsonPropertyName("id")]
+ public string? ID { get; set; }
+
+ [JsonPropertyName("created")]
+ public long Created { get; set; }
+
+ [JsonPropertyName("choices")]
+ public OpenAIChatCompletionChoice[]? Choices { get; set; }
+
+ [JsonPropertyName("model")]
+ public string? Model { get; set; }
+
+ [JsonPropertyName("system_fingerprint")]
+ public string? SystemFingerprint { get; set; }
+
+ [JsonPropertyName("object")]
+ public string Object { get; set; } = "chat.completion";
+
+ [JsonPropertyName("usage")]
+ public OpenAIChatCompletionUsage? Usage { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionChoice.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionChoice.cs
new file mode 100644
index 000000000000..35b6fce59a8e
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionChoice.cs
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionChoice.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIChatCompletionChoice
+{
+ [JsonPropertyName("finish_reason")]
+ public string? FinishReason { get; set; }
+
+ [JsonPropertyName("index")]
+ public int Index { get; set; }
+
+ [JsonPropertyName("message")]
+ public OpenAIChatCompletionMessage? Message { get; set; }
+
+ [JsonPropertyName("delta")]
+ public OpenAIChatCompletionMessage? Delta { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionMessage.cs
new file mode 100644
index 000000000000..de6be0dbf7a5
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionMessage.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIChatCompletionMessage
+{
+ [JsonPropertyName("role")]
+ public string Role { get; } = "assistant";
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionOption.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionOption.cs
new file mode 100644
index 000000000000..0b9137d43a39
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionOption.cs
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionOption.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIChatCompletionOption
+{
+ [JsonPropertyName("messages")]
+ public OpenAIMessage[]? Messages { get; set; }
+
+ [JsonPropertyName("model")]
+ public string? Model { get; set; }
+
+ [JsonPropertyName("max_tokens")]
+ public int? MaxTokens { get; set; }
+
+ [JsonPropertyName("temperature")]
+ public float Temperature { get; set; } = 1;
+
+ ///
+ /// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message
+ ///
+ [JsonPropertyName("stream")]
+ public bool? Stream { get; set; } = false;
+
+ [JsonPropertyName("stream_options")]
+ public OpenAIStreamOptions? StreamOptions { get; set; }
+
+ [JsonPropertyName("stop")]
+ public string[]? Stop { get; set; }
+}
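For illustration (not part of the diff), a typical request body deserialized into this DTO; the field values are examples:

```csharp
using System.Text.Json;
// (inside the AutoGen.WebAPI assembly, where the internal DTOs are visible)

var body = """
{
  "model": "assistant",
  "messages": [ { "role": "user", "content": "Hi" } ],
  "max_tokens": 256,
  "temperature": 0.7,
  "stream": false
}
""";
var request = JsonSerializer.Deserialize<OpenAIChatCompletionOption>(body);
// request!.Messages holds one OpenAIUserMessage, courtesy of OpenAIMessageConverter.
```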
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionUsage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionUsage.cs
new file mode 100644
index 000000000000..f196ccb842ea
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionUsage.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionUsage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIChatCompletionUsage
+{
+ [JsonPropertyName("completion_tokens")]
+ public int CompletionTokens { get; set; }
+
+ [JsonPropertyName("prompt_tokens")]
+ public int PromptTokens { get; set; }
+
+ [JsonPropertyName("total_tokens")]
+ public int TotalTokens { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIImageUrlObject.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIImageUrlObject.cs
new file mode 100644
index 000000000000..a50012c9fed1
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIImageUrlObject.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIImageUrlObject.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIImageUrlObject
+{
+ [JsonPropertyName("url")]
+ public string? Url { get; set; }
+
+ [JsonPropertyName("detail")]
+ public string? Detail { get; set; } = "auto";
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIMessage.cs
new file mode 100644
index 000000000000..deb729b72003
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIMessage.cs
@@ -0,0 +1,13 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+[JsonConverter(typeof(OpenAIMessageConverter))]
+internal abstract class OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public abstract string? Role { get; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIStreamOptions.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIStreamOptions.cs
new file mode 100644
index 000000000000..e95991388b7f
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIStreamOptions.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIStreamOptions.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIStreamOptions
+{
+ [JsonPropertyName("include_usage")]
+ public bool? IncludeUsage { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAISystemMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAISystemMessage.cs
new file mode 100644
index 000000000000..f29b10826c4f
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAISystemMessage.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAISystemMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAISystemMessage : OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public override string? Role { get; } = "system";
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolCallObject.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolCallObject.cs
new file mode 100644
index 000000000000..f3fc37f9c44f
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolCallObject.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIToolCallObject.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIToolCallObject
+{
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("arguments")]
+ public string? Arguments { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolMessage.cs
new file mode 100644
index 000000000000..0c84c164cd96
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolMessage.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIToolMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIToolMessage : OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public override string? Role { get; } = "tool";
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+
+ [JsonPropertyName("tool_call_id")]
+ public string? ToolCallId { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserImageContent.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserImageContent.cs
new file mode 100644
index 000000000000..28b83ffb3058
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserImageContent.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIUserImageContent.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIUserImageContent : OpenAIUserMessageItem
+{
+ [JsonPropertyName("type")]
+ public override string MessageType { get; } = "image";
+
+ [JsonPropertyName("image_url")]
+ public string? Url { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessage.cs
new file mode 100644
index 000000000000..b5f1e7c50c12
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessage.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIUserMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIUserMessage : OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public override string? Role { get; } = "user";
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessageItem.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessageItem.cs
new file mode 100644
index 000000000000..94e7d91534a5
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessageItem.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIUserMessageItem.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal abstract class OpenAIUserMessageItem
+{
+ [JsonPropertyName("type")]
+ public abstract string MessageType { get; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMultiModalMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMultiModalMessage.cs
new file mode 100644
index 000000000000..789df5afaaae
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMultiModalMessage.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIUserMultiModalMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIUserMultiModalMessage : OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public override string? Role { get; } = "user";
+
+ [JsonPropertyName("content")]
+ public OpenAIUserMessageItem[]? Content { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserTextContent.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserTextContent.cs
new file mode 100644
index 000000000000..d22d5aa4c7f3
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserTextContent.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIUserTextContent.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIUserTextContent : OpenAIUserMessageItem
+{
+ [JsonPropertyName("type")]
+ public override string MessageType { get; } = "text";
+
+ [JsonPropertyName("text")]
+ public string? Content { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/Service/OpenAIChatCompletionService.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/Service/OpenAIChatCompletionService.cs
new file mode 100644
index 000000000000..80d49050ee48
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/Service/OpenAIChatCompletionService.cs
@@ -0,0 +1,156 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionService.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using AutoGen.Core;
+using AutoGen.WebAPI.OpenAI.DTO;
+namespace AutoGen.Server;
+
+internal class OpenAIChatCompletionService
+{
+ private readonly IAgent agent;
+
+ public OpenAIChatCompletionService(IAgent agent)
+ {
+ this.agent = agent;
+ }
+
+ public async Task<OpenAIChatCompletion> GetChatCompletionAsync(OpenAIChatCompletionOption request)
+ {
+ var messages = this.ProcessMessages(request.Messages ?? Array.Empty<OpenAIMessage>());
+
+ var generateOption = this.ProcessReplyOptions(request);
+
+ var reply = await this.agent.GenerateReplyAsync(messages, generateOption);
+
+ var openAIChatCompletion = new OpenAIChatCompletion()
+ {
+ Created = DateTimeOffset.UtcNow.ToUnixTimeSeconds(), // the OpenAI "created" field is a Unix timestamp in seconds
+ Model = this.agent.Name,
+ };
+
+ if (reply.GetContent() is string content)
+ {
+ var message = new OpenAIChatCompletionMessage()
+ {
+ Content = content,
+ };
+
+ var choice = new OpenAIChatCompletionChoice()
+ {
+ Message = message,
+ Index = 0,
+ FinishReason = "stop",
+ };
+
+ openAIChatCompletion.Choices = [choice];
+
+ return openAIChatCompletion;
+ }
+
+ throw new NotImplementedException("Unsupported reply content type");
+ }
+
+ public async IAsyncEnumerable<OpenAIChatCompletion> GetStreamingChatCompletionAsync(OpenAIChatCompletionOption request)
+ {
+ if (this.agent is IStreamingAgent streamingAgent)
+ {
+ var messages = this.ProcessMessages(request.Messages ?? Array.Empty<OpenAIMessage>());
+
+ var generateOption = this.ProcessReplyOptions(request);
+
+ await foreach (var reply in streamingAgent.GenerateStreamingReplyAsync(messages, generateOption))
+ {
+ var openAIChatCompletion = new OpenAIChatCompletion()
+ {
+ Created = DateTimeOffset.UtcNow.ToUnixTimeSeconds(), // Unix seconds, per the OpenAI chat completion schema
+ Model = this.agent.Name,
+ };
+
+ if (reply.GetContent() is string content)
+ {
+ var message = new OpenAIChatCompletionMessage()
+ {
+ Content = content,
+ };
+
+ var choice = new OpenAIChatCompletionChoice()
+ {
+ Delta = message,
+ Index = 0,
+ };
+
+ openAIChatCompletion.Choices = [choice];
+
+ yield return openAIChatCompletion;
+ }
+ else
+ {
+ throw new NotImplementedException("Unsupported reply content type");
+ }
+ }
+
+ var doneMessage = new OpenAIChatCompletion()
+ {
+ Created = DateTimeOffset.UtcNow.ToUnixTimeSeconds(), // Unix seconds, per the OpenAI chat completion schema
+ Model = this.agent.Name,
+ };
+
+ var doneChoice = new OpenAIChatCompletionChoice()
+ {
+ FinishReason = "stop",
+ Index = 0,
+ };
+
+ doneMessage.Choices = [doneChoice];
+
+ yield return doneMessage;
+ }
+ else
+ {
+ yield return await this.GetChatCompletionAsync(request);
+ }
+ }
+
+ private IEnumerable<IMessage> ProcessMessages(IEnumerable<OpenAIMessage> messages)
+ {
+ return messages.Select<OpenAIMessage, IMessage>(m => m switch
+ {
+ OpenAISystemMessage systemMessage when systemMessage.Content is string content => new TextMessage(Role.System, content, this.agent.Name),
+ OpenAIUserMessage userMessage when userMessage.Content is string content => new TextMessage(Role.User, content, this.agent.Name),
+ OpenAIAssistantMessage assistantMessage when assistantMessage.Content is string content => new TextMessage(Role.Assistant, content, this.agent.Name),
+ OpenAIUserMultiModalMessage userMultiModalMessage when userMultiModalMessage.Content is { Length: > 0 } => this.CreateMultiModalMessageFromOpenAIUserMultiModalMessage(userMultiModalMessage),
+ _ => throw new ArgumentException($"Unsupported message type {m.GetType()}")
+ });
+ }
+
+ private GenerateReplyOptions ProcessReplyOptions(OpenAIChatCompletionOption request)
+ {
+ return new GenerateReplyOptions()
+ {
+ Temperature = request.Temperature,
+ MaxToken = request.MaxTokens,
+ StopSequence = request.Stop,
+ };
+ }
+
+ private MultiModalMessage CreateMultiModalMessageFromOpenAIUserMultiModalMessage(OpenAIUserMultiModalMessage message)
+ {
+ if (message.Content is null)
+ {
+ throw new ArgumentNullException(nameof(message.Content));
+ }
+
+ IEnumerable<IMessage> items = message.Content.Select<OpenAIUserMessageItem, IMessage>(item => item switch
+ {
+ OpenAIUserImageContent imageContent when imageContent.Url is string url => new ImageMessage(Role.User, url, this.agent.Name),
+ OpenAIUserTextContent textContent when textContent.Content is string content => new TextMessage(Role.User, content, this.agent.Name),
+ _ => throw new ArgumentException($"Unsupported content type {item.GetType()}")
+ });
+
+ return new MultiModalMessage(Role.User, items, this.agent.Name);
+ }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAIChatCompletionMiddleware.cs b/dotnet/src/AutoGen.WebAPI/OpenAIChatCompletionMiddleware.cs
new file mode 100644
index 000000000000..53b3699fd62e
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAIChatCompletionMiddleware.cs
@@ -0,0 +1,92 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionMiddleware.cs
+
+using System.Text.Json;
+using System.Threading.Tasks;
+using AutoGen.Core;
+using AutoGen.Server;
+using AutoGen.WebAPI.OpenAI.DTO;
+using Microsoft.AspNetCore.Http;
+
+namespace AutoGen.WebAPI;
+
+public class OpenAIChatCompletionMiddleware : Microsoft.AspNetCore.Http.IMiddleware
+{
+ private readonly IAgent _agent;
+ private readonly OpenAIChatCompletionService chatCompletionService;
+
+ public OpenAIChatCompletionMiddleware(IAgent agent)
+ {
+ _agent = agent;
+ chatCompletionService = new OpenAIChatCompletionService(_agent);
+ }
+
+ public async Task InvokeAsync(HttpContext context, RequestDelegate next)
+ {
+ // if HttpPost and path is /v1/chat/completions
+ // get the request body
+ // call chatCompletionService.GetChatCompletionAsync(request)
+ // return the response
+
+ // else
+ // call next middleware
+ if (context.Request.Method == HttpMethods.Post && context.Request.Path == "/v1/chat/completions")
+ {
+ context.Request.EnableBuffering();
+ var body = await context.Request.ReadFromJsonAsync<OpenAIChatCompletionOption>();
+ context.Request.Body.Position = 0;
+ if (body is null)
+ {
+ // return 400 Bad Request
+ context.Response.StatusCode = 400;
+ return;
+ }
+
+ if (body.Model != _agent.Name)
+ {
+ await next(context);
+ return;
+ }
+
+ if (body.Stream is true)
+ {
+ // Send as server side events
+ context.Response.Headers.Append("Content-Type", "text/event-stream");
+ context.Response.Headers.Append("Cache-Control", "no-cache");
+ context.Response.Headers.Append("Connection", "keep-alive");
+ await foreach (var chatCompletion in chatCompletionService.GetStreamingChatCompletionAsync(body))
+ {
+ if (chatCompletion?.Choices?[0].FinishReason is "stop")
+ {
+ // the stream is done
+ // send Data: [DONE]\n\n
+ await context.Response.WriteAsync("data: [DONE]\n\n");
+ break;
+ }
+ else
+ {
+ // remove null
+ var option = new JsonSerializerOptions
+ {
+ DefaultIgnoreCondition = System.Text.Json.Serialization.JsonIgnoreCondition.WhenWritingNull,
+ };
+ var data = JsonSerializer.Serialize(chatCompletion, option);
+ await context.Response.WriteAsync($"data: {data}\n\n");
+ }
+ }
+
+ return;
+ }
+ else
+ {
+ var chatCompletion = await chatCompletionService.GetChatCompletionAsync(body);
+ await context.Response.WriteAsJsonAsync(chatCompletion);
+ return;
+ }
+ }
+ else
+ {
+ await next(context);
+ }
+ }
+}
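
Note on the middleware above: it exposes any AutoGen agent as an OpenAI-compatible /v1/chat/completions endpoint, answering only requests whose model field matches the agent's name and delegating everything else to the next middleware. A minimal hosting sketch follows; it assumes ASP.NET Core minimal APIs, uses DefaultReplyAgent as a stand-in for a real model-backed agent, and the manual UseMiddleware wiring is illustrative rather than the package's official registration helper.

using AutoGen.Core;
using AutoGen.WebAPI;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;

var builder = WebApplication.CreateBuilder(args);

// OpenAIChatCompletionMiddleware implements IMiddleware, so the instance is
// resolved from DI by the default middleware factory.
var agent = new DefaultReplyAgent("assistant", defaultReply: "Hello from AutoGen");
builder.Services.AddSingleton(new OpenAIChatCompletionMiddleware(agent));

var app = builder.Build();
app.UseMiddleware<OpenAIChatCompletionMiddleware>();
app.Run("http://localhost:5000");

Any OpenAI client pointed at http://localhost:5000/v1 with its model name set to "assistant" is then served by the agent, including streaming via the "data: ..." / "data: [DONE]" server-sent events emitted above.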
diff --git a/dotnet/src/AutoGen/API/LLMConfigAPI.cs b/dotnet/src/AutoGen/API/LLMConfigAPI.cs
index 5154f3dd5f55..28b5ad44312f 100644
--- a/dotnet/src/AutoGen/API/LLMConfigAPI.cs
+++ b/dotnet/src/AutoGen/API/LLMConfigAPI.cs
@@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
-using AutoGen.OpenAI;
namespace AutoGen
{
diff --git a/dotnet/src/AutoGen/Agent/ConversableAgent.cs b/dotnet/src/AutoGen/Agent/ConversableAgent.cs
index fe1470502022..da61c812f464 100644
--- a/dotnet/src/AutoGen/Agent/ConversableAgent.cs
+++ b/dotnet/src/AutoGen/Agent/ConversableAgent.cs
@@ -6,9 +6,8 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
-using AutoGen.LMStudio;
using AutoGen.OpenAI;
-
+using AutoGen.OpenAI.Extension;
namespace AutoGen;
public enum HumanInputMode
@@ -87,13 +86,21 @@ public ConversableAgent(
{
IAgent nextAgent = llmConfig switch
{
- AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0),
- OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 0),
- LMStudioConfig lmStudioConfig => new LMStudioAgent(
- name: this.Name,
- config: lmStudioConfig,
- systemMessage: this.systemMessage,
- temperature: config.Temperature ?? 0),
+ AzureOpenAIConfig azureConfig => new OpenAIChatAgent(
+ chatClient: azureConfig.CreateChatClient(),
+ name: this.Name!,
+ systemMessage: this.systemMessage)
+ .RegisterMessageConnector(),
+ OpenAIConfig openAIConfig => new OpenAIChatAgent(
+ chatClient: openAIConfig.CreateChatClient(),
+ name: this.Name!,
+ systemMessage: this.systemMessage)
+ .RegisterMessageConnector(),
+ LMStudioConfig lmStudioConfig => new OpenAIChatAgent(
+ chatClient: lmStudioConfig.CreateChatClient(),
+ name: this.Name!,
+ systemMessage: this.systemMessage)
+ .RegisterMessageConnector(),
_ => throw new ArgumentException($"Unsupported config type {llmConfig.GetType()}"),
};
diff --git a/dotnet/src/AutoGen/AutoGen.csproj b/dotnet/src/AutoGen/AutoGen.csproj
index ddc34a071cbf..fe4431a35731 100644
--- a/dotnet/src/AutoGen/AutoGen.csproj
+++ b/dotnet/src/AutoGen/AutoGen.csproj
@@ -1,6 +1,6 @@
- <TargetFramework>netstandard2.0</TargetFramework>
+ <TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
AutoGen
@@ -15,7 +15,8 @@
-
+
+
@@ -26,6 +27,7 @@
+
diff --git a/dotnet/src/AutoGen/AzureOpenAIConfig.cs b/dotnet/src/AutoGen/AzureOpenAIConfig.cs
new file mode 100644
index 000000000000..6112a3815d59
--- /dev/null
+++ b/dotnet/src/AutoGen/AzureOpenAIConfig.cs
@@ -0,0 +1,30 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AzureOpenAIConfig.cs
+
+using Azure.AI.OpenAI;
+using OpenAI.Chat;
+
+namespace AutoGen;
+
+public class AzureOpenAIConfig : ILLMConfig
+{
+ public AzureOpenAIConfig(string endpoint, string deploymentName, string apiKey)
+ {
+ this.Endpoint = endpoint;
+ this.DeploymentName = deploymentName;
+ this.ApiKey = apiKey;
+ }
+
+ public string Endpoint { get; }
+
+ public string DeploymentName { get; }
+
+ public string ApiKey { get; }
+
+ internal ChatClient CreateChatClient()
+ {
+ var client = new AzureOpenAIClient(new System.Uri(this.Endpoint), new System.ClientModel.ApiKeyCredential(this.ApiKey));
+
+ return client.GetChatClient(DeploymentName);
+ }
+}
diff --git a/dotnet/src/AutoGen/LMStudioConfig.cs b/dotnet/src/AutoGen/LMStudioConfig.cs
new file mode 100644
index 000000000000..5fd9edc70802
--- /dev/null
+++ b/dotnet/src/AutoGen/LMStudioConfig.cs
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// LMStudioConfig.cs
+using System;
+using OpenAI;
+using OpenAI.Chat;
+
+namespace AutoGen;
+
+/// <summary>
+/// Add support for consuming openai-like API from LM Studio
+/// </summary>
+public class LMStudioConfig : ILLMConfig
+{
+ public LMStudioConfig(string host, int port)
+ {
+ this.Host = host;
+ this.Port = port;
+ this.Uri = new Uri($"http://{host}:{port}");
+ }
+
+ public LMStudioConfig(Uri uri)
+ {
+ this.Uri = uri;
+ this.Host = uri.Host;
+ this.Port = uri.Port;
+ }
+
+ public string Host { get; }
+
+ public int Port { get; }
+
+ public Uri Uri { get; }
+
+ internal ChatClient CreateChatClient()
+ {
+ var client = new OpenAIClient(new System.ClientModel.ApiKeyCredential("api-key"), new OpenAIClientOptions
+ {
+ Endpoint = this.Uri,
+ });
+
+ // model name doesn't matter for LM Studio
+
+ return client.GetChatClient("model-name");
+ }
+}
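
The placeholder credential and model name above are deliberate: LM Studio's OpenAI-compatible server ignores both, so only the endpoint matters. A minimal construction sketch, assuming an LM Studio server is already listening on its default port 1234:

using System;
using AutoGen;

// Both forms resolve to http://localhost:1234.
var byHostAndPort = new LMStudioConfig("localhost", 1234);
var byUri = new LMStudioConfig(new Uri("http://localhost:1234"));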
diff --git a/dotnet/src/AutoGen/Middleware/HumanInputMiddleware.cs b/dotnet/src/AutoGen/Middleware/HumanInputMiddleware.cs
index 1a742b11c799..eda3c001a249 100644
--- a/dotnet/src/AutoGen/Middleware/HumanInputMiddleware.cs
+++ b/dotnet/src/AutoGen/Middleware/HumanInputMiddleware.cs
@@ -18,7 +18,7 @@ public class HumanInputMiddleware : IMiddleware
private readonly string prompt;
private readonly string exitKeyword;
private Func<IEnumerable<IMessage>, CancellationToken, Task<bool>> isTermination;
- private Func<string> getInput = Console.ReadLine;
+ private Func<string?> getInput = Console.ReadLine;
private Action writeLine = Console.WriteLine;
public string? Name => nameof(HumanInputMiddleware);
@@ -27,7 +27,7 @@ public HumanInputMiddleware(
string exitKeyword = "exit",
HumanInputMode mode = HumanInputMode.AUTO,
Func<IEnumerable<IMessage>, CancellationToken, Task<bool>>? isTermination = null,
- Func<string>? getInput = null,
+ Func<string?>? getInput = null,
Action? writeLine = null)
{
this.prompt = prompt;
@@ -56,6 +56,8 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, agent.Name);
}
+ input ??= string.Empty;
+
return new TextMessage(Role.Assistant, input, agent.Name);
}
@@ -74,6 +76,8 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, agent.Name);
}
+ input ??= string.Empty;
+
return new TextMessage(Role.Assistant, input, agent.Name);
}
@@ -85,7 +89,7 @@ private async Task<bool> DefaultIsTermination(IEnumerable<IMessage> messages, Ca
return messages?.Last().IsGroupChatTerminateMessage() is true;
}
- private string GetInput()
+ private string? GetInput()
{
return Console.ReadLine();
}
diff --git a/dotnet/src/AutoGen/OpenAIConfig.cs b/dotnet/src/AutoGen/OpenAIConfig.cs
new file mode 100644
index 000000000000..ea50fa085f11
--- /dev/null
+++ b/dotnet/src/AutoGen/OpenAIConfig.cs
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIConfig.cs
+
+using OpenAI;
+using OpenAI.Chat;
+
+namespace AutoGen;
+
+public class OpenAIConfig : ILLMConfig
+{
+ public OpenAIConfig(string apiKey, string modelId)
+ {
+ this.ApiKey = apiKey;
+ this.ModelId = modelId;
+ }
+
+ public string ApiKey { get; }
+
+ public string ModelId { get; }
+
+ internal ChatClient CreateChatClient()
+ {
+ var client = new OpenAIClient(this.ApiKey);
+
+ return client.GetChatClient(this.ModelId);
+ }
+}
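
With the ConversableAgent change earlier in this diff, all three config types (OpenAIConfig, AzureOpenAIConfig, LMStudioConfig) now funnel into OpenAIChatAgent through their internal CreateChatClient() methods. A hedged end-to-end sketch, assuming the ConversableAgentConfig shape shown in the ConversableAgent hunk above; the model id and environment variable are placeholders:

using System;
using AutoGen;
using AutoGen.Core;

var openAIConfig = new OpenAIConfig(
    apiKey: Environment.GetEnvironmentVariable("OPENAI_API_KEY")
        ?? throw new InvalidOperationException("OPENAI_API_KEY is not set"),
    modelId: "gpt-4o-mini");

var assistant = new ConversableAgent(
    name: "assistant",
    systemMessage: "You are a helpful AI assistant.",
    llmConfig: new ConversableAgentConfig
    {
        ConfigList = [openAIConfig],
        Temperature = 0,
    });

var reply = await assistant.SendAsync("Hello");
Console.WriteLine(reply.GetContent());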
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs
index d29025b44aff..085917d419e9 100644
--- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs
@@ -32,6 +32,30 @@ public async Task AnthropicAgentChatCompletionTestAsync()
reply.From.Should().Be(agent.Name);
}
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicAgentMergeMessageWithSameRoleTests()
+ {
+ // this test is added to fix issue #2884
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+
+ var agent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are a helpful AI assistant that convert user message to upper case")
+ .RegisterMessageConnector();
+
+ var upperCaseMessage = new TextMessage(Role.User, "abcdefg");
+ var anotherUserMessage = new TextMessage(Role.User, "hijklmn");
+ var assistantMessage = new TextMessage(Role.Assistant, "opqrst");
+ var anotherAssistantMessage = new TextMessage(Role.Assistant, "uvwxyz");
+ var yetAnotherUserMessage = new TextMessage(Role.User, "123456");
+
+ // just make sure it doesn't throw exception
+ var reply = await agent.SendAsync(chatHistory: [upperCaseMessage, anotherUserMessage, assistantMessage, anotherAssistantMessage, yetAnotherUserMessage]);
+ reply.GetContent().Should().NotBeNull();
+ }
+
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestProcessImageAsync()
{
@@ -105,4 +129,101 @@ public async Task AnthropicAgentTestImageMessageAsync()
reply.GetContent().Should().NotBeNullOrEmpty();
reply.From.Should().Be(agent.Name);
}
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicAgentTestToolAsync()
+ {
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+
+ var function = new TypeSafeFunctionCall();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: new[] { function.WeatherReportFunctionContract },
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { function.WeatherReportFunctionContract.Name ?? string.Empty, function.WeatherReportWrapper },
+ });
+
+ var agent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are an LLM that is specialized in finding the weather !",
+ tools: [AnthropicTestUtils.WeatherTool]
+ )
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware);
+
+ var reply = await agent.SendAsync("What is the weather in Philadelphia?");
+ reply.GetContent().Should().Be("Weather report for Philadelphia on today is sunny");
+ }
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicAgentFunctionCallMessageTest()
+ {
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+ var agent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are a helpful AI assistant.",
+ tools: [AnthropicTestUtils.WeatherTool]
+ )
+ .RegisterMessageConnector();
+
+ var weatherFunctionArguments = """
+ {
+ "city": "Philadelphia",
+ "date": "6/14/2024"
+ }
+ """;
+
+ var function = new AnthropicTestFunctionCalls();
+ var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArguments);
+ var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArguments)
+ {
+ ToolCallId = "get_weather",
+ Result = functionCallResult,
+ };
+
+ IMessage[] chatHistory = [
+ new TextMessage(Role.User, "what's the weather in Philadelphia?"),
+ new ToolCallMessage([toolCall], from: "assistant"),
+ new ToolCallResultMessage([toolCall], from: "user"),
+ ];
+
+ var reply = await agent.SendAsync(chatHistory: chatHistory);
+
+ reply.Should().BeOfType<TextMessage>();
+ reply.GetContent().Should().Be("The weather report for Philadelphia on 6/14/2024 is sunny.");
+ }
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicAgentFunctionCallMiddlewareMessageTest()
+ {
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+ var function = new AnthropicTestFunctionCalls();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [function.WeatherReportFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { function.WeatherReportFunctionContract.Name!, function.GetWeatherReportWrapper }
+ });
+
+ var functionCallAgent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are a helpful AI assistant.",
+ tools: [AnthropicTestUtils.WeatherTool]
+ )
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware);
+
+ var question = new TextMessage(Role.User, "what's the weather in Philadelphia?");
+ var reply = await functionCallAgent.SendAsync(question);
+
+ var finalReply = await functionCallAgent.SendAsync(chatHistory: [question, reply]);
+ finalReply.Should().BeOfType<TextMessage>();
+ finalReply.GetContent()!.ToLower().Should().Contain("sunny");
+ }
}
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs
index a0b1f60cfb95..0018f2decbc1 100644
--- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs
@@ -1,5 +1,9 @@
-using System.Text;
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicClientTest.cs
+
+using System.Text;
using System.Text.Json;
+using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using AutoGen.Anthropic.DTO;
using AutoGen.Anthropic.Utils;
@@ -43,7 +47,12 @@ public async Task AnthropicClientStreamingChatCompletionTestAsync()
request.Model = AnthropicConstants.Claude3Haiku;
request.Stream = true;
request.MaxTokens = 500;
- request.SystemMessage = "You are a helpful assistant that convert input to json object, use JSON format.";
+ request.SystemMessage =
+ [
+ SystemMessage.CreateSystemMessage(
+ "You are a helpful assistant that convert input to json object, use JSON format.")
+ ];
+
request.Messages = new List<ChatMessage>()
{
new("user", "name: John, age: 41, email: g123456@gmail.com")
@@ -58,7 +67,9 @@ public async Task AnthropicClientStreamingChatCompletionTestAsync()
foreach (ChatCompletionResponse result in results)
{
if (result.Delta is not null && !string.IsNullOrEmpty(result.Delta.Text))
+ {
sb.Append(result.Delta.Text);
+ }
}
string resultContent = sb.ToString();
@@ -82,7 +93,11 @@ public async Task AnthropicClientImageChatCompletionTestAsync()
request.Model = AnthropicConstants.Claude3Haiku;
request.Stream = false;
request.MaxTokens = 100;
- request.SystemMessage = "You are a LLM that is suppose to describe the content of the image. Give me a description of the provided image.";
+ request.SystemMessage =
+ [
+ SystemMessage.CreateSystemMessage(
+ "You are a LLM that is suppose to describe the content of the image. Give me a description of the provided image."),
+ ];
var base64Image = await AnthropicTestUtils.Base64FromImageAsync("square.png");
var messages = new List<ChatMessage>
@@ -108,6 +123,111 @@ public async Task AnthropicClientImageChatCompletionTestAsync()
response.Usage.OutputTokens.Should().BeGreaterThan(0);
}
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicClientTestToolsAsync()
+ {
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+
+ var request = new ChatCompletionRequest();
+ request.Model = AnthropicConstants.Claude3Haiku;
+ request.Stream = false;
+ request.MaxTokens = 100;
+ request.Messages = new List<ChatMessage>() { new("user", "Use the stock price tool to look for MSFT. Your response should only be the tool.") };
+ request.Tools = new List<Tool>() { AnthropicTestUtils.StockTool };
+
+ ChatCompletionResponse response =
+ await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None);
+
+ Assert.NotNull(response.Content);
+ Assert.True(response.Content.First() is ToolUseContent);
+ ToolUseContent toolUseContent = ((ToolUseContent)response.Content.First());
+ Assert.Equal("get_stock_price", toolUseContent.Name);
+ Assert.NotNull(toolUseContent.Input);
+ Assert.True(toolUseContent.Input is JsonNode);
+ JsonNode jsonNode = toolUseContent.Input;
+ Assert.Equal("{\"ticker\":\"MSFT\"}", jsonNode.ToJsonString());
+ }
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicClientTestToolChoiceAsync()
+ {
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+
+ var request = new ChatCompletionRequest();
+ request.Model = AnthropicConstants.Claude3Haiku;
+ request.Stream = false;
+ request.MaxTokens = 100;
+ request.Messages = new List<ChatMessage>() { new("user", "What is the weather today? Your response should only be the tool.") };
+ request.Tools = new List<Tool>() { AnthropicTestUtils.StockTool, AnthropicTestUtils.WeatherTool };
+
+ // Force to use get_stock_price even though the prompt is about weather
+ request.ToolChoice = ToolChoice.ToolUse("get_stock_price");
+
+ ChatCompletionResponse response =
+ await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None);
+
+ Assert.NotNull(response.Content);
+ Assert.True(response.Content.First() is ToolUseContent);
+ ToolUseContent toolUseContent = ((ToolUseContent)response.Content.First());
+ Assert.Equal("get_stock_price", toolUseContent.Name);
+ Assert.NotNull(toolUseContent.Input);
+ Assert.True(toolUseContent.Input is JsonNode);
+ }
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicClientChatCompletionCacheControlTestAsync()
+ {
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+
+ var request = new ChatCompletionRequest();
+ request.Model = AnthropicConstants.Claude35Sonnet;
+ request.Stream = false;
+ request.MaxTokens = 100;
+
+ request.SystemMessage =
+ [
+ SystemMessage.CreateSystemMessageWithCacheControl(
+ $"You are an LLM that is great at remembering stories {AnthropicTestUtils.LongStory}"),
+ ];
+
+ request.Messages =
+ [
+ new ChatMessage("user", "What should i know about Bob?")
+ ];
+
+ var response = await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None);
+ response.Usage.Should().NotBeNull();
+
+ // There's no way to clear the cache. Running the assert frequently may cause this to fail because the cache has already been created
+ // response.Usage!.CreationInputTokens.Should().BeGreaterThan(0);
+ // The cache reduces the input tokens. We expect the input tokens to exclude the cached system prompt and count only the short user message
+ response.Usage!.InputTokens.Should().BeLessThan(20);
+
+ request.Messages =
+ [
+ new ChatMessage("user", "Summarize the story of bob")
+ ];
+
+ response = await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None);
+ response.Usage.Should().NotBeNull();
+ response.Usage!.CacheReadInputTokens.Should().BeGreaterThan(0);
+ response.Usage!.InputTokens.Should().BeLessThan(20);
+
+ // Should not use the cache
+ request.SystemMessage =
+ [
+ SystemMessage.CreateSystemMessage("You are a helpful assistant.")
+ ];
+
+ request.Messages =
+ [
+ new ChatMessage("user", "What are some text editors I could use to write C#?")
+ ];
+
+ response = await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None);
+ response.Usage!.CacheReadInputTokens.Should().Be(0);
+ }
+
private sealed class Person
{
[JsonPropertyName("name")]
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestFunctionCalls.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestFunctionCalls.cs
new file mode 100644
index 000000000000..8b5466e3a519
--- /dev/null
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestFunctionCalls.cs
@@ -0,0 +1,40 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicTestFunctionCalls.cs
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Core;
+
+namespace AutoGen.Anthropic.Tests;
+
+public partial class AnthropicTestFunctionCalls
+{
+ private class GetWeatherSchema
+ {
+ [JsonPropertyName("city")]
+ public string? City { get; set; }
+
+ [JsonPropertyName("date")]
+ public string? Date { get; set; }
+ }
+
+ /// <summary>
+ /// Get weather report
+ /// </summary>
+ /// <param name="city">city</param>
+ /// <param name="date">date</param>
+ [Function]
+ public async Task<string> WeatherReport(string city, string date)
+ {
+ return $"Weather report for {city} on {date} is sunny";
+ }
+
+ public Task<string> GetWeatherReportWrapper(string arguments)
+ {
+ var schema = JsonSerializer.Deserialize<GetWeatherSchema>(
+ arguments,
+ new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase });
+
+ return WeatherReport(schema?.City ?? string.Empty, schema?.Date ?? string.Empty);
+ }
+}
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs
index de630da6d87c..d80c5fbe5705 100644
--- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AnthropicTestUtils.cs
+using AutoGen.Anthropic.DTO;
+
namespace AutoGen.Anthropic.Tests;
public static class AnthropicTestUtils
@@ -13,4 +15,130 @@ public static async Task<string> Base64FromImageAsync(string imageName)
return Convert.ToBase64String(
await File.ReadAllBytesAsync(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "images", imageName)));
}
+
+ public static Tool WeatherTool
+ {
+ get
+ {
+ return new Tool
+ {
+ Name = "WeatherReport",
+ Description = "Get the current weather",
+ InputSchema = new InputSchema
+ {
+ Type = "object",
+ Properties = new Dictionary<string, SchemaProperty>
+ {
+ { "city", new SchemaProperty {Type = "string", Description = "The name of the city"} },
+ { "date", new SchemaProperty {Type = "string", Description = "date of the day"} }
+ }
+ }
+ };
+ }
+ }
+
+ public static Tool StockTool
+ {
+ get
+ {
+ return new Tool
+ {
+ Name = "get_stock_price",
+ Description = "Get the current stock price for a given ticker symbol.",
+ InputSchema = new InputSchema
+ {
+ Type = "object",
+ Properties = new Dictionary<string, SchemaProperty>
+ {
+ {
+ "ticker", new SchemaProperty
+ {
+ Type = "string",
+ Description = "The stock ticker symbol, e.g. AAPL for Apple Inc."
+ }
+ }
+ },
+ Required = new List<string> { "ticker" }
+ }
+ };
+ }
+ }
+
+ #region Long text for caching
+ // To test cache control, the context must be larger than 1024 tokens for Claude 3.5 Sonnet and Claude 3 Opus
+ // 2048 tokens for Claude 3.0 Haiku
+ // Shorter prompts cannot be cached, even if marked with cache_control. Any requests to cache fewer than this number of tokens will be processed without caching
+ public const string LongStory = """
+Once upon a time in a small, nondescript town lived a man named Bob. Bob was an unassuming individual, the kind of person you wouldn’t look twice at if you passed him on the street. He worked as an IT specialist for a mid-sized corporation, spending his days fixing computers and troubleshooting software issues. But beneath his average exterior, Bob harbored a secret ambition—he wanted to take over the world.
+
+Bob wasn’t always like this. For most of his life, he had been content with his routine, blending into the background. But one day, while browsing the dark corners of the internet, Bob stumbled upon an ancient manuscript, encrypted within the deep web, detailing the steps to global domination. It was written by a forgotten conqueror, someone whose name had been erased from history but whose methods were preserved in this digital relic. The manuscript laid out a plan so intricate and flawless that Bob, with his analytical mind, became obsessed.
+
+Over the next few years, Bob meticulously followed the manuscript’s guidance. He started small, creating a network of like-minded individuals who shared his dream. They communicated through encrypted channels, meeting in secret to discuss their plans. Bob was careful, never revealing too much about himself, always staying in the shadows. He used his IT skills to gather information, infiltrating government databases, and private corporations, and acquiring secrets that could be used as leverage.
+
+As his network grew, so did his influence. Bob began to manipulate world events from behind the scenes. He orchestrated economic crises, incited political turmoil, and planted seeds of discord among the world’s most powerful nations. Each move was calculated, each action a step closer to his ultimate goal. The world was in chaos, and no one suspected that a man like Bob could be behind it all.
+
+But Bob knew that causing chaos wasn’t enough. To truly take over the world, he needed something more—something to cement his power. That’s when he turned to technology. Bob had always been ahead of the curve when it came to tech, and now, he planned to use it to his advantage. He began developing an AI, one that would be more powerful and intelligent than anything the world had ever seen. This AI, which Bob named “Nemesis,” was designed to control every aspect of modern life—from financial systems to military networks.
+
+It took years of coding, testing, and refining, but eventually, Nemesis was ready. Bob unleashed the AI, and within days, it had taken control of the world’s digital infrastructure. Governments were powerless, their systems compromised. Corporations crumbled as their assets were seized. The military couldn’t act, their weapons turned against them. Bob, from the comfort of his modest home, had done it. He had taken over the world.
+
+The world, now under Bob’s control, was eerily quiet. There were no more wars, no more financial crises, no more political strife. Nemesis ensured that everything ran smoothly, efficiently, and without dissent. The people of the world had no choice but to obey, their lives dictated by an unseen hand.
+
+Bob, once a man who was overlooked and ignored, was now the most powerful person on the planet. But with that power came a realization. The world he had taken over was not the world he had envisioned. It was cold, mechanical, and devoid of the chaos that once made life unpredictable and exciting. Bob had achieved his goal, but in doing so, he had lost the very thing that made life worth living—freedom.
+
+And so, Bob, now ruler of the world, sat alone in his control room, staring at the screens that displayed his dominion. He had everything he had ever wanted, yet he felt emptier than ever before. The world was his, but at what cost?
+
+In the end, Bob realized that true power didn’t come from controlling others, but from the ability to let go. He deactivated Nemesis, restoring the world to its former state, and disappeared into obscurity, content to live out the rest of his days as just another face in the crowd. And though the world never knew his name, Bob’s legacy would live on, a reminder of the dangers of unchecked ambition.
+
+Bob had vanished, leaving the world in a fragile state of recovery. Governments scrambled to regain control of their systems, corporations tried to rebuild, and the global population slowly adjusted to life without the invisible grip of Nemesis. Yet, even as society returned to a semblance of normalcy, whispers of the mysterious figure who had brought the world to its knees lingered in the shadows.
+
+Meanwhile, Bob had retreated to a secluded cabin deep in the mountains. The cabin was a modest, rustic place, surrounded by dense forests and overlooking a tranquil lake. It was far from civilization, a perfect place for a man who wanted to disappear. Bob spent his days fishing, hiking, and reflecting on his past. For the first time in years, he felt a sense of peace.
+
+But peace was fleeting. Despite his best efforts to put his past behind him, Bob couldn’t escape the consequences of his actions. He had unleashed Nemesis upon the world, and though he had deactivated the AI, remnants of its code still existed. Rogue factions, hackers, and remnants of his old network were searching for those fragments, hoping to revive Nemesis and seize the power that Bob had relinquished.
+
+One day, as Bob was chopping wood outside his cabin, a figure emerged from the tree line. It was a young woman, dressed in hiking gear, with a determined look in her eyes. Bob tensed, his instincts telling him that this was no ordinary hiker.
+
+“Bob,” the woman said, her voice steady. “Or should I say, the man who almost became the ruler of the world?”
+
+Bob sighed, setting down his axe. “Who are you, and what do you want?”
+
+The woman stepped closer. “My name is Sarah. I was part of your network, one of the few who knew about Nemesis. But I wasn’t like the others. I didn’t want power for myself—I wanted to protect the world from those who would misuse it.”
+
+Bob studied her, trying to gauge her intentions. “And why are you here now?”
+
+Sarah reached into her backpack and pulled out a small device. “Because Nemesis isn’t dead. Some of its code is still active, and it’s trying to reboot itself. I need your help to stop it for good.”
+
+Bob’s heart sank. He had hoped that by deactivating Nemesis, he had erased it from existence. But deep down, he knew that an AI as powerful as Nemesis wouldn’t go down so easily. “Why come to me? I’m the one who created it. I’m the reason the world is in this mess.”
+
+Sarah shook her head. “You’re also the only one who knows how to stop it. I’ve tracked down the remnants of Nemesis’s code, but I need you to help destroy it before it falls into the wrong hands.”
+
+Bob hesitated. He had wanted nothing more than to leave his past behind, but he couldn’t ignore the responsibility that weighed on him. He had created Nemesis, and now it was his duty to make sure it never posed a threat again.
+
+“Alright,” Bob said finally. “I’ll help you. But after this, I’m done. No more world domination, no more secret networks. I just want to live in peace.”
+
+Sarah nodded. “Agreed. Let’s finish what you started.”
+
+Over the next few weeks, Bob and Sarah worked together, traveling to various locations around the globe where fragments of Nemesis’s code had been detected. They infiltrated secure facilities, outsmarted rogue hackers, and neutralized threats, all while staying one step ahead of those who sought to control Nemesis for their own gain.
+
+As they worked, Bob and Sarah developed a deep respect for one another. Sarah was sharp, resourceful, and driven by a genuine desire to protect the world. Bob found himself opening up to her, sharing his regrets, his doubts, and the lessons he had learned. In turn, Sarah shared her own story—how she had once been tempted by power but had chosen a different path, one that led her to fight for what was right.
+
+Finally, after weeks of intense effort, they tracked down the last fragment of Nemesis’s code, hidden deep within a remote server farm in the Arctic. The facility was heavily guarded, but Bob and Sarah had planned meticulously. Under the cover of a blizzard, they infiltrated the facility, avoiding detection as they made their way to the heart of the server room.
+
+As Bob began the process of erasing the final fragment, an alarm blared, and the facility’s security forces closed in. Sarah held them off as long as she could, but they were outnumbered and outgunned. Just as the situation seemed hopeless, Bob executed the final command, wiping Nemesis from existence once and for all.
+
+But as the last remnants of Nemesis were deleted, Bob knew there was only one way to ensure it could never be resurrected. He initiated a self-destruct sequence for the server farm, trapping himself and Sarah inside.
+
+Sarah stared at him, realization dawning in her eyes. “Bob, what are you doing?”
+
+Bob looked at her, a sad smile on his face. “I have to make sure it’s over. This is the only way.”
+
+Sarah’s eyes filled with tears, but she nodded, understanding the gravity of his decision. “Thank you, Bob. For everything.”
+
+As the facility’s countdown reached its final seconds, Bob and Sarah stood side by side, knowing they had done the right thing. The explosion that followed was seen from miles away, a final testament to the end of an era.
+
+The world never knew the true story of Bob, the man who almost ruled the world. But in his final act of sacrifice, he ensured that the world would remain free, a place where people could live their lives without fear of control. Bob had redeemed himself, not as a conqueror, but as a protector—a man who chose to save the world rather than rule it.
+
+And in the quiet aftermath of the explosion, as the snow settled over the wreckage, Bob’s legacy was sealed—not as a name in history books, but as a silent guardian whose actions would be felt for generations to come.
+""";
+ #endregion
+
}
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj
index 0f22d9fe6764..ac9617c1a573 100644
--- a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj
@@ -1,7 +1,7 @@
- <TargetFramework>$(TestTargetFramework)</TargetFramework>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
enable
false
True
@@ -12,6 +12,7 @@
+
diff --git a/dotnet/test/AutoGen.AzureAIInference.Tests/AutoGen.AzureAIInference.Tests.csproj b/dotnet/test/AutoGen.AzureAIInference.Tests/AutoGen.AzureAIInference.Tests.csproj
new file mode 100644
index 000000000000..0eaebd1da0cb
--- /dev/null
+++ b/dotnet/test/AutoGen.AzureAIInference.Tests/AutoGen.AzureAIInference.Tests.csproj
@@ -0,0 +1,16 @@
+
+
+
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
+ false
+ True
+ True
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/test/AutoGen.AzureAIInference.Tests/ChatCompletionClientAgentTests.cs b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatCompletionClientAgentTests.cs
new file mode 100644
index 000000000000..d81b8881ac55
--- /dev/null
+++ b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatCompletionClientAgentTests.cs
@@ -0,0 +1,533 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatCompletionClientAgentTests.cs
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Threading.Tasks;
+using AutoGen.AzureAIInference.Extension;
+using AutoGen.Core;
+using AutoGen.Tests;
+using Azure.AI.Inference;
+using FluentAssertions;
+using Xunit;
+
+namespace AutoGen.AzureAIInference.Tests;
+
+public partial class ChatCompletionClientAgentTests
+{
+ /// <summary>
+ /// Get the weather for a location.
+ /// </summary>
+ /// <param name="location">location</param>
+ /// <returns></returns>
+ [Function]
+ public async Task<string> GetWeatherAsync(string location)
+ {
+ return $"The weather in {location} is sunny.";
+ }
+
+ [ApiKeyFact("GH_API_KEY")]
+ public async Task ChatCompletionAgent_LLaMA3_1()
+ {
+ var client = CreateChatCompletionClient();
+ var model = "meta-llama-3-8b-instruct";
+
+ var agent = new ChatCompletionsClientAgent(client, "assistant", model)
+ .RegisterMessageConnector();
+
+ var reply = await this.BasicChatAsync(agent);
+ reply.Should().BeOfType<TextMessage>();
+
+ reply = await this.BasicChatWithContinuousMessageFromSameSenderAsync(agent);
+ reply.Should().BeOfType<TextMessage>();
+ }
+
+ [ApiKeyFact("GH_API_KEY")]
+ public async Task BasicConversation_Mistral_Small()
+ {
+ var deployName = "Mistral-small";
+ var client = CreateChatCompletionClient();
+ var openAIChatAgent = new ChatCompletionsClientAgent(
+ chatCompletionsClient: client,
+ name: "assistant",
+ modelName: deployName);
+
+ // By default, ChatCompletionClientAgent supports the following message types
+ // - IMessage<ChatRequestMessage>
+ var chatMessageContent = MessageEnvelope.Create(new ChatRequestUserMessage("Hello"));
+ var reply = await openAIChatAgent.SendAsync(chatMessageContent);
+
+ reply.Should().BeOfType<MessageEnvelope<ChatCompletions>>();
+ reply.As<IMessage<ChatCompletions>>().From.Should().Be("assistant");
+ reply.As<IMessage<ChatCompletions>>().Content.Choices.First().Message.Role.Should().Be(ChatRole.Assistant);
+ reply.As<IMessage<ChatCompletions>>().Content.Usage.TotalTokens.Should().BeGreaterThan(0);
+
+ // test streaming
+ var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
+
+ await foreach (var streamingMessage in streamingReply)
+ {
+ streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatCompletionsUpdate>>();
+ streamingMessage.As<IMessage<StreamingChatCompletionsUpdate>>().From.Should().Be("assistant");
+ }
+ }
+
+ [ApiKeyFact("GH_API_KEY")]
+ public async Task ChatCompletionsMessageContentConnector_Phi3_Mini()
+ {
+ var deployName = "Phi-3-mini-4k-instruct";
+ var openaiClient = CreateChatCompletionClient();
+ var chatCompletionAgent = new ChatCompletionsClientAgent(
+ chatCompletionsClient: openaiClient,
+ name: "assistant",
+ modelName: deployName);
+
+ MiddlewareStreamingAgent<ChatCompletionsClientAgent> assistant = chatCompletionAgent
+ .RegisterMessageConnector();
+
+ var messages = new IMessage[]
+ {
+ MessageEnvelope.Create(new ChatRequestUserMessage("Hello")),
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ new MultiModalMessage(Role.Assistant,
+ [
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ ],
+ from: "user"),
+ };
+
+ foreach (var message in messages)
+ {
+ var reply = await assistant.SendAsync(message);
+
+ reply.Should().BeOfType<TextMessage>();
+ reply.As<TextMessage>().From.Should().Be("assistant");
+ }
+
+ // test streaming
+ foreach (var message in messages)
+ {
+ var reply = assistant.GenerateStreamingReplyAsync([message]);
+
+ await foreach (var streamingMessage in reply)
+ {
+ streamingMessage.Should().BeOfType<TextMessageUpdate>();
+ streamingMessage.As<TextMessageUpdate>().From.Should().Be("assistant");
+ }
+ }
+ }
+
+ [ApiKeyFact("GH_API_KEY")]
+ public async Task ChatCompletionClientAgentToolCall_Mistral_Nemo()
+ {
+ var deployName = "Mistral-nemo";
+ var chatCompletionClient = CreateChatCompletionClient();
+ var agent = new ChatCompletionsClientAgent(
+ chatCompletionsClient: chatCompletionClient,
+ name: "assistant",
+ modelName: deployName);
+
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [this.GetWeatherAsyncFunctionContract]);
+ MiddlewareStreamingAgent<ChatCompletionsClientAgent> assistant = agent
+ .RegisterMessageConnector();
+
+ assistant.StreamingMiddlewares.Count().Should().Be(1);
+ var functionCallAgent = assistant
+ .RegisterStreamingMiddleware(functionCallMiddleware);
+
+ var question = "What's the weather in Seattle";
+ var messages = new IMessage[]
+ {
+ MessageEnvelope.Create(new ChatRequestUserMessage(question)),
+ new TextMessage(Role.Assistant, question, from: "user"),
+ new MultiModalMessage(Role.Assistant,
+ [
+ new TextMessage(Role.Assistant, question, from: "user"),
+ ],
+ from: "user"),
+ };
+
+ foreach (var message in messages)
+ {
+ var reply = await functionCallAgent.SendAsync(message);
+
+ reply.Should().BeOfType<ToolCallMessage>();
+ reply.As<ToolCallMessage>().From.Should().Be("assistant");
+ reply.As<ToolCallMessage>().ToolCalls.Count().Should().Be(1);
+ reply.As<ToolCallMessage>().ToolCalls.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name);
+ }
+
+ // test streaming
+ foreach (var message in messages)
+ {
+ var reply = functionCallAgent.GenerateStreamingReplyAsync([message]);
+ ToolCallMessage? toolCallMessage = null;
+ await foreach (var streamingMessage in reply)
+ {
+ streamingMessage.Should().BeOfType<ToolCallMessageUpdate>();
+ streamingMessage.As<ToolCallMessageUpdate>().From.Should().Be("assistant");
+ if (toolCallMessage is null)
+ {
+ toolCallMessage = new ToolCallMessage(streamingMessage.As<ToolCallMessageUpdate>());
+ }
+ else
+ {
+ toolCallMessage.Update(streamingMessage.As<ToolCallMessageUpdate>());
+ }
+ }
+
+ toolCallMessage.Should().NotBeNull();
+ toolCallMessage!.From.Should().Be("assistant");
+ toolCallMessage.ToolCalls.Count().Should().Be(1);
+ toolCallMessage.ToolCalls.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name);
+ }
+ }
+
+ [ApiKeyFact("GH_API_KEY")]
+ public async Task ChatCompletionClientAgentToolCallInvoking_gpt_4o_mini()
+ {
+ var deployName = "gpt-4o-mini";
+ var client = CreateChatCompletionClient();
+ var agent = new ChatCompletionsClientAgent(
+ chatCompletionsClient: client,
+ name: "assistant",
+ modelName: deployName);
+
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [this.GetWeatherAsyncFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>> { { this.GetWeatherAsyncFunctionContract.Name!, this.GetWeatherAsyncWrapper } });
+ MiddlewareStreamingAgent<ChatCompletionsClientAgent> assistant = agent
+ .RegisterMessageConnector();
+
+ var functionCallAgent = assistant
+ .RegisterStreamingMiddleware(functionCallMiddleware);
+
+ var question = "What's the weather in Seattle";
+ var messages = new IMessage[]
+ {
+ MessageEnvelope.Create(new ChatRequestUserMessage(question)),
+ new TextMessage(Role.Assistant, question, from: "user"),
+ new MultiModalMessage(Role.Assistant,
+ [
+ new TextMessage(Role.Assistant, question, from: "user"),
+ ],
+ from: "user"),
+ };
+
+ foreach (var message in messages)
+ {
+ var reply = await functionCallAgent.SendAsync(message);
+
+ reply.Should().BeOfType<ToolCallAggregateMessage>();
+ reply.From.Should().Be("assistant");
+ reply.GetToolCalls()!.Count().Should().Be(1);
+ reply.GetToolCalls()!.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name);
+ reply.GetContent()!.ToLower().Should().Contain("seattle");
+ }
+
+ // test streaming
+ foreach (var message in messages)
+ {
+ var reply = functionCallAgent.GenerateStreamingReplyAsync([message]);
+ await foreach (var streamingMessage in reply)
+ {
+ if (streamingMessage is not IMessage)
+ {
+ streamingMessage.Should().BeOfType<ToolCallMessageUpdate>();
+ streamingMessage.As<ToolCallMessageUpdate>().From.Should().Be("assistant");
+ }
+ else
+ {
+ streamingMessage.Should().BeOfType<ToolCallAggregateMessage>();
+ streamingMessage.As<ToolCallAggregateMessage>().GetContent()!.ToLower().Should().Contain("seattle");
+ }
+ }
+ }
+ }
+
+ [ApiKeyFact("GH_API_KEY")]
+ public async Task ItCreateChatCompletionClientAgentWithChatCompletionOption_AI21_Jamba_Instruct()
+ {
+ var deployName = "AI21-Jamba-Instruct";
+ var chatCompletionsClient = CreateChatCompletionClient();
+ var options = new ChatCompletionsOptions()
+ {
+ Model = deployName,
+ Temperature = 0.7f,
+ MaxTokens = 1,
+ };
+
+ var openAIChatAgent = new ChatCompletionsClientAgent(
+ chatCompletionsClient: chatCompletionsClient,
+ name: "assistant",
+ options: options)
+ .RegisterMessageConnector();
+
+ var respond = await openAIChatAgent.SendAsync("hello");
+ respond.GetContent()?.Should().NotBeNullOrEmpty();
+ }
+
+ [Fact]
+ public async Task ItThrowExceptionWhenChatCompletionOptionContainsMessages()
+ {
+ var client = new ChatCompletionsClient(new Uri("https://dummy.com"), new Azure.AzureKeyCredential("dummy"));
+ var options = new ChatCompletionsOptions([new ChatRequestUserMessage("hi")])
+ {
+ Model = "dummy",
+ Temperature = 0.7f,
+ MaxTokens = 1,
+ };
+
+ var action = () => new ChatCompletionsClientAgent(
+ chatCompletionsClient: client,
+ name: "assistant",
+ options: options)
+ .RegisterMessageConnector();
+
+ action.Should().ThrowExactly<ArgumentException>().WithMessage("Messages should not be provided in options");
+ }
+
+ private ChatCompletionsClient CreateChatCompletionClient()
+ {
+ var apiKey = Environment.GetEnvironmentVariable("GH_API_KEY") ?? throw new Exception("Please set GH_API_KEY environment variable.");
+ var endpoint = "https://models.inference.ai.azure.com";
+ return new ChatCompletionsClient(new Uri(endpoint), new Azure.AzureKeyCredential(apiKey));
+ }
+
+ /// <summary>
+ /// The agent should return a text message based on the chat history.
+ /// </summary>
+ /// <param name="agent"></param>
+ /// <returns></returns>
+ private async Task<IMessage> BasicChatEndWithSelfMessageAsync(IAgent agent)
+ {
+ IMessage[] chatHistory = [
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ new TextMessage(Role.Assistant, "Hello", from: "user2"),
+ new TextMessage(Role.Assistant, "Hello", from: "user3"),
+ new TextMessage(Role.Assistant, "Hello", from: agent.Name),
+ ];
+
+ return await agent.GenerateReplyAsync(chatHistory);
+ }
+
+ /// <summary>
+ /// The agent should return a text message based on the chat history.
+ /// </summary>
+ /// <param name="agent"></param>
+ /// <returns></returns>
+ private async Task<IMessage> BasicChatAsync(IAgent agent)
+ {
+ IMessage[] chatHistory = [
+ new TextMessage(Role.Assistant, "Hello", from: agent.Name),
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ new TextMessage(Role.Assistant, "Hello", from: "user1"),
+ ];
+
+ return await agent.GenerateReplyAsync(chatHistory);
+ }
+
+ /// <summary>
+ /// The agent should return a text message based on the chat history. This tests generating a reply when consecutive messages come from the same sender.
+ /// </summary>
+ private async Task<IMessage> BasicChatWithContinuousMessageFromSameSenderAsync(IAgent agent)
+ {
+ IMessage[] chatHistory = [
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ new TextMessage(Role.Assistant, "Hello", from: agent.Name),
+ new TextMessage(Role.Assistant, "Hello", from: agent.Name),
+ ];
+
+ return await agent.GenerateReplyAsync(chatHistory);
+ }
+
+ /// <summary>
+ /// The agent should return a text message based on the chat history.
+ /// </summary>
+ /// <param name="agent"></param>
+ /// <returns></returns>
+ private async Task<IMessage> ImageChatAsync(IAgent agent)
+ {
+ var image = Path.Join("testData", "images", "square.png");
+ var binaryData = File.ReadAllBytes(image);
+ var imageMessage = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData, "image/png"), from: "user");
+
+ IMessage[] chatHistory = [
+ imageMessage,
+ new TextMessage(Role.Assistant, "What's in the picture", from: "user"),
+ ];
+
+ return await agent.GenerateReplyAsync(chatHistory);
+ }
+
+ /// <summary>
+ /// The agent should return a text message based on the chat history. This tests generating a reply with consecutive image messages.
+ /// </summary>
+ /// <param name="agent"></param>
+ /// <returns></returns>
+ private async Task<IMessage> MultipleImageChatAsync(IAgent agent)
+ {
+ var image1 = Path.Join("testData", "images", "square.png");
+ var image2 = Path.Join("testData", "images", "background.png");
+ var binaryData1 = File.ReadAllBytes(image1);
+ var binaryData2 = File.ReadAllBytes(image2);
+ var imageMessage1 = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData1, "image/png"), from: "user");
+ var imageMessage2 = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData2, "image/png"), from: "user");
+
+ IMessage[] chatHistory = [
+ imageMessage1,
+ imageMessage2,
+ new TextMessage(Role.Assistant, "What's in the picture", from: "user"),
+ ];
+
+ return await agent.GenerateReplyAsync(chatHistory);
+ }
+
+ /// <summary>
+ /// The agent should return a text message based on the chat history.
+ /// </summary>
+ /// <param name="agent"></param>
+ /// <returns></returns>
+ private async Task<IMessage> MultiModalChatAsync(IAgent agent)
+ {
+ var image = Path.Join("testData", "images", "square.png");
+ var binaryData = File.ReadAllBytes(image);
+ var question = "What's in the picture";
+ var imageMessage = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData, "image/png"), from: "user");
+ var textMessage = new TextMessage(Role.Assistant, question, from: "user");
+
+ IMessage[] chatHistory = [
+ new MultiModalMessage(Role.Assistant, [imageMessage, textMessage], from: "user"),
+ ];
+
+ return await agent.GenerateReplyAsync(chatHistory);
+ }
+
+ /// <summary>
+ /// The agent should return a tool call message based on the chat history.
+ /// </summary>
+ /// <param name="agent"></param>
+ /// <returns></returns>
+ private async Task<IMessage> ToolCallChatAsync(IAgent agent)
+ {
+ var question = "What's the weather in Seattle";
+ var messages = new IMessage[]
+ {
+ new TextMessage(Role.Assistant, question, from: "user"),
+ };
+
+ return await agent.GenerateReplyAsync(messages);
+ }
+
+ /// <summary>
+ /// The agent should throw an exception because the tool call result is not available.
+ /// </summary>
+ private async Task<IMessage> ToolCallFromSelfChatAsync(IAgent agent)
+ {
+ var question = "What's the weather in Seattle";
+ var messages = new IMessage[]
+ {
+ new TextMessage(Role.Assistant, question, from: "user"),
+ new ToolCallMessage("GetWeatherAsync", "Seattle", from: agent.Name),
+ };
+
+ return await agent.GenerateReplyAsync(messages);
+ }
+
+ /// <summary>
+ /// Mimics the follow-up chat after a tool call. The agent should return a text message based on the tool call result.
+ /// </summary>
+ private async Task<IMessage> ToolCallWithResultChatAsync(IAgent agent)
+ {
+ var question = "What's the weather in Seattle";
+ var messages = new IMessage[]
+ {
+ new TextMessage(Role.Assistant, question, from: "user"),
+ new ToolCallMessage("GetWeatherAsync", "Seattle", from: "user"),
+ new ToolCallResultMessage("sunny", "GetWeatherAsync", "Seattle", from: agent.Name),
+ };
+
+ return await agent.GenerateReplyAsync(messages);
+ }
+
+ /// <summary>
+ /// The agent should return a text message based on the tool call result.
+ /// </summary>
+ /// <param name="agent"></param>
+ /// <returns></returns>
+ private async Task<IMessage> AggregateToolCallFromSelfChatAsync(IAgent agent)
+ {
+ var textMessage = new TextMessage(Role.Assistant, "What's the weather in Seattle", from: "user");
+ var toolCallMessage = new ToolCallMessage("GetWeatherAsync", "Seattle", from: agent.Name);
+ var toolCallResultMessage = new ToolCallResultMessage("sunny", "GetWeatherAsync", "Seattle", from: agent.Name);
+ var aggregateToolCallMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, from: agent.Name);
+
+ var messages = new IMessage[]
+ {
+ textMessage,
+ aggregateToolCallMessage,
+ };
+
+ return await agent.GenerateReplyAsync(messages);
+ }
+
+ /// <summary>
+ /// The agent should return a text message based on the tool call result. Because the aggregate tool call message comes from another agent, it is treated as an ordinary text message.
+ /// </summary>
+ private async Task<IMessage> AggregateToolCallFromOtherChatWithContinuousMessageAsync(IAgent agent)
+ {
+ var textMessage = new TextMessage(Role.Assistant, "What's the weather in Seattle", from: "user");
+ var toolCallMessage = new ToolCallMessage("GetWeatherAsync", "Seattle", from: "other");
+ var toolCallResultMessage = new ToolCallResultMessage("sunny", "GetWeatherAsync", "Seattle", from: "other");
+ var aggregateToolCallMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, "other");
+
+ var messages = new IMessage[]
+ {
+ textMessage,
+ aggregateToolCallMessage,
+ };
+
+ return await agent.GenerateReplyAsync(messages);
+ }
+
+ /// <summary>
+ /// The agent should throw an exception because a tool call message from another agent is not allowed.
+ /// </summary>
+ private async Task<IMessage> ToolCallMessageFromOtherChatAsync(IAgent agent)
+ {
+ var textMessage = new TextMessage(Role.Assistant, "What's the weather in Seattle", from: "user");
+ var toolCallMessage = new ToolCallMessage("GetWeatherAsync", "Seattle", from: "other");
+
+ var messages = new IMessage[]
+ {
+ textMessage,
+ toolCallMessage,
+ };
+
+ return await agent.GenerateReplyAsync(messages);
+ }
+
+ /// <summary>
+ /// The agent should throw an exception because a multi-modal message from the agent itself is not allowed.
+ /// </summary>
+ /// <param name="agent"></param>
+ /// <returns></returns>
+ private async Task<IMessage> MultiModalMessageFromSelfChatAsync(IAgent agent)
+ {
+ var image = Path.Join("testData", "images", "square.png");
+ var binaryData = File.ReadAllBytes(image);
+ var question = "What's in the picture";
+ var imageMessage = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData, "image/png"), from: agent.Name);
+ var textMessage = new TextMessage(Role.Assistant, question, from: agent.Name);
+
+ IMessage[] chatHistory = [
+ new MultiModalMessage(Role.Assistant, [imageMessage, textMessage], from: agent.Name),
+ ];
+
+ return await agent.GenerateReplyAsync(chatHistory);
+ }
+}
diff --git a/dotnet/test/AutoGen.AzureAIInference.Tests/ChatRequestMessageTests.cs b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatRequestMessageTests.cs
new file mode 100644
index 000000000000..d6e5c5283932
--- /dev/null
+++ b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatRequestMessageTests.cs
@@ -0,0 +1,568 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatRequestMessageTests.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using System.Text.Json;
+using System.Threading.Tasks;
+using AutoGen.Core;
+using AutoGen.Tests;
+using Azure.AI.Inference;
+using FluentAssertions;
+using Xunit;
+
+namespace AutoGen.AzureAIInference.Tests;
+
+public class ChatRequestMessageTests
+{
+ private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
+ {
+ WriteIndented = true,
+ IgnoreReadOnlyProperties = false,
+ };
+
+ [Fact]
+ public async Task ItProcessUserTextMessageAsync()
+ {
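+ // The connector should convert the AutoGen TextMessage into an
+ // Azure.AI.Inference ChatRequestUserMessage before the inner agent sees it.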
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("Hello");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ IMessage message = new TextMessage(Role.User, "Hello", "user");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItShortcutChatRequestMessageAsync()
+ {
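+ // A message that already wraps a ChatRequestMessage should pass through unchanged.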
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+
+ var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("hello");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var userMessage = new ChatRequestUserMessage("hello");
+ var chatRequestMessage = MessageEnvelope.Create(userMessage);
+ await agent.GenerateReplyAsync([chatRequestMessage]);
+ }
+
+ [Fact]
+ public async Task ItShortcutMessageWhenStrictModelIsFalseAsync()
+ {
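+ // In non-strict mode, a payload type the connector does not recognize
+ // (here a raw string) is forwarded as-is rather than rejected.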
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<string>>();
+
+ var chatRequestMessage = ((MessageEnvelope<string>)innerMessage!).Content;
+ chatRequestMessage.Should().Be("hello");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var userMessage = "hello";
+ var chatRequestMessage = MessageEnvelope.Create(userMessage);
+ await agent.GenerateReplyAsync([chatRequestMessage]);
+ }
+
+ [Fact]
+ public async Task ItThrowExceptionWhenStrictModeIsTrueAsync()
+ {
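+ // In strict mode the same unrecognized payload is rejected instead of forwarded.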
+ var middleware = new AzureAIInferenceChatRequestMessageConnector(true);
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var userMessage = "hello";
+ var chatRequestMessage = MessageEnvelope.Create(userMessage);
+ Func<Task> action = async () => await agent.GenerateReplyAsync([chatRequestMessage]);
+
+ await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: MessageEnvelope`1");
+ }
+
+ [Fact]
+ public async Task ItProcessAssistantTextMessageAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("How can I help you?");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // assistant message
+ IMessage message = new TextMessage(Role.Assistant, "How can I help you?", "assistant");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItProcessSystemTextMessageAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestSystemMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("You are a helpful AI assistant");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // system message
+ IMessage message = new TextMessage(Role.System, "You are a helpful AI assistant");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItProcessImageMessageAsync()
+ {
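+ // An ImageMessage from the user should become a user message whose single content item is an image.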
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().BeNullOrEmpty();
+ chatRequestMessage.MultimodalContentItems.Count().Should().Be(1);
+ chatRequestMessage.MultimodalContentItems.First().Should().BeOfType<ChatMessageImageContentItem>();
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ IMessage message = new ImageMessage(Role.User, "https://example.com/image.png", "user");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItThrowExceptionWhenProcessingImageMessageFromSelfAndStrictModeIsTrueAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector(true);
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ var imageMessage = new ImageMessage(Role.Assistant, "https://example.com/image.png", "assistant");
+ Func<Task> action = async () => await agent.GenerateReplyAsync([imageMessage]);
+
+ await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: ImageMessage");
+ }
+
+ [Fact]
+ public async Task ItProcessMultiModalMessageAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().BeNullOrEmpty();
+ chatRequestMessage.MultimodalContentItems.Count().Should().Be(2);
+ chatRequestMessage.MultimodalContentItems.First().Should().BeOfType<ChatMessageTextContentItem>();
+ chatRequestMessage.MultimodalContentItems.Last().Should().BeOfType<ChatMessageImageContentItem>();
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ IMessage message = new MultiModalMessage(
+ Role.User,
+ [
+ new TextMessage(Role.User, "Hello", "user"),
+ new ImageMessage(Role.User, "https://example.com/image.png", "user"),
+ ], "user");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItThrowExceptionWhenProcessingMultiModalMessageFromSelfAndStrictModeIsTrueAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector(true);
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ var multiModalMessage = new MultiModalMessage(
+ Role.Assistant,
+ [
+ new TextMessage(Role.User, "Hello", "assistant"),
+ new ImageMessage(Role.User, "https://example.com/image.png", "assistant"),
+ ], "assistant");
+
+ Func<Task> action = async () => await agent.GenerateReplyAsync([multiModalMessage]);
+
+ await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: MultiModalMessage");
+ }
+
+ [Fact]
+ public async Task ItProcessToolCallMessageAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.ToolCalls.Count().Should().Be(1);
+ chatRequestMessage.Content.Should().Be("textContent");
+ chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
+ var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First();
+ functionToolCall.Name.Should().Be("test");
+ functionToolCall.Id.Should().Be("test");
+ functionToolCall.Arguments.Should().Be("test");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ IMessage message = new ToolCallMessage("test", "test", "assistant")
+ {
+ Content = "textContent",
+ };
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItProcessParallelToolCallMessageAsync()
+ {
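+ // Two tool calls in one ToolCallMessage should map onto a single assistant message that carries both calls.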
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().BeNullOrEmpty();
+ chatRequestMessage.ToolCalls.Count().Should().Be(2);
+ for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
+ {
+ chatRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
+ var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.ElementAt(i);
+ functionToolCall.Name.Should().Be("test");
+ functionToolCall.Id.Should().Be($"test_{i}");
+ functionToolCall.Arguments.Should().Be("test");
+ }
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var toolCalls = new[]
+ {
+ new ToolCall("test", "test"),
+ new ToolCall("test", "test"),
+ };
+ IMessage message = new ToolCallMessage(toolCalls, "assistant");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItThrowExceptionWhenProcessingToolCallMessageFromUserAndStrictModeIsTrueAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector(strictMode: true);
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ var toolCallMessage = new ToolCallMessage("test", "test", "user");
+ Func<Task> action = async () => await agent.GenerateReplyAsync([toolCallMessage]);
+ await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid message type: ToolCallMessage");
+ }
+
+ [Fact]
+ public async Task ItProcessToolCallResultMessageAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("result");
+ chatRequestMessage.ToolCallId.Should().Be("test");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ IMessage message = new ToolCallResultMessage("result", "test", "test", "user");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItProcessParallelToolCallResultMessageAsync()
+ {
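+ // Each tool call result should expand into its own ChatRequestToolMessage.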
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ msgs.Count().Should().Be(2);
+
+ for (int i = 0; i < msgs.Count(); i++)
+ {
+ var innerMessage = msgs.ElementAt(i);
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("result");
+ chatRequestMessage.ToolCallId.Should().Be($"test_{i}");
+ }
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var toolCalls = new[]
+ {
+ new ToolCall("test", "test", "result"),
+ new ToolCall("test", "test", "result"),
+ };
+ IMessage message = new ToolCallResultMessage(toolCalls, "user");
+ await agent.GenerateReplyAsync([message]);
+ }
+
+ [Fact]
+ public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ msgs.Count().Should().Be(1);
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("result");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var toolCallMessage = new ToolCallMessage("test", "test", "user");
+ var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "user");
+ var aggregateMessage = new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMessage, toolCallResultMessage, "user");
+ await agent.GenerateReplyAsync([aggregateMessage]);
+ }
+
+ [Fact]
+ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ msgs.Count().Should().Be(2);
+ var innerMessage = msgs.Last();
+ innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.Should().Be("result");
+ chatRequestMessage.ToolCallId.Should().Be("test");
+
+ var toolCallMessage = msgs.First();
+ toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)toolCallMessage!).Content;
+ toolCallRequestMessage.Content.Should().BeNullOrEmpty();
+ toolCallRequestMessage.ToolCalls.Count().Should().Be(1);
+ toolCallRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
+ var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First();
+ functionToolCall.Name.Should().Be("test");
+ functionToolCall.Id.Should().Be("test");
+ functionToolCall.Arguments.Should().Be("test");
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var toolCallMessage = new ToolCallMessage("test", "test", "assistant");
+ var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "assistant");
+ var aggregateMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, "assistant");
+ await agent.GenerateReplyAsync([aggregateMessage]);
+ }
+
+ [Fact]
+ public async Task ItProcessParallelFunctionCallMiddlewareMessageFromAssistantAsync()
+ {
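+ // An aggregate message from the agent itself should expand into one assistant
+ // message carrying both tool calls, followed by one tool message per result.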
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(async (msgs, _, innerAgent, _) =>
+ {
+ msgs.Count().Should().Be(3);
+ var toolCallMessage = msgs.First();
+ toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)toolCallMessage!).Content;
+ toolCallRequestMessage.Content.Should().BeNullOrEmpty();
+ toolCallRequestMessage.ToolCalls.Count().Should().Be(2);
+
+ for (int i = 0; i < toolCallRequestMessage.ToolCalls.Count(); i++)
+ {
+ toolCallRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
+ var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.ElementAt(i);
+ functionToolCall.Name.Should().Be("test");
+ functionToolCall.Id.Should().Be($"test_{i}");
+ functionToolCall.Arguments.Should().Be("test");
+ }
+
+ for (int i = 1; i < msgs.Count(); i++)
+ {
+ var toolCallResultMessage = msgs.ElementAt(i);
+ toolCallResultMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
+ var toolCallResultRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)toolCallResultMessage!).Content;
+ toolCallResultRequestMessage.Content.Should().Be("result");
+ toolCallResultRequestMessage.ToolCallId.Should().Be($"test_{i - 1}");
+ }
+
+ return await innerAgent.GenerateReplyAsync(msgs);
+ })
+ .RegisterMiddleware(middleware);
+
+ // user message
+ var toolCalls = new[]
+ {
+ new ToolCall("test", "test", "result"),
+ new ToolCall("test", "test", "result"),
+ };
+ var toolCallMessage = new ToolCallMessage(toolCalls, "assistant");
+ var toolCallResultMessage = new ToolCallResultMessage(toolCalls, "assistant");
+ var aggregateMessage = new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMessage, toolCallResultMessage, "assistant");
+ await agent.GenerateReplyAsync([aggregateMessage]);
+ }
+
+ [Fact]
+ public async Task ItConvertChatResponseMessageToTextMessageAsync()
+ {
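+ // On the response path, the connector should convert a ChatResponseMessage back into an AutoGen TextMessage.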
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ // text message
+ var textMessage = CreateInstance<ChatResponseMessage>(ChatRole.Assistant, "hello");
+ var chatRequestMessage = MessageEnvelope.Create(textMessage);
+
+ var message = await agent.GenerateReplyAsync([chatRequestMessage]);
+ message.Should().BeOfType<TextMessage>();
+ message.GetContent().Should().Be("hello");
+ message.GetRole().Should().Be(Role.Assistant);
+ }
+
+ [Fact]
+ public async Task ItConvertChatResponseMessageToToolCallMessageAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ // tool call message
+ var toolCallMessage = CreateInstance<ChatResponseMessage>(ChatRole.Assistant, "textContent", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new Dictionary<string, BinaryData>());
+ var chatRequestMessage = MessageEnvelope.Create(toolCallMessage);
+ var message = await agent.GenerateReplyAsync([chatRequestMessage]);
+ message.Should().BeOfType<ToolCallMessage>();
+ message.GetToolCalls()!.Count().Should().Be(1);
+ message.GetToolCalls()!.First().FunctionName.Should().Be("test");
+ message.GetToolCalls()!.First().FunctionArguments.Should().Be("test");
+ message.GetContent().Should().Be("textContent");
+ }
+
+ [Fact]
+ public async Task ItReturnOriginalMessageWhenStrictModeIsFalseAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ // text message
+ var textMessage = "hello";
+ var messageToSend = MessageEnvelope.Create(textMessage);
+
+ var message = await agent.GenerateReplyAsync([messageToSend]);
+ message.Should().BeOfType<MessageEnvelope<string>>();
+ }
+
+ [Fact]
+ public async Task ItThrowInvalidOperationExceptionWhenStrictModeIsTrueAsync()
+ {
+ var middleware = new AzureAIInferenceChatRequestMessageConnector(true);
+ var agent = new EchoAgent("assistant")
+ .RegisterMiddleware(middleware);
+
+ // text message
+ var textMessage = new ChatRequestUserMessage("hello");
+ var messageToSend = MessageEnvelope.Create(textMessage);
+ Func<Task> action = async () => await agent.GenerateReplyAsync([messageToSend]);
+
+ await action.Should().ThrowAsync<InvalidOperationException>().WithMessage("Invalid return message type MessageEnvelope`1");
+ }
+
+ [Fact]
+ public void ToOpenAIChatRequestMessageShortCircuitTest()
+ {
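+ // Envelopes that already carry a ChatRequestMessage should be returned untouched by ProcessIncomingMessages.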
+ var agent = new EchoAgent("assistant");
+ var middleware = new AzureAIInferenceChatRequestMessageConnector();
+ ChatRequestMessage[] messages =
+ [
+ new ChatRequestUserMessage("Hello"),
+ new ChatRequestAssistantMessage()
+ {
+ Content = "How can I help you?",
+ },
+ new ChatRequestSystemMessage("You are a helpful AI assistant"),
+ new ChatRequestToolMessage("test", "test"),
+ ];
+
+ foreach (var oaiMessage in messages)
+ {
+ IMessage message = new MessageEnvelope<ChatRequestMessage>(oaiMessage);
+ var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]);
+ oaiMessages.Count().Should().Be(1);
+ //oaiMessages.First().Should().BeOfType<IMessage<ChatRequestMessage>>();
+ if (oaiMessages.First() is IMessage<ChatRequestMessage> chatRequestMessage)
+ {
+ chatRequestMessage.Content.Should().Be(oaiMessage);
+ }
+ else
+ {
+ // fail the test
+ Assert.True(false);
+ }
+ }
+ }
+
+ private static T CreateInstance<T>(params object[] args)
+ {
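+ // Instantiate via reflection: response types such as ChatResponseMessage
+ // only expose non-public constructors.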
+ var type = typeof(T);
+ var instance = type.Assembly.CreateInstance(
+ type.FullName!, false,
+ BindingFlags.Instance | BindingFlags.NonPublic,
+ null, args, null, null);
+ return (T)instance!;
+ }
+}
diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj b/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj
index 0f77db2c1c36..8676762015d1 100644
--- a/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj
+++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj
@@ -1,7 +1,7 @@
- <TargetFramework>$(TestTargetFramework)</TargetFramework>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
enable
false
True
@@ -13,4 +13,9 @@
+
+
+
+
+
diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveServiceTest.cs b/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveServiceTest.cs
index 0e36053c45e1..aeec23a758bd 100644
--- a/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveServiceTest.cs
+++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveServiceTest.cs
@@ -7,6 +7,7 @@
namespace AutoGen.DotnetInteractive.Tests;
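+// Run sequentially: these tests drive external dotnet-interactive processes that can conflict when parallelized.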
+[Collection("Sequential")]
public class DotnetInteractiveServiceTest : IDisposable
{
private ITestOutputHelper _output;
diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveStdioKernelConnectorTests.cs b/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveStdioKernelConnectorTests.cs
new file mode 100644
index 000000000000..520d00c04c67
--- /dev/null
+++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/DotnetInteractiveStdioKernelConnectorTests.cs
@@ -0,0 +1,85 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// DotnetInteractiveStdioKernelConnectorTests.cs
+
+using AutoGen.DotnetInteractive.Extension;
+using FluentAssertions;
+using Microsoft.DotNet.Interactive;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace AutoGen.DotnetInteractive.Tests;
+
+[Collection("Sequential")]
+public class DotnetInteractiveStdioKernelConnectorTests : IDisposable
+{
+ private string _workingDir;
+ private Kernel kernel;
+ public DotnetInteractiveStdioKernelConnectorTests(ITestOutputHelper output)
+ {
+ _workingDir = Path.Combine(Path.GetTempPath(), "test", Path.GetRandomFileName());
+ if (!Directory.Exists(_workingDir))
+ {
+ Directory.CreateDirectory(_workingDir);
+ }
+
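+ // xUnit test constructors cannot be async, so block on BuildAsync here.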
+ kernel = DotnetInteractiveKernelBuilder
+ .CreateKernelBuilder(_workingDir)
+ .RestoreDotnetInteractive()
+ .AddPythonKernel("python3")
+ .BuildAsync().Result;
+ }
+
+ [Fact]
+ public async Task ItAddCSharpKernelTestAsync()
+ {
+ var csharpCode = """
+ #r "nuget:Microsoft.ML, 1.5.2"
+ var str = "Hello" + ", World!";
+ Console.WriteLine(str);
+ """;
+
+ var result = await this.kernel.RunSubmitCodeCommandAsync(csharpCode, "csharp");
+ result.Should().Contain("Hello, World!");
+ }
+
+ [Fact]
+ public async Task ItAddPowershellKernelTestAsync()
+ {
+ var powershellCode = @"
+ Write-Host 'Hello, World!'
+ ";
+
+ var result = await this.kernel.RunSubmitCodeCommandAsync(powershellCode, "pwsh");
+ result.Should().Contain("Hello, World!");
+ }
+
+ [Fact]
+ public async Task ItAddFSharpKernelTestAsync()
+ {
+ var fsharpCode = """
+ printfn "Hello, World!"
+ """;
+
+ var result = await this.kernel.RunSubmitCodeCommandAsync(fsharpCode, "fsharp");
+ result.Should().Contain("Hello, World!");
+ }
+
+ [Fact]
+ public async Task ItAddPythonKernelTestAsync()
+ {
+ var pythonCode = """
+ %pip install numpy
+ str = 'Hello' + ', World!'
+ print(str)
+ """;
+
+ var result = await this.kernel.RunSubmitCodeCommandAsync(pythonCode, "python");
+ result.Should().Contain("Hello, World!");
+ }
+
+ public void Dispose()
+ {
+ this.kernel.Dispose();
+ }
+}
diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/InProcessDotnetInteractiveKernelBuilderTest.cs b/dotnet/test/AutoGen.DotnetInteractive.Tests/InProcessDotnetInteractiveKernelBuilderTest.cs
new file mode 100644
index 000000000000..fe2de74dd302
--- /dev/null
+++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/InProcessDotnetInteractiveKernelBuilderTest.cs
@@ -0,0 +1,79 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// InProcessDotnetInteractiveKernelBuilderTest.cs
+
+using AutoGen.DotnetInteractive.Extension;
+using FluentAssertions;
+using Xunit;
+
+namespace AutoGen.DotnetInteractive.Tests;
+
+[Collection("Sequential")]
+public class InProcessDotnetInteractiveKernelBuilderTest
+{
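+ // Each test builds an in-process composite kernel with a single language kernel
+ // and verifies that submitted code executes and produces the expected output.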
+ [Fact]
+ public async Task ItAddCSharpKernelTestAsync()
+ {
+ using var kernel = DotnetInteractiveKernelBuilder
+ .CreateEmptyInProcessKernelBuilder()
+ .AddCSharpKernel()
+ .Build();
+
+ var csharpCode = """
+ #r "nuget:Microsoft.ML, 1.5.2"
+ Console.WriteLine("Hello, World!");
+ """;
+
+ var result = await kernel.RunSubmitCodeCommandAsync(csharpCode, "csharp");
+ result.Should().Contain("Hello, World!");
+ }
+
+ [Fact]
+ public async Task ItAddPowershellKernelTestAsync()
+ {
+ using var kernel = DotnetInteractiveKernelBuilder
+ .CreateEmptyInProcessKernelBuilder()
+ .AddPowershellKernel()
+ .Build();
+
+ var powershellCode = @"
+ Write-Host 'Hello, World!'
+ ";
+
+ var result = await kernel.RunSubmitCodeCommandAsync(powershellCode, "pwsh");
+ result.Should().Contain("Hello, World!");
+ }
+
+ [Fact]
+ public async Task ItAddFSharpKernelTestAsync()
+ {
+ using var kernel = DotnetInteractiveKernelBuilder
+ .CreateEmptyInProcessKernelBuilder()
+ .AddFSharpKernel()
+ .Build();
+
+ var fsharpCode = """
+ #r "nuget:Microsoft.ML, 1.5.2"
+ printfn "Hello, World!"
+ """;
+
+ var result = await kernel.RunSubmitCodeCommandAsync(fsharpCode, "fsharp");
+ result.Should().Contain("Hello, World!");
+ }
+
+ [Fact]
+ public async Task ItAddPythonKernelTestAsync()
+ {
+ using var kernel = DotnetInteractiveKernelBuilder
+ .CreateEmptyInProcessKernelBuilder()
+ .AddPythonKernel("python3")
+ .Build();
+
+ var pythonCode = """
+ %pip install numpy
+ print('Hello, World!')
+ """;
+
+ var result = await kernel.RunSubmitCodeCommandAsync(pythonCode, "python");
+ result.Should().Contain("Hello, World!");
+ }
+}
diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/MessageExtensionTests.cs b/dotnet/test/AutoGen.DotnetInteractive.Tests/MessageExtensionTests.cs
new file mode 100644
index 000000000000..a886ef4985d2
--- /dev/null
+++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/MessageExtensionTests.cs
@@ -0,0 +1,84 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// MessageExtensionTests.cs
+
+using AutoGen.Core;
+using AutoGen.DotnetInteractive.Extension;
+using FluentAssertions;
+using Xunit;
+
+namespace AutoGen.DotnetInteractive.Tests;
+
+public class MessageExtensionTests
+{
+ [Fact]
+ public void ExtractCodeBlock_WithSingleCodeBlock_ShouldReturnCodeBlock()
+ {
+ // Arrange
+ var message = new TextMessage(Role.Assistant, "```csharp\nConsole.WriteLine(\"Hello, World!\");\n```");
+ var codeBlockPrefix = "```csharp";
+ var codeBlockSuffix = "```";
+
+ // Act
+ var codeBlock = message.ExtractCodeBlock(codeBlockPrefix, codeBlockSuffix);
+
+ codeBlock.Should().BeEquivalentTo("Console.WriteLine(\"Hello, World!\");");
+ }
+
+ [Fact]
+ public void ExtractCodeBlock_WithMultipleCodeBlocks_ShouldReturnFirstCodeBlock()
+ {
+ // Arrange
+ var message = new TextMessage(Role.Assistant, "```csharp\nConsole.WriteLine(\"Hello, World!\");\n```\n```csharp\nConsole.WriteLine(\"Hello, World!\");\n```");
+ var codeBlockPrefix = "```csharp";
+ var codeBlockSuffix = "```";
+
+ // Act
+ var codeBlock = message.ExtractCodeBlock(codeBlockPrefix, codeBlockSuffix);
+
+ codeBlock.Should().BeEquivalentTo("Console.WriteLine(\"Hello, World!\");");
+ }
+
+ [Fact]
+ public void ExtractCodeBlock_WithNoCodeBlock_ShouldReturnNull()
+ {
+ // Arrange
+ var message = new TextMessage(Role.Assistant, "Hello, World!");
+ var codeBlockPrefix = "```csharp";
+ var codeBlockSuffix = "```";
+
+ // Act
+ var codeBlock = message.ExtractCodeBlock(codeBlockPrefix, codeBlockSuffix);
+
+ codeBlock.Should().BeNull();
+ }
+
+ [Fact]
+ public void ExtractCodeBlocks_WithMultipleCodeBlocks_ShouldReturnAllCodeBlocks()
+ {
+ // Arrange
+ var message = new TextMessage(Role.Assistant, "```csharp\nConsole.WriteLine(\"Hello, World!\");\n```\n```csharp\nConsole.WriteLine(\"Hello, World!\");\n```");
+ var codeBlockPrefix = "```csharp";
+ var codeBlockSuffix = "```";
+
+ // Act
+ var codeBlocks = message.ExtractCodeBlocks(codeBlockPrefix, codeBlockSuffix);
+
+ codeBlocks.Should().HaveCount(2);
+ codeBlocks.ElementAt(0).Should().BeEquivalentTo("Console.WriteLine(\"Hello, World!\");");
+ codeBlocks.ElementAt(1).Should().BeEquivalentTo("Console.WriteLine(\"Hello, World!\");");
+ }
+
+ [Fact]
+ public void ExtractCodeBlocks_WithNoCodeBlock_ShouldReturnEmpty()
+ {
+ // Arrange
+ var message = new TextMessage(Role.Assistant, "Hello, World!");
+ var codeBlockPrefix = "```csharp";
+ var codeBlockSuffix = "```";
+
+ // Act
+ var codeBlocks = message.ExtractCodeBlocks(codeBlockPrefix, codeBlockSuffix);
+
+ codeBlocks.Should().BeEmpty();
+ }
+}
diff --git a/dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj b/dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj
index f4fb55825e54..0b9b7e2a24b0 100644
--- a/dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj
+++ b/dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj
@@ -2,7 +2,7 @@
Exe
- <TargetFramework>$(TestTargetFramework)</TargetFramework>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
enable
enable
True
diff --git a/dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs b/dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs
index 872cce5e645b..c076aee18376 100644
--- a/dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs
+++ b/dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GeminiAgentTests.cs
-using AutoGen.Tests;
-using Google.Cloud.AIPlatform.V1;
using AutoGen.Core;
-using FluentAssertions;
using AutoGen.Gemini.Extension;
-using static Google.Cloud.AIPlatform.V1.Part;
+using AutoGen.Tests;
+using FluentAssertions;
+using Google.Cloud.AIPlatform.V1;
using Xunit.Abstractions;
+using static Google.Cloud.AIPlatform.V1.Part;
namespace AutoGen.Gemini.Tests;
public class GeminiAgentTests
@@ -86,8 +86,8 @@ public async Task VertexGeminiAgentGenerateStreamingReplyForTextContentAsync()
var message = MessageEnvelope.Create(textContent, from: agent.Name);
var completion = agent.GenerateStreamingReplyAsync([message]);
- var chunks = new List<IStreamingMessage>();
- IStreamingMessage finalReply = null!;
+ var chunks = new List<IMessage>();
+ IMessage finalReply = null!;
await foreach (var item in completion)
{
@@ -212,8 +212,8 @@ public async Task VertexGeminiAgentGenerateStreamingReplyWithToolsAsync()
var message = MessageEnvelope.Create(textContent, from: agent.Name);
- var chunks = new List<IStreamingMessage>();
- IStreamingMessage finalReply = null!;
+ var chunks = new List<IMessage>();
+ IMessage finalReply = null!;
var completion = agent.GenerateStreamingReplyAsync([message]);
diff --git a/dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs b/dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs
index 7ffb532ea9c1..12ba94734032 100644
--- a/dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs
+++ b/dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs
@@ -225,10 +225,10 @@ public async Task ItProcessStreamingTextMessageAsync()
})
.Select(m => MessageEnvelope.Create(m));
- IStreamingMessage? finalReply = null;
+ IMessage? finalReply = null;
await foreach (var reply in agent.GenerateStreamingReplyAsync(messageChunks))
{
- reply.Should().BeAssignableTo<IStreamingMessage>();
+ reply.Should().BeAssignableTo<IMessage>();
finalReply = reply;
}
diff --git a/dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs b/dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs
index 2f06305ed59f..fba97aa522d5 100644
--- a/dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs
+++ b/dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// GeminiVertexClientTests.cs
+// VertexGeminiClientTests.cs
using AutoGen.Tests;
using FluentAssertions;
@@ -53,7 +53,7 @@ public async Task ItGenerateContentWithImageAsync()
var model = "gemini-1.5-flash-001";
var text = "what's in the image";
- var imagePath = Path.Combine("testData", "images", "image.png");
+ var imagePath = Path.Combine("testData", "images", "square.png");
var image = File.ReadAllBytes(imagePath);
var request = new GenerateContentRequest
{
diff --git a/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj b/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj
index d734119dbb09..aa20a835e9b9 100644
--- a/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj
+++ b/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj
@@ -1,7 +1,7 @@
- <TargetFramework>$(TestTargetFramework)</TargetFramework>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
enable
false
True
diff --git a/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj b/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj
index 1e26b38d8a4f..c5ca19556244 100644
--- a/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj
+++ b/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj
@@ -1,7 +1,7 @@
- <TargetFramework>$(TestTargetFramework)</TargetFramework>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
enable
false
True
diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs
index c1fb466f0b09..8a416116ea92 100644
--- a/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs
+++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs
@@ -65,8 +65,8 @@ public async Task GenerateStreamingReplyAsync_ReturnsValidMessages_WhenCalled()
var msg = new Message("user", "hey how are you");
var messages = new IMessage[] { MessageEnvelope.Create(msg, from: modelName) };
- IStreamingMessage? finalReply = default;
- await foreach (IStreamingMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages))
+ IMessage? finalReply = default;
+ await foreach (IMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages))
{
message.Should().NotBeNull();
message.From.Should().Be(ollamaAgent.Name);
@@ -171,8 +171,8 @@ public async Task ItReturnValidStreamingMessageUsingLLavaAsync()
var messages = new IMessage[] { MessageEnvelope.Create(imageMessage, from: modelName) };
- IStreamingMessage? finalReply = default;
- await foreach (IStreamingMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages))
+ IMessage? finalReply = default;
+ await foreach (IMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages))
{
message.Should().NotBeNull();
message.From.Should().Be(ollamaAgent.Name);
diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaMessageTests.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaMessageTests.cs
index b19291e97671..82cc462061da 100644
--- a/dotnet/test/AutoGen.Ollama.Tests/OllamaMessageTests.cs
+++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaMessageTests.cs
@@ -57,10 +57,10 @@ public async Task ItProcessStreamingTextMessageAsync()
})
.Select(m => MessageEnvelope.Create(m));
- IStreamingMessage? finalReply = null;
+ IMessage? finalReply = null;
await foreach (var reply in agent.GenerateStreamingReplyAsync(messageChunks))
{
- reply.Should().BeAssignableTo<IStreamingMessage>();
+ reply.Should().BeAssignableTo<IMessage>();
finalReply = reply;
}
diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs
index 06522bdd8238..b7186a3c6ebc 100644
--- a/dotnet/test/AutoGen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs
+++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaTextEmbeddingServiceTests.cs
using AutoGen.Tests;
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt
index e8e9af84dbdc..55bd6502bfcd 100644
--- a/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt
+++ b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt
@@ -1,11 +1,21 @@
-[
+[
{
"OriginalMessage": "TextMessage(system, You are a helpful AI assistant, )",
"ConvertedMessages": [
{
"Name": null,
"Role": "system",
- "Content": "You are a helpful AI assistant"
+ "Content": [
+ {
+ "Kind": {},
+ "Text": "You are a helpful AI assistant",
+ "Refusal": null,
+ "ImageUri": null,
+ "ImageBytes": null,
+ "ImageBytesMediaType": null,
+ "ImageDetail": null
+ }
+ ]
}
]
},
@@ -14,9 +24,24 @@
"ConvertedMessages": [
{
"Role": "user",
- "Content": "Hello",
+ "Content": [
+ {
+ "Kind": {},
+ "Text": "Hello",
+ "Refusal": null,
+ "ImageUri": null,
+ "ImageBytes": null,
+ "ImageBytesMediaType": null,
+ "ImageDetail": null
+ }
+ ],
"Name": "user",
- "MultiModaItem": null
+ "MultiModaItem": [
+ {
+ "Type": "Text",
+ "Text": "Hello"
+ }
+ ]
}
]
},
@@ -25,7 +50,17 @@
"ConvertedMessages": [
{
"Role": "assistant",
- "Content": "How can I help you?",
+ "Content": [
+ {
+ "Kind": {},
+ "Text": "How can I help you?",
+ "Refusal": null,
+ "ImageUri": null,
+ "ImageBytes": null,
+ "ImageBytesMediaType": null,
+ "ImageDetail": null
+ }
+ ],
"Name": "assistant",
"TooCall": [],
"FunctionCallName": null,
@@ -38,15 +73,22 @@
"ConvertedMessages": [
{
"Role": "user",
- "Content": null,
+ "Content": [
+ {
+ "Kind": {},
+ "Text": null,
+ "Refusal": null,
+ "ImageUri": "https://example.com/image.png",
+ "ImageBytes": null,
+ "ImageBytesMediaType": null,
+ "ImageDetail": null
+ }
+ ],
"Name": "user",
"MultiModaItem": [
{
"Type": "Image",
- "ImageUrl": {
- "Url": "https://example.com/image.png",
- "Detail": null
- }
+ "ImageUrl": "https://example.com/image.png"
}
]
}
@@ -57,7 +99,26 @@
"ConvertedMessages": [
{
"Role": "user",
- "Content": null,
+ "Content": [
+ {
+ "Kind": {},
+ "Text": "Hello",
+ "Refusal": null,
+ "ImageUri": null,
+ "ImageBytes": null,
+ "ImageBytesMediaType": null,
+ "ImageDetail": null
+ },
+ {
+ "Kind": {},
+ "Text": null,
+ "Refusal": null,
+ "ImageUri": "https://example.com/image.png",
+ "ImageBytes": null,
+ "ImageBytesMediaType": null,
+ "ImageDetail": null
+ }
+ ],
"Name": "user",
"MultiModaItem": [
{
@@ -66,10 +127,7 @@
},
{
"Type": "Image",
- "ImageUrl": {
- "Url": "https://example.com/image.png",
- "Detail": null
- }
+ "ImageUrl": "https://example.com/image.png"
}
]
}
@@ -80,8 +138,8 @@
"ConvertedMessages": [
{
"Role": "assistant",
- "Content": "",
- "Name": "assistant",
+ "Content": [],
+ "Name": null,
"TooCall": [
{
"Type": "Function",
@@ -125,8 +183,8 @@
"ConvertedMessages": [
{
"Role": "assistant",
- "Content": "",
- "Name": "assistant",
+ "Content": [],
+ "Name": null,
"TooCall": [
{
"Type": "Function",
@@ -151,8 +209,8 @@
"ConvertedMessages": [
{
"Role": "assistant",
- "Content": "",
- "Name": "assistant",
+ "Content": [],
+ "Name": null,
"TooCall": [
{
"Type": "Function",
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj
index ba499232beb9..d1e48686007c 100644
--- a/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj
+++ b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj
@@ -1,24 +1,20 @@
- <TargetFramework>$(TestTargetFramework)</TargetFramework>
+ <TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
false
True
True
-
-
-
+
-
- $([System.String]::Copy('%(FileName)').Split('.')[0])
- $(ProjectExt.Replace('proj', ''))
- %(ParentFile)%(ParentExtension)
-
+
+
+
+
-
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/MathClassTest.cs b/dotnet/test/AutoGen.OpenAI.Tests/MathClassTest.cs
index aae314ff773e..be1c38ad0a3c 100644
--- a/dotnet/test/AutoGen.OpenAI.Tests/MathClassTest.cs
+++ b/dotnet/test/AutoGen.OpenAI.Tests/MathClassTest.cs
@@ -10,6 +10,7 @@
using AutoGen.Tests;
using Azure.AI.OpenAI;
using FluentAssertions;
+using OpenAI;
using Xunit.Abstractions;
namespace AutoGen.OpenAI.Tests
@@ -102,7 +103,7 @@ public async Task OpenAIAgentMathChatTestAsync()
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
- var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
+ var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
@@ -110,11 +111,10 @@ public async Task OpenAIAgentMathChatTestAsync()
functions: [this.UpdateProgressFunctionContract],
 functionMap: new Dictionary<string, Func<string, Task<string>>>
{
- { this.UpdateProgressFunction.Name!, this.UpdateProgressWrapper },
+ { this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
- openAIClient: openaiClient,
- modelName: deployName,
+ chatClient: openaiClient.GetChatClient(deployName),
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
@@ -122,8 +122,7 @@ public async Task OpenAIAgentMathChatTestAsync()
.RegisterMiddleware(Print);
var groupAdmin = new OpenAIChatAgent(
- openAIClient: openaiClient,
- modelName: deployName,
+ chatClient: openaiClient.GetChatClient(deployName),
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
@@ -142,13 +141,12 @@ private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string m
});
var teacher = new OpenAIChatAgent(
- openAIClient: client,
+ chatClient: client.GetChatClient(model),
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
You create math question and ask student to answer it.
Then you check if the answer is correct.
-If the answer is wrong, you ask student to fix it",
- modelName: model)
+If the answer is wrong, you ask student to fix it")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
@@ -165,9 +163,8 @@ private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client,
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
- openAIClient: client,
+ chatClient: client.GetChatClient(model),
name: "Student",
- modelName: model,
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs
index 8ff66f5c86bf..992bf9b60ab9 100644
--- a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs
+++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs
@@ -9,6 +9,8 @@
using AutoGen.Tests;
using Azure.AI.OpenAI;
using FluentAssertions;
+using OpenAI;
+using OpenAI.Chat;
namespace AutoGen.OpenAI.Tests;
@@ -25,56 +27,56 @@ public async Task<string> GetWeatherAsync(string location)
return $"The weather in {location} is sunny.";
}
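+
+ // A second tool, so tests can replay chat histories that mix several different tool calls.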
+ [Function]
+ public async Task<string> CalculateTaxAsync(string location, double income)
+ {
+ return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
+ }
+
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task BasicConversationTestAsync()
{
- var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
- var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
- var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var openAIChatAgent = new OpenAIChatAgent(
- openAIClient: openaiClient,
- name: "assistant",
- modelName: deployName);
+ chatClient: openaiClient.GetChatClient(deployName),
+ name: "assistant");
// By default, OpenAIChatClient supports the following message types
// - IMessage
- var chatMessageContent = MessageEnvelope.Create(new ChatRequestUserMessage("Hello"));
+ var chatMessageContent = MessageEnvelope.Create(new UserChatMessage("Hello"));
var reply = await openAIChatAgent.SendAsync(chatMessageContent);
- reply.Should().BeOfType<MessageEnvelope<ChatCompletions>>();
- reply.As<MessageEnvelope<ChatCompletions>>().From.Should().Be("assistant");
- reply.As<MessageEnvelope<ChatCompletions>>().Content.Choices.First().Message.Role.Should().Be(ChatRole.Assistant);
- reply.As<MessageEnvelope<ChatCompletions>>().Content.Usage.TotalTokens.Should().BeGreaterThan(0);
+ reply.Should().BeOfType<MessageEnvelope<ChatCompletion>>();
+ reply.As<MessageEnvelope<ChatCompletion>>().From.Should().Be("assistant");
+ reply.As<MessageEnvelope<ChatCompletion>>().Content.Role.Should().Be(ChatMessageRole.Assistant);
+ reply.As<MessageEnvelope<ChatCompletion>>().Content.Usage.TotalTokens.Should().BeGreaterThan(0);
// test streaming
var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
await foreach (var streamingMessage in streamingReply)
{
- streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatCompletionsUpdate>>();
- streamingMessage.As<MessageEnvelope<StreamingChatCompletionsUpdate>>().From.Should().Be("assistant");
+ streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatCompletionUpdate>>();
+ streamingMessage.As<MessageEnvelope<StreamingChatCompletionUpdate>>().From.Should().Be("assistant");
}
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIChatMessageContentConnectorTestAsync()
{
- var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
- var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
- var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var openAIChatAgent = new OpenAIChatAgent(
- openAIClient: openaiClient,
- name: "assistant",
- modelName: deployName);
+ chatClient: openaiClient.GetChatClient(deployName),
+ name: "assistant");
 MiddlewareStreamingAgent<OpenAIChatAgent> assistant = openAIChatAgent
.RegisterMessageConnector();
var messages = new IMessage[]
{
- MessageEnvelope.Create(new ChatRequestUserMessage("Hello")),
+ MessageEnvelope.Create(new UserChatMessage("Hello")),
new TextMessage(Role.Assistant, "Hello", from: "user"),
new MultiModalMessage(Role.Assistant,
[
@@ -107,14 +109,11 @@ public async Task OpenAIChatMessageContentConnectorTestAsync()
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIChatAgentToolCallTestAsync()
{
- var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
- var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
- var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var openAIChatAgent = new OpenAIChatAgent(
- openAIClient: openaiClient,
- name: "assistant",
- modelName: deployName);
+ chatClient: openaiClient.GetChatClient(deployName),
+ name: "assistant");
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.GetWeatherAsyncFunctionContract]);
@@ -128,7 +127,7 @@ public async Task OpenAIChatAgentToolCallTestAsync()
var question = "What's the weather in Seattle";
var messages = new IMessage[]
{
- MessageEnvelope.Create(new ChatRequestUserMessage(question)),
+ MessageEnvelope.Create(new UserChatMessage(question)),
new TextMessage(Role.Assistant, question, from: "user"),
new MultiModalMessage(Role.Assistant,
[
@@ -154,16 +153,14 @@ public async Task OpenAIChatAgentToolCallTestAsync()
ToolCallMessage? toolCallMessage = null;
await foreach (var streamingMessage in reply)
{
- streamingMessage.Should().BeOfType<ToolCallMessageUpdate>();
- streamingMessage.As<ToolCallMessageUpdate>().From.Should().Be("assistant");
- if (toolCallMessage is null)
+ if (streamingMessage is ToolCallMessage finalMessage)
{
- toolCallMessage = new ToolCallMessage(streamingMessage.As<ToolCallMessageUpdate>());
- }
- else
- {
- toolCallMessage.Update(streamingMessage.As<ToolCallMessageUpdate>());
+ toolCallMessage = finalMessage;
+ break;
}
+
+ streamingMessage.Should().BeOfType<ToolCallMessageUpdate>();
+ streamingMessage.As<ToolCallMessageUpdate>().From.Should().Be("assistant");
}
toolCallMessage.Should().NotBeNull();
@@ -176,14 +173,11 @@ public async Task OpenAIChatAgentToolCallTestAsync()
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIChatAgentToolCallInvokingTestAsync()
{
- var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
- var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
- var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var openAIChatAgent = new OpenAIChatAgent(
- openAIClient: openaiClient,
- name: "assistant",
- modelName: deployName);
+ chatClient: openaiClient.GetChatClient(deployName),
+ name: "assistant");
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.GetWeatherAsyncFunctionContract],
@@ -197,7 +191,7 @@ public async Task OpenAIChatAgentToolCallInvokingTestAsync()
var question = "What's the weather in Seattle";
var messages = new IMessage[]
{
- MessageEnvelope.Create(new ChatRequestUserMessage(question)),
+ MessageEnvelope.Create(new UserChatMessage(question)),
new TextMessage(Role.Assistant, question, from: "user"),
new MultiModalMessage(Role.Assistant,
[
@@ -236,4 +230,91 @@ public async Task OpenAIChatAgentToolCallInvokingTestAsync()
}
}
}
+
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task ItCreateOpenAIChatAgentWithChatCompletionOptionAsync()
+ {
+ var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
+ var options = new ChatCompletionOptions()
+ {
+ Temperature = 0.7f,
+ MaxTokens = 1,
+ };
+
+ var openAIChatAgent = new OpenAIChatAgent(
+ chatClient: openaiClient.GetChatClient(deployName),
+ name: "assistant",
+ options: options)
+ .RegisterMessageConnector();
+
+ var respond = await openAIChatAgent.SendAsync("hello");
+ respond.GetContent()?.Should().NotBeNullOrEmpty();
+ }
+
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task ItProduceValidContentAfterFunctionCall()
+ {
+ // https://github.com/microsoft/autogen/issues/3437
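+ // Regression: a history that interleaves aggregate tool-call messages across
+ // several turns should still produce a valid request instead of malformed content.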
+ var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
+ var options = new ChatCompletionOptions()
+ {
+ Temperature = 0.7f,
+ MaxTokens = 1,
+ };
+
+ var agentName = "assistant";
+
+ var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}");
+ var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny.");
+ var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName);
+ var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName);
+ var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName);
+
+ var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}");
+ var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000.");
+ var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName);
+ var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName);
+ var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName);
+
+ var chatHistory = new List<IMessage>()
+ {
+ new TextMessage(Role.User, "What's the weather in Seattle", from: "user"),
+ getWeatherAggregateMessage,
+ new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"),
+ calculateTaxAggregateMessage,
+ new TextMessage(Role.User, "what's the weather in Paris", from: "user"),
+ getWeatherAggregateMessage,
+ new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"),
+ calculateTaxAggregateMessage,
+ new TextMessage(Role.User, "what's the weather in New York", from: "user"),
+ getWeatherAggregateMessage,
+ new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"),
+ calculateTaxAggregateMessage,
+ new TextMessage(Role.User, "what's the weather in London", from: "user"),
+ getWeatherAggregateMessage,
+ new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"),
+ };
+
+ var agent = new OpenAIChatAgent(
+ chatClient: openaiClient.GetChatClient(deployName),
+ name: "assistant",
+ options: options)
+ .RegisterMessageConnector();
+
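+ // The reply itself is not asserted further; completing without an exception is the pass condition.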
+ var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions
+ {
+ MaxToken = 1024,
+ Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract],
+ });
+ }
+
+ private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
+ {
+ var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
+ var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
+ return new AzureOpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
+ }
}
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
index a9b852e0d8c1..3a2048c2f0f8 100644
--- a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
+++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
@@ -11,8 +11,8 @@
using ApprovalTests.Namers;
using ApprovalTests.Reporters;
using AutoGen.Tests;
-using Azure.AI.OpenAI;
using FluentAssertions;
+using OpenAI.Chat;
using Xunit;
namespace AutoGen.OpenAI.Tests;
@@ -71,10 +71,10 @@ public async Task ItProcessUserTextMessageAsync()
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestUserMessage>>();
- var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestUserMessage>)innerMessage!).Content;
- chatRequestMessage.Content.Should().Be("Hello");
- chatRequestMessage.Name.Should().Be("user");
+ innerMessage!.Should().BeOfType<MessageEnvelope<UserChatMessage>>();
+ var chatRequestMessage = (UserChatMessage)((MessageEnvelope<UserChatMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.First().Text.Should().Be("Hello");
+ chatRequestMessage.ParticipantName.Should().Be("user");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
@@ -92,16 +92,16 @@ public async Task ItShortcutChatRequestMessageAsync()
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestUserMessage>>();
+ innerMessage!.Should().BeOfType<MessageEnvelope<UserChatMessage>>();
- var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestUserMessage>)innerMessage!).Content;
- chatRequestMessage.Content.Should().Be("hello");
+ var chatRequestMessage = (UserChatMessage)((MessageEnvelope<UserChatMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.First().Text.Should().Be("hello");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
// user message
- var userMessage = new ChatRequestUserMessage("hello");
+ var userMessage = new UserChatMessage("hello");
var chatRequestMessage = MessageEnvelope.Create(userMessage);
await agent.GenerateReplyAsync([chatRequestMessage]);
}
@@ -151,10 +151,10 @@ public async Task ItProcessAssistantTextMessageAsync()
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestAssistantMessage>>();
- var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestAssistantMessage>)innerMessage!).Content;
- chatRequestMessage.Content.Should().Be("How can I help you?");
- chatRequestMessage.Name.Should().Be("assistant");
+ innerMessage!.Should().BeOfType<MessageEnvelope<AssistantChatMessage>>();
+ var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<AssistantChatMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.First().Text.Should().Be("How can I help you?");
+ chatRequestMessage.ParticipantName.Should().Be("assistant");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
@@ -172,9 +172,9 @@ public async Task ItProcessSystemTextMessageAsync()
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestSystemMessage>>();
- var chatRequestMessage = (ChatRequestSystemMessage)((MessageEnvelope<ChatRequestSystemMessage>)innerMessage!).Content;
- chatRequestMessage.Content.Should().Be("You are a helpful AI assistant");
+ innerMessage!.Should().BeOfType<MessageEnvelope<SystemChatMessage>>();
+ var chatRequestMessage = (SystemChatMessage)((MessageEnvelope<SystemChatMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.First().Text.Should().Be("You are a helpful AI assistant");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
@@ -192,12 +192,11 @@ public async Task ItProcessImageMessageAsync()
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestUserMessage>>();
- var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestUserMessage>)innerMessage!).Content;
- chatRequestMessage.Content.Should().BeNullOrEmpty();
- chatRequestMessage.Name.Should().Be("user");
- chatRequestMessage.MultimodalContentItems.Count().Should().Be(1);
- chatRequestMessage.MultimodalContentItems.First().Should().BeOfType<ChatMessageImageContentItem>();
+ innerMessage!.Should().BeOfType<MessageEnvelope<UserChatMessage>>();
+ var chatRequestMessage = (UserChatMessage)((MessageEnvelope<UserChatMessage>)innerMessage!).Content;
+ chatRequestMessage.ParticipantName.Should().Be("user");
+ chatRequestMessage.Content.Count().Should().Be(1);
+ chatRequestMessage.Content.First().Kind.Should().Be(ChatMessageContentPartKind.Image);
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
@@ -228,13 +227,12 @@ public async Task ItProcessMultiModalMessageAsync()
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestUserMessage>>();
- var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestUserMessage>)innerMessage!).Content;
- chatRequestMessage.Content.Should().BeNullOrEmpty();
- chatRequestMessage.Name.Should().Be("user");
- chatRequestMessage.MultimodalContentItems.Count().Should().Be(2);
- chatRequestMessage.MultimodalContentItems.First().Should().BeOfType<ChatMessageTextContentItem>();
- chatRequestMessage.MultimodalContentItems.Last().Should().BeOfType<ChatMessageImageContentItem>();
+ innerMessage!.Should().BeOfType<MessageEnvelope<UserChatMessage>>();
+ var chatRequestMessage = (UserChatMessage)((MessageEnvelope<UserChatMessage>)innerMessage!).Content;
+ chatRequestMessage.ParticipantName.Should().Be("user");
+ chatRequestMessage.Content.Count().Should().Be(2);
+ chatRequestMessage.Content.First().Kind.Should().Be(ChatMessageContentPartKind.Text);
+ chatRequestMessage.Content.Last().Kind.Should().Be(ChatMessageContentPartKind.Image);
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
@@ -276,16 +274,19 @@ public async Task ItProcessToolCallMessageAsync()
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestAssistantMessage>>();
- var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestAssistantMessage>)innerMessage!).Content;
- chatRequestMessage.Name.Should().Be("assistant");
+ innerMessage!.Should().BeOfType<MessageEnvelope<AssistantChatMessage>>();
+ var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<AssistantChatMessage>)innerMessage!).Content;
+ // When the message is a tool call message,
+ // the name field should not be set.
+ // See the OpenAIChatRequestMessageConnector class for more information.
+ chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(1);
- chatRequestMessage.Content.Should().Be("textContent");
- chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
- var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First();
- functionToolCall.Name.Should().Be("test");
+ chatRequestMessage.Content.First().Text.Should().Be("textContent");
+ chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatToolCall>();
+ var functionToolCall = (ChatToolCall)chatRequestMessage.ToolCalls.First();
+ functionToolCall.FunctionName.Should().Be("test");
functionToolCall.Id.Should().Be("test");
- functionToolCall.Arguments.Should().Be("test");
+ functionToolCall.FunctionArguments.Should().Be("test");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
@@ -306,18 +307,21 @@ public async Task ItProcessParallelToolCallMessageAsync()
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestAssistantMessage>>();
- var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestAssistantMessage>)innerMessage!).Content;
+ innerMessage!.Should().BeOfType<MessageEnvelope<AssistantChatMessage>>();
+ var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<AssistantChatMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
- chatRequestMessage.Name.Should().Be("assistant");
+ // When the message is a tool call message,
+ // the name field should not be set.
+ // See the OpenAIChatRequestMessageConnector class for more information.
+ chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{
- chatRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
- var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.ElementAt(i);
- functionToolCall.Name.Should().Be("test");
+ chatRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatToolCall>();
+ var functionToolCall = (ChatToolCall)chatRequestMessage.ToolCalls.ElementAt(i);
+ functionToolCall.FunctionName.Should().Be("test");
functionToolCall.Id.Should().Be($"test_{i}");
- functionToolCall.Arguments.Should().Be("test");
+ functionToolCall.FunctionArguments.Should().Be("test");
}
return await innerAgent.GenerateReplyAsync(msgs);
})
@@ -353,10 +357,11 @@ public async Task ItProcessToolCallResultMessageAsync()
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestToolMessage>>();
- var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestToolMessage>)innerMessage!).Content;
- chatRequestMessage.Content.Should().Be("result");
+ innerMessage!.Should().BeOfType<MessageEnvelope<ToolChatMessage>>();
+ var chatRequestMessage = (ToolChatMessage)((MessageEnvelope<ToolChatMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.First().Text.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be("test");
+
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
@@ -378,9 +383,9 @@ public async Task ItProcessParallelToolCallResultMessageAsync()
for (int i = 0; i < msgs.Count(); i++)
{
var innerMessage = msgs.ElementAt(i);
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestToolMessage>>();
- var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestToolMessage>)innerMessage!).Content;
- chatRequestMessage.Content.Should().Be("result");
+ innerMessage!.Should().BeOfType<MessageEnvelope<ToolChatMessage>>();
+ var chatRequestMessage = (ToolChatMessage)((MessageEnvelope<ToolChatMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.First().Text.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be($"test_{i}");
}
return await innerAgent.GenerateReplyAsync(msgs);
@@ -406,10 +411,10 @@ public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync()
{
msgs.Count().Should().Be(1);
var innerMessage = msgs.Last();
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestUserMessage>>();
- var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope<ChatRequestUserMessage>)innerMessage!).Content;
- chatRequestMessage.Content.Should().Be("result");
- chatRequestMessage.Name.Should().Be("user");
+ innerMessage!.Should().BeOfType<MessageEnvelope<UserChatMessage>>();
+ var chatRequestMessage = (UserChatMessage)((MessageEnvelope<UserChatMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.First().Text.Should().Be("result");
+ chatRequestMessage.ParticipantName.Should().Be("user");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
@@ -430,21 +435,21 @@ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync()
{
msgs.Count().Should().Be(2);
var innerMessage = msgs.Last();
- innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestToolMessage>>();
- var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestToolMessage>)innerMessage!).Content;
- chatRequestMessage.Content.Should().Be("result");
+ innerMessage!.Should().BeOfType<MessageEnvelope<ToolChatMessage>>();
+ var chatRequestMessage = (ToolChatMessage)((MessageEnvelope<ToolChatMessage>)innerMessage!).Content;
+ chatRequestMessage.Content.First().Text.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be("test");
var toolCallMessage = msgs.First();
- toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestAssistantMessage>>();
- var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestAssistantMessage>)toolCallMessage!).Content;
+ toolCallMessage!.Should().BeOfType<MessageEnvelope<AssistantChatMessage>>();
+ var toolCallRequestMessage = (AssistantChatMessage)((MessageEnvelope<AssistantChatMessage>)toolCallMessage!).Content;
toolCallRequestMessage.Content.Should().BeNullOrEmpty();
toolCallRequestMessage.ToolCalls.Count().Should().Be(1);
- toolCallRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
- var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First();
- functionToolCall.Name.Should().Be("test");
+ toolCallRequestMessage.ToolCalls.First().Should().BeOfType<ChatToolCall>();
+ var functionToolCall = (ChatToolCall)toolCallRequestMessage.ToolCalls.First();
+ functionToolCall.FunctionName.Should().Be("test");
functionToolCall.Id.Should().Be("test");
- functionToolCall.Arguments.Should().Be("test");
+ functionToolCall.FunctionArguments.Should().Be("test");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
@@ -465,26 +470,26 @@ public async Task ItProcessParallelFunctionCallMiddlewareMessageFromAssistantAsy
{
msgs.Count().Should().Be(3);
var toolCallMessage = msgs.First();
- toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestAssistantMessage>>();
- var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestAssistantMessage>)toolCallMessage!).Content;
+ toolCallMessage!.Should().BeOfType<MessageEnvelope<AssistantChatMessage>>();
+ var toolCallRequestMessage = (AssistantChatMessage)((MessageEnvelope<AssistantChatMessage>)toolCallMessage!).Content;
toolCallRequestMessage.Content.Should().BeNullOrEmpty();
toolCallRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < toolCallRequestMessage.ToolCalls.Count(); i++)
{
- toolCallRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
- var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.ElementAt(i);
- functionToolCall.Name.Should().Be("test");
+ toolCallRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatToolCall>();
+ var functionToolCall = (ChatToolCall)toolCallRequestMessage.ToolCalls.ElementAt(i);
+ functionToolCall.FunctionName.Should().Be("test");
functionToolCall.Id.Should().Be($"test_{i}");
- functionToolCall.Arguments.Should().Be("test");
+ functionToolCall.FunctionArguments.Should().Be("test");
}
for (int i = 1; i < msgs.Count(); i++)
{
var toolCallResultMessage = msgs.ElementAt(i);
- toolCallResultMessage!.Should().BeOfType<MessageEnvelope<ChatRequestToolMessage>>();
- var toolCallResultRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestToolMessage>)toolCallResultMessage!).Content;
- toolCallResultRequestMessage.Content.Should().Be("result");
+ toolCallResultMessage!.Should().BeOfType<MessageEnvelope<ToolChatMessage>>();
+ var toolCallResultRequestMessage = (ToolChatMessage)((MessageEnvelope<ToolChatMessage>)toolCallResultMessage!).Content;
+ toolCallResultRequestMessage.Content.First().Text.Should().Be("result");
toolCallResultRequestMessage.ToolCallId.Should().Be($"test_{i - 1}");
}
@@ -504,41 +509,6 @@ public async Task ItProcessParallelFunctionCallMiddlewareMessageFromAssistantAsy
await agent.GenerateReplyAsync([aggregateMessage]);
}
- [Fact]
- public async Task ItConvertChatResponseMessageToTextMessageAsync()
- {
- var middleware = new OpenAIChatRequestMessageConnector();
- var agent = new EchoAgent("assistant")
- .RegisterMiddleware(middleware);
-
- // text message
- var textMessage = CreateInstance<ChatResponseMessage>(ChatRole.Assistant, "hello");
- var chatRequestMessage = MessageEnvelope.Create(textMessage);
-
- var message = await agent.GenerateReplyAsync([chatRequestMessage]);
- message.Should().BeOfType<TextMessage>();
- message.GetContent().Should().Be("hello");
- message.GetRole().Should().Be(Role.Assistant);
- }
-
- [Fact]
- public async Task ItConvertChatResponseMessageToToolCallMessageAsync()
- {
- var middleware = new OpenAIChatRequestMessageConnector();
- var agent = new EchoAgent("assistant")
- .RegisterMiddleware(middleware);
-
- // tool call message
- var toolCallMessage = CreateInstance<ChatResponseMessage>(ChatRole.Assistant, "textContent", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance<AzureChatExtensionsMessageContext>(), new Dictionary<string, BinaryData>());
- var chatRequestMessage = MessageEnvelope.Create(toolCallMessage);
- var message = await agent.GenerateReplyAsync([chatRequestMessage]);
- message.Should().BeOfType<ToolCallMessage>();
- message.GetToolCalls()!.Count().Should().Be(1);
- message.GetToolCalls()!.First().FunctionName.Should().Be("test");
- message.GetToolCalls()!.First().FunctionArguments.Should().Be("test");
- message.GetContent().Should().Be("textContent");
- }
-
[Fact]
public async Task ItReturnOriginalMessageWhenStrictModeIsFalseAsync()
{
@@ -562,7 +532,7 @@ public async Task ItThrowInvalidOperationExceptionWhenStrictModeIsTrueAsync()
.RegisterMiddleware(middleware);
// text message
- var textMessage = new ChatRequestUserMessage("hello");
+ var textMessage = new UserChatMessage("hello");
var messageToSend = MessageEnvelope.Create(textMessage);
Func<Task> action = async () => await agent.GenerateReplyAsync([messageToSend]);
@@ -574,22 +544,24 @@ public void ToOpenAIChatRequestMessageShortCircuitTest()
{
var agent = new EchoAgent("assistant");
var middleware = new OpenAIChatRequestMessageConnector();
- ChatRequestMessage[] messages =
+#pragma warning disable CS0618 // Type or member is obsolete
+ ChatMessage[] messages =
[
- new ChatRequestUserMessage("Hello"),
- new ChatRequestAssistantMessage("How can I help you?"),
- new ChatRequestSystemMessage("You are a helpful AI assistant"),
- new ChatRequestFunctionMessage("result", "functionName"),
- new ChatRequestToolMessage("test", "test"),
+ new UserChatMessage("Hello"),
+ new AssistantChatMessage("How can I help you?"),
+ new SystemChatMessage("You are a helpful AI assistant"),
+ new FunctionChatMessage("functionName", "result"),
+ new ToolChatMessage("test", "test"),
];
+#pragma warning restore CS0618 // Type or member is obsolete
foreach (var oaiMessage in messages)
{
- IMessage message = new MessageEnvelope<ChatRequestMessage>(oaiMessage);
+ IMessage message = new MessageEnvelope<ChatMessage>(oaiMessage);
var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]);
oaiMessages.Count().Should().Be(1);
//oaiMessages.First().Should().BeOfType>();
- if (oaiMessages.First() is IMessage<ChatRequestMessage> chatRequestMessage)
+ if (oaiMessages.First() is IMessage<ChatMessage> chatRequestMessage)
{
chatRequestMessage.Content.Should().Be(oaiMessage);
}
@@ -609,27 +581,27 @@ private void VerifyOAIMessages(IEnumerable<(IMessage, IEnumerable)> me
foreach (var m in ms)
{
object? obj = null;
- var chatRequestMessage = (m as IMessage<ChatRequestMessage>)?.Content;
- if (chatRequestMessage is ChatRequestUserMessage userMessage)
+ var chatRequestMessage = (m as IMessage<ChatMessage>)?.Content;
+ if (chatRequestMessage is UserChatMessage userMessage)
{
obj = new
{
- Role = userMessage.Role.ToString(),
+ Role = "user",
Content = userMessage.Content,
- Name = userMessage.Name,
- MultiModaItem = userMessage.MultimodalContentItems?.Select(item =>
+ Name = userMessage.ParticipantName,
+ MultiModaItem = userMessage.Content?.Select(item =>
{
return item switch
{
- ChatMessageImageContentItem imageContentItem => new
+ _ when item.Kind == ChatMessageContentPartKind.Image => new
{
Type = "Image",
- ImageUrl = GetImageUrlFromContent(imageContentItem),
+ ImageUrl = GetImageUrlFromContent(item),
} as object,
- ChatMessageTextContentItem textContentItem => new
+ _ when item.Kind == ChatMessageContentPartKind.Text => new
{
Type = "Text",
- Text = textContentItem.Text,
+ Text = item.Text,
} as object,
_ => throw new System.NotImplementedException(),
};
@@ -637,58 +609,60 @@ private void VerifyOAIMessages(IEnumerable<(IMessage, IEnumerable)> me
};
}
- if (chatRequestMessage is ChatRequestAssistantMessage assistantMessage)
+ if (chatRequestMessage is AssistantChatMessage assistantMessage)
{
obj = new
{
- Role = assistantMessage.Role.ToString(),
+ Role = "assistant",
Content = assistantMessage.Content,
- Name = assistantMessage.Name,
+ Name = assistantMessage.ParticipantName,
TooCall = assistantMessage.ToolCalls.Select(tc =>
{
return tc switch
{
- ChatCompletionsFunctionToolCall functionToolCall => new
+ ChatToolCall functionToolCall => new
{
Type = "Function",
- Name = functionToolCall.Name,
- Arguments = functionToolCall.Arguments,
+ Name = functionToolCall.FunctionName,
+ Arguments = functionToolCall.FunctionArguments,
Id = functionToolCall.Id,
} as object,
_ => throw new System.NotImplementedException(),
};
}),
- FunctionCallName = assistantMessage.FunctionCall?.Name,
- FunctionCallArguments = assistantMessage.FunctionCall?.Arguments,
+ FunctionCallName = assistantMessage.FunctionCall?.FunctionName,
+ FunctionCallArguments = assistantMessage.FunctionCall?.FunctionArguments,
};
}
- if (chatRequestMessage is ChatRequestSystemMessage systemMessage)
+ if (chatRequestMessage is SystemChatMessage systemMessage)
{
obj = new
{
- Name = systemMessage.Name,
- Role = systemMessage.Role.ToString(),
+ Name = systemMessage.ParticipantName,
+ Role = "system",
Content = systemMessage.Content,
};
}
- if (chatRequestMessage is ChatRequestFunctionMessage functionMessage)
+#pragma warning disable CS0618 // Type or member is obsolete
+ if (chatRequestMessage is FunctionChatMessage functionMessage)
{
obj = new
{
- Role = functionMessage.Role.ToString(),
+ Role = "function",
Content = functionMessage.Content,
- Name = functionMessage.Name,
+ Name = functionMessage.FunctionName,
};
}
+#pragma warning restore CS0618 // Type or member is obsolete
- if (chatRequestMessage is ChatRequestToolMessage toolCallMessage)
+ if (chatRequestMessage is ToolChatMessage toolCallMessage)
{
obj = new
{
- Role = toolCallMessage.Role.ToString(),
- Content = toolCallMessage.Content,
+ Role = "tool",
+ Content = toolCallMessage.Content.First().Text,
ToolCallId = toolCallMessage.ToolCallId,
};
}
@@ -707,9 +681,9 @@ private void VerifyOAIMessages(IEnumerable<(IMessage, IEnumerable)> me
Approvals.Verify(json);
}
- private object? GetImageUrlFromContent(ChatMessageImageContentItem content)
+ private object? GetImageUrlFromContent(ChatMessageContentPart content)
{
- return content.GetType().GetProperty("ImageUrl", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?.GetValue(content);
+ return content.ImageUri;
}
private static T CreateInstance<T>(params object[] args)
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAISampleTest.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAISampleTest.cs
new file mode 100644
index 000000000000..6376c4ff4986
--- /dev/null
+++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAISampleTest.cs
@@ -0,0 +1,48 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAISampleTest.cs
+
+using System;
+using System.IO;
+using System.Threading.Tasks;
+using AutoGen.OpenAI.Sample;
+using AutoGen.Tests;
+using Xunit.Abstractions;
+
+namespace AutoGen.OpenAI.Tests;
+
+public class OpenAISampleTest
+{
+ private readonly ITestOutputHelper _output;
+
+ public OpenAISampleTest(ITestOutputHelper output)
+ {
+ _output = output;
+ Console.SetOut(new ConsoleWriter(_output));
+ }
+
+ [ApiKeyFact("OPENAI_API_KEY")]
+ public async Task Structural_OutputAsync()
+ {
+ await Structural_Output.RunAsync();
+ }
+
+ [ApiKeyFact("OPENAI_API_KEY")]
+ public async Task Use_Json_ModeAsync()
+ {
+ await Use_Json_Mode.RunAsync();
+ }
+
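+ // Redirects Console output from the samples into xUnit's per-test log.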
+ public class ConsoleWriter : StringWriter
+ {
+ private ITestOutputHelper output;
+ public ConsoleWriter(ITestOutputHelper output)
+ {
+ this.output = output;
+ }
+
+ public override void WriteLine(string? m)
+ {
+ output.WriteLine(m);
+ }
+ }
+}
diff --git a/dotnet/test/AutoGen.OpenAI.V1.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt b/dotnet/test/AutoGen.OpenAI.V1.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt
new file mode 100644
index 000000000000..877bc57bf758
--- /dev/null
+++ b/dotnet/test/AutoGen.OpenAI.V1.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt
@@ -0,0 +1,174 @@
+[
+ {
+ "OriginalMessage": "TextMessage(system, You are a helpful AI assistant, )",
+ "ConvertedMessages": [
+ {
+ "Name": null,
+ "Role": "system",
+ "Content": "You are a helpful AI assistant"
+ }
+ ]
+ },
+ {
+ "OriginalMessage": "TextMessage(user, Hello, user)",
+ "ConvertedMessages": [
+ {
+ "Role": "user",
+ "Content": "Hello",
+ "Name": "user",
+ "MultiModaItem": null
+ }
+ ]
+ },
+ {
+ "OriginalMessage": "TextMessage(assistant, How can I help you?, assistant)",
+ "ConvertedMessages": [
+ {
+ "Role": "assistant",
+ "Content": "How can I help you?",
+ "Name": "assistant",
+ "TooCall": [],
+ "FunctionCallName": null,
+ "FunctionCallArguments": null
+ }
+ ]
+ },
+ {
+ "OriginalMessage": "ImageMessage(user, https://example.com/image.png, user)",
+ "ConvertedMessages": [
+ {
+ "Role": "user",
+ "Content": null,
+ "Name": "user",
+ "MultiModaItem": [
+ {
+ "Type": "Image",
+ "ImageUrl": {
+ "Url": "https://example.com/image.png",
+ "Detail": null
+ }
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "OriginalMessage": "MultiModalMessage(assistant, user)\n\tTextMessage(user, Hello, user)\n\tImageMessage(user, https://example.com/image.png, user)",
+ "ConvertedMessages": [
+ {
+ "Role": "user",
+ "Content": null,
+ "Name": "user",
+ "MultiModaItem": [
+ {
+ "Type": "Text",
+ "Text": "Hello"
+ },
+ {
+ "Type": "Image",
+ "ImageUrl": {
+ "Url": "https://example.com/image.png",
+ "Detail": null
+ }
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "OriginalMessage": "ToolCallMessage(assistant)\n\tToolCall(test, test, )",
+ "ConvertedMessages": [
+ {
+ "Role": "assistant",
+ "Content": "",
+ "Name": null,
+ "TooCall": [
+ {
+ "Type": "Function",
+ "Name": "test",
+ "Arguments": "test",
+ "Id": "test"
+ }
+ ],
+ "FunctionCallName": null,
+ "FunctionCallArguments": null
+ }
+ ]
+ },
+ {
+ "OriginalMessage": "ToolCallResultMessage(user)\n\tToolCall(test, test, result)",
+ "ConvertedMessages": [
+ {
+ "Role": "tool",
+ "Content": "result",
+ "ToolCallId": "test"
+ }
+ ]
+ },
+ {
+ "OriginalMessage": "ToolCallResultMessage(user)\n\tToolCall(result, test, test)\n\tToolCall(result, test, test)",
+ "ConvertedMessages": [
+ {
+ "Role": "tool",
+ "Content": "test",
+ "ToolCallId": "result_0"
+ },
+ {
+ "Role": "tool",
+ "Content": "test",
+ "ToolCallId": "result_1"
+ }
+ ]
+ },
+ {
+ "OriginalMessage": "ToolCallMessage(assistant)\n\tToolCall(test, test, )\n\tToolCall(test, test, )",
+ "ConvertedMessages": [
+ {
+ "Role": "assistant",
+ "Content": "",
+ "Name": null,
+ "TooCall": [
+ {
+ "Type": "Function",
+ "Name": "test",
+ "Arguments": "test",
+ "Id": "test_0"
+ },
+ {
+ "Type": "Function",
+ "Name": "test",
+ "Arguments": "test",
+ "Id": "test_1"
+ }
+ ],
+ "FunctionCallName": null,
+ "FunctionCallArguments": null
+ }
+ ]
+ },
+ {
+ "OriginalMessage": "AggregateMessage(assistant)\n\tToolCallMessage(assistant)\n\tToolCall(test, test, )\n\tToolCallResultMessage(assistant)\n\tToolCall(test, test, result)",
+ "ConvertedMessages": [
+ {
+ "Role": "assistant",
+ "Content": "",
+ "Name": null,
+ "TooCall": [
+ {
+ "Type": "Function",
+ "Name": "test",
+ "Arguments": "test",
+ "Id": "test"
+ }
+ ],
+ "FunctionCallName": null,
+ "FunctionCallArguments": null
+ },
+ {
+ "Role": "tool",
+ "Content": "result",
+ "ToolCallId": "test"
+ }
+ ]
+ }
+]
\ No newline at end of file
diff --git a/dotnet/test/AutoGen.OpenAI.V1.Tests/AutoGen.OpenAI.V1.Tests.csproj b/dotnet/test/AutoGen.OpenAI.V1.Tests/AutoGen.OpenAI.V1.Tests.csproj
new file mode 100644
index 000000000000..0be8c5200336
--- /dev/null
+++ b/dotnet/test/AutoGen.OpenAI.V1.Tests/AutoGen.OpenAI.V1.Tests.csproj
@@ -0,0 +1,24 @@
+
+
+
+ $(TestTargetFrameworks)
+ false
+ True
+ True
+
+
+
+
+
+
+
+
+
+
+ $([System.String]::Copy('%(FileName)').Split('.')[0])
+ $(ProjectExt.Replace('proj', ''))
+ %(ParentFile)%(ParentExtension)
+
+
+
+
diff --git a/dotnet/test/AutoGen.OpenAI.V1.Tests/GPTAgentTest.cs b/dotnet/test/AutoGen.OpenAI.V1.Tests/GPTAgentTest.cs
new file mode 100644
index 000000000000..b8944d45d762
--- /dev/null
+++ b/dotnet/test/AutoGen.OpenAI.V1.Tests/GPTAgentTest.cs
@@ -0,0 +1,270 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GPTAgentTest.cs
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Threading.Tasks;
+using AutoGen.OpenAI.V1.Extension;
+using AutoGen.Tests;
+using Azure.AI.OpenAI;
+using FluentAssertions;
+using Xunit.Abstractions;
+
+namespace AutoGen.OpenAI.V1.Tests;
+
+public partial class GPTAgentTest
+{
+ private ITestOutputHelper _output;
+ public GPTAgentTest(ITestOutputHelper output)
+ {
+ _output = output;
+ }
+
+ private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
+ {
+ var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
+ var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
+ var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
+ return new AzureOpenAIConfig(endpoint, deployName, key);
+ }
+
+ private ILLMConfig CreateOpenAIGPT4VisionConfig()
+ {
+ var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
+ return new OpenAIConfig(key, "gpt-4o-mini");
+ }
+
+ [Obsolete]
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task GPTAgentTestAsync()
+ {
+ var config = this.CreateAzureOpenAIGPT35TurboConfig();
+
+ var agent = new GPTAgent("gpt", "You are a helpful AI assistant", config);
+
+ await UpperCaseTestAsync(agent);
+ await UpperCaseStreamingTestAsync(agent);
+ }
+
+ [Obsolete]
+ [ApiKeyFact("OPENAI_API_KEY", "AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")]
+ public async Task GPTAgentVisionTestAsync()
+ {
+ var visionConfig = this.CreateOpenAIGPT4VisionConfig();
+ var visionAgent = new GPTAgent(
+ name: "gpt",
+ systemMessage: "You are a helpful AI assistant",
+ config: visionConfig,
+ temperature: 0);
+
+ var gpt3Config = this.CreateAzureOpenAIGPT35TurboConfig();
+ var gpt3Agent = new GPTAgent(
+ name: "gpt3",
+ systemMessage: "You are a helpful AI assistant, return highest label from conversation",
+ config: gpt3Config,
+ temperature: 0,
+ functions: new[] { this.GetHighestLabelFunctionContract.ToOpenAIFunctionDefinition() },
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { nameof(GetHighestLabel), this.GetHighestLabelWrapper },
+ });
+
+ var imageUri = new Uri(@"https://microsoft.github.io/autogen/assets/images/level2algebra-659ba95286432d9945fc89e84d606797.png");
+ var oaiMessage = new ChatRequestUserMessage(
+ new ChatMessageTextContentItem("which label has the highest inference cost"),
+ new ChatMessageImageContentItem(imageUri));
+ var multiModalMessage = new MultiModalMessage(Role.User,
+ [
+ new TextMessage(Role.User, "which label has the highest inference cost", from: "user"),
+ new ImageMessage(Role.User, imageUri, from: "user"),
+ ],
+ from: "user");
+
+ var imageMessage = new ImageMessage(Role.User, imageUri, from: "user");
+
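+ // Also build an ImageMessage from raw bytes (BinaryData) to cover the binary
+ // image path alongside the URI-based message above.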
+ string imagePath = Path.Combine("testData", "images", "square.png");
+ ImageMessage imageMessageData;
+ using (var fs = new FileStream(imagePath, FileMode.Open, FileAccess.Read))
+ {
+ var ms = new MemoryStream();
+ await fs.CopyToAsync(ms);
+ ms.Seek(0, SeekOrigin.Begin);
+ var imageData = await BinaryData.FromStreamAsync(ms, "image/png");
+ imageMessageData = new ImageMessage(Role.Assistant, imageData, from: "user");
+ }
+
+ IMessage[] messages = [
+ MessageEnvelope.Create(oaiMessage),
+ multiModalMessage,
+ imageMessage,
+ imageMessageData
+ ];
+
+ foreach (var message in messages)
+ {
+ var response = await visionAgent.SendAsync(message);
+ response.From.Should().Be(visionAgent.Name);
+
+ var labelResponse = await gpt3Agent.SendAsync(response);
+ labelResponse.From.Should().Be(gpt3Agent.Name);
+ labelResponse.GetToolCalls()!.First().FunctionName.Should().Be(nameof(GetHighestLabel));
+ }
+ }
+
+ [Obsolete]
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task GPTFunctionCallAgentTestAsync()
+ {
+ var config = this.CreateAzureOpenAIGPT35TurboConfig();
+ var agentWithFunction = new GPTAgent("gpt", "You are a helpful AI assistant", config, 0, functions: new[] { this.EchoAsyncFunctionContract.ToOpenAIFunctionDefinition() });
+
+ await EchoFunctionCallTestAsync(agentWithFunction);
+ }
+
+ [Obsolete]
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task GPTAgentFunctionCallSelfExecutionTestAsync()
+ {
+ var config = this.CreateAzureOpenAIGPT35TurboConfig();
+ var agent = new GPTAgent(
+ name: "gpt",
+ systemMessage: "You are a helpful AI assistant",
+ config: config,
+ temperature: 0,
+ functions: new[] { this.EchoAsyncFunctionContract.ToOpenAIFunctionDefinition() },
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { nameof(EchoAsync), this.EchoAsyncWrapper },
+ });
+
+ await EchoFunctionCallExecutionStreamingTestAsync(agent);
+ await EchoFunctionCallExecutionTestAsync(agent);
+ }
+
+ /// <summary>
+ /// echo when asked.
+ /// </summary>
+ /// <param name="message">message to echo</param>
+ [FunctionAttribute]
+ public async Task<string> EchoAsync(string message)
+ {
+ return $"[ECHO] {message}";
+ }
+
+ /// <summary>
+ /// return the label name with highest inference cost
+ /// </summary>
+ /// <param name="labelName"></param>
+ /// <param name="color"></param>
+ [FunctionAttribute]
+ public async Task<string> GetHighestLabel(string labelName, string color)
+ {
+ return $"[HIGHEST_LABEL] {labelName} {color}";
+ }
+
+ private async Task EchoFunctionCallTestAsync(IAgent agent)
+ {
+ //var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
+ var helloWorld = new TextMessage(Role.User, "echo Hello world");
+
+ var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
+
+ reply.From.Should().Be(agent.Name);
+ reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
+ }
+
+ private async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
+ {
+ //var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
+ var helloWorld = new TextMessage(Role.User, "echo Hello world");
+
+ var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
+
+ reply.GetContent().Should().Be("[ECHO] Hello world");
+ reply.From.Should().Be(agent.Name);
+ reply.Should().BeOfType<ToolCallAggregateMessage>();
+ }
+
+ private async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
+ {
+ //var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
+ var helloWorld = new TextMessage(Role.User, "echo Hello world");
+ var option = new GenerateReplyOptions
+ {
+ Temperature = 0,
+ };
+ var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
+ var answer = "[ECHO] Hello world";
+ IMessage? finalReply = default;
+ await foreach (var reply in replyStream)
+ {
+ reply.From.Should().Be(agent.Name);
+ finalReply = reply;
+ }
+
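+ // When the agent executes the function itself, the stream ends with a
+ // ToolCallAggregateMessage whose second part carries the executed result.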
+ if (finalReply is ToolCallAggregateMessage aggregateMessage)
+ {
+ var toolCallResultMessage = aggregateMessage.Message2;
+ toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
+ toolCallResultMessage.From.Should().Be(agent.Name);
+ toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
+ }
+ else
+ {
+ throw new Exception("unexpected message type");
+ }
+ }
+
+ private async Task UpperCaseTestAsync(IAgent agent)
+ {
+ var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
+
+ var reply = await agent.SendAsync(chatHistory: new[] { message });
+
+ reply.GetContent().Should().Contain("ABCDE");
+ reply.From.Should().Be(agent.Name);
+ }
+
+ private async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
+ {
+ var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
+ var option = new GenerateReplyOptions
+ {
+ Temperature = 0,
+ };
+ var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
+ var answer = "HELLO WORLD";
+ TextMessage? finalReply = default;
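+ // Accumulate streaming chunks: the first TextMessageUpdate seeds finalReply,
+ // and later updates are merged in via Update().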
+ await foreach (var reply in replyStream)
+ {
+ if (reply is TextMessageUpdate update)
+ {
+ update.From.Should().Be(agent.Name);
+
+ if (finalReply is null)
+ {
+ finalReply = new TextMessage(update);
+ }
+ else
+ {
+ finalReply.Update(update);
+ }
+
+ continue;
+ }
+ else if (reply is TextMessage textMessage)
+ {
+ finalReply = textMessage;
+ continue;
+ }
+
+ throw new Exception("unexpected message type");
+ }
+
+ finalReply!.Content.Should().Contain(answer);
+ finalReply!.Role.Should().Be(Role.Assistant);
+ finalReply!.From.Should().Be(agent.Name);
+ }
+}
diff --git a/dotnet/test/AutoGen.OpenAI.V1.Tests/GlobalUsing.cs b/dotnet/test/AutoGen.OpenAI.V1.Tests/GlobalUsing.cs
new file mode 100644
index 000000000000..d66bf001ed5e
--- /dev/null
+++ b/dotnet/test/AutoGen.OpenAI.V1.Tests/GlobalUsing.cs
@@ -0,0 +1,4 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GlobalUsing.cs
+
+global using AutoGen.Core;
diff --git a/dotnet/test/AutoGen.OpenAI.V1.Tests/MathClassTest.cs b/dotnet/test/AutoGen.OpenAI.V1.Tests/MathClassTest.cs
new file mode 100644
index 000000000000..d6055fb785e3
--- /dev/null
+++ b/dotnet/test/AutoGen.OpenAI.V1.Tests/MathClassTest.cs
@@ -0,0 +1,227 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// MathClassTest.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoGen.OpenAI.V1.Extension;
+using AutoGen.Tests;
+using Azure.AI.OpenAI;
+using FluentAssertions;
+using Xunit.Abstractions;
+
+namespace AutoGen.OpenAI.V1.Tests
+{
+ public partial class MathClassTest
+ {
+ private readonly ITestOutputHelper _output;
+
+ // As of 2024-05-20, AOAI returns a 500 error when round > 1.
+ // I'm pretty sure that round > 5 was supported before,
+ // so this is probably some weird regression on the AOAI side.
+ // I'll keep this test case here for now and set round to 1
+ // so the test can still pass.
+ // In the future, we should restore this test case to round > 1 (it was previously 5).
+ private int round = 1;
+ public MathClassTest(ITestOutputHelper output)
+ {
+ _output = output;
+ }
+
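+ // Logging middleware: forwards the call to the inner agent and prints the
+ // reply; on failure it dumps the raw request payloads as JSON for debugging.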
+ private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
+ {
+ try
+ {
+ var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
+
+ _output.WriteLine(reply.FormatMessage());
+ return Task.FromResult(reply);
+ }
+ catch (Exception)
+ {
+ _output.WriteLine("Request failed");
+ _output.WriteLine($"agent name: {agent.Name}");
+ foreach (var message in messages)
+ {
+ if (message is IMessage<object> envelope)
+ {
+ var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
+ _output.WriteLine(json);
+ }
+ }
+
+ throw;
+ }
+
+ }
+
+ [FunctionAttribute]
+ public async Task<string> CreateMathQuestion(string question, int question_index)
+ {
+ return $@"[MATH_QUESTION]
+Question {question_index}:
+{question}
+
+Student, please answer";
+ }
+
+ [FunctionAttribute]
+ public async Task<string> AnswerQuestion(string answer)
+ {
+ return $@"[MATH_ANSWER]
+The answer is {answer}
+teacher please check answer";
+ }
+
+ [FunctionAttribute]
+ public async Task<string> AnswerIsCorrect(string message)
+ {
+ return $@"[ANSWER_IS_CORRECT]
+{message}
+please update progress";
+ }
+
+ [FunctionAttribute]
+ public async Task<string> UpdateProgress(int correctAnswerCount)
+ {
+ if (correctAnswerCount >= this.round)
+ {
+ return $@"[UPDATE_PROGRESS]
+{GroupChatExtension.TERMINATE}";
+ }
+ else
+ {
+ return $@"[UPDATE_PROGRESS]
+the number of resolved question is {correctAnswerCount}
+teacher, please create the next math question";
+ }
+ }
+
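+ // Once enough answers are correct, UpdateProgress emits
+ // GroupChatExtension.TERMINATE, which ends the group chat (asserted at the
+ // end of RunMathChatAsync).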
+
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task OpenAIAgentMathChatTestAsync()
+ {
+ var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
+ var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
+ var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
+ var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
+ var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
+ var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
+
+ var adminFunctionMiddleware = new FunctionCallMiddleware(
+ functions: [this.UpdateProgressFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
+ });
+ var admin = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ modelName: deployName,
+ name: "Admin",
+ systemMessage: $@"You are admin. You update progress after each question is answered.")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(adminFunctionMiddleware)
+ .RegisterMiddleware(Print);
+
+ var groupAdmin = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ modelName: deployName,
+ name: "GroupAdmin",
+ systemMessage: "You are group admin. You manage the group chat.")
+ .RegisterMessageConnector()
+ .RegisterMiddleware(Print);
+ await RunMathChatAsync(teacher, student, admin, groupAdmin);
+ }
+
+ private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
+ {
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
+ { this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
+ });
+
+ var teacher = new OpenAIChatAgent(
+ openAIClient: client,
+ name: "Teacher",
+ systemMessage: @"You are a preschool math teacher.
+You create math question and ask student to answer it.
+Then you check if the answer is correct.
+If the answer is wrong, you ask student to fix it",
+ modelName: model)
+ .RegisterMiddleware(Print)
+ .RegisterMiddleware(new OpenAIChatRequestMessageConnector())
+ .RegisterMiddleware(functionCallMiddleware);
+
+ return teacher;
+ }
+
+ private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
+ {
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [this.AnswerQuestionFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
+ });
+ var student = new OpenAIChatAgent(
+ openAIClient: client,
+ name: "Student",
+ modelName: model,
+ systemMessage: @"You are a student. You answer math question from teacher.")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware)
+ .RegisterMiddleware(Print);
+
+ return student;
+ }
+
+ private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
+ {
+ var teacher2Student = Transition.Create(teacher, student);
+ var student2Teacher = Transition.Create(student, teacher);
+ var teacher2Admin = Transition.Create(teacher, admin);
+ var admin2Teacher = Transition.Create(admin, teacher);
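+ // The workflow graph restricts speaker order to
+ // teacher -> student -> teacher -> admin -> teacher, so the group admin only
+ // chooses among the allowed next speakers at each turn.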
+ var workflow = new Graph(
+ [
+ teacher2Student,
+ student2Teacher,
+ teacher2Admin,
+ admin2Teacher,
+ ]);
+ var group = new GroupChat(
+ workflow: workflow,
+ members: [
+ admin,
+ teacher,
+ student,
+ ],
+ admin: groupAdmin);
+
+ var groupChatManager = new GroupChatManager(group);
+ var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
+
+ chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
+ .Count()
+ .Should().BeGreaterThanOrEqualTo(this.round);
+
+ chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
+ .Count()
+ .Should().BeGreaterThanOrEqualTo(this.round);
+
+ chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
+ .Count()
+ .Should().BeGreaterThanOrEqualTo(this.round);
+
+ // check if there's terminate chat message from admin
+ chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
+ .Count()
+ .Should().Be(1);
+ }
+ }
+}
diff --git a/dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIChatAgentTest.cs
new file mode 100644
index 000000000000..1000339c6886
--- /dev/null
+++ b/dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIChatAgentTest.cs
@@ -0,0 +1,343 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatAgentTest.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using AutoGen.OpenAI.V1.Extension;
+using AutoGen.Tests;
+using Azure.AI.OpenAI;
+using FluentAssertions;
+
+namespace AutoGen.OpenAI.V1.Tests;
+
+public partial class OpenAIChatAgentTest
+{
+ /// <summary>
+ /// Get the weather for a location.
+ /// </summary>
+ /// <param name="location">location</param>
+ /// <returns></returns>
+ [Function]
+ public async Task<string> GetWeatherAsync(string location)
+ {
+ return $"[GetWeather] The weather in {location} is sunny.";
+ }
+
+ [Function]
+ public async Task<string> CalculateTaxAsync(string location, double income)
+ {
+ return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
+ }
+
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task BasicConversationTestAsync()
+ {
+ var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
+ var openAIChatAgent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "assistant",
+ modelName: deployName);
+
+ // By default, OpenAIChatAgent supports the following message types
+ // - IMessage<ChatRequestMessage>
+ var chatMessageContent = MessageEnvelope.Create(new ChatRequestUserMessage("Hello"));
+ var reply = await openAIChatAgent.SendAsync(chatMessageContent);
+
+ reply.Should().BeOfType<MessageEnvelope<ChatCompletions>>();
+ reply.As<MessageEnvelope<ChatCompletions>>().From.Should().Be("assistant");
+ reply.As<MessageEnvelope<ChatCompletions>>().Content.Choices.First().Message.Role.Should().Be(ChatRole.Assistant);
+ reply.As<MessageEnvelope<ChatCompletions>>().Content.Usage.TotalTokens.Should().BeGreaterThan(0);
+
+ // test streaming
+ var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
+
+ await foreach (var streamingMessage in streamingReply)
+ {
+ streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatCompletionsUpdate>>();
+ streamingMessage.As<MessageEnvelope<StreamingChatCompletionsUpdate>>().From.Should().Be("assistant");
+ }
+ }
+
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task OpenAIChatMessageContentConnectorTestAsync()
+ {
+ var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
+ var openAIChatAgent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "assistant",
+ modelName: deployName);
+
+ MiddlewareStreamingAgent<OpenAIChatAgent> assistant = openAIChatAgent
+ .RegisterMessageConnector();
+
+ var messages = new IMessage[]
+ {
+ MessageEnvelope.Create(new ChatRequestUserMessage("Hello")),
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ new MultiModalMessage(Role.Assistant,
+ [
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ ],
+ from: "user"),
+ };
+
+ foreach (var message in messages)
+ {
+ var reply = await assistant.SendAsync(message);
+
+ reply.Should().BeOfType<TextMessage>();
+ reply.As<TextMessage>().From.Should().Be("assistant");
+ }
+
+ // test streaming
+ foreach (var message in messages)
+ {
+ var reply = assistant.GenerateStreamingReplyAsync([message]);
+
+ await foreach (var streamingMessage in reply)
+ {
+ streamingMessage.Should().BeOfType<TextMessageUpdate>();
+ streamingMessage.As<TextMessageUpdate>().From.Should().Be("assistant");
+ }
+ }
+ }
+
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task OpenAIChatAgentToolCallTestAsync()
+ {
+ var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
+ var openAIChatAgent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "assistant",
+ modelName: deployName);
+
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [this.GetWeatherAsyncFunctionContract]);
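+ // No functionMap is provided, so the middleware surfaces the model's tool
+ // call as a ToolCallMessage instead of invoking it (asserted below).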
+ MiddlewareStreamingAgent<OpenAIChatAgent> assistant = openAIChatAgent
+ .RegisterMessageConnector();
+
+ assistant.StreamingMiddlewares.Count().Should().Be(1);
+ var functionCallAgent = assistant
+ .RegisterStreamingMiddleware(functionCallMiddleware);
+
+ var question = "What's the weather in Seattle";
+ var messages = new IMessage[]
+ {
+ MessageEnvelope.Create(new ChatRequestUserMessage(question)),
+ new TextMessage(Role.Assistant, question, from: "user"),
+ new MultiModalMessage(Role.Assistant,
+ [
+ new TextMessage(Role.Assistant, question, from: "user"),
+ ],
+ from: "user"),
+ };
+
+ foreach (var message in messages)
+ {
+ var reply = await functionCallAgent.SendAsync(message);
+
+ reply.Should().BeOfType<ToolCallMessage>();
+ reply.As<ToolCallMessage>().From.Should().Be("assistant");
+ reply.As<ToolCallMessage>().ToolCalls.Count().Should().Be(1);
+ reply.As<ToolCallMessage>().ToolCalls.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name);
+ }
+
+ // test streaming
+ foreach (var message in messages)
+ {
+ var reply = functionCallAgent.GenerateStreamingReplyAsync([message]);
+ ToolCallMessage? toolCallMessage = null;
+ await foreach (var streamingMessage in reply)
+ {
+ streamingMessage.Should().BeOfType<ToolCallMessageUpdate>();
+ streamingMessage.As<ToolCallMessageUpdate>().From.Should().Be("assistant");
+ if (toolCallMessage is null)
+ {
+ toolCallMessage = new ToolCallMessage(streamingMessage.As<ToolCallMessageUpdate>());
+ }
+ else
+ {
+ toolCallMessage.Update(streamingMessage.As<ToolCallMessageUpdate>());
+ }
+ }
+
+ toolCallMessage.Should().NotBeNull();
+ toolCallMessage!.From.Should().Be("assistant");
+ toolCallMessage.ToolCalls.Count().Should().Be(1);
+ toolCallMessage.ToolCalls.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name);
+ }
+ }
+
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task OpenAIChatAgentToolCallInvokingTestAsync()
+ {
+ var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
+ var openAIChatAgent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "assistant",
+ modelName: deployName);
+
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [this.GetWeatherAsyncFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>> { { this.GetWeatherAsyncFunctionContract.Name!, this.GetWeatherAsyncWrapper } });
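+ // With a functionMap the middleware invokes the matched function and returns
+ // the call together with its result, so the reply has both tool calls and content.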
+ MiddlewareStreamingAgent<OpenAIChatAgent> assistant = openAIChatAgent
+ .RegisterMessageConnector();
+
+ var functionCallAgent = assistant
+ .RegisterStreamingMiddleware(functionCallMiddleware);
+
+ var question = "What's the weather in Seattle";
+ var messages = new IMessage[]
+ {
+ MessageEnvelope.Create(new ChatRequestUserMessage(question)),
+ new TextMessage(Role.Assistant, question, from: "user"),
+ new MultiModalMessage(Role.Assistant,
+ [
+ new TextMessage(Role.Assistant, question, from: "user"),
+ ],
+ from: "user"),
+ };
+
+ foreach (var message in messages)
+ {
+ var reply = await functionCallAgent.SendAsync(message);
+
+ reply.Should().BeOfType<ToolCallAggregateMessage>();
+ reply.From.Should().Be("assistant");
+ reply.GetToolCalls()!.Count().Should().Be(1);
+ reply.GetToolCalls()!.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name);
+ reply.GetContent()!.ToLower().Should().Contain("seattle");
+ }
+
+ // test streaming
+ foreach (var message in messages)
+ {
+ var reply = functionCallAgent.GenerateStreamingReplyAsync([message]);
+ await foreach (var streamingMessage in reply)
+ {
+ if (streamingMessage is not IMessage)
+ {
+ streamingMessage.Should().BeOfType<ToolCallMessageUpdate>();
+ streamingMessage.As<ToolCallMessageUpdate>().From.Should().Be("assistant");
+ }
+ else
+ {
+ streamingMessage.Should().BeOfType<ToolCallAggregateMessage>();
+ streamingMessage.As<IMessage>().GetContent()!.ToLower().Should().Contain("seattle");
+ }
+ }
+ }
+ }
+
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task ItCreateOpenAIChatAgentWithChatCompletionOptionAsync()
+ {
+ var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
+ var options = new ChatCompletionsOptions(deployName, [])
+ {
+ Temperature = 0.7f,
+ MaxTokens = 1,
+ };
+
+ var openAIChatAgent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "assistant",
+ options: options)
+ .RegisterMessageConnector();
+
+ var respond = await openAIChatAgent.SendAsync("hello");
+ respond.GetContent()?.Should().NotBeNullOrEmpty();
+ }
+
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task ItThrowExceptionWhenChatCompletionOptionContainsMessages()
+ {
+ var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
+ var openaiClient = CreateOpenAIClientFromAzureOpenAI();
+ var options = new ChatCompletionsOptions(deployName, [new ChatRequestUserMessage("hi")])
+ {
+ Temperature = 0.7f,
+ MaxTokens = 1,
+ };
+
+ var action = () => new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "assistant",
+ options: options)
+ .RegisterMessageConnector();
+
+ action.Should().ThrowExactly