How to force tool calling behavior
Prerequisites
This guide assumes familiarity with the following concepts:
- Chat models
- LangChain Tools
- How to use a model to call tools
To force our LLM to select a specific tool, we can use the tool_choice
parameter, which controls whether and which tool the model calls. First, let's define
our model and tools:
import { tool } from "@langchain/core/tools";
import { z } from "zod";
const add = tool(
(input) => {
return `${input.a + input.b}`;
},
{
name: "add",
description: "Adds a and b.",
schema: z.object({
a: z.number(),
b: z.number(),
}),
}
);
const multiply = tool(
(input) => {
return `${input.a * input.b}`;
},
{
name: "multiply", // tool_choice must use this exact name
description: "Multiplies a and b.",
schema: z.object({
a: z.number(),
b: z.number(),
}),
}
);
const tools = [add, multiply];
import { ChatOpenAI } from "@langchain/openai";
const llm = new ChatOpenAI({
model: "gpt-3.5-turbo",
});
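Because tool_choice refers to a tool by name, it helps to picture what the provider actually receives. The sketch below writes the two tools out as plain objects in the shape of OpenAI's Chat Completions "tools" payload. The interface, the hand-written JSON Schema (an approximation of the zod schemas above), and the `findToolDefinition` helper are hypothetical illustrations, not LangChain APIs. The helper shows that matching is by exact, case-sensitive name.

```typescript
// Hypothetical shape approximating one entry of OpenAI's "tools" array.
interface FunctionToolDefinition {
  type: "function";
  function: {
    name: string;
    description: string;
    // JSON Schema describing the arguments object.
    parameters: Record<string, unknown>;
  };
}

// Hand-written approximation of z.object({ a: z.number(), b: z.number() }).
const twoNumbers: Record<string, unknown> = {
  type: "object",
  properties: { a: { type: "number" }, b: { type: "number" } },
  required: ["a", "b"],
};

const toolDefinitions: FunctionToolDefinition[] = [
  {
    type: "function",
    function: { name: "add", description: "Adds a and b.", parameters: twoNumbers },
  },
  {
    type: "function",
    function: { name: "multiply", description: "Multiplies a and b.", parameters: twoNumbers },
  },
];

// Hypothetical helper: exact, case-sensitive lookup by tool name.
function findToolDefinition(
  defs: FunctionToolDefinition[],
  name: string,
): FunctionToolDefinition | undefined {
  return defs.find((def) => def.function.name === name);
}

console.log(findToolDefinition(toolDefinitions, "multiply")?.function.description); // Multiplies a and b.
console.log(findToolDefinition(toolDefinitions, "Multiply")); // undefined
```

In this sketch, a tool_choice value that differs from the defined name even in letter case finds no tool.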
For example, we can force our model to call the multiply tool by passing its name as tool_choice:
const llmForcedToMultiply = llm.bindTools(tools, {
  // Must exactly match the `name` given to the multiply tool.
  tool_choice: "multiply",
});
const multiplyResult = await llmForcedToMultiply.invoke("what is 2 + 4");
console.log(multiplyResult.tool_calls);
Even if we pass it something that doesn't require multiplication, it will still call the tool!
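OpenAI's Chat Completions API expresses "call this particular function" as an object of the form `{ type: "function", function: { name } }`. Presumably the OpenAI integration converts a tool name string into that object; the sketch below builds it by hand. The `forceNamedTool` helper and its unknown-name check are hypothetical and exist only to make the exact-name requirement visible.

```typescript
// Object form of OpenAI's tool_choice that names one specific function.
type NamedToolChoice = { type: "function"; function: { name: string } };

// Hypothetical helper: builds the named tool_choice payload, rejecting names
// that do not exactly match a bound tool (e.g. "Multiply" vs "multiply").
function forceNamedTool(knownToolNames: string[], name: string): NamedToolChoice {
  if (!knownToolNames.includes(name)) {
    throw new Error(`Unknown tool "${name}"; expected one of: ${knownToolNames.join(", ")}`);
  }
  return { type: "function", function: { name } };
}

const choice = forceNamedTool(["add", "multiply"], "multiply");
console.log(JSON.stringify(choice));
// {"type":"function","function":{"name":"multiply"}}
```

Failing fast on an unknown name in a helper like this surfaces typos before any request is sent.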
We can also force our model to select at least one of our tools, without
specifying which, by passing the "any" keyword (or "required", which is
OpenAI-specific) to the tool_choice
parameter.
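OpenAI's tool_choice accepts the keywords "auto" (the model decides), "none" (no tool calls), and "required" (at least one tool call), but not "any". A minimal sketch of how "any" could be normalized for OpenAI is shown below; the `toOpenAIKeyword` function and both type aliases are hypothetical, and whether LangChain performs exactly this mapping internally is an assumption.

```typescript
// Keywords accepted by this sketch, and the subset OpenAI understands.
type ToolChoiceKeyword = "auto" | "none" | "any" | "required";
type OpenAIToolChoiceKeyword = "auto" | "none" | "required";

// Hypothetical normalizer: "any" maps onto OpenAI's "required", which also
// forces at least one tool call; the other keywords pass through unchanged.
function toOpenAIKeyword(choice: ToolChoiceKeyword): OpenAIToolChoiceKeyword {
  return choice === "any" ? "required" : choice;
}

const keywords: ToolChoiceKeyword[] = ["auto", "none", "any", "required"];
console.log(JSON.stringify(keywords.map(toOpenAIKeyword)));
// ["auto","none","required","required"]
```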
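const llmForcedToUseTool = llm.bindTools(tools, {
  // Force at least one tool call, letting the model pick which tool.
  tool_choice: "any",
});
const anyToolResult = await llmForcedToUseTool.invoke("What day is today?");
console.log(anyToolResult.tool_calls);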
The model still responds with a tool call rather than an answer: in the sample run, the message content is empty and tool_calls contains a single call to the addition tool with arguments { a: 1, b: 2 }, even though the question has nothing to do with arithmetic.
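When relying on forced tool use, a caller may want to check that a response actually honored it. The sketch below assumes a minimal hypothetical `ToolCallLike` shape that mirrors the `name` and `args` fields seen in a response's tool_calls list; the `honoredForcedToolUse` helper is likewise hypothetical, and the sample data is a mock modeled on the run described above.

```typescript
// Hypothetical minimal shape of one entry in a response's tool_calls list.
interface ToolCallLike {
  name: string;
  args: Record<string, unknown>;
  id?: string;
}

// Hypothetical check for tool_choice "any"/"required": there must be at least
// one tool call, and every call must target one of the bound tools.
function honoredForcedToolUse(toolCalls: ToolCallLike[], boundToolNames: string[]): boolean {
  return toolCalls.length > 0 && toolCalls.every((call) => boundToolNames.includes(call.name));
}

// Mock data shaped like the sample run: one call to the addition tool.
const sampleCalls: ToolCallLike[] = [{ name: "add", args: { a: 1, b: 2 } }];
console.log(honoredForcedToolUse(sampleCalls, ["add", "multiply"])); // true
console.log(honoredForcedToolUse([], ["add", "multiply"])); // false
```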