Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tool-call: fix Qwen 2.5 Coder support, add micro benchmarks, support trigger patterns for lazy grammars #12034

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
b37779b
sampler: turn lazy grammar trigger words to regexes
ochafik Feb 21, 2025
a456911
add scripts/tool_bench.sh & .py
ochafik Feb 21, 2025
14a4388
optionally allow any spaces in json schema grammars (useful for llama…
ochafik Feb 22, 2025
e2ca8be
constrain llama json output regardless of function name if matches at…
ochafik Feb 22, 2025
53266f9
better error when wrong function called
ochafik Feb 22, 2025
7833c16
improve error message in weather test
ochafik Feb 22, 2025
0e1a00e
add more models to tool_bench.sh
ochafik Feb 22, 2025
44740f7
benchmark other sizes of qwen 2.5 coder
ochafik Feb 23, 2025
dd6eb97
rm duplicate in tool_bench.sh
ochafik Feb 23, 2025
0fc6218
add missing <variant> include
ochafik Feb 23, 2025
6fd4972
fix lints
ochafik Feb 23, 2025
2e656f9
improve "bad" qwen triggers
ochafik Feb 23, 2025
fbd3c19
add cast to please some gccs
ochafik Feb 23, 2025
62a1416
ditch server test request retry logic
ochafik Feb 23, 2025
596ff7f
fix flake8 lints
ochafik Feb 23, 2025
fe6968f
nits
ochafik Feb 23, 2025
1caacd5
remove any_spaces grammar option, allow extra line for airy llama jso…
ochafik Feb 23, 2025
789a3e1
Update test_tool_call.py
ochafik Feb 23, 2025
6493a14
test w/ beefier qwen 2.5 coder 3b
ochafik Feb 23, 2025
cc817a0
revert some test_hello_world diffs
ochafik Feb 23, 2025
ead02c6
diff
ochafik Feb 23, 2025
d7acf2c
Update test_tool_call.py
ochafik Feb 23, 2025
0db4073
add requirements for tool_bench
ochafik Feb 23, 2025
0ce606b
fix test_thoughts deepseek test expectation
ochafik Feb 23, 2025
a3cde16
Update README.md
ochafik Feb 23, 2025
79ad623
update relaxed newline space rule in grammar tests
ochafik Feb 23, 2025
3fe208a
support add_generation_prompt query parameter (useful for /apply_temp…
ochafik Feb 25, 2025
fe8c79b
Merge remote-tracking branch 'origin/master' into tool-bench-prod
ochafik Feb 25, 2025
99d2d80
token cast tweak for gcc
ochafik Feb 25, 2025
c7fa19a
fix warning on gcc13 w/ uninitialized variant
ochafik Feb 25, 2025
6e5a830
fix python lints
ochafik Feb 25, 2025
0b5d105
fix gcc13 warning
ochafik Feb 25, 2025
7bcc5af
fix pyright lints in tool_bench.py
ochafik Feb 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
447 changes: 303 additions & 144 deletions common/chat.cpp

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,11 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}

// Escape a literal string so it can be embedded verbatim inside an
// ECMAScript regular expression (every metacharacter gets a leading backslash).
std::string regex_escape(const std::string & s) {
    // Character class of all ECMAScript regex metacharacters that require escaping.
    static const std::regex metacharacters(R"([.^$|()*+?\[\]{}\\])");
    // "$0" refers to the matched character; prefix it with a single backslash.
    static const std::string replacement = "\\$0";
    return std::regex_replace(s, metacharacters, replacement);
}

std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
std::ostringstream result;
for (size_t i = 0; i < values.size(); ++i) {
Expand Down
17 changes: 13 additions & 4 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <string>
#include <vector>
#include <sstream>
#include <variant>

#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
Expand Down Expand Up @@ -110,9 +111,16 @@ enum common_conversation_mode {
COMMON_CONVERSATION_MODE_AUTO = 2,
};

// Kind of event that activates a lazy grammar (see common_grammar_trigger).
enum common_grammar_trigger_type {
    COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,         // a specific token id activates the grammar
    COMMON_GRAMMAR_TRIGGER_TYPE_WORD,          // a literal word; regex-escaped and matched anywhere in the output
    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,       // a regex pattern matched anywhere in the output
    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, // a regex pattern that must match at the start of the output
};

// A single lazy-grammar trigger.
struct common_grammar_trigger {
    common_grammar_trigger_type type;
    // Holds a llama_token for TYPE_TOKEN, a std::string (word or regex
    // pattern) for the other trigger types.
    std::variant<llama_token, std::string> value;
};

// sampling parameters
Expand Down Expand Up @@ -163,8 +171,7 @@ struct common_params_sampling {

std::string grammar; // optional BNF-like grammar to constrain sampling
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
std::vector<common_grammar_trigger> grammar_triggers; // optional trigger words to trigger lazy grammar
std::set<llama_token> preserved_tokens;

std::vector<llama_logit_bias> logit_bias; // logit biases to apply
Expand Down Expand Up @@ -453,6 +460,8 @@ std::string string_repeat(const std::string & str, size_t n);

void string_replace_all(std::string & s, const std::string & search, const std::string & replace);

std::string regex_escape(const std::string & s);

template<class T>
static std::vector<T> string_split(const std::string & str, char delim) {
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
Expand Down
9 changes: 4 additions & 5 deletions common/json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
throw std::runtime_error("At least one of min_value or max_value must be set");
}

const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";

struct BuiltinRule {
std::string content;
Expand Down Expand Up @@ -764,11 +764,10 @@ class SchemaConverter {
public:
SchemaConverter(
const std::function<json(const std::string &)> & fetch_json,
bool dotall,
bool compact_spaces)
bool dotall)
: _fetch_json(fetch_json), _dotall(dotall)
{
_rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
_rules["space"] = SPACE_RULE;
}

void resolve_refs(json & schema, const std::string & url) {
Expand Down Expand Up @@ -1007,7 +1006,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
}

std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);
Expand Down
1 change: 0 additions & 1 deletion common/json-schema-to-grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ struct common_grammar_builder {

struct common_grammar_options {
bool dotall = false;
bool compact_spaces = false;
};

std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
51 changes: 44 additions & 7 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,53 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector<const char *> trigger_words;
trigger_words.reserve(params.grammar_trigger_words.size());
for (const auto & str : params.grammar_trigger_words) {
trigger_words.push_back(str.word.c_str());
std::vector<std::string> patterns_at_start;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto & trigger : params.grammar_triggers) {
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = std::get<std::string>(trigger.value);
patterns_anywhere.push_back(regex_escape(word));
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
{
const auto & pattern = std::get<std::string>(trigger.value);
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
{
const auto & token = std::get<llama_token>(trigger.value);
trigger_tokens.push_back(token);
break;
}
default:
GGML_ASSERT(false && "unknown trigger type");
}
}

std::vector<std::string> trigger_patterns;
if (!patterns_at_start.empty()) {
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}

std::vector<const char *> trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto & regex : trigger_patterns) {
trigger_patterns_c.push_back(regex.c_str());
}

grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
trigger_words.data(), trigger_words.size(),
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}

Expand Down
2 changes: 1 addition & 1 deletion examples/json_schema_to_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def __init__(self, content: str, deps: list | None = None):
self.deps = deps or []

# Constraining spaces to prevent model "running away".
SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}'
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'

PRIMITIVE_RULES = {
'boolean' : BuiltinRule('("true" | "false") space', []),
Expand Down
2 changes: 1 addition & 1 deletion examples/server/public_legacy/json-schema-to-grammar.mjs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// WARNING: This file was ported from json_schema_to_grammar.py, please fix bugs / add features there first.
const SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}';
const SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}';

function _buildRepetition(itemRule, minItems, maxItems, opts={}) {
if (minItems === 0 && maxItems === 1) {
Expand Down
86 changes: 60 additions & 26 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,22 @@ struct slot_params {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
}

std::vector<std::string> grammar_trigger_words;
for (const auto & trigger : sampling.grammar_trigger_words) {
grammar_trigger_words.push_back(trigger.word);
auto grammar_triggers = json::array();
for (const auto & trigger : sampling.grammar_triggers) {
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
grammar_triggers.push_back({{"word", std::get<std::string>(trigger.value)}});
break;
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
grammar_triggers.push_back({{"pattern", std::get<std::string>(trigger.value)}});
break;
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
grammar_triggers.push_back({{"pattern_start", std::get<std::string>(trigger.value)}});
break;
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
grammar_triggers.push_back({{"token", std::get<llama_token>(trigger.value)}});
break;
}
}

return json {
Expand Down Expand Up @@ -170,8 +183,8 @@ struct slot_params {
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
{"grammar", sampling.grammar},
{"grammar_trigger_words", grammar_trigger_words},
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
{"grammar_lazy", sampling.grammar_lazy},
{"grammar_triggers", grammar_triggers},
{"preserved_tokens", sampling.preserved_tokens},
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
{"samplers", samplers},
Expand Down Expand Up @@ -356,24 +369,6 @@ struct server_task {
}

{
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
for (const auto & t : *grammar_triggers) {
common_grammar_trigger trigger;
trigger.word = t.at("word");
trigger.at_start = t.at("at_start");

auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
continue;
}
SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
}
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
for (const auto & t : *preserved_tokens) {
Expand All @@ -383,12 +378,51 @@ struct server_task {
params.sampling.preserved_tokens.insert(ids[0]);
} else {
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
}
}
}
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
for (const auto & t : *grammar_triggers) {
auto type = static_cast<common_grammar_trigger_type>(t.at("type"));
switch (type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const std::string & word = t.at("value");
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
auto token = ids[0];
if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
}
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
common_grammar_trigger trigger;
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
trigger.value = token;
params.sampling.grammar_triggers.push_back(trigger);
} else {
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
}
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
{
const std::string & pattern = t.at("value");
params.sampling.grammar_triggers.push_back({type, pattern});
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
throw std::runtime_error("Unespected token trigger");
default:
throw std::runtime_error("Unknown trigger type");
}
}
}
if (params.sampling.grammar_lazy) {
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
GGML_ASSERT(params.sampling.grammar_triggers.size() > 0);
}
}

Expand Down Expand Up @@ -2045,7 +2079,7 @@ struct server_context {

if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
// Might be better to reject the request with a 400 ?
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict);
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict);
slot.params.n_predict = slot.n_predict;
}

Expand Down
Loading
Loading