Skip to content

Commit 6137c32

Browse files
chat : add Granite 4.0 chat template with correct tool_call role mapping (ggml-org#20804)
* chat : add Granite 4.0 chat template with correct tool_call role mapping Introduce `LLM_CHAT_TEMPLATE_GRANITE_4_0` alongside the existing Granite 3.x template (renamed `LLM_CHAT_TEMPLATE_GRANITE_3_X`). The Granite 4.0 Jinja template uses `<tool_call>` XML tags and maps the `assistant_tool_call` role to `<|start_of_role|>assistant<|end_of_role|><|tool_call|>`. Without a matching C++ handler, the fallback path emits the literal role `assistant_tool_call` which the model does not recognize, breaking tool calling when `--jinja` is not used. Changes: - Rename `LLM_CHAT_TEMPLATE_GRANITE` to `LLM_CHAT_TEMPLATE_GRANITE_3_X` (preserves existing 3.x behavior unchanged) - Add `LLM_CHAT_TEMPLATE_GRANITE_4_0` enum, map entry, and handler - Detection: `<|start_of_role|>` + (`<tool_call>` or `<tools>`) → 4.0, otherwise → 3.x - Add production Granite 4.0 Jinja template - Add tests for both 3.x and 4.0 template paths (C++ and Jinja) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Code review: follow standard format and use common logic in test-chat-template.cpp * Rename the custom_conversation variable to extra_conversation to give it a more meaningful name --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 17193cc commit 6137c32

5 files changed

Lines changed: 189 additions & 9 deletions

File tree

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
{%- set tools_system_message_prefix = 'You are a helpful assistant with access to the following tools. You may call one or more tools to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>' %}
2+
{%- set tools_system_message_suffix = '\n</tools>\n\nFor each tool call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.' %}
3+
{%- set documents_system_message_prefix = 'You are a helpful assistant with access to the following documents. You may use one or more documents to assist with the user query.\n\nYou are given a list of documents within <documents></documents> XML tags:\n<documents>' %}
4+
{%- set documents_system_message_suffix = '\n</documents>\n\nWrite the response to the user\'s input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.' %}
5+
{%- set g4_default_system_message = 'You are a helpful assistant. Please ensure responses are professional, accurate, and safe.' %}
6+
{%- if available_tools is defined and available_tools %}
7+
{%- set tools = available_tools %}
8+
{%- endif %}
9+
{%- set ns = namespace(tools_system_message=tools_system_message_prefix,
10+
documents_system_message=documents_system_message_prefix,
11+
default_system_message=g4_default_system_message,
12+
system_message=''
13+
) %}
14+
{%- if tools %}
15+
{%- for tool in tools %}
16+
{%- set ns.tools_system_message = ns.tools_system_message + '\n' + (tool | tojson) %}
17+
{%- endfor %}
18+
{%- set ns.tools_system_message = ns.tools_system_message + tools_system_message_suffix %}
19+
{%- else %}
20+
{%- set ns.tools_system_message = '' %}
21+
{%- endif %}
22+
{%- if documents %}
23+
{%- for document in documents %}
24+
{%- set ns.documents_system_message = ns.documents_system_message + '\n' + (document | tojson) %}
25+
{%- endfor %}
26+
{%- set ns.documents_system_message = ns.documents_system_message + documents_system_message_suffix %}
27+
{%- else %}
28+
{%- set ns.documents_system_message = '' %}
29+
{%- endif %}
30+
{%- if messages[0].role == 'system' %}
31+
{%- if messages[0].content is string %}
32+
{%- set ns.system_message = messages[0].content %}
33+
{%- elif messages[0].content is iterable %}
34+
{%- for entry in messages[0].content %}
35+
{%- if entry.type== 'text' %}
36+
{%- if ns.system_message != '' %}
37+
{%- set ns.system_message = ns.system_message + '\n' %}
38+
{%- endif %}
39+
{%- set ns.system_message = ns.system_message + entry.text %}
40+
{%- endif %}
41+
{%- endfor %}
42+
{%- endif %}
43+
{%- if tools and documents %}
44+
{%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message + '\n\n' + ns.documents_system_message %}
45+
{%- elif tools %}
46+
{%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message %}
47+
{%- elif documents %}
48+
{%- set ns.system_message = ns.system_message + '\n\n' + ns.documents_system_message %}
49+
{%- endif %}
50+
{%- else %}
51+
{%- if tools and documents %}
52+
{%- set ns.system_message = ns.tools_system_message + '\n\n' + ns.documents_system_message %}
53+
{%- elif tools %}
54+
{%- set ns.system_message = ns.tools_system_message %}
55+
{%- elif documents %}
56+
{%- set ns.system_message = ns.documents_system_message %}
57+
{%- endif %}
58+
{%- endif %}
59+
{%- if ns.system_message %}
60+
{{- '<|start_of_role|>system<|end_of_role|>' + ns.system_message + '<|end_of_text|>\n' }}
61+
{%- else %}
62+
{{- '<|start_of_role|>system<|end_of_role|>' + ns.default_system_message + '<|end_of_text|>\n' }}
63+
{%- endif %}
64+
{%- for message in messages %}
65+
{%- set content = namespace(val='') %}
66+
{%- if message.content is string %}
67+
{%- set content.val = message.content %}
68+
{%- else %}
69+
{%- if message.content is iterable %}
70+
{%- for entry in message.content %}
71+
{%- if entry.type== 'text' %}
72+
{%- if content.val != '' %}
73+
{%- set content.val = content.val + '\n' %}
74+
{%- endif %}
75+
{%- set content.val = content.val + entry.text %}
76+
{%- endif %}
77+
{%- endfor %}
78+
{%- endif %}
79+
{%- endif %}
80+
{%- if (message.role == 'user') or (message.role == 'system' and not loop.first) %}
81+
{{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val + '<|end_of_text|>\n' }}
82+
{%- elif message.role == 'assistant' %}
83+
{{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }}
84+
{%- if message.tool_calls %}
85+
{%- for tool_call in message.tool_calls %}
86+
{%- if (loop.first and content.val) or (not loop.first) %}
87+
{{- '\n' }}
88+
{%- endif %}
89+
{%- if tool_call.function %}
90+
{%- set tool_call = tool_call.function %}
91+
{%- endif %}
92+
{{- '<tool_call>\n{"name": "' }}
93+
{{- tool_call.name }}
94+
{{- '", "arguments": ' }}
95+
{%- if tool_call.arguments is string %}
96+
{{- tool_call.arguments }}
97+
{%- else %}
98+
{{- tool_call.arguments | tojson }}
99+
{%- endif %}
100+
{{- '}\n</tool_call>' }}
101+
{%- endfor %}
102+
{%- endif %}
103+
{{- '<|end_of_text|>\n' }}
104+
{%- elif message.role == 'tool' %}
105+
{%- if loop.first or (messages[loop.index0 - 1].role != 'tool') %}
106+
{{- '<|start_of_role|>user<|end_of_role|>' }}
107+
{%- endif %}
108+
{{- '\n<tool_response>\n' }}
109+
{{- content.val }}
110+
{{- '\n</tool_response>' }}
111+
{%- if loop.last or (messages[loop.index0 + 1].role != 'tool') %}
112+
{{- '<|end_of_text|>\n' }}
113+
{%- endif %}
114+
{%- endif %}
115+
{%- endfor %}
116+
{%- if add_generation_prompt %}
117+
{{- '<|start_of_role|>assistant<|end_of_role|>' }}
118+
{%- endif %}

src/llama-chat.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
6060
{ "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
6161
{ "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE },
6262
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
63-
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
63+
{ "granite", LLM_CHAT_TEMPLATE_GRANITE_3_X },
64+
{ "granite-4.0", LLM_CHAT_TEMPLATE_GRANITE_4_0 },
6465
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
6566
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
6667
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
@@ -191,7 +192,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
191192
} else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
192193
return LLM_CHAT_TEMPLATE_RWKV_WORLD;
193194
} else if (tmpl_contains("<|start_of_role|>")) {
194-
return LLM_CHAT_TEMPLATE_GRANITE;
195+
if (tmpl_contains("<tool_call>") || tmpl_contains("<tools>")) {
196+
return LLM_CHAT_TEMPLATE_GRANITE_4_0;
197+
}
198+
return LLM_CHAT_TEMPLATE_GRANITE_3_X;
195199
} else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) {
196200
return LLM_CHAT_TEMPLATE_GIGACHAT;
197201
} else if (tmpl_contains("<|role_start|>")) {
@@ -617,8 +621,8 @@ int32_t llm_chat_apply_template(
617621
ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
618622
}
619623
}
620-
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
621-
// IBM Granite template
624+
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_3_X) {
625+
// IBM Granite 3.x template
622626
for (const auto & message : chat) {
623627
std::string role(message->role);
624628
ss << "<|start_of_role|>" << role << "<|end_of_role|>";
@@ -630,6 +634,20 @@ int32_t llm_chat_apply_template(
630634
if (add_ass) {
631635
ss << "<|start_of_role|>assistant<|end_of_role|>";
632636
}
637+
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_4_0) {
638+
// IBM Granite 4.0 template
639+
for (const auto & message : chat) {
640+
std::string role(message->role);
641+
if (role == "assistant_tool_call") {
642+
ss << "<|start_of_role|>assistant<|end_of_role|><|tool_call|>";
643+
} else {
644+
ss << "<|start_of_role|>" << role << "<|end_of_role|>";
645+
}
646+
ss << message->content << "<|end_of_text|>\n";
647+
}
648+
if (add_ass) {
649+
ss << "<|start_of_role|>assistant<|end_of_role|>";
650+
}
633651
} else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
634652
// GigaChat template
635653
bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";

src/llama-chat.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ enum llm_chat_template {
3939
LLM_CHAT_TEMPLATE_EXAONE_4,
4040
LLM_CHAT_TEMPLATE_EXAONE_MOE,
4141
LLM_CHAT_TEMPLATE_RWKV_WORLD,
42-
LLM_CHAT_TEMPLATE_GRANITE,
42+
LLM_CHAT_TEMPLATE_GRANITE_3_X,
43+
LLM_CHAT_TEMPLATE_GRANITE_4_0,
4344
LLM_CHAT_TEMPLATE_GIGACHAT,
4445
LLM_CHAT_TEMPLATE_MEGREZ,
4546
LLM_CHAT_TEMPLATE_YANDEX,

tests/test-chat-template.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ int main_automated_tests(void) {
354354
std::string bos_token = "";
355355
std::string eos_token = "";
356356
bool supported_with_jinja = true;
357+
std::vector<llama_chat_message> extra_conversation = {};
357358
};
358359
std::vector<TestCase> test_cases {
359360
{
@@ -604,6 +605,26 @@ int main_automated_tests(void) {
604605
/* .expected_output_jinja= */ "<seed:bos>system\nYou are a helpful assistant<seed:eos><seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>user\nWho are you<seed:eos><seed:bos>assistant\nI am an assistant<seed:eos><seed:bos>user\nAnother question<seed:eos><seed:bos>assistant\n",
605606
/* .bos_token= */ "<seed:bos>",
606607
/* .eos_token= */ "<seed:eos>",
608+
},
609+
{
610+
/* .name= */ "ibm-granite/granite-3.x (tool call)",
611+
/* .template_str= */ "{%- for message in messages %}\n {%- if message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- else %}\n {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}",
612+
/* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What is the weather?<|end_of_text|>\n<|start_of_role|>assistant_tool_call<|end_of_role|><|tool_call|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}]<|end_of_text|>\n<|start_of_role|>tool_response<|end_of_role|>{\"temperature\": 72}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
613+
/* .expected_output_jinja= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What is the weather?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|><|tool_call|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}]<|end_of_text|>\n<|start_of_role|>tool_response<|end_of_role|>{\"temperature\": 72}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
614+
/* .bos_token= */ "",
615+
/* .eos_token= */ "",
616+
/* .supported_with_jinja= */ true,
617+
/* .extra_conversation= */ {{"user", "What is the weather?"}, {"assistant_tool_call", "[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}]"}, {"tool_response", "{\"temperature\": 72}"}},
618+
},
619+
{
620+
/* .name= */ "ibm-granite/granite-4.0 (tool call)",
621+
/* .template_str= */ "{%- for message in messages %}\n {%- if message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- else %}\n {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}\n{# <tool_call> <tools> #}",
622+
/* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What is the weather?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|><|tool_call|><tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}\n</tool_call><|end_of_text|>\n<|start_of_role|>tool_response<|end_of_role|>{\"temperature\": 72}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
623+
/* .expected_output_jinja= */ "",
624+
/* .bos_token= */ "",
625+
/* .eos_token= */ "",
626+
/* .supported_with_jinja= */ true,
627+
/* .extra_conversation= */ {{"user", "What is the weather?"}, {"assistant_tool_call", "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}\n</tool_call>"}, {"tool_response", "{\"temperature\": 72}"}},
607628
}
608629
};
609630
std::vector<char> formatted_chat(1024);
@@ -627,11 +648,13 @@ int main_automated_tests(void) {
627648

628649
for (const auto & test_case : test_cases) {
629650
std::cout << "\n\n=== " << test_case.name << " ===\n\n";
630-
formatted_chat.resize(1024);
651+
auto conv = conversation;
652+
conv.insert(conv.end(), test_case.extra_conversation.begin(), test_case.extra_conversation.end());
653+
formatted_chat.resize(2048);
631654
res = llama_chat_apply_template(
632655
test_case.template_str.c_str(),
633-
conversation.data(),
634-
conversation.size(),
656+
conv.data(),
657+
conv.size(),
635658
add_generation_prompt,
636659
formatted_chat.data(),
637660
formatted_chat.size()
@@ -658,11 +681,15 @@ int main_automated_tests(void) {
658681
}
659682
std::cout << "\n\n=== " << test_case.name << " (jinja) ===\n\n";
660683
try {
684+
auto msgs = messages;
685+
for (const auto & msg : test_case.extra_conversation) {
686+
msgs.push_back(simple_msg(msg.role, msg.content));
687+
}
661688
auto output = format_using_common(
662689
test_case.template_str,
663690
test_case.bos_token,
664691
test_case.eos_token,
665-
messages);
692+
msgs);
666693
auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
667694
if (output != expected_output) {
668695
std::cout << "Template:```\n" << test_case.template_str << "\n```";

tests/test-chat.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,6 +1929,22 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
19291929
// .run();
19301930
}
19311931

1932+
{
1933+
// IBM Granite 4.0 (production template shared by h-tiny, h-small, micro)
1934+
// Uses <tool_call> XML tags for tool calls, tools in system message
1935+
auto tst = peg_tester("models/templates/ibm-granite-granite-4.0.jinja", detailed_debug);
1936+
1937+
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
1938+
1939+
tst.test(
1940+
"<tool_call>\n"
1941+
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
1942+
"</tool_call>")
1943+
.tools({ special_function_tool })
1944+
.expect(message_assist_call)
1945+
.run();
1946+
}
1947+
19321948
{
19331949
// ByteDance-Seed-OSS (reasoning and tool calling model)
19341950
auto tst = peg_tester("models/templates/ByteDance-Seed-OSS.jinja", detailed_debug);

0 commit comments

Comments
 (0)