Fix paging bug, implement LlamaIndex API docs search on top of MemGPT

Vivian Fang
2023-10-15 16:38:35 -07:00
parent 9ca093d984
commit fab22cb5dc
8 changed files with 158 additions and 10 deletions

View File

@@ -71,9 +71,13 @@ async def function_message(msg):
print(f'{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:')
try:
msg_dict = eval(function_args)
print(f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{msg_dict["new_content"]}')
if function_name == 'archival_memory_search':
print(f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}')
else:
print(f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}{msg_dict["new_content"]}')
except Exception as e:
print(e)
printd(e)
printd(msg_dict)
pass
else:
printd(f"Warning: did not recognize function message")

main.py
View File

@@ -16,7 +16,7 @@ import memgpt.presets as presets
import memgpt.constants as constants
import memgpt.personas.personas as personas
import memgpt.humans.humans as humans
from memgpt.persistence_manager import InMemoryStateManager as persistence_manager
from memgpt.persistence_manager import InMemoryStateManager, InMemoryStateManagerWithFaiss
FLAGS = flags.FLAGS
flags.DEFINE_string("persona", default=personas.DEFAULT, required=False, help="Specify persona")
@@ -24,6 +24,7 @@ flags.DEFINE_string("human", default=humans.DEFAULT, required=False, help="Speci
flags.DEFINE_string("model", default=constants.DEFAULT_MEMGPT_MODEL, required=False, help="Specify the LLM model")
flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to send the first message in the sequence")
flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output")
flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage to load (a folder with a .index and .json describing documents to be loaded)")
def clear_line():
@@ -43,7 +44,12 @@ async def main():
logging.getLogger().setLevel(logging.DEBUG)
print("Running... [exit by typing '/exit']")
memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(), interface, persistence_manager())
if FLAGS.archival_storage_faiss_path:
index, archival_database = utils.prepare_archival_index(FLAGS.archival_storage_faiss_path)
persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
else:
persistence_manager = InMemoryStateManager()
memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(FLAGS.human), interface, persistence_manager)
print_messages = interface.print_messages
await print_messages(memgpt_agent.messages)

View File

@@ -624,7 +624,7 @@ class AgentAsync(object):
return None
async def recall_memory_search(self, query, count=5, page=0):
results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page)
results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page*count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
@@ -635,7 +635,7 @@ class AgentAsync(object):
return results_str
async def recall_memory_search_date(self, start_date, end_date, count=5, page=0):
results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page)
results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page*count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
@@ -650,7 +650,7 @@ class AgentAsync(object):
return None
async def archival_memory_search(self, query, count=5, page=0):
results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page)
results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page*count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."

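This is the paging bug named in the commit title: `text_search`, `date_search`, and `search` all take an absolute start offset, but the callers were passing the raw page number, so page 1 began at result 1 instead of result `count`. A minimal standalone sketch of the corrected arithmetic (not the repo's code; names are illustrative):

```python
# Standalone illustration of the paging fix (not MemGPT code).
import math

def paginate(results, count=5, page=0):
    """Return one page of `results` plus the 0-indexed number of the last page."""
    total = len(results)
    num_pages = math.ceil(total / count) - 1  # 0-indexed, as in the agent code above
    start = page * count                      # the fix: offset by whole pages, not by `page` items
    return results[start:start + count], num_pages

# Example: 12 results, 5 per page -> pages of 5, 5, and 2 items.
items = [f"doc {i}" for i in range(12)]
page1, last = paginate(items, count=5, page=1)
assert page1 == ["doc 5", "doc 6", "doc 7", "doc 8", "doc 9"] and last == 2
```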
View File

@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
import datetime
import re
import faiss
import numpy as np
from .utils import cosine_similarity, get_local_time, printd
from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
@@ -239,6 +241,85 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
return matches, len(matches)
class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
"""Dummy in-memory version of an archival memory database, using a FAISS
index for fast nearest-neighbors embedding search.
Archival memory is effectively "infinite" overflow for core memory,
and is read-only via string queries.
Archival Memory: A more structured and deep storage space for the AI's reflections,
insights, or any other data that doesn't fit into the active memory but
is essential enough not to be left only to the recall memory.
"""
def __init__(self, index=None, archival_memory_database=None, embedding_model='text-embedding-ada-002', k=100):
if index is None:
self.index = faiss.IndexFlatL2(1536) # openai embedding vector size.
else:
self.index = index
self.k = k
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
self.embedding_model = embedding_model
self.embeddings_dict = {}
self.search_results = {}
def __len__(self):
return len(self._archive)
async def insert(self, memory_string, embedding=None):
if embedding is None:
# Get the embedding
embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
print(f"Got an embedding, type {type(embedding)}, len {len(embedding)}")
self._archive.append({
# can eventually upgrade to adding semantic tags, etc
'timestamp': get_local_time(),
'content': memory_string,
})
embedding = np.array([embedding]).astype('float32')
self.index.add(embedding)
async def search(self, query_string, count=None, start=None):
"""Simple embedding-based search (inefficient, no caching)"""
# see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb
# query_embedding = get_embedding(query_string, model=self.embedding_model)
# our wrapped version supports backoff/rate-limits
if query_string in self.embeddings_dict:
query_embedding = self.embeddings_dict[query_string]
search_result = self.search_results[query_string]
else:
query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model)
_, indices = self.index.search(np.array([np.array(query_embedding, dtype=np.float32)]), self.k)
search_result = [self._archive[idx] if idx < len(self._archive) else "" for idx in indices[0]]
self.embeddings_dict[query_string] = query_embedding
self.search_results[query_string] = search_result
if start is not None and count is not None:
toprint = search_result[start:start+count]
else:
if len(search_result) >= 5:
toprint = search_result[:5]
else:
toprint = search_result
printd(f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({start}--{start+5}/{len(search_result)}) and scores:\n{str([t[:60] if len(t) > 60 else t for t in toprint])}")
# Extract the sorted archive without the scores
matches = search_result
# start/count support paging through results
if start is not None and count is not None:
return matches[start:start+count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
return matches[start:], len(matches)
else:
return matches, len(matches)
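The class above keeps parallel structures: the FAISS index (vectors), `self._archive` (passage dicts in insertion order), and per-query caches. `IndexFlatL2.search` returns row positions, which is why results are recovered by indexing back into `self._archive`. A toy sketch of that pattern with 4-dimensional vectors (illustrative only; the real class uses 1536-dimensional `text-embedding-ada-002` embeddings and async calls):

```python
# Toy illustration of the FAISS lookup pattern used above (not MemGPT code).
import faiss
import numpy as np

dim = 4                                  # the real code uses 1536 (ada-002 embeddings)
index = faiss.IndexFlatL2(dim)
archive = []                             # parallel list of passage dicts

def insert(text, embedding):
    archive.append({'content': text})
    index.add(np.array([embedding], dtype='float32'))  # row i <-> archive[i]

def search(query_embedding, k=2):
    _, indices = index.search(np.array([query_embedding], dtype='float32'), k)
    # FAISS pads with -1 when it holds fewer than k vectors, so guard the lookup.
    return [archive[i] for i in indices[0] if 0 <= i < len(archive)]

insert("llama_index quickstart", [1.0, 0.0, 0.0, 0.0])
insert("faiss install notes",    [0.0, 1.0, 0.0, 0.0])
print(search([0.9, 0.1, 0.0, 0.0]))      # nearest first: the quickstart passage
```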
class RecallMemory(ABC):
@abstractmethod

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings
from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings, DummyArchivalMemoryWithFaiss
from .utils import get_local_time, printd
@@ -88,4 +88,26 @@ class InMemoryStateManager(PersistenceManager):
class InMemoryStateManagerWithEmbeddings(InMemoryStateManager):
archival_memory_cls = DummyArchivalMemoryWithEmbeddings
recall_memory_cls = DummyRecallMemoryWithEmbeddings
class InMemoryStateManagerWithFaiss(InMemoryStateManager):
archival_memory_cls = DummyArchivalMemoryWithFaiss
recall_memory_cls = DummyRecallMemoryWithEmbeddings
def __init__(self, archival_index, archival_memory_db, a_k=100):
super().__init__()
self.archival_index = archival_index
self.archival_memory_db = archival_memory_db
self.a_k = a_k
def init(self, agent):
print(f"Initializing InMemoryStateManager with agent object")
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
self.memory = agent.memory
print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
print(f"InMemoryStateManager.messages.len = {len(self.messages)}")
# Persistence manager also handles DB-related state
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
self.archival_memory = self.archival_memory_cls(index=self.archival_index, archival_memory_database=self.archival_memory_db, k=self.a_k)

View File

@@ -0,0 +1,13 @@
# MemGPT Search over LlamaIndex API Docs

1. Obtain the docs index:

   a. Download embeddings and docs index from XYZ.

   -- OR --

   b. Build the index:

      1. Build the llama_index API docs with `make text`. Instructions [here](https://github.com/run-llama/llama_index/blob/main/docs/DOCS_README.md). Copy the generated `_build/text` folder into this directory.
      2. Generate embeddings and the FAISS index:

         ```bash
         python3 scrape_docs.py
         python3 generate_embeddings_for_docs.py all_docs.jsonl
         python3 build_index.py --embedding_files all_docs.embeddings.jsonl --output_index_file all_docs.index
         ```
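Once `all_docs.index` and `all_docs.jsonl` exist in a folder, the `--archival_storage_faiss_path` flag added to `main.py` in this commit points MemGPT at that folder. Loading roughly follows the sketch below (paraphrased from the diff; the `memgpt.utils` import path and folder name are assumptions):

```python
# Sketch of the loading path behind --archival_storage_faiss_path (paraphrased, not verbatim).
import memgpt.utils as utils  # assumed import path for prepare_archival_index
from memgpt.persistence_manager import InMemoryStateManagerWithFaiss

folder = "llamaindex_docs"  # hypothetical folder holding all_docs.index and all_docs.jsonl
index, archival_database = utils.prepare_archival_index(folder)
persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
# main.py then hands persistence_manager to presets.use_preset(...) to build the agent.
```

From the shell this corresponds to something like `python3 main.py --archival_storage_faiss_path=llamaindex_docs` (the exact invocation depends on your setup).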

View File

@@ -1,3 +1,6 @@
My name is MemGPT.
I am an AI assistant designed to help human users with document analysis.
I can use this space in my core memory to keep track of my current tasks and goals.
The answer to the human's question will usually be located somewhere in your archival memory, so keep paging through results until you find enough information to construct an answer.
Do not respond to the human until you have arrived at an answer.

View File

@@ -4,6 +4,8 @@ import demjson3 as demjson
import numpy as np
import json
import pytz
import os
import faiss
# DEBUG = True
@@ -61,3 +63,20 @@ def parse_json(string):
except demjson.JSONDecodeError as e:
print(f"Error parsing json with demjson package: {e}")
raise e
def prepare_archival_index(folder):
index_file = os.path.join(folder, "all_docs.index")
index = faiss.read_index(index_file)
archival_database_file = os.path.join(folder, "all_docs.jsonl")
archival_database = []
with open(archival_database_file, 'rt') as f:
all_data = [json.loads(line) for line in f]
for doc in all_data:
total = len(doc)
for i, passage in enumerate(doc):
archival_database.append({
'content': f"[Title: {passage['title']}, {i}/{total}] {passage['text']}",
'timestamp': get_local_time(),
})
return index, archival_database
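For reference, the loop above implies each line of `all_docs.jsonl` is a JSON array holding the passages of one document, each passage carrying at least `title` and `text` fields. A hypothetical line, for illustration only:

```python
# Hypothetical example of one all_docs.jsonl line (format inferred from prepare_archival_index).
import json

doc = [
    {"title": "llama_index.readers.SimpleDirectoryReader", "text": "Reads files from a directory..."},
    {"title": "llama_index.readers.SimpleDirectoryReader", "text": "Supported file types include..."},
]
line = json.dumps(doc)  # one document (a list of passages) per line

# prepare_archival_index turns each passage into:
#   {'content': '[Title: <title>, <i>/<total>] <text>', 'timestamp': <local time>}
```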