mirror of https://github.com/cpacker/MemGPT.git (synced 2023-10-17 01:28:22 +03:00)

Commit: fix paging bug, implement llamaindex api search on top of memgpt
@@ -71,9 +71,13 @@ async def function_message(msg):
         print(f'{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:')
         try:
             msg_dict = eval(function_args)
-            print(f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t→ {msg_dict["new_content"]}')
+            if function_name == 'archival_memory_search':
+                print(f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}')
+            else:
+                print(f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}→ {msg_dict["new_content"]}')
         except Exception as e:
-            print(e)
+            printd(e)
+            printd(msg_dict)
             pass
     else:
         printd(f"Warning: did not recognize function message")
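A side note on the rewritten except branch above (not part of the commit): if `eval(function_args)` itself raises, `msg_dict` is never bound, so the handler's `printd(msg_dict)` would raise a NameError of its own. A minimal sketch of a guarded variant, using plain `print` in place of the repo's `printd` helper and a made-up malformed argument string:

```python
# Sketch only: mirrors the try/except shape above, but guards the case
# where eval() fails before msg_dict is ever assigned.
function_args = '{"query": "faiss", "page": }'  # deliberately malformed example

try:
    msg_dict = eval(function_args)
except Exception as e:
    print(e)
    print(function_args)  # fall back to the raw argument string, which is always bound
```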
main.py (10 changed lines)
@@ -16,7 +16,7 @@ import memgpt.presets as presets
 import memgpt.constants as constants
 import memgpt.personas.personas as personas
 import memgpt.humans.humans as humans
-from memgpt.persistence_manager import InMemoryStateManager as persistence_manager
+from memgpt.persistence_manager import InMemoryStateManager, InMemoryStateManagerWithFaiss
 
 FLAGS = flags.FLAGS
 flags.DEFINE_string("persona", default=personas.DEFAULT, required=False, help="Specify persona")
@@ -24,6 +24,7 @@ flags.DEFINE_string("human", default=humans.DEFAULT, required=False, help="Specify human")
 flags.DEFINE_string("model", default=constants.DEFAULT_MEMGPT_MODEL, required=False, help="Specify the LLM model")
 flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to send the first message in the sequence")
 flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output")
+flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage to load (a folder with a .index and .json describing documents to be loaded)")
 
 
 def clear_line():
@@ -43,7 +44,12 @@ async def main():
         logging.getLogger().setLevel(logging.DEBUG)
     print("Running... [exit by typing '/exit']")
 
-    memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(), interface, persistence_manager())
+    if FLAGS.archival_storage_faiss_path:
+        index, archival_database = utils.prepare_archival_index(FLAGS.archival_storage_faiss_path)
+        persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
+    else:
+        persistence_manager = InMemoryStateManager()
+    memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(FLAGS.human), interface, persistence_manager)
     print_messages = interface.print_messages
     await print_messages(memgpt_agent.messages)
 
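For context on the new flag: `prepare_archival_index` (added to `utils.py` further down in this diff) expects the folder passed via `--archival_storage_faiss_path` to contain an `all_docs.index` FAISS file and an `all_docs.jsonl` passage database. A small pre-flight sketch, with an assumed example path:

```python
# Hypothetical sanity check; the docqa folder is an example path, not required by the code.
import os

folder = "memgpt/personas/examples/docqa"
for name in ("all_docs.index", "all_docs.jsonl"):
    path = os.path.join(folder, name)
    print(path, "found" if os.path.exists(path) else "MISSING")
```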
@@ -624,7 +624,7 @@ class AgentAsync(object):
             return None
 
     async def recall_memory_search(self, query, count=5, page=0):
-        results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page)
+        results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page*count)
         num_pages = math.ceil(total / count) - 1  # 0 index
         if len(results) == 0:
             results_str = f"No results found."
@@ -635,7 +635,7 @@ class AgentAsync(object):
         return results_str
 
     async def recall_memory_search_date(self, start_date, end_date, count=5, page=0):
-        results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page)
+        results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page*count)
         num_pages = math.ceil(total / count) - 1  # 0 index
         if len(results) == 0:
             results_str = f"No results found."
@@ -650,7 +650,7 @@ class AgentAsync(object):
             return None
 
     async def archival_memory_search(self, query, count=5, page=0):
-        results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page)
+        results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page*count)
         num_pages = math.ceil(total / count) - 1  # 0 index
         if len(results) == 0:
             results_str = f"No results found."
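The three one-line changes above are the paging bug fix named in the commit message: `page` is a 0-indexed page number, but the search backends treat `start` as an element offset, so passing `start=page` made successive pages advance by a single result instead of a full page. Multiplying by `count` converts the page number into an offset. A minimal standalone sketch of the corrected arithmetic (toy data, not the repository's code):

```python
import math

def page_slice(results, count=5, page=0):
    # The fix: turn a 0-indexed page number into an element offset before slicing,
    # instead of treating the page number itself as the offset.
    start = page * count
    num_pages = math.ceil(len(results) / count) - 1  # 0-indexed last page, as in agent.py
    return results[start:start + count], num_pages

# With 12 results and count=5, page=1 now yields items 5..9;
# the old start=page behaviour would have returned items 1..5.
items, last_page = page_slice(list(range(12)), count=5, page=1)
assert items == [5, 6, 7, 8, 9] and last_page == 2
```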
@@ -1,6 +1,8 @@
 from abc import ABC, abstractmethod
 import datetime
 import re
+import faiss
+import numpy as np
 
 from .utils import cosine_similarity, get_local_time, printd
 from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
@@ -239,6 +241,85 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
             return matches, len(matches)
 
 
+class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
+    """Dummy in-memory version of an archival memory database, using a FAISS
+    index for fast nearest-neighbors embedding search.
+
+    Archival memory is effectively "infinite" overflow for core memory,
+    and is read-only via string queries.
+
+    Archival Memory: A more structured and deep storage space for the AI's reflections,
+    insights, or any other data that doesn't fit into the active memory but
+    is essential enough not to be left only to the recall memory.
+    """
+
+    def __init__(self, index=None, archival_memory_database=None, embedding_model='text-embedding-ada-002', k=100):
+        if index is None:
+            self.index = faiss.IndexFlatL2(1536)  # openai embedding vector size
+        else:
+            self.index = index
+        self.k = k
+        self._archive = [] if archival_memory_database is None else archival_memory_database  # consists of {'content': str} dicts
+        self.embedding_model = embedding_model
+        self.embeddings_dict = {}
+        self.search_results = {}
+
+    def __len__(self):
+        return len(self._archive)
+
+    async def insert(self, memory_string, embedding=None):
+        if embedding is None:
+            # Get the embedding
+            embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
+            print(f"Got an embedding, type {type(embedding)}, len {len(embedding)}")
+
+        self._archive.append({
+            # can eventually upgrade to adding semantic tags, etc
+            'timestamp': get_local_time(),
+            'content': memory_string,
+        })
+        embedding = np.array([embedding]).astype('float32')
+        self.index.add(embedding)
+
+    async def search(self, query_string, count=None, start=None):
+        """Simple embedding-based search (inefficient, no caching)"""
+        # see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb
+
+        # query_embedding = get_embedding(query_string, model=self.embedding_model)
+        # our wrapped version supports backoff/rate-limits
+        if query_string in self.embeddings_dict:
+            query_embedding = self.embeddings_dict[query_string]
+            search_result = self.search_results[query_string]
+        else:
+            query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model)
+            _, indices = self.index.search(np.array([np.array(query_embedding, dtype=np.float32)]), self.k)
+            search_result = [self._archive[idx] if idx < len(self._archive) else "" for idx in indices[0]]
+            self.embeddings_dict[query_string] = query_embedding
+            self.search_results[query_string] = search_result
+
+        if start is not None and count is not None:
+            toprint = search_result[start:start+count]
+        else:
+            if len(search_result) >= 5:
+                toprint = search_result[:5]
+            else:
+                toprint = search_result
+        printd(f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({start}--{start+5}/{len(search_result)}) and scores:\n{str([t[:60] if len(t) > 60 else t for t in toprint])}")
+
+        # Extract the sorted archive without the scores
+        matches = search_result
+
+        # start/count support paging through results
+        if start is not None and count is not None:
+            return matches[start:start+count], len(matches)
+        elif start is None and count is not None:
+            return matches[:count], len(matches)
+        elif start is not None and count is None:
+            return matches[start:], len(matches)
+        else:
+            return matches, len(matches)
+
+
 class RecallMemory(ABC):
 
     @abstractmethod
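For readers unfamiliar with FAISS, here is a minimal standalone sketch of the pattern `DummyArchivalMemoryWithFaiss` relies on: an `IndexFlatL2` over fixed-size embedding vectors, queried for the `k` nearest neighbours, with the returned row ids mapped back onto the passage list. Dimension 4 and random vectors stand in for the real 1536-dimensional OpenAI embeddings:

```python
import faiss
import numpy as np

dim, k = 4, 3
index = faiss.IndexFlatL2(dim)

passages = ["passage one", "passage two", "passage three", "passage four"]
vectors = np.random.rand(len(passages), dim).astype('float32')
index.add(vectors)  # rows are added in the same order as `passages`

query = np.random.rand(1, dim).astype('float32')
distances, indices = index.search(query, k)  # both have shape (1, k)

# indices[0] holds row ids into `passages`, nearest first; -1 marks empty slots
results = [passages[i] for i in indices[0] if i != -1]
print(results)
```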
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 
-from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings
+from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings, DummyArchivalMemoryWithFaiss
 from .utils import get_local_time, printd
 
 
@@ -88,4 +88,26 @@ class InMemoryStateManager(PersistenceManager):
 class InMemoryStateManagerWithEmbeddings(InMemoryStateManager):
 
     archival_memory_cls = DummyArchivalMemoryWithEmbeddings
     recall_memory_cls = DummyRecallMemoryWithEmbeddings
+
+
+class InMemoryStateManagerWithFaiss(InMemoryStateManager):
+    archival_memory_cls = DummyArchivalMemoryWithFaiss
+    recall_memory_cls = DummyRecallMemoryWithEmbeddings
+
+    def __init__(self, archival_index, archival_memory_db, a_k=100):
+        super().__init__()
+        self.archival_index = archival_index
+        self.archival_memory_db = archival_memory_db
+        self.a_k = a_k
+
+    def init(self, agent):
+        print(f"Initializing InMemoryStateManager with agent object")
+        self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
+        self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
+        self.memory = agent.memory
+        print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
+        print(f"InMemoryStateManager.messages.len = {len(self.messages)}")
+
+        # Persistence manager also handles DB-related state
+        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
+        self.archival_memory = self.archival_memory_cls(index=self.archival_index, archival_memory_database=self.archival_memory_db, k=self.a_k)
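Worth noting about the design: the persistence managers select their memory backends through the `archival_memory_cls` / `recall_memory_cls` class attributes, so `InMemoryStateManagerWithFaiss` only has to swap in `DummyArchivalMemoryWithFaiss` and carry the prebuilt index and passage database until `init()` instantiates the memories. A toy sketch of that dispatch pattern (placeholder classes, not the repository's):

```python
class BaseManager:
    # Subclasses override these attributes to choose the memory implementations.
    archival_memory_cls = list
    recall_memory_cls = list

    def init(self):
        self.archival_memory = self.archival_memory_cls()
        self.recall_memory = self.recall_memory_cls()

class FancyManager(BaseManager):
    archival_memory_cls = dict  # stand-in for DummyArchivalMemoryWithFaiss

m = FancyManager()
m.init()
print(type(m.archival_memory), type(m.recall_memory))  # <class 'dict'> <class 'list'>
```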
memgpt/personas/examples/docqa/README.md (new file, 13 lines)
@@ -0,0 +1,13 @@
+# MemGPT Search over LlamaIndex API Docs
+
+1.
+   a. Download embeddings and docs index from XYZ.
+   -- OR --
+   b. Build the index:
+      1. Build llama_index API docs with `make text`. Instructions [here](https://github.com/run-llama/llama_index/blob/main/docs/DOCS_README.md). Copy over the generated `_build/text` folder to this directory.
+      2. Generate embeddings and FAISS index.
+         ```bash
+         python3 scrape_docs.py
+         python3 generate_embeddings_for_docs.py all_docs.jsonl
+         python3 build_index.py --embedding_files all_docs.embeddings.jsonl --output_index_file all_docs.index
+         ```
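The three build scripts referenced in the README are not part of this diff, so their internals are a guess; but since `prepare_archival_index` (in `utils.py` below) reads the result with `faiss.read_index`, the final `build_index.py` step presumably looks roughly like this sketch, which assumes one JSON object per line with an `embedding` field (field name assumed, not confirmed by the diff):

```python
# Hypothetical sketch of the index-building step, not the repository's build_index.py.
import json
import faiss
import numpy as np

embeddings = []
with open("all_docs.embeddings.jsonl", "rt") as f:
    for line in f:
        embeddings.append(json.loads(line)["embedding"])  # 'embedding' field is an assumption

vectors = np.array(embeddings, dtype="float32")
index = faiss.IndexFlatL2(vectors.shape[1])  # 1536 for text-embedding-ada-002
index.add(vectors)
faiss.write_index(index, "all_docs.index")
```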
@@ -1,3 +1,6 @@
 My name is MemGPT.
 I am an AI assistant designed to help human users with document analysis.
 I can use this space in my core memory to keep track of my current tasks and goals.
+
+The answer to the human's question will usually be located somewhere in your archival memory, so keep paging through results until you find enough information to construct an answer.
+Do not respond to the human until you have arrived at an answer.
@@ -4,6 +4,8 @@ import demjson3 as demjson
 import numpy as np
 import json
 import pytz
+import os
+import faiss
 
 
 # DEBUG = True
@@ -61,3 +63,20 @@ def parse_json(string):
     except demjson.JSONDecodeError as e:
         print(f"Error parsing json with demjson package: {e}")
         raise e
+
+
+def prepare_archival_index(folder):
+    index_file = os.path.join(folder, "all_docs.index")
+    index = faiss.read_index(index_file)
+
+    archival_database_file = os.path.join(folder, "all_docs.jsonl")
+    archival_database = []
+    with open(archival_database_file, 'rt') as f:
+        all_data = [json.loads(line) for line in f]
+    for doc in all_data:
+        total = len(doc)
+        for i, passage in enumerate(doc):
+            archival_database.append({
+                'content': f"[Title: {passage['title']}, {i}/{total}] {passage['text']}",
+                'timestamp': get_local_time(),
+            })
+    return index, archival_database
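To tie the pieces of this commit together, a hedged end-to-end sketch: load the prebuilt index and passage database with the new `prepare_archival_index`, wrap them in `DummyArchivalMemoryWithFaiss`, and page through results the way `archival_memory_search` in `agent.py` now does. The module paths are inferred from the relative imports in this diff, the docqa folder is an example, and an asyncio wrapper plus an available OpenAI API key (for query embeddings) are assumptions:

```python
import asyncio

from memgpt.utils import prepare_archival_index
from memgpt.memory import DummyArchivalMemoryWithFaiss

index, archival_db = prepare_archival_index("memgpt/personas/examples/docqa")
archival_memory = DummyArchivalMemoryWithFaiss(index=index, archival_memory_database=archival_db)

async def demo():
    count, page = 5, 0
    # start=page*count mirrors the paging fix in agent.py
    results, total = await archival_memory.search("how do I build a vector index?", count=count, start=page * count)
    print(f"{total} candidates, showing page {page}:")
    for r in results:
        if r:  # entries past the end of the archive come back as empty strings
            print("-", r["content"][:80])

asyncio.run(demo())
```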