Building a Custom Agent¶
In this cookbook we show you how to build a custom agent using LlamaIndex.
- The easiest way to build a custom agent is to simply define a stateful function and plug it into `FnAgentWorker`.
- [Optional] Another approach, which lets you peek into our agent abstractions a bit more, is to subclass `CustomSimpleAgentWorker` and implement a few required functions. You have complete flexibility in defining the agent's step-wise logic.

This lets you add arbitrarily complex reasoning logic on top of your RAG pipeline.
We show you how to build a simple agent that adds a retry layer on top of a RouterQueryEngine, allowing it to retry queries until the task is complete. We build this on top of both a SQL tool and a vector index query tool. Even if a tool makes an error or only answers part of the question, the agent can keep retrying until the task is done.
NOTE: Any Text-to-SQL application should be aware that executing arbitrary SQL queries can be a security risk. It is recommended to take precautions as needed, such as using restricted roles, read-only databases, sandboxing, etc.
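As one lightweight example of such a precaution, here is a minimal sketch of opening SQLite in read-only mode via its URI syntax. This assumes a file-backed database (the `cities.db` filename is hypothetical; this notebook itself uses an in-memory database below):

```python
from sqlalchemy import create_engine

# mode=ro makes SQLite reject write statements at the driver level;
# uri=true tells the sqlite3 driver to interpret the path as a URI
readonly_engine = create_engine(
    "sqlite:///file:cities.db?mode=ro&uri=true", future=True
)
```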
```
%pip install llama-index-readers-wikipedia
%pip install llama-index-llms-openai
```
Setup Data and Tools¶
We set up a SQL tool as well as a vector index tool for each city.
```python
from llama_index.llms.openai import OpenAI
from llama_index.core.tools import QueryEngineTool

llm = OpenAI(model="gpt-4o")
```
Setup SQL DB + Tool¶
```python
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    column,
)
from llama_index.core import SQLDatabase

engine = create_engine("sqlite:///:memory:", future=True)
metadata_obj = MetaData()
```
```python
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
```
```python
from sqlalchemy import insert

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Berlin", "population": 3645000, "country": "Germany"},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)
```
```python
from llama_index.core.query_engine import NLSQLTableQueryEngine

sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database, tables=["city_stats"], verbose=True, llm=llm
)
sql_tool = QueryEngineTool.from_defaults(
    query_engine=sql_query_engine,
    description=(
        "Useful for translating a natural language query into a SQL query over"
        " a table containing: city_stats, containing the population/country of"
        " each city"
    ),
)
```
Setup Vector Tools¶
```python
from llama_index.readers.wikipedia import WikipediaReader
from llama_index.core import VectorStoreIndex

cities = ["Toronto", "Berlin", "Tokyo"]
wiki_docs = WikipediaReader().load_data(pages=cities)

# build a separate vector index per city
# (you could also define a single vector index across all docs and annotate
# each chunk with metadata; see the sketch below)
vector_tools = []
for city, wiki_doc in zip(cities, wiki_docs):
    vector_index = VectorStoreIndex.from_documents([wiki_doc])
    vector_query_engine = vector_index.as_query_engine()
    vector_tool = QueryEngineTool.from_defaults(
        query_engine=vector_query_engine,
        description=f"Useful for answering semantic questions about {city}",
    )
    vector_tools.append(vector_tool)
```
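As a rough sketch of the single-index alternative mentioned in the comment above (not used in the rest of this notebook), you could tag each document with its city and filter at query time. The `combined_index` and `toronto_engine` names are just illustrative:

```python
from llama_index.core.vector_stores import ExactMatchFilter, MetadataFilters

# annotate each document, then build one index over all of them
for city, wiki_doc in zip(cities, wiki_docs):
    wiki_doc.metadata["city"] = city
combined_index = VectorStoreIndex.from_documents(wiki_docs)

# restrict retrieval to a single city's chunks via a metadata filter
toronto_engine = combined_index.as_query_engine(
    filters=MetadataFilters(
        filters=[ExactMatchFilter(key="city", value="Toronto")]
    )
)
```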
Setup the Custom Agent¶
Here we set up the custom agent. There are two ways to do this.

In the first approach, you just define a custom function. In the second, you learn a bit more about the low-level agent components that LlamaIndex has to offer, which gives you a more structured way to handle validation, run things step-wise, and modify the output.
Basic Setup¶
Here we define some common functions used for both implementations.
```python
from typing import Dict, Any, List, Tuple, Optional

from llama_index.core.tools import QueryEngineTool
from llama_index.core.program import FunctionCallingProgram
from llama_index.core.query_engine import RouterQueryEngine
from llama_index.core import ChatPromptTemplate
from llama_index.core.selectors import PydanticSingleSelector
from llama_index.core.bridge.pydantic import Field, BaseModel
```
Here we define some helper variables and methods: the prompt template used to detect errors, as well as the Pydantic model for the response format.
```python
from llama_index.core.llms import ChatMessage, MessageRole

DEFAULT_PROMPT_STR = """
Given previous question/response pairs, please determine if an error has occurred in the response, and suggest \
a modified question that will not trigger the error.

Examples of modified questions:
- The question itself is modified to elicit a non-erroneous response
- The question is augmented with context that will help the downstream system better answer the question.
- The question is augmented with examples of negative responses, or other negative questions.

An error means that either an exception has triggered, or the response is completely irrelevant to the question.

Please return the evaluation of the response in the following JSON format.
"""
```
```python
def get_chat_prompt_template(
    system_prompt: str, current_reasoning: List[Tuple[str, str]]
) -> ChatPromptTemplate:
    """Build a chat prompt from the system prompt and the reasoning so far."""
    system_msg = ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
    messages = [system_msg]
    for raw_msg in current_reasoning:
        if raw_msg[0] == "user":
            messages.append(
                ChatMessage(role=MessageRole.USER, content=raw_msg[1])
            )
        else:
            messages.append(
                ChatMessage(role=MessageRole.ASSISTANT, content=raw_msg[1])
            )
    return ChatPromptTemplate(message_templates=messages)
```
```python
class ResponseEval(BaseModel):
    """Evaluation of whether the response has an error."""

    has_error: bool = Field(
        ..., description="Whether the response has an error."
    )
    new_question: str = Field(..., description="The suggested new question.")
    explanation: str = Field(
        ...,
        description=(
            "The explanation for the error as well as for the new question. "
            "Can include the direct stack trace as well."
        ),
    )
```
Define Agent State Function¶
Here we define a simple Python function that modifies the `state` variable and executes a single step. It returns a tuple of the state dictionary and whether or not the agent has completed execution.

We wrap it with a `FnAgentWorker`, which gives us an agent that can run this function over multiple steps.
Notes:

- The state dictionary passed to the Python function can access a special `__task__` variable that the `FnAgentWorker` injects during execution; it represents the task object that the agent maintains throughout execution.
- The output of the agent is defined by the `__output__` variable in the state dictionary. When `is_done` is True, make sure `__output__` is defined as well.
- You can customize the key names of both the input and output variables by setting `task_input_key` and `output_key` on the `FnAgentWorker`.
- You can also inject any variables you want during initialization through the `initial_state` parameter of the `FnAgentWorker`.
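To make this contract concrete, here is a minimal sketch (separate from the retry agent below) of a function that reads `__task__`, sets `__output__`, and finishes in a single step. The `echo_fn` name is just illustrative:

```python
from llama_index.core.agent import FnAgentWorker

def echo_fn(state: Dict[str, Any]) -> Tuple[Dict[str, Any], bool]:
    # __task__ is injected by FnAgentWorker; __output__ must be set when done
    state["__output__"] = f"You said: {state['__task__'].input}"
    return state, True  # is_done=True, so the agent stops after one step

echo_agent = FnAgentWorker(fn=echo_fn, initial_state={}).as_agent()
print(echo_agent.chat("hello"))
```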
```python
def retry_agent_fn(state: Dict[str, Any]) -> Tuple[Dict[str, Any], bool]:
    """Retry agent.

    Runs a single step.

    Returns:
        Tuple of (updated state, is_done)
    """
    task, router_query_engine = state["__task__"], state["router_query_engine"]
    llm, prompt_str = state["llm"], state["prompt_str"]
    verbose = state.get("verbose", False)

    if "new_input" not in state:
        new_input = task.input
    else:
        new_input = state["new_input"]

    # first run router query engine
    response = router_query_engine.query(new_input)

    # append to current reasoning
    state["current_reasoning"].extend(
        [("user", new_input), ("assistant", str(response))]
    )

    # Then, check for errors
    # dynamically create pydantic program for structured output extraction based on template
    chat_prompt_tmpl = get_chat_prompt_template(
        prompt_str, state["current_reasoning"]
    )
    llm_program = FunctionCallingProgram.from_defaults(
        output_cls=ResponseEval,
        prompt=chat_prompt_tmpl,
        llm=llm,
    )

    # run program, look at the result
    response_eval = llm_program(
        query_str=new_input, response_str=str(response)
    )
    if not response_eval.has_error:
        is_done = True
    else:
        is_done = False
        state["new_input"] = response_eval.new_question

    if verbose:
        print(f"> Question: {new_input}")
        print(f"> Response: {response}")
        print(f"> Response eval: {response_eval.dict()}")

    # set output
    state["__output__"] = str(response)

    return state, is_done
```
```python
from llama_index.llms.openai import OpenAI
from llama_index.core.agent import FnAgentWorker

llm = OpenAI(model="gpt-4o")

router_query_engine = RouterQueryEngine(
    selector=PydanticSingleSelector.from_defaults(llm=llm),
    query_engine_tools=[sql_tool] + vector_tools,
    verbose=True,
)
agent = FnAgentWorker(
    fn=retry_agent_fn,
    initial_state={
        "prompt_str": DEFAULT_PROMPT_STR,
        "llm": llm,
        "router_query_engine": router_query_engine,
        "current_reasoning": [],
        "verbose": True,
    },
).as_agent()
```
Try Out Some Queries¶
Now that we've defined the agent, you can try out some queries.
response = agent.chat("Which countries are each city from?")
print(str(response))
```
Selecting query engine 0: The question asks for the countries of each city, which requires translating a natural language query into a SQL query over a table containing city statistics, including population and country information..
> Question: Which countries are each city from?
> Response: Here are the countries for each city:
- Toronto is in Canada.
- Tokyo is in Japan.
- Berlin is in Germany.
> Response eval: {'has_error': True, 'new_question': 'Can you provide the countries for the following cities: Toronto, Tokyo, and Berlin?', 'explanation': 'The original question is too vague and does not specify which cities need to be identified. The response assumes a set of cities without confirmation. By specifying the cities in the question, the response can be more accurate and relevant.'}
Selecting query engine 0: The question requires translating a natural language query into a SQL query over a table containing city statistics, including the population and country of each city..
> Question: Can you provide the countries for the following cities: Toronto, Tokyo, and Berlin?
> Response: Sure! Here are the countries for the given cities:
- Toronto is in Canada.
- Tokyo is in Japan.
- Berlin is in Germany.
> Response eval: {'has_error': False, 'new_question': 'Can you provide the countries for the following cities: Toronto, Tokyo, and Berlin?', 'explanation': 'The response correctly identifies the countries for the given cities: Toronto (Canada), Tokyo (Japan), and Berlin (Germany). No error is present in the response.'}

Sure! Here are the countries for the given cities:
- Toronto is in Canada.
- Tokyo is in Japan.
- Berlin is in Germany.
```
```python
response = agent.chat(
    "What is the city in Canada, and what are the top modes of transport for that city?"
)
print(str(response))
```
```
Selecting query engine 1: The question asks about a city in Canada, and Toronto is a city in Canada. Therefore, the choice that is useful for answering semantic questions about Toronto is the most relevant..
> Question: What is the city in Canada, and what are the top modes of transport for that city?
> Response: The city in Canada is Toronto. The top modes of transport for Toronto are the Toronto subway system, buses, streetcars, and an extensive network of bicycle lanes and multi-use trails and paths.
> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transport in Toronto, Canada?', 'explanation': 'The original question is ambiguous and could refer to any city in Canada. The response incorrectly assumes the city is Toronto without any context. The modified question specifies Toronto directly to avoid ambiguity.'}
Selecting query engine 1: The question is about semantic information specific to Toronto, so the choice that is useful for answering semantic questions about Toronto is the most relevant..
> Question: What are the top modes of transport in Toronto, Canada?
> Response: The top modes of transport in Toronto, Canada are the public transportation system operated by the Toronto Transit Commission (TTC), which includes the subway system, buses, and streetcars, as well as the regional rail and bus transit system operated by GO Transit. Additionally, Toronto is served by major highways, an extensive network of bicycle lanes, and two airports - Toronto Pearson International Airport and Billy Bishop Toronto City Airport.
> Response eval: {'has_error': False, 'new_question': 'What are the top modes of transport in Toronto, Canada?', 'explanation': 'The response correctly identifies the top modes of transport in Toronto, Canada, including the public transportation system operated by the TTC, GO Transit, major highways, bicycle lanes, and airports.'}

The top modes of transport in Toronto, Canada are the public transportation system operated by the Toronto Transit Commission (TTC), which includes the subway system, buses, and streetcars, as well as the regional rail and bus transit system operated by GO Transit. Additionally, Toronto is served by major highways, an extensive network of bicycle lanes, and two airports - Toronto Pearson International Airport and Billy Bishop Toronto City Airport.
```
For comparison, querying the SQL query engine directly (without the retry layer) can fail outright; here it hallucinates a nonexistent `trip_data` table:

```python
response = sql_query_engine.query(
    "What are the top modes of transportation for the city with the lowest population?"
)
print(str(response.metadata["sql_query"]))
print(str(response))
```
```
SELECT mode_of_transportation, COUNT(*) as num_trips FROM trip_data WHERE city_name = (SELECT city_name FROM city_stats ORDER BY population ASC LIMIT 1) GROUP BY mode_of_transportation ORDER BY num_trips DESC LIMIT 3;

It seems there was an error in the SQL query provided. To find the top modes of transportation for the city with the lowest population, you would need to first identify the city with the lowest population from the city_stats table, and then query the trip_data table for the mode of transportation used in that city. Once you have the city name, you can then count the number of trips for each mode of transportation in that city to determine the top modes of transportation.
```
response = agent.chat("What are the sports teams of each city in Asia?")
print(str(response))
```
Selecting query engine 3: Tokyo is a city in Asia and is likely to have information about sports teams in that region..
> Question: What are the sports teams of each city in Asia?
> Response: Tokyo is home to two professional baseball clubs, the Yomiuri Giants and Tokyo Yakult Swallows, as well as soccer clubs F.C. Tokyo, Tokyo Verdy 1969, and FC Machida Zelvia. Rugby Union teams in Tokyo include Black Rams Tokyo, Tokyo Sungoliath, and Toshiba Brave Lupus Tokyo. Additionally, basketball clubs in Tokyo include the Hitachi SunRockers, Toyota Alvark Tokyo, and Tokyo Excellence.
> Response eval: {'has_error': True, 'new_question': 'What are some sports teams in Tokyo, Japan?', 'explanation': 'The original question was too broad, as there are many cities in Asia with multiple sports teams. The response only provided information about sports teams in Tokyo, Japan. The new question narrows the scope to a specific city in Asia.'}
Selecting query engine 3: The choice (4) is the most relevant as it is useful for answering semantic questions about Tokyo, which includes providing information about sports teams in Tokyo, Japan..
> Question: What are some sports teams in Tokyo, Japan?
> Response: Some sports teams in Tokyo, Japan include the Yomiuri Giants and Tokyo Yakult Swallows in baseball, F.C. Tokyo and Tokyo Verdy 1969 in soccer, FC Machida Zelvia in soccer, Black Rams Tokyo, Tokyo Sungoliath, and Toshiba Brave Lupus Tokyo in Rugby Union, Hitachi SunRockers, Toyota Alvark Tokyo, and Tokyo Excellence in basketball.
> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}

Some sports teams in Tokyo, Japan include the Yomiuri Giants and Tokyo Yakult Swallows in baseball, F.C. Tokyo and Tokyo Verdy 1969 in soccer, FC Machida Zelvia in soccer, Black Rams Tokyo, Tokyo Sungoliath, and Toshiba Brave Lupus Tokyo in Rugby Union, Hitachi SunRockers, Toyota Alvark Tokyo, and Tokyo Excellence in basketball.
```
[Optional] Build a Custom Agent through Subclassing¶
If you'd like, you can also build a custom agent by subclassing `CustomSimpleAgentWorker`. This is useful if you want to more heavily customize the mechanisms of our agent interfaces, such as the Task and AgentChatResponse objects and step-wise execution.
NOTE: You probably don't need to read this section for most custom agent flows.
Refresher¶
An agent in LlamaIndex consists of an agent runner plus an agent worker. The agent runner is an orchestrator that stores state such as memory, whereas the agent worker controls the step-wise execution of a Task. Agent runners include both sequential and parallel execution variants. More details can be found in our lower-level API guide.
Most core agent logic (e.g. ReAct, function calling loops) can be executed in the agent worker. Therefore, we've made it easy to subclass an agent worker, letting you plug it into any agent runner.
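As a minimal sketch of this split, assuming `agent_worker` is an agent worker instance (such as the `RetryAgentWorker` constructed at the end of this section), the generic `AgentRunner` can drive any worker step-wise:

```python
from llama_index.core.agent import AgentRunner

# the runner orchestrates tasks and memory; the worker executes single steps
agent = AgentRunner(agent_worker)  # worker.as_agent() is shorthand for this

# step-wise execution through the runner
task = agent.create_task("Which countries are each city from?")
step_output = agent.run_step(task.task_id)
while not step_output.is_last:
    step_output = agent.run_step(task.task_id)
response = agent.finalize_response(task.task_id)
```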
Creating a Custom Agent Worker Subclass¶
As mentioned above, we subclass `CustomSimpleAgentWorker`. This class already sets up some scaffolding for you: it can take in tools, callbacks, and an LLM, and it ensures that state/steps are properly formatted. You then mainly need to implement the following functions:
- `_initialize_state`
- `_run_step`
- `_finalize_task`
Some additional notes:

- You can implement `_arun_step` as well if you want to support async chat in the agent.
- You can choose to override `__init__` as long as you pass all remaining args and kwargs to `super()`.
- `CustomSimpleAgentWorker` is implemented as a Pydantic `BaseModel`, meaning that you can define your own custom properties as well.
Here is the full set of base properties on each `CustomSimpleAgentWorker` (that you need to/can pass in when constructing your custom agent):

- `tools: Sequence[BaseTool]`
- `tool_retriever: Optional[ObjectRetriever[BaseTool]]`
- `llm: LLM`
- `callback_manager: CallbackManager`
- `verbose: bool`
Note that `tools` and `tool_retriever` are mutually exclusive: you can only pass in one or the other (i.e., either define a static list of tools, or define a callable that returns relevant tools given a user message). You can call `get_tools(message: str)` to return relevant tools given a message.
All of these properties are accessible via `self` when defining your custom agent.
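As an aside, here is a minimal sketch of the `tool_retriever` path (not used below; the `RetryAgentWorker` takes a static `tools` list instead), which indexes the tools as objects and retrieves the top matches per message:

```python
from llama_index.core import VectorStoreIndex
from llama_index.core.objects import ObjectIndex

# index the tools themselves so they can be retrieved by semantic similarity
obj_index = ObjectIndex.from_objects(
    [sql_tool] + vector_tools, index_cls=VectorStoreIndex
)
tool_retriever = obj_index.as_retriever(similarity_top_k=2)
# a worker constructed with tool_retriever=tool_retriever would then get the
# top-2 relevant tools from self.get_tools(message)
```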
```python
from typing import Dict, Any, List, Tuple, Optional

from llama_index.core.agent import (
    CustomSimpleAgentWorker,
    Task,
    AgentChatResponse,
)
from llama_index.core.tools import BaseTool, QueryEngineTool
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.query_engine import RouterQueryEngine
from llama_index.core import ChatPromptTemplate, PromptTemplate
from llama_index.core.selectors import PydanticSingleSelector
from llama_index.core.bridge.pydantic import Field, BaseModel, PrivateAttr
```
```python
class RetryAgentWorker(CustomSimpleAgentWorker):
    """Agent worker that adds a retry layer on top of a router.

    Continues iterating until there are no errors / the task is done.
    """

    prompt_str: str = Field(default=DEFAULT_PROMPT_STR)
    max_iterations: int = Field(default=10)

    _router_query_engine: RouterQueryEngine = PrivateAttr()

    def __init__(self, tools: List[BaseTool], **kwargs: Any) -> None:
        """Init params."""
        # validate that all tools are query engine tools
        for tool in tools:
            if not isinstance(tool, QueryEngineTool):
                raise ValueError(
                    f"Tool {tool.metadata.name} is not a query engine tool."
                )
        self._router_query_engine = RouterQueryEngine(
            selector=PydanticSingleSelector.from_defaults(),
            query_engine_tools=tools,
            verbose=kwargs.get("verbose", False),
        )
        super().__init__(
            tools=tools,
            **kwargs,
        )

    def _initialize_state(self, task: Task, **kwargs: Any) -> Dict[str, Any]:
        """Initialize state."""
        return {"count": 0, "current_reasoning": []}

    def _run_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        """Run step.

        Returns:
            Tuple of (agent_response, is_done)
        """
        if "new_input" not in state:
            new_input = task.input
        else:
            new_input = state["new_input"]

        # first run router query engine
        response = self._router_query_engine.query(new_input)

        # append to current reasoning
        state["current_reasoning"].extend(
            [("user", new_input), ("assistant", str(response))]
        )

        # Then, check for errors
        # dynamically create pydantic program for structured output extraction based on template
        chat_prompt_tmpl = get_chat_prompt_template(
            self.prompt_str, state["current_reasoning"]
        )
        llm_program = LLMTextCompletionProgram.from_defaults(
            output_parser=PydanticOutputParser(output_cls=ResponseEval),
            prompt=chat_prompt_tmpl,
            llm=self.llm,
        )

        # run program, look at the result
        response_eval = llm_program(
            query_str=new_input, response_str=str(response)
        )
        if not response_eval.has_error:
            is_done = True
        else:
            is_done = False
            state["new_input"] = response_eval.new_question

        if self.verbose:
            print(f"> Question: {new_input}")
            print(f"> Response: {response}")
            print(f"> Response eval: {response_eval.dict()}")

        # return response
        return AgentChatResponse(response=str(response)), is_done

    def _finalize_task(self, state: Dict[str, Any], **kwargs) -> None:
        """Finalize task."""
        # nothing to finalize here
        # this is usually where you would modify any internal state
        # beyond what is set in `_initialize_state`
        pass
```
Define Custom Agent¶
```python
from llama_index.llms.openai import OpenAI

llm = OpenAI(model="gpt-4")
callback_manager = llm.callback_manager

query_engine_tools = [sql_tool] + vector_tools

agent_worker = RetryAgentWorker.from_tools(
    query_engine_tools,
    llm=llm,
    verbose=True,
    callback_manager=callback_manager,
)
agent = agent_worker.as_agent(callback_manager=callback_manager)
```
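The subclassed agent exposes the same chat interface as the function-based one, so you can rerun the earlier queries against it, for example:

```python
response = agent.chat("Which countries are each city from?")
print(str(response))
```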