Add aws embedding & LLM

This commit is contained in:
Tomaz Bratanic
2023-10-15 14:29:06 +02:00
parent 91410af873
commit d3041302e6
6 changed files with 54 additions and 18 deletions

View File

@@ -1,6 +1,10 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.embeddings import (
OllamaEmbeddings,
SentenceTransformerEmbeddings,
BedrockEmbeddings,
)
from langchain.chat_models import ChatOpenAI, ChatOllama, BedrockChat
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
@@ -15,13 +19,19 @@ from utils import BaseLogger
def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
if embedding_model_name == "ollama":
embeddings = OllamaEmbeddings(base_url=config["ollama_base_url"], model="llama2")
embeddings = OllamaEmbeddings(
base_url=config["ollama_base_url"], model="llama2"
)
dimension = 4096
logger.info("Embedding: Using Ollama")
elif embedding_model_name == "openai":
embeddings = OpenAIEmbeddings()
dimension = 1536
logger.info("Embedding: Using OpenAI")
if embedding_model_name == "aws":
embeddings = BedrockEmbeddings()
dimension = 1536
logger.info("Embedding: Using AWS")
else:
embeddings = SentenceTransformerEmbeddings(
model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
@@ -38,6 +48,13 @@ def load_llm(llm_name: str, logger=BaseLogger(), config={}):
elif llm_name == "gpt-3.5":
logger.info("LLM: Using GPT-3.5")
return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
elif llm_name == "claudev2":
logger.info("LLM: ClaudeV2")
return BedrockChat(
model_id="anthropic.claude-v2",
model_kwargs={"temperature": 0.0, "max_tokens_to_sample": 1024},
streaming=True,
)
elif len(llm_name):
logger.info(f"LLM: Using Ollama: {llm_name}")
return ChatOllama(
@@ -79,7 +96,7 @@ def configure_llm_only_chain(llm):
def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
# RAG response
# System: Always talk in pirate speech.
# System: Always talk in pirate speech.
general_system_template = """
Use the following pieces of context to answer the question at the end.
The context contains question-answer pairs and their links from Stackoverflow.

View File

@@ -51,6 +51,9 @@ services:
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
networks:
- net
depends_on:
@@ -89,6 +92,9 @@ services:
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
networks:
- net
depends_on:
@@ -123,6 +129,9 @@ services:
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
networks:
- net
depends_on:
@@ -159,6 +168,9 @@ services:
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
networks:
- net
depends_on:

View File

@@ -1,4 +1,7 @@
#OPENAI_API_KEY=sk-...
#AWS_ACCESS_KEY_ID=
#AWS_SECRET_ACCESS_KEY=
#AWS_DEFAULT_REGION=us-east-1
#OLLAMA_BASE_URL=http://host.docker.internal:11434
#NEO4J_URI=neo4j://database:7687
#NEO4J_USERNAME=neo4j

View File

@@ -15,7 +15,7 @@ COPY <<EOF pull_model.clj
(let [llm (get (System/getenv) "LLM")
url (get (System/getenv) "OLLAMA_BASE_URL")]
(println (format "pulling ollama model %s using %s" llm url))
(if (and llm url (not (#{"gpt-4" "gpt-3.5"} llm)))
(if (and llm url (not (#{"gpt-4" "gpt-3.5" "claudev2"} llm)))
;; ----------------------------------------------------------------------
;; just call `ollama pull` here - create OLLAMA_HOST from OLLAMA_BASE_URL

View File

@@ -9,19 +9,22 @@ Learn more about the details in the [technical blog post](https://neo4j.com/deve
Create a `.env` file from the environment template file `env.example`
Available variables:
| Variable Name | Default value | Description |
|------------------------|------------------------------------|-------------------------------------------------------------|
| OLLAMA_BASE_URL | http://host.docker.internal:11434 | REQUIRED - URL to Ollama LLM API |
| NEO4J_URI | neo4j://database:7687 | REQUIRED - URL to Neo4j database |
| NEO4J_USERNAME | neo4j | REQUIRED - Username for Neo4j database |
| NEO4J_PASSWORD | password | REQUIRED - Password for Neo4j database |
| LLM | llama2 | REQUIRED - Can be any Ollama model tag, or gpt-4 or gpt-3.5 |
| OPENAI_API_KEY | | REQUIRED - Only if LLM=gpt-4 or LLM=gpt-3.5 |
| EMBEDDING_MODEL | sentence_transformer | REQUIRED - Can be sentence_transformer, openai or ollama |
| LANGCHAIN_ENDPOINT | "https://api.smith.langchain.com" | OPTIONAL - URL to Langchain Smith API |
| LANGCHAIN_TRACING_V2 | false | OPTIONAL - Enable Langchain tracing v2 |
| LANGCHAIN_PROJECT | | OPTIONAL - Langchain project name |
| LANGCHAIN_API_KEY | | OPTIONAL - Langchain API key |
| Variable Name | Default value | Description |
|------------------------|------------------------------------|-------------------------------------------------------------------------|
| OLLAMA_BASE_URL | http://host.docker.internal:11434 | REQUIRED - URL to Ollama LLM API |
| NEO4J_URI | neo4j://database:7687 | REQUIRED - URL to Neo4j database |
| NEO4J_USERNAME | neo4j | REQUIRED - Username for Neo4j database |
| NEO4J_PASSWORD | password | REQUIRED - Password for Neo4j database |
| LLM | llama2 | REQUIRED - Can be any Ollama model tag, or gpt-4 or gpt-3.5 or claudev2 |
| EMBEDDING_MODEL | sentence_transformer | REQUIRED - Can be sentence_transformer, openai, aws or ollama |
| AWS_ACCESS_KEY_ID | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
| AWS_SECRET_ACCESS_KEY | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
| AWS_DEFAULT_REGION | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
| OPENAI_API_KEY | | REQUIRED - Only if LLM=gpt-4 or LLM=gpt-3.5 or embedding_model=openai |
| LANGCHAIN_ENDPOINT | "https://api.smith.langchain.com" | OPTIONAL - URL to Langchain Smith API |
| LANGCHAIN_TRACING_V2 | false | OPTIONAL - Enable Langchain tracing v2 |
| LANGCHAIN_PROJECT | | OPTIONAL - Langchain project name |
| LANGCHAIN_API_KEY | | OPTIONAL - Langchain API key |
## LLM Configuration
MacOS and Linux users can use any LLM that's available via Ollama. Check the "tags" section under the model page you want to use on https://ollama.ai/library and write the tag for the value of the environment variable `LLM=` in th e`.env` file.

View File

@@ -12,3 +12,4 @@ torch==2.0.1
pydantic
uvicorn
sse-starlette
boto3