Overview
LlamaIndex is a great library for building RAG applications with LLMs, but it is hard to learn: its functions are very high-level, and as a beginner it is difficult to understand what each one is doing. In this post, I will walk through the key components with source code and example code.
To do simple RAG, the following components are the most basic ones:
- Readers
- Indices
- QueryRunner
This post uses version 0.2.17, the oldest tagged release, which makes it easier to understand the key concepts behind the whole of LlamaIndex.
Reader
The base class is BaseReader, and it has one main abstract method:
def load_data(self, *args: Any, **load_kwargs: Any) -> List[Document]
So what is the Document object? It is a generic interface for a data document, and it inherits from BaseDocument.
@dataclass
class Document(BaseDocument):
    """Generic interface for a data document.

    This document connects to data sources.
    """

    ...

@dataclass
class BaseDocument(DataClassJsonMixin):
    """Base document.

    Generic abstract interfaces that captures both index structs
    as well as documents.
    """

    text: Optional[str] = None
    doc_id: Optional[str] = None
    embedding: Optional[List[float]] = None
    # extra fields
    extra_info: Optional[Dict[str, Any]] = None
We can see that each BaseDocument stores the text, doc_id, embedding, and extra_info fields. DataClassJsonMixin is a mixin that makes it easy to dump the dataclass to JSON.
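To make the contract concrete, a minimal custom reader only needs to implement load_data and return Document objects. Here is a sketch; the TextFileReader class is my own illustration, and the import paths follow the 0.2.x layout (the package was still published as gpt_index back then), so adjust them for your version:

from typing import List

from gpt_index.readers.base import BaseReader
from gpt_index.readers.schema.base import Document

class TextFileReader(BaseReader):
    """Hypothetical reader that loads one text file into one Document."""

    def load_data(self, file_path: str) -> List[Document]:
        with open(file_path, "r") as f:
            text = f.read()
        # attach free-form metadata via extra_info
        return [Document(text, extra_info={"source": file_path})]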
So a Reader reads raw data and saves it into Document objects. How exactly does that happen? Let's take a look at SimpleDirectoryReader.
class SimpleDirectoryReader(BaseReader):
    """Simple directory reader.

    Can read files into separate documents, or concatenates
    files into one document text.

    Args:
        input_dir (str): Path to the directory.
        input_files (List): List of file paths to read (Optional; overrides input_dir)
        exclude_hidden (bool): Whether to exclude hidden files (dotfiles).
        errors (str): how encoding and decoding errors are to be handled,
            see https://docs.python.org/3/library/functions.html#open
        recursive (bool): Whether to recursively search in subdirectories.
            False by default.
        required_exts (Optional[List[str]]): List of required extensions.
            Default is None.
        file_extractor (Optional[Dict[str, BaseParser]]): A mapping of file
            extension to a BaseParser class that specifies how to convert that file
            to text. See DEFAULT_FILE_EXTRACTOR.
        num_files_limit (Optional[int]): Maximum number of files to read.
            Default is None.
        file_metadata (Optional[Callable[str, Dict]]): A function that takes
            in a filename and returns a Dict of metadata for the Document.
            Default is None.
    """
Here we can define where to read files from, which extensions to read, how to extract content from each file, and a file_metadata callable.
In load_data, the reader basically iterates over the file list, uses the appropriate parser for each file, and loads metadata via the file_metadata callable. We can also concatenate everything into one document if needed.
def load_data(self, concatenate: bool = False) -> List[Document]:
    """Load data from the input directory.

    Args:
        concatenate (bool): whether to concatenate all files into one document.
            If set to True, file metadata is ignored.
            False by default.

    Returns:
        List[Document]: A list of documents.
    """
    data: Union[str, List[str]] = ""
    data_list: List[str] = []
    metadata_list = []
    for input_file in self.input_files:
        if input_file.suffix in self.file_extractor:
            parser = self.file_extractor[input_file.suffix]
            if not parser.parser_config_set:
                parser.init_parser()
            data = parser.parse_file(input_file, errors=self.errors)
        else:
            # do standard read
            with open(input_file, "r", errors=self.errors) as f:
                data = f.read()
        if isinstance(data, List):
            data_list.extend(data)
        else:
            data_list.append(str(data))
        if self.file_metadata is not None:
            metadata_list.append(self.file_metadata(str(input_file)))

    if concatenate:
        return [Document("\n".join(data_list))]
    elif self.file_metadata is not None:
        return [Document(d, extra_info=m) for d, m in zip(data_list, metadata_list)]
    else:
        return [Document(d) for d in data_list]
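For example, a typical call that exercises these options might look like this (the directory path and metadata function are illustrative; the keyword arguments come from the docstring above):

documents = SimpleDirectoryReader(
    'data',
    recursive=True,                                    # descend into subdirectories
    required_exts=['.txt', '.md'],                     # only read these extensions
    file_metadata=lambda filename: {"source": filename},  # attach per-file metadata
).load_data()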
Indices
This is the key part of the whole project; after all, the library is called LlamaIndex. There is a lot to learn here, but I will only go through the basic examples.
Short Example
Here I will first show how it is typically used, and then go deep.
documents = SimpleDirectoryReader('data').load_data()
index = GPTTreeIndex(documents)
new_index.query("What is the name of the professional women's basketball team in New York City?")
Looks very simple, right? Let’s go deep.
GPTTreeIndex class
The class GPTTreeIndex is defined as class GPTTreeIndex(BaseGPTIndex[IndexGraph]).
BaseGPTIndex is defined as class BaseGPTIndex(Generic[IS]), where IS = TypeVar("IS", bound=IndexStruct).
IndexStruct is defined as class IndexStruct(BaseDocument, DataClassJsonMixin).
Since IndexStruct is also a BaseDocument, it has the following fields:
- text: Optional[str] = None
- doc_id: Optional[str] = None
- embedding: Optional[List[float]] = None
- extra_info: Optional[Dict[str, Any]] = None
IndexGraph
IndexGraph is defined as class IndexGraph(IndexStruct), and Node is also defined as class Node(IndexStruct).
Here is a diagram (generated by ChatGPT) showing the relationship between these classes.
That is a lot of abstraction. What does each layer define?
- IndexStruct is empty.
- Node is the base struct used in most indices. It defines the following fields:
  - index: int = 0
  - child_indices: Set[int] = field(default_factory=set)
  - embedding: Optional[List[float]] = None
  - ref_doc_id: Optional[str] = None
  - node_info: Optional[Dict[str, Any]] = None
  - Note that embedding is redefined here, even though BaseDocument already defines it.
- IndexGraph is a graph representing the tree-structured index. It defines the following fields:
  - all_nodes: Dict[int, Node]
  - root_nodes: Dict[int, Node]
  - The key is the node index, which is global across the whole index object.
- BaseGPTIndex is the base GPT index. Its constructor takes the following (see the example after this list):
  - documents: Optional[Sequence[DOCUMENTS_INPUT]]
  - index_struct: Optional[IS]
  - llm_predictor: Optional[LLMPredictor]
  - embed_model: Optional[BaseEmbedding]
  - docstore: Optional[DocumentStore]
  - prompt_helper: Optional[PromptHelper]
  - chunk_size_limit: Optional[int]
  - It can load documents to build the index.
  - It can add/update/delete a document by doc_id.
  - It uses QueryRunner to answer a user question against the index.
- GPTTreeIndex uses GPTTreeIndexBuilder to build the tree.
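As an example of those constructor knobs, building the tree index with an explicit LLM predictor might look like this. This is a sketch: in 0.2.x LLMPredictor wraps a LangChain LLM, but double-check the exact imports and signatures for your version.

from langchain.llms import OpenAI

from gpt_index import GPTTreeIndex, LLMPredictor, SimpleDirectoryReader

documents = SimpleDirectoryReader('data').load_data()
# llm_predictor controls which LLM writes the summaries and answers
llm_predictor = LLMPredictor(llm=OpenAI(temperature=0))
index = GPTTreeIndex(documents, llm_predictor=llm_predictor)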
There are also other indices, which I will discuss in other posts.
GPTTreeIndexBuilder
Let's take a deeper look at GPTTreeIndexBuilder. The key logic is in the build_index_from_nodes method:
def build_index_from_nodes(
    self,
    cur_nodes: Dict[int, Node],
    all_nodes: Dict[int, Node],
    verbose: bool = False,
) -> Dict[int, Node]:
    """Consolidates chunks recursively, in a bottoms-up fashion."""
    cur_node_list = get_sorted_node_list(cur_nodes)  # sort by node index
    cur_index = len(all_nodes)
    new_node_dict = {}
    print(
        f"> Building index from nodes: {len(cur_nodes) // self.num_children} chunks"
    )
    for i in range(0, len(cur_node_list), self.num_children):
        cur_nodes_chunk = cur_node_list[i : i + self.num_children]
        # take num_children nodes from cur_node_list as one text chunk
        text_chunk = self._prompt_helper.get_text_from_nodes(
            cur_nodes_chunk, prompt=self.summary_prompt
        )
        # use the LLM to summarize the text
        new_summary, _ = self._llm_predictor.predict(
            self.summary_prompt, context_str=text_chunk
        )
        # create a new node and record its children
        new_node = Node(
            text=new_summary,
            index=cur_index,
            child_indices={n.index for n in cur_nodes_chunk},
        )
        new_node_dict[cur_index] = new_node
        cur_index += 1
    # update all nodes
    all_nodes.update(new_node_dict)
    if len(new_node_dict) <= self.num_children:
        return new_node_dict
    else:
        # recursively build the next level up
        return self.build_index_from_nodes(new_node_dict, all_nodes)
The return value of this function becomes root_nodes.
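To make the recursion concrete, here is a standalone sketch of the same bottom-up consolidation, with the LLM summary call replaced by a placeholder string. All names here are mine, not the library's:

from typing import Dict, List, Tuple

def build_tree(chunks: List[str], num_children: int = 4) -> Dict[int, str]:
    """Group nodes num_children at a time, 'summarize' each group,
    and repeat until the top layer fits within num_children nodes."""
    all_nodes: Dict[int, str] = dict(enumerate(chunks))
    cur_layer: List[Tuple[int, str]] = list(all_nodes.items())
    while len(cur_layer) > num_children:
        next_layer: List[Tuple[int, str]] = []
        for i in range(0, len(cur_layer), num_children):
            group = cur_layer[i : i + num_children]
            # placeholder for self._llm_predictor.predict(summary_prompt, ...)
            summary = "summary(" + ",".join(str(idx) for idx, _ in group) + ")"
            new_index = len(all_nodes)  # global index across the whole tree
            all_nodes[new_index] = summary
            next_layer.append((new_index, summary))
        cur_layer = next_layer
    return dict(cur_layer)  # these become root_nodes

# e.g. 16 chunks with num_children=4 collapse into 4 summaries,
# which are returned as the root nodes (indices 16..19)
roots = build_tree([f"chunk{i}" for i in range(16)])
print(roots)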
QueryRunner
QueryRunner takes an index and a query, and produces an answer based on the index.
It is also fairly complicated, so let's follow the high-level class structure.
QueryRunner is defined as class QueryRunner(BaseQueryRunner). BaseQueryRunner defines only one abstract method: def query(self, query: str, index_struct: IndexStruct) -> Response.
Response
Response is defined as follows, so we can see what makes up a response:
@dataclass
class Response:
    """Response.

    Attributes:
        response: The response text.

    """

    response: Optional[str]
    source_nodes: List[SourceNode] = field(default_factory=list)
    extra_info: Optional[Dict[str, Any]] = None
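In practice, after running a query we read these fields directly, e.g. (continuing the earlier tree-index example):

response = index.query(
    "What is the name of the professional women's basketball team in New York City?"
)
print(response.response)  # the answer text
for source_node in response.source_nodes:
    print(source_node)    # which nodes the answer was built from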
In QueryRunner, we can see how the query function is implemented:
def query(self, query_str: str, index_struct: IndexStruct) -> Response:
    """Run query."""
    index_struct_type = IndexStructType.from_index_struct(index_struct)
    if index_struct_type not in self._config_dict:
        raise ValueError(f"IndexStructType {index_struct_type} not in config_dict")
    config = self._config_dict[index_struct_type]
    mode = config.query_mode
    query_cls = get_query_cls(index_struct_type, mode)
    # if recursive, pass self as query_runner to each individual query
    query_runner = self if self._recursive else None
    query_kwargs = self._get_query_kwargs(config)
    query_obj = query_cls(
        index_struct,
        **query_kwargs,
        query_runner=query_runner,
        docstore=self._docstore,
    )
    return query_obj.query(query_str, verbose=self._verbose)
So the key dispatch actually happens in get_query_cls. To keep this post simple, we can assume it returns a GPTTreeIndexLeafQuery.
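In other words, the query mode chosen at the call site decides which query class handles the question. If I read the 0.2.x defaults correctly, the earlier call is equivalent to:

index.query(
    "What is the name of the professional women's basketball team in New York City?",
    mode="default",  # dispatched via get_query_cls to GPTTreeIndexLeafQuery
)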
GPTTreeIndexLeafQuery
It is defined as class GPTTreeIndexLeafQuery(BaseGPTIndexQuery[IndexGraph]), and BaseGPTIndexQuery is defined as class BaseGPTIndexQuery(Generic[IS]).
- BaseGPTIndexQuery defines the following properties:
  - index_struct: IS
  - llm_predictor: Optional[LLMPredictor]
  - prompt_helper: Optional[PromptHelper]
  - embed_model: Optional[BaseEmbedding]
  - docstore: Optional[DocumentStore]
  - query_runner: Optional[BaseQueryRunner]
  - required_keywords: Optional[List[str]]
  - exclude_keywords: Optional[List[str]]
  - response_mode: ResponseMode = ResponseMode.DEFAULT
  - text_qa_template: Optional[QuestionAnswerPrompt]
  - refine_template: Optional[RefinePrompt]
  - include_summary: bool = False
  - response_kwargs: Optional[Dict]
  - similarity_cutoff: Optional[float]
- _query is the method that subclasses override.
- query takes the _query result and builds the source nodes for the Response class.
- The default _query calls get_nodes_and_similarities_for_response to select nodes, and _should_use_node to apply the keyword filters and similarity cutoff (see the sketch after this list).
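As a rough paraphrase of that node filtering (my own stand-in, not the library code), the keyword and cutoff checks behave like this:

from typing import List, Optional

def should_use_node(
    text: str,
    required_keywords: Optional[List[str]],
    exclude_keywords: Optional[List[str]],
    similarity: Optional[float] = None,
    cutoff: Optional[float] = None,
) -> bool:
    """Illustrative stand-in for _should_use_node plus the similarity cutoff."""
    if required_keywords and not all(k in text for k in required_keywords):
        return False  # node must mention every required keyword
    if exclude_keywords and any(k in text for k in exclude_keywords):
        return False  # node must mention no excluded keyword
    if cutoff is not None and (similarity is None or similarity < cutoff):
        return False  # node must meet the similarity cutoff
    return True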
For GPTTreeIndexLeafQuery, the _query function is defined as follows:
def _query(self, query_str: str, verbose: bool = False) -> Response:
    """Answer a query."""
    # NOTE: this overrides the _query method in the base class
    print(f"> Starting query: {query_str}")
    source_builder = ResponseSourceBuilder()
    response_str = self._query_level(
        self.index_struct.root_nodes,
        query_str,
        source_builder,
        level=0,
        verbose=verbose,
    ).strip()
    return Response(response_str, source_nodes=source_builder.get_sources())
So we start by querying the root nodes. In _query_level:
def _query_level(
    self,
    cur_nodes: Dict[int, Node],
    query_str: str,
    source_builder: ResponseSourceBuilder,
    level: int = 0,
    verbose: bool = False,
) -> str:
    """Answer a query recursively."""
    cur_node_list = get_sorted_node_list(cur_nodes)
    if self.child_branch_factor == 1:
        query_template = self.query_template.partial_format(
            num_chunks=len(cur_node_list), query_str=query_str
        )
        # numbered_node_text is a str containing all nodes' info
        numbered_node_text = self._prompt_helper.get_numbered_text_from_nodes(
            cur_node_list, prompt=query_template
        )
        # the LLM response contains the index of the chosen node
        response, formatted_query_prompt = self._llm_predictor.predict(
            query_template,
            context_list=numbered_node_text,
        )
    else:
        # omitted in this post
        pass
    # regex to extract the number(s) from the response
    numbers = extract_numbers_given_response(response, n=self.child_branch_factor)
    result_response = None
    for number_str in numbers:
        number = int(number_str)
        if number > len(cur_node_list):
            if verbose:
                print(
                    f">[Level {level}] Invalid response: {response} - "
                    f"number {number} out of range"
                )
            return response
        # number is 1-indexed, so subtract 1
        selected_node = cur_node_list[number - 1]
        print(
            f">[Level {level}] Selected node: "
            f"[{number}]/[{','.join([str(int(n)) for n in numbers])}]"
        )
        if verbose:
            summary_text = " ".join(selected_node.get_text().splitlines())
            fmt_summary_text = truncate_text(summary_text, 100)
            print(
                f">[Level {level}] Node "
                f"[{number}] Summary text: {fmt_summary_text}"
            )
        # query the selected node's children
        result_response = self._query_with_selected_node(
            selected_node,
            query_str,
            source_builder,
            prev_response=result_response,
            level=level,
            verbose=verbose,
        )
    # result_response should not be None
    return cast(str, result_response)
def _query_with_selected_node(
    self,
    selected_node: Node,
    query_str: str,
    source_builder: ResponseSourceBuilder,
    prev_response: Optional[str] = None,
    level: int = 0,
    verbose: bool = False,
) -> str:
    """Get response for selected node.

    If not leaf node, it will recursively call _query on the child nodes.
    If prev_response is provided, we will update prev_response with the answer.
    """
    # if the node has no children, it is a leaf: add it to source_builder
    if len(selected_node.child_indices) == 0:
        response_builder = ResponseBuilder(
            self._prompt_helper,
            self._llm_predictor,
            self.text_qa_template,
            self.refine_template,
        )
        source_builder.add_node(selected_node)
        # use the response builder to get an answer from the node
        node_text, _ = self._get_text_from_node(
            query_str, selected_node, verbose=verbose, level=level
        )
        cur_response = response_builder.get_response_over_chunks(
            query_str, [node_text], prev_response=prev_response, verbose=verbose
        )
        if verbose:
            print(f">[Level {level}] Current answer response: {cur_response} ")
    else:
        # query the next level with all child nodes
        cur_response = self._query_level(
            {
                i: self.index_struct.all_nodes[i]
                for i in selected_node.child_indices
            },
            query_str,
            source_builder,
            level=level + 1,
            verbose=verbose,
        )
    if prev_response is None:
        return cur_response
    else:
        # refine: combine the previous response with the current one via the LLM
        context_msg = "\n".join([selected_node.get_text(), cur_response])
        cur_response, formatted_refine_prompt = self._llm_predictor.predict(
            self.refine_template,
            query_str=query_str,
            existing_answer=prev_response,
            context_msg=context_msg,
        )
        if verbose:
            print(f">[Level {level}] Refine prompt: {formatted_refine_prompt}")
            print(f">[Level {level}] Current refined response: {cur_response} ")
        return cur_response
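Putting the two methods together with verbose=True and child_branch_factor=1, a query over a two-level tree prints roughly the following. This is an illustrative trace assembled from the print statements quoted above, not captured output:

> Starting query: What is the name of the professional women's basketball team in New York City?
>[Level 0] Selected node: [2]/[2]
>[Level 1] Selected node: [3]/[3]
>[Level 1] Current answer response: ...

The "Selected node" lines come from _query_level choosing a child at each level, and the final answer line comes from the leaf branch of _query_with_selected_node.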
Summary
This was a deep dive into the basic components of LlamaIndex. I believe it is a good starting point for learning the whole project, and it has already helped me use the library better.
There are many knobs we can tune to improve performance on our own projects. The best way to learn them is to read the code, or to use LlamaIndex together with an LLM to understand the code repo quickly.