feat: Introduce Memoize (#7)

* update openaiwrapper to use memoize and move parsing logic to get_embedding and chat_completion functions
* add memoize functions
* remove commented out code
* remove missed commented out code
* add comment explaining changes to prompts
* remove print statements, use environment vars, etc
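With this change the OpenAI wrapper owns both model selection and response parsing, so call sites collapse to a single return. A minimal before/after sketch of a typical caller (the openai_wrapper variable name is only illustrative):

    # before: each caller named the model and unpacked the API response itself
    # response = openai_wrapper.chat_completion(model="gpt-4-1106-preview", messages=messages)
    # text = response.choices[0].message['content'].strip()

    # after: the wrapper defaults the model from OPENAI_MODEL and returns the stripped message text
    text = openai_wrapper.chat_completion(messages=messages)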
bearney74 committed on 2023-12-24 14:16:01 -06:00 (committed by GitHub)
parent 4436a1a0bd, commit 1767cbe48e
10 changed files with 94 additions and 46 deletions

View File

@@ -11,7 +11,6 @@ from prompt_management.prompts import (
 DEFAULT_MAX_AGENTS = 20
 PRIME_AGENT_WEIGHT = 25
-MODEL_NAME = "gpt-4-1106-preview"
 class AgentCreation:
     def __init__(self, openai_wrapper: OpenAIAPIWrapper, max_agents: int = DEFAULT_MAX_AGENTS):
@@ -72,13 +71,9 @@ class AgentCreation:
             {"role": "system", "content": PROMPT_ENGINEERING_SYSTEM_PROMPT},
             {"role": "user", "content": PROMPT_ENGINEERING_TEMPLATE.format(goal=goal, sample_input=sample_input, examples=EXAMPLES)}
         ]
         try:
-            response = self.openai_wrapper.chat_completion(
-                model=MODEL_NAME,
-                messages=messages
-            )
-            return response.choices[0].message['content'].strip()
+            return self.openai_wrapper.chat_completion(messages=messages)
         except Exception as e:
             print(f"Error generating LLM prompt: {e}")
             return ""

View File

@@ -4,9 +4,6 @@ from integrations.openaiwrapper import OpenAIAPIWrapper
 # Basic logging setup
 logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s')
-# Constants
-MODEL_NAME = "gpt-4-1106-preview"
 class AgentEvaluator:
     """
     Evaluates AI agent's responses using OpenAI's GPT model.
@@ -24,12 +21,7 @@ class AgentEvaluator:
                      "for quality/relevance. Possible Answers: Poor, Good, Perfect. "
                      "LLM output: '{output}'").format(input=input_text, prompt=prompt, output=output)
-            result = self.openai_api.chat_completion(
-                model=MODEL_NAME,
-                messages=[{"role": "system", "content": query}]
-            )
-            return result.choices[0].message['content']
+            return self.openai_api.chat_completion(messages=[{"role": "system", "content": query}])
         except Exception as error:
             logging.info(f"Agent evaluation error: {error}")
-            raise
+            raise

View File

@@ -27,7 +27,6 @@ class AgentResponse:
         react_prompt = self._build_react_prompt(input_text, conversation_accumulator, thought_number, action_number)
         self.agent.update_status('Thinking .. (Iteration #' + str(thought_number) + ')')
         response = self._generate_chat_response(system_prompt, react_prompt)
-
         conversation_accumulator, thought_number, action_number = self._process_response(
             response, conversation_accumulator, thought_number, action_number, input_text
         )
@@ -59,12 +58,11 @@ class AgentResponse:
     def _generate_chat_response(self, system_prompt, react_prompt):
         return self.openai_wrapper.chat_completion(
-            model="gpt-4-1106-preview",
             messages=[
                 {"role": "system", "content": system_prompt},
                 {"role": "user", "content": react_prompt}
             ]
-        ).choices[0].message['content']
+        )
     def _process_response(self, response, conversation_accumulator, thought_number, action_number, input_text):
         conversation_accumulator += f"\n{response}"
@@ -94,12 +92,11 @@ class AgentResponse:
     def _conclude_output(self, conversation):
         react_prompt = conversation
         self.agent.update_status('Reviewing output')
         return self.openai_wrapper.chat_completion(
-            model="gpt-4-1106-preview",
             messages=[
                 {"role": "system", "content": REACT_SYSTEM_PROMPT},
                 {"role": "user", "content": react_prompt}
             ]
-        ).choices[0].message['content']
+        )

View File

@@ -1,7 +1,7 @@
 import numpy as np
 from typing import List, Tuple, Optional
 from sklearn.metrics.pairwise import cosine_similarity
-from integrations.openaiwrapper import OpenAIAPIWrapper
+from integrations.openaiwrapper import OpenAIAPIWrapper
 class Agent:
     def __init__(self, purpose: str):
@@ -73,4 +73,4 @@ class AgentSimilarity:
             return closest_agent, highest_similarity
         except Exception as e:
-            raise ValueError(f"Error finding closest agent: {e}")
+            raise ValueError(f"Error finding closest agent: {e}")

View File

@@ -15,7 +15,7 @@ class MicroAgent:
     The MicroAgent class encapsulates the behavior of a small, purpose-driven agent
     that interacts with the OpenAI API.
     """
     def __init__(self, initial_prompt, purpose, depth, agent_creator, openai_wrapper, max_depth=3, bootstrap_agent=False, is_prime=False):
         self.dynamic_prompt = initial_prompt
         self.purpose = purpose

View File

@@ -46,10 +46,7 @@ class ResponseExtraction:
             {"role": "user", "content": formatted_prompt}
         ]
-        extraction = self.openai_wrapper.chat_completion(
-            model="gpt-4",
+        return self.openai_wrapper.chat_completion(
             messages=messages,
             max_tokens=100,
         )
-        return extraction.choices[0].message['content'].strip()

integrations/memoize.py (new file, +49 lines)
View File

@@ -0,0 +1,49 @@
+import hashlib
+import json
+import sqlite3
+
+
+## retrieved from https://www.kevinkatz.io/posts/memoize-to-sqlite
+def memoize_to_sqlite(func_name:str, filename: str = "cache.db"):
+    """
+    Memoization decorator that caches the output of a method in a SQLite
+    database.
+    """
+    db_conn = sqlite3.connect(filename)
+    print("opening database")
+    db_conn.execute(
+        "CREATE TABLE IF NOT EXISTS cache (hash TEXT PRIMARY KEY, result TEXT)"
+    )
+    db_conn.execute(
+        "CREATE INDEX IF NOT EXISTS cache_ndx on cache(hash)"
+    )
+
+    def memoize(func):
+        def wrapped(*args, **kwargs):
+            # Compute the hash of the <function name>:<argument>
+            xs = f"{func_name}:{repr(tuple(args[1:]))}:{repr(kwargs)}".encode("utf-8")
+            arg_hash = hashlib.sha256(xs).hexdigest()
+
+            # Check if the result is already cached
+            cursor = db_conn.cursor()
+            cursor.execute(
+                "SELECT result FROM cache WHERE hash = ?", (arg_hash,)
+            )
+            row = cursor.fetchone()
+            if row is not None:
+                return json.loads(row[0])
+
+            # Compute the result and cache it
+            result = func(*args, **kwargs)
+            if func_name == "chat_completion":
+                print(result)
+            cursor.execute(
+                "INSERT INTO cache (hash, result) VALUES (?, ?)",
+                (arg_hash, json.dumps(result))
+            )
+            db_conn.commit()
+            return result
+
+        return wrapped
+
+    return memoize
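memoize_to_sqlite keys the cache on func_name plus the repr of every argument after self (args[1:]) and stores the JSON-encoded return value, so it is intended for instance methods whose results are JSON-serializable. A small usage sketch (the Example class, slow_square method, and example_cache.db filename are hypothetical):

    class Example:
        @memoize_to_sqlite(func_name="slow_square", filename="example_cache.db")
        def slow_square(self, n):
            print("computing...")  # only runs on a cache miss
            return n * n

    e = Example()
    e.slow_square(7)  # computes 49 and inserts it into example_cache.db
    e.slow_square(7)  # returns 49 straight from the SQLite cache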

View File

@@ -2,9 +2,15 @@ import openai
 import time
 import logging
+from utils.utility import get_env_variable
+from .memoize import memoize_to_sqlite
 RETRY_SLEEP_DURATION = 1 # seconds
 MAX_RETRIES = 5
+ENGINE=get_env_variable("OPENAI_EMBEDDING", "text-embedding-ada-002", False)
+MODEL=get_env_variable("OPENAI_MODEL", "gpt-4-1106-preview", False)
 class OpenAIAPIWrapper:
     """
     A wrapper class for OpenAI's API.
@@ -20,8 +26,8 @@ class OpenAIAPIWrapper:
         self.api_key = api_key
         openai.api_key = api_key
         self.timeout = timeout
-        self.cache = {}
+    @memoize_to_sqlite(func_name="get_embedding", filename="openai_embedding_cache.db")
     def get_embedding(self, text):
         """
         Retrieves the embedding for the given text.
@@ -29,24 +35,23 @@ class OpenAIAPIWrapper:
         :param text: The text for which embedding is required.
         :return: The embedding for the given text.
         """
-        if text in self.cache:
-            return self.cache[text]
         start_time = time.time()
         retries = 0
         while time.time() - start_time < self.timeout:
             try:
-                embedding = openai.Embedding.create(input=text, engine="text-embedding-ada-002")
-                self.cache[text] = embedding
-                return embedding
+                return openai.Embedding.create(input=text, engine=ENGINE)
             except openai.error.OpenAIError as e:
                 logging.error(f"OpenAI API error: {e}")
                 retries += 1
                 if retries >= MAX_RETRIES:
                     raise
                 time.sleep(RETRY_SLEEP_DURATION)
+                if f"{e}".startswith("Rate limit"):
+                    print("Rate limit reached... sleeping for 20 seconds")
+                    start_time+=20
+                    time.sleep(20)
         raise TimeoutError("API call timed out")
@@ -56,17 +61,30 @@ def chat_completion(self, **kwargs):
         :param kwargs: Keyword arguments for the chat completion API call.
         :return: The result of the chat completion API call.
         """
+        if 'model' not in kwargs:
+            kwargs['model']=MODEL
         start_time = time.time()
         retries = 0
         while time.time() - start_time < self.timeout:
             try:
-                return openai.ChatCompletion.create(**kwargs)
+                res=openai.ChatCompletion.create(**kwargs)
+                if isinstance(res, dict):
+                    if isinstance(res['choices'][0], dict):
+                        return res['choices'][0]['message']['content'].strip()
+                    return res['choices'][0].message['content'].strip()
+                return res.choices[0].message['content'].strip()
             except openai.error.OpenAIError as e:
                 logging.error(f"OpenAI API error: {e}")
                 retries += 1
                 if retries >= MAX_RETRIES:
                     raise
                 time.sleep(RETRY_SLEEP_DURATION)
-        raise TimeoutError("API call timed out")
+                if f"{e}".startswith("Rate limit"):
+                    print("Rate limit reached... sleeping for 20 seconds")
+                    start_time+=20
+                    time.sleep(20)
+        raise TimeoutError("API call timed out")
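Because ENGINE and MODEL are read from the environment at import time, the embedding engine and chat model can now be switched without editing call sites; chat_completion only injects MODEL when the caller passes no model kwarg. A rough sketch, assuming get_env_variable reads os.environ (the constructor arguments and the OPENAI_API_KEY variable below are illustrative, not taken from this diff):

    import os
    os.environ["OPENAI_MODEL"] = "gpt-4"  # hypothetical override; set before the module is imported

    from integrations.openaiwrapper import OpenAIAPIWrapper

    wrapper = OpenAIAPIWrapper(api_key=os.environ["OPENAI_API_KEY"], timeout=60)
    # no model kwarg: the wrapper fills in MODEL and returns the parsed, stripped text
    answer = wrapper.chat_completion(messages=[{"role": "user", "content": "Say hello."}])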

View File

@@ -115,4 +115,6 @@ def microagent_factory(initial_prompt, purpose, api_key, depth, max_depth, boots
     return MicroAgent(initial_prompt, purpose, api_key, depth, max_depth, bootstrap_agent)
 if __name__ == "__main__":
+    from dotenv import load_dotenv
+    load_dotenv()
     main()

View File

@@ -12,7 +12,7 @@ class PromptEvolution:
     def evolve_prompt(self, input_text: str, dynamic_prompt: str, output: str, full_conversation: str, new_solution: bool, depth: int) -> str:
         """
         Evolves the prompt based on feedback from the output and full conversation.
         Args:
             input_text: The input text for the prompt.
             dynamic_prompt: The dynamic part of the prompt.
@@ -27,7 +27,7 @@ class PromptEvolution:
         full_conversation = self._truncate_conversation(full_conversation)
         runtime_context = self._generate_runtime_context(depth)
         evolve_prompt_query = self._build_evolve_prompt_query(dynamic_prompt, output, full_conversation, new_solution)
         try:
             new_prompt = self._get_new_prompt(evolve_prompt_query, runtime_context)
         except Exception as e:
@@ -57,8 +57,6 @@ class PromptEvolution:
     def _get_new_prompt(self, evolve_prompt_query: str, runtime_context: str) -> str:
         """Fetches a new prompt from the OpenAI API."""
-        response = self.openai_wrapper.chat_completion(
-            model="gpt-4-1106-preview",
+        return self.openai_wrapper.chat_completion(
             messages=[{"role": "system", "content": evolve_prompt_query + runtime_context}]
         )
-        return response.choices[0].message['content'].strip()