Multi Document RAG

Multi-document RAG is not as simple as single-document QA. zyx has decided to let the professionals handle this one: it provides a simple wrapper interface for working with Document objects, built on the chromadb library.


Example

Let's begin by loading the documents we'll use for this example.

import zyx

links = [
    "https://openreview.net/pdf?id=zAdUB0aCTQ", # AgentBench: Evaluating LLMs as Agents
    "https://openreview.net/pdf?id=z8TW0ttBPp", # MathCoder: Seamless Code Integration in LLMs for Enhanced Mathematical Reasoning
    "https://openreview.net/pdf?id=yoVq2BGQdP", # Achieving Fairness in Multi-Agent MDP Using Reinforcement Learning
    "https://openreview.net/pdf?id=yRrPfKyJQ2", # Conversational Drug Editing Using Retrieval and Domain Feedback
]

documents = zyx.read(links)                     # Handles lists of links/paths as well

Now let's inspect the documents we just read:

for doc in documents:
    print("---")
    print(doc.content[:200])
Output
---
Published as a conference paper at ICLR 2024
AGENT BENCH : EVALUATING LLM S AS AGENTS
Xiao Liu1,*, Hao Yu1,*,†, Hanchen Zhang1,*, Yifan Xu1, Xuanyu Lei1, Hanyu Lai1, Yu Gu2,†,
Hangliang Ding1, Kaiwen
---
Published as a conference paper at ICLR 2024
MATHCODER : S EAMLESS CODE INTEGRATION IN
LLM S FOR ENHANCED MATHEMATICAL REASONING
Ke Wang1,4∗Houxing Ren1∗Aojun Zhou1∗Zimu Lu1∗Sichun Luo3∗
Weikang Shi1∗
---
Published as a conference paper at ICLR 2024
ACHIEVING FAIRNESS IN MULTI -AGENT MDP U SING
REINFORCEMENT LEARNING
Peizhong Ju
Department of ECE
The Ohio State University
Columbus, OH 43210, USA
ju.171
---
Published as a conference paper at ICLR 2024
CONVERSATIONAL DRUG EDITING USING RETRIEVAL
AND DOMAIN FEEDBACK
Shengchao Liu1 *, Jiongxiao Wang2 *, Yijin Yang3, Chengpeng Wang4, Ling Liu5,
Hongyu Guo6,7
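
Each Document exposes the extracted text as .content and a metadata dict as .metadata (both are used by Memory.add below). For example:

print(documents[0].metadata)  # whatever metadata zyx.read attached; contents depend on the source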

Creating a Memory Store

# Initialize an in-memory store
store = zyx.Memory()

Now let's add our documents to the store:

store.add(documents)

# Now we can use the store to search for documents
# One of our papers is about LLMs in the domain of Drug Editing
results = store.search("Drug Editing")
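
Each search returns a SearchResponse (see the API reference below), whose results field holds ChromaNode objects. A quick way to inspect the hits:

for node in results.results:
    # each ChromaNode carries the matched chunk text and its metadata
    print(node.text[:100])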

LLM Completions in the Store

# We can also query our store with an LLM
response = store.completion("How have LLMs been used in the domain of Drug Editing?")

print(response)
Output
ChatCompletion(
    id='chatcmpl-ACvGG7JCm2pCwIZgxNCQa5Iew9HEZ',
    choices=[
        Choice(
            finish_reason='stop',
            index=0,
            logprobs=None,
            message=ChatCompletionMessage(
                content='Large Language Models (LLMs) have been utilized in the domain of drug editing primarily for their capabilities in data analysis,
predictive modeling, and natural language processing. They assist in the identification of potential drug candidates by analyzing vast databases of chemical
compounds and biological data. LLMs can predict the interactions between drugs and biological targets, facilitate the design of novel drug molecules, and
streamline the drug discovery process by automating literature reviews and synthesizing relevant information. Moreover, their ability to generate hypotheses and
simulate molecular interactions aids researchers in optimizing drug formulations and improving efficacy. Overall, LLMs enhance efficiency and innovation in drug
editing and development.',
                refusal=None,
                role='assistant',
                function_call=None,
                tool_calls=None
            )
        )
    ],
    created=1727643412,
    model='gpt-4o-mini-2024-07-18',
    object='chat.completion',
    service_tier=None,
    system_fingerprint='fp_f85bea6784',
    usage=CompletionUsage(completion_tokens=126, prompt_tokens=46, total_tokens=172, completion_tokens_details=CompletionTokensDetails(reasoning_tokens=0))
)
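
Since completion() returns a standard ChatCompletion object, the generated answer itself lives on the first choice:

print(response.choices[0].message.content)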

API Reference

Class for storing and retrieving data using Chroma.

Source code in zyx/resources/stores/memory.py
class Memory:
    """
    Class for storing and retrieving data using Chroma.
    """

    def __init__(
        self,
        collection_name: str = "my_collection",
        model_class: Optional[Type[BaseModel]] = None,
        embedding_api_key: Optional[str] = None,
        location: Union[Literal[":memory:"], str] = ":memory:",
        persist_directory: str = "chroma_db",
        chunk_size: int = 512,
        model: str = "gpt-4o-mini",
    ):
        """
        Class for storing and retrieving data using Chroma.

        Args:
            collection_name (str): The name of the collection.
            model_class (Type[BaseModel], optional): Model class for storing data.
            embedding_api_key (str, optional): API key for embedding model.
            location (str): ":memory:" for in-memory database or a string path for persistent storage.
            persist_directory (str): Directory for persisting Chroma database (if not using in-memory storage).
            chunk_size (int): Size of chunks for text splitting.
            model (str): Model name for text summarization.
        """

        self.collection_name = collection_name
        self.embedding_api_key = embedding_api_key
        self.model_class = model_class
        self.location = location
        self.persist_directory = persist_directory
        self.chunk_size = chunk_size
        self.model = model

        self.client = self._initialize_client()
        self.collection = self._create_or_get_collection()

    def _initialize_client(self):
        """
        Initialize Chroma client. Use in-memory database if location is ":memory:",
        otherwise, use persistent storage at the specified directory.
        """
        if self.location == ":memory:":
            logger.info("Using in-memory Chroma storage.")
            return Client()  # In-memory by default
        else:
            logger.info(f"Using persistent Chroma storage at {self.persist_directory}.")
            settings = Settings(persist_directory=self.persist_directory)
            return Client(settings)

    def _create_or_get_collection(self):
        """Retrieve or create a Chroma collection with a custom embedding function."""
        embedding_fn = CustomEmbeddingFunction(api_key=self.embedding_api_key)
        if self.collection_name in self.client.list_collections():
            logger.info(f"Collection '{self.collection_name}' already exists.")
            return self.client.get_collection(
                self.collection_name, embedding_function=embedding_fn
            )
        else:
            logger.info(f"Creating collection '{self.collection_name}'.")
            return self.client.create_collection(
                name=self.collection_name, embedding_function=embedding_fn
            )

    def _get_embedding(self, text: str) -> List[float]:
        """Generate embeddings for a given text using the custom embedding function.

        Args:
            text (str): The text to generate an embedding for.

        Returns:
            List[float]: The embedding for the text.
        """
        embedding_fn = CustomEmbeddingFunction(api_key=self.embedding_api_key)
        return embedding_fn([text])[0]  # Return the first (and only) embedding

    def add(
        self,
        data: Union[str, List[str], Document, List[Document]],
        metadata: Optional[dict] = None,
    ):
        """Add documents or data to Chroma.

        Args:
            data (Union[str, List[str], Document, List[Document]]): The data to add to Chroma.
            metadata (Optional[dict]): The metadata to add to the data.
        """
        if isinstance(data, str):
            data = [data]
        elif isinstance(data, Document):
            data = [data]

        ids, embeddings, texts, metadatas = [], [], [], []

        for item in data:
            try:
                if isinstance(item, Document):
                    text = item.content
                    # Keep per-item metadata local so a Document's metadata
                    # does not leak into later plain-string items.
                    item_metadata = item.metadata
                else:
                    text = item
                    item_metadata = metadata

                # Chunk the content
                chunks = chunk(text, chunk_size=self.chunk_size, model=self.model)

                for chunk_text in chunks:
                    embedding_vector = self._get_embedding(chunk_text)
                    ids.append(str(uuid.uuid4()))
                    embeddings.append(embedding_vector)
                    texts.append(chunk_text)
                    chunk_metadata = item_metadata.copy() if item_metadata else {}
                    chunk_metadata["chunk"] = True
                    metadatas.append(chunk_metadata)
            except Exception as e:
                logger.error(f"Error processing item: {item}. Error: {e}")

        if embeddings:
            try:
                # Ensure metadatas is not empty
                metadatas = [m if m else {"default": "empty"} for m in metadatas]
                self.collection.add(
                    ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
                )
                logger.info(
                    f"Successfully added {len(embeddings)} chunks to the collection."
                )
            except Exception as e:
                logger.error(f"Error adding points to collection: {e}")
        else:
            logger.warning("No valid embeddings to add to the collection.")

    def search(self, query: str, top_k: int = 5) -> SearchResponse:
        """Search in Chroma collection.

        Args:
            query (str): The query to search for.
            top_k (int): The number of results to return.

        Returns:
            SearchResponse: The search results.
        """
        try:
            query_embedding = self._get_embedding(query)
            search_results = self.collection.query(
                query_embeddings=[query_embedding], n_results=top_k
            )

            nodes = []
            for i in range(len(search_results["ids"][0])):  # Note the [0] here
                node = ChromaNode(
                    id=search_results["ids"][0][i],
                    text=search_results["documents"][0][i],
                    embedding=query_embedding,
                    metadata=search_results["metadatas"][0][i]
                    if search_results["metadatas"]
                    else {},
                )
                nodes.append(node)
            return SearchResponse(query=query, results=nodes)
        except Exception as e:
            logger.error(f"Error during search: {e}")
            return SearchResponse(query=query)  # Return empty results on error

    def _summarize_results(self, results: List[ChromaNode]) -> str:
        """Summarize the search results.

        Args:
            results (List[ChromaNode]): The search results.

        Returns:
            str: The summary of the search results.
        """

        class SummaryModel(BaseModel):
            summary: str

        texts = [node.text for node in results]
        combined_text = "\n\n".join(texts)

        summary = generate(
            SummaryModel,
            # Include the retrieved text in the prompt
            instructions=(
                "Provide a concise summary of the following text, focusing on "
                "the most important information:\n\n" + combined_text
            ),
            model=self.model,
            n=1,
        )

        return summary.summary

    def completion(
        self,
        messages: Union[str, List[dict]] = None,
        model: Optional[str] = None,
        top_k: Optional[int] = 5,
        tools: Optional[List[Union[Callable, dict, BaseModel]]] = None,
        run_tools: Optional[bool] = True,
        response_model: Optional[BaseModel] = None,
        mode: Optional[InstructorMode] = "tool_call",
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        max_retries: Optional[int] = 3,
        verbose: Optional[bool] = False,
    ):
        """Perform completion with context from Chroma.

        Args:
            messages (Union[str, List[dict]]): The messages to use for the completion.
            model (Optional[str]): The model to use for the completion.
            top_k (Optional[int]): The number of results to return from the search.
            tools (Optional[List[Union[Callable, dict, BaseModel]]]): The tools to use for the completion.
            run_tools (Optional[bool]): Whether to run the tools for the completion.
            response_model (Optional[BaseModel]): The response model to use for the completion.
            mode (Optional[InstructorMode]): The mode to use for the completion.
            base_url (Optional[str]): The base URL to use for the completion.
            api_key (Optional[str]): The API key to use for the completion.
            organization (Optional[str]): The organization to use for the completion.
            top_p (Optional[float]): The top p to use for the completion.
            temperature (Optional[float]): The temperature to use for the completion.
            max_tokens (Optional[int]): The maximum number of tokens to generate.
            max_retries (Optional[int]): The maximum number of retries to use for the completion.
            verbose (Optional[bool]): Whether to print the messages to the console.
        """
        logger.info(f"Initial messages: {messages}")

        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        elif isinstance(messages, list):
            messages = [
                {"role": "user", "content": m} if isinstance(m, str) else m
                for m in messages
            ]

        query = messages[-1].get("content", "") if messages else ""

        try:
            results = self.search(query, top_k=top_k)
            summarized_results = self._summarize_results(results.results)
        except Exception as e:
            logger.error(f"Error during search or summarization: {e}")
            summarized_results = ""

        if messages:
            if not any(message.get("role", "") == "system" for message in messages):
                system_message = {
                    "role": "system",
                    "content": f"Relevant information retrieved: \n {summarized_results}",
                }
                messages.insert(0, system_message)
            else:
                for message in messages:
                    if message.get("role", "") == "system":
                        message["content"] += (
                            f"\nAdditional context: {summarized_results}"
                        )

        try:
            result = completion(
                messages=messages,
                model=model or self.model,
                tools=tools,
                run_tools=run_tools,
                response_model=response_model,
                mode=mode,
                base_url=base_url,
                api_key=api_key,
                organization=organization,
                top_p=top_p,
                temperature=temperature,
                max_tokens=max_tokens,
                max_retries=max_retries,
            )

            if verbose:
                logger.info(f"Completion result: {result}")

            return result
        except Exception as e:
            logger.error(f"Error during completion: {e}")
            raise

__init__(collection_name='my_collection', model_class=None, embedding_api_key=None, location=':memory:', persist_directory='chroma_db', chunk_size=512, model='gpt-4o-mini')

Class for storing and retrieving data using Chroma.

Parameters:

Name               Type             Description                                                                  Default
collection_name    str              The name of the collection.                                                  'my_collection'
model_class        Type[BaseModel]  Model class for storing data.                                                None
embedding_api_key  str              API key for embedding model.                                                 None
location           str              ":memory:" for in-memory database or a string path for persistent storage.   ':memory:'
persist_directory  str              Directory for persisting Chroma database (if not using in-memory storage).   'chroma_db'
chunk_size         int              Size of chunks for text splitting.                                           512
model              str              Model name for text summarization.                                           'gpt-4o-mini'
Source code in zyx/resources/stores/memory.py
def __init__(
    self,
    collection_name: str = "my_collection",
    model_class: Optional[Type[BaseModel]] = None,
    embedding_api_key: Optional[str] = None,
    location: Union[Literal[":memory:"], str] = ":memory:",
    persist_directory: str = "chroma_db",
    chunk_size: int = 512,
    model: str = "gpt-4o-mini",
):
    """
    Class for storing and retrieving data using Chroma.

    Args:
        collection_name (str): The name of the collection.
        model_class (Type[BaseModel], optional): Model class for storing data.
        embedding_api_key (str, optional): API key for embedding model.
        location (str): ":memory:" for in-memory database or a string path for persistent storage.
        persist_directory (str): Directory for persisting Chroma database (if not using in-memory storage).
        chunk_size (int): Size of chunks for text splitting.
        model (str): Model name for text summarization.
    """

    self.collection_name = collection_name
    self.embedding_api_key = embedding_api_key
    self.model_class = model_class
    self.location = location
    self.persist_directory = persist_directory
    self.chunk_size = chunk_size
    self.model = model

    self.client = self._initialize_client()
    self.collection = self._create_or_get_collection()
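
The store is in-memory by default. A minimal persistence sketch (per _initialize_client above, any location other than ":memory:" switches to persistent storage; the collection name here is illustrative):

store = zyx.Memory(
    collection_name="papers",         # illustrative name
    location="disk",                  # any value other than ":memory:" enables persistence
    persist_directory="./chroma_db",  # where Chroma writes its files
)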

add(data, metadata=None)

Add documents or data to Chroma.

Parameters:

Name      Type                                             Description                       Default
data      Union[str, List[str], Document, List[Document]]  The data to add to Chroma.        required
metadata  Optional[dict]                                   The metadata to add to the data.  None
Source code in zyx/resources/stores/memory.py
def add(
    self,
    data: Union[str, List[str], Document, List[Document]],
    metadata: Optional[dict] = None,
):
    """Add documents or data to Chroma.

    Args:
        data (Union[str, List[str], Document, List[Document]]): The data to add to Chroma.
        metadata (Optional[dict]): The metadata to add to the data.
    """
    if isinstance(data, str):
        data = [data]
    elif isinstance(data, Document):
        data = [data]

    ids, embeddings, texts, metadatas = [], [], [], []

    for item in data:
        try:
            if isinstance(item, Document):
                text = item.content
                # Keep per-item metadata local so a Document's metadata
                # does not leak into later plain-string items.
                item_metadata = item.metadata
            else:
                text = item
                item_metadata = metadata

            # Chunk the content
            chunks = chunk(text, chunk_size=self.chunk_size, model=self.model)

            for chunk_text in chunks:
                embedding_vector = self._get_embedding(chunk_text)
                ids.append(str(uuid.uuid4()))
                embeddings.append(embedding_vector)
                texts.append(chunk_text)
                chunk_metadata = item_metadata.copy() if item_metadata else {}
                chunk_metadata["chunk"] = True
                metadatas.append(chunk_metadata)
        except Exception as e:
            logger.error(f"Error processing item: {item}. Error: {e}")

    if embeddings:
        try:
            # Ensure metadatas is not empty
            metadatas = [m if m else {"default": "empty"} for m in metadatas]
            self.collection.add(
                ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
            )
            logger.info(
                f"Successfully added {len(embeddings)} chunks to the collection."
            )
        except Exception as e:
            logger.error(f"Error adding points to collection: {e}")
    else:
        logger.warning("No valid embeddings to add to the collection.")
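
Besides Document objects, add() also accepts plain strings, and the metadata argument (when given) is copied onto every chunk. A small sketch with illustrative metadata keys:

store.add(
    ["LLM agents can plan and act in interactive environments."],
    metadata={"source": "notes", "topic": "agents"},  # illustrative keys
)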

completion(messages=None, model=None, top_k=5, tools=None, run_tools=True, response_model=None, mode='tool_call', base_url=None, api_key=None, organization=None, top_p=None, temperature=None, max_tokens=None, max_retries=3, verbose=False)

Perform completion with context from Chroma.

Parameters:

Name            Type                                              Description                                        Default
messages        Union[str, List[dict]]                            The messages to use for the completion.            None
model           Optional[str]                                     The model to use for the completion.               None
top_k           Optional[int]                                     The number of results to return from the search.   5
tools           Optional[List[Union[Callable, dict, BaseModel]]]  The tools to use for the completion.               None
run_tools       Optional[bool]                                    Whether to run the tools for the completion.       True
response_model  Optional[BaseModel]                               The response model to use for the completion.      None
mode            Optional[InstructorMode]                          The mode to use for the completion.                'tool_call'
base_url        Optional[str]                                     The base URL to use for the completion.            None
api_key         Optional[str]                                     The API key to use for the completion.             None
organization    Optional[str]                                     The organization to use for the completion.        None
top_p           Optional[float]                                   The top p to use for the completion.               None
temperature     Optional[float]                                   The temperature to use for the completion.         None
max_tokens      Optional[int]                                     The maximum number of tokens to generate.          None
max_retries     Optional[int]                                     The maximum number of retries for the completion.  3
verbose         Optional[bool]                                    Whether to print the messages to the console.      False
Source code in zyx/resources/stores/memory.py
def completion(
    self,
    messages: Union[str, List[dict]] = None,
    model: Optional[str] = None,
    top_k: Optional[int] = 5,
    tools: Optional[List[Union[Callable, dict, BaseModel]]] = None,
    run_tools: Optional[bool] = True,
    response_model: Optional[BaseModel] = None,
    mode: Optional[InstructorMode] = "tool_call",
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    organization: Optional[str] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    max_retries: Optional[int] = 3,
    verbose: Optional[bool] = False,
):
    """Perform completion with context from Chroma.

    Args:
        messages (Union[str, List[dict]]): The messages to use for the completion.
        model (Optional[str]): The model to use for the completion.
        top_k (Optional[int]): The number of results to return from the search.
        tools (Optional[List[Union[Callable, dict, BaseModel]]]): The tools to use for the completion.
        run_tools (Optional[bool]): Whether to run the tools for the completion.
        response_model (Optional[BaseModel]): The response model to use for the completion.
        mode (Optional[InstructorMode]): The mode to use for the completion.
        base_url (Optional[str]): The base URL to use for the completion.
        api_key (Optional[str]): The API key to use for the completion.
        organization (Optional[str]): The organization to use for the completion.
        top_p (Optional[float]): The top p to use for the completion.
        temperature (Optional[float]): The temperature to use for the completion.
        max_tokens (Optional[int]): The maximum number of tokens to generate.
        max_retries (Optional[int]): The maximum number of retries to use for the completion.
        verbose (Optional[bool]): Whether to print the messages to the console.
    """
    logger.info(f"Initial messages: {messages}")

    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]
    elif isinstance(messages, list):
        messages = [
            {"role": "user", "content": m} if isinstance(m, str) else m
            for m in messages
        ]

    query = messages[-1].get("content", "") if messages else ""

    try:
        results = self.search(query, top_k=top_k)
        summarized_results = self._summarize_results(results.results)
    except Exception as e:
        logger.error(f"Error during search or summarization: {e}")
        summarized_results = ""

    if messages:
        if not any(message.get("role", "") == "system" for message in messages):
            system_message = {
                "role": "system",
                "content": f"Relevant information retrieved: \n {summarized_results}",
            }
            messages.insert(0, system_message)
        else:
            for message in messages:
                if message.get("role", "") == "system":
                    message["content"] += (
                        f"\nAdditional context: {summarized_results}"
                    )

    try:
        result = completion(
            messages=messages,
            model=model or self.model,
            tools=tools,
            run_tools=run_tools,
            response_model=response_model,
            mode=mode,
            base_url=base_url,
            api_key=api_key,
            organization=organization,
            top_p=top_p,
            temperature=temperature,
            max_tokens=max_tokens,
            max_retries=max_retries,
        )

        if verbose:
            logger.info(f"Completion result: {result}")

        return result
    except Exception as e:
        logger.error(f"Error during completion: {e}")
        raise
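
Because response_model is forwarded to the underlying completion call, structured output works as well. A sketch, assuming a simple Pydantic model of your own:

from pydantic import BaseModel

class PaperSummary(BaseModel):  # hypothetical response model
    title: str
    key_findings: list[str]

summary = store.completion(
    "Summarize the MathCoder paper.",
    response_model=PaperSummary,
)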

search(query, top_k=5)

Search in Chroma collection.

Parameters:

Name   Type  Description                       Default
query  str   The query to search for.          required
top_k  int   The number of results to return.  5

Returns:

Name            Type            Description
SearchResponse  SearchResponse  The search results.

Source code in zyx/resources/stores/memory.py
def search(self, query: str, top_k: int = 5) -> SearchResponse:
    """Search in Chroma collection.

    Args:
        query (str): The query to search for.
        top_k (int): The number of results to return.

    Returns:
        SearchResponse: The search results.
    """
    try:
        query_embedding = self._get_embedding(query)
        search_results = self.collection.query(
            query_embeddings=[query_embedding], n_results=top_k
        )

        nodes = []
        for i in range(len(search_results["ids"][0])):  # Note the [0] here
            node = ChromaNode(
                id=search_results["ids"][0][i],
                text=search_results["documents"][0][i],
                embedding=query_embedding,
                metadata=search_results["metadatas"][0][i]
                if search_results["metadatas"]
                else {},
            )
            nodes.append(node)
        return SearchResponse(query=query, results=nodes)
    except Exception as e:
        logger.error(f"Error during search: {e}")
        return SearchResponse(query=query)  # Return empty results on error
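
A usage sketch combining top_k with the per-chunk metadata that add() attaches:

hits = store.search("fairness in multi-agent reinforcement learning", top_k=3)
for node in hits.results:
    print(node.metadata, node.text[:80])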