docker/genai-stack (mirror of https://github.com/docker/genai-stack.git)

Commit: Add aws embedding & LLM
chains.py (25 changed lines)
@@ -1,6 +1,10 @@
 from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
-from langchain.chat_models import ChatOpenAI, ChatOllama
+from langchain.embeddings import (
+    OllamaEmbeddings,
+    SentenceTransformerEmbeddings,
+    BedrockEmbeddings,
+)
+from langchain.chat_models import ChatOpenAI, ChatOllama, BedrockChat
 from langchain.vectorstores.neo4j_vector import Neo4jVector
 from langchain.chains import RetrievalQAWithSourcesChain
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
@@ -15,13 +19,19 @@ from utils import BaseLogger
 
 def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
     if embedding_model_name == "ollama":
-        embeddings = OllamaEmbeddings(base_url=config["ollama_base_url"], model="llama2")
+        embeddings = OllamaEmbeddings(
+            base_url=config["ollama_base_url"], model="llama2"
+        )
         dimension = 4096
         logger.info("Embedding: Using Ollama")
     elif embedding_model_name == "openai":
         embeddings = OpenAIEmbeddings()
         dimension = 1536
         logger.info("Embedding: Using OpenAI")
+    elif embedding_model_name == "aws":
+        embeddings = BedrockEmbeddings()
+        dimension = 1536
+        logger.info("Embedding: Using AWS")
     else:
         embeddings = SentenceTransformerEmbeddings(
             model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
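A minimal sketch of exercising the new "aws" branch above, assuming AWS credentials are already available to boto3 via the standard environment variables; the query text is illustrative:

# Sketch only, not part of the commit: embed a query with Bedrock and check
# the vector size that Neo4jVector gets configured with (dimension = 1536 above).
from langchain.embeddings import BedrockEmbeddings

embeddings = BedrockEmbeddings()  # uses the default Bedrock embedding model
vector = embeddings.embed_query("How do I sort a list in Python?")
print(len(vector))  # expected: 1536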
@@ -38,6 +48,13 @@ def load_llm(llm_name: str, logger=BaseLogger(), config={}):
     elif llm_name == "gpt-3.5":
         logger.info("LLM: Using GPT-3.5")
         return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
+    elif llm_name == "claudev2":
+        logger.info("LLM: ClaudeV2")
+        return BedrockChat(
+            model_id="anthropic.claude-v2",
+            model_kwargs={"temperature": 0.0, "max_tokens_to_sample": 1024},
+            streaming=True,
+        )
     elif len(llm_name):
        logger.info(f"LLM: Using Ollama: {llm_name}")
        return ChatOllama(
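A minimal sketch of invoking the new claudev2 branch directly, assuming the AWS account has Bedrock access to anthropic.claude-v2 in the configured region; the prompt is illustrative:

# Sketch only, not part of the commit: call the BedrockChat model
# exactly as load_llm configures it.
from langchain.chat_models import BedrockChat
from langchain.schema import HumanMessage

llm = BedrockChat(
    model_id="anthropic.claude-v2",
    model_kwargs={"temperature": 0.0, "max_tokens_to_sample": 1024},
)
response = llm([HumanMessage(content="What does a vector index do?")])
print(response.content)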
@@ -79,7 +96,7 @@ def configure_llm_only_chain(llm):
 
 def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
     # RAG response
-    # System: Always talk in pirate speech.
+    # System: Always talk in pirate speech.
     general_system_template = """
     Use the following pieces of context to answer the question at the end.
     The context contains question-answer pairs and their links from Stackoverflow.
docker-compose.yml

@@ -51,6 +51,9 @@ services:
       - LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
       - LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
       - LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
+      - AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
     networks:
       - net
     depends_on:
@@ -89,6 +92,9 @@ services:
       - LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
       - LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
       - LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
+      - AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
     networks:
       - net
     depends_on:
@@ -123,6 +129,9 @@ services:
       - LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
       - LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
       - LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
+      - AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
     networks:
       - net
     depends_on:
@@ -159,6 +168,9 @@ services:
       - LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
      - LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
       - LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
+      - AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
     networks:
       - net
     depends_on:
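The compose changes only forward the three AWS variables into each service container; boto3, and therefore BedrockEmbeddings and BedrockChat, picks them up from the environment with no further wiring. A hypothetical sanity check, not part of the repo:

# Sketch only: confirm the forwarded variables are visible inside a container.
import os
import boto3  # added to requirements.txt below

for var in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION"):
    print(var, "is set" if os.environ.get(var) else "is MISSING")

session = boto3.Session()  # reads the same environment variables
print("resolved region:", session.region_name)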
env.example

@@ -1,4 +1,7 @@
 #OPENAI_API_KEY=sk-...
+#AWS_ACCESS_KEY_ID=
+#AWS_SECRET_ACCESS_KEY=
+#AWS_DEFAULT_REGION=us-east-1
 #OLLAMA_BASE_URL=http://host.docker.internal:11434
 #NEO4J_URI=neo4j://database:7687
 #NEO4J_USERNAME=neo4j
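For example, a `.env` selecting the new Bedrock-backed paths might uncomment and fill these entries (the keys come from env.example and the README table below; the credential values are placeholders):

LLM=claudev2
EMBEDDING_MODEL=aws
AWS_ACCESS_KEY_ID=your-access-key-id
AWS_SECRET_ACCESS_KEY=your-secret-access-key
AWS_DEFAULT_REGION=us-east-1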
pull_model.Dockerfile

@@ -15,7 +15,7 @@ COPY <<EOF pull_model.clj
 (let [llm (get (System/getenv) "LLM")
       url (get (System/getenv) "OLLAMA_BASE_URL")]
   (println (format "pulling ollama model %s using %s" llm url))
-  (if (and llm url (not (#{"gpt-4" "gpt-3.5"} llm)))
+  (if (and llm url (not (#{"gpt-4" "gpt-3.5" "claudev2"} llm)))
 
 ;; ----------------------------------------------------------------------
 ;; just call `ollama pull` here - create OLLAMA_HOST from OLLAMA_BASE_URL
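The intent of the guard, restated as an illustrative Python sketch: the model is pulled through Ollama only when it is not one of the hosted models, so adding "claudev2" to the set keeps the stack from asking Ollama to pull a Bedrock-hosted model.

# Sketch only: Python restatement of the Clojure guard above.
HOSTED_MODELS = {"gpt-4", "gpt-3.5", "claudev2"}

def should_pull(llm, url):
    return bool(llm and url and llm not in HOSTED_MODELS)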
readme.md (29 changed lines)
@@ -9,19 +9,22 @@ Learn more about the details in the [technical blog post](https://neo4j.com/deve
 Create a `.env` file from the environment template file `env.example`
 
 Available variables:
-| Variable Name          | Default value                      | Description                                                  |
-|------------------------|------------------------------------|--------------------------------------------------------------|
-| OLLAMA_BASE_URL        | http://host.docker.internal:11434  | REQUIRED - URL to Ollama LLM API                             |
-| NEO4J_URI              | neo4j://database:7687              | REQUIRED - URL to Neo4j database                             |
-| NEO4J_USERNAME         | neo4j                              | REQUIRED - Username for Neo4j database                       |
-| NEO4J_PASSWORD         | password                           | REQUIRED - Password for Neo4j database                       |
-| LLM                    | llama2                             | REQUIRED - Can be any Ollama model tag, or gpt-4 or gpt-3.5  |
-| OPENAI_API_KEY         |                                    | REQUIRED - Only if LLM=gpt-4 or LLM=gpt-3.5                  |
-| EMBEDDING_MODEL        | sentence_transformer               | REQUIRED - Can be sentence_transformer, openai or ollama     |
-| LANGCHAIN_ENDPOINT     | "https://api.smith.langchain.com"  | OPTIONAL - URL to Langchain Smith API                        |
-| LANGCHAIN_TRACING_V2   | false                              | OPTIONAL - Enable Langchain tracing v2                       |
-| LANGCHAIN_PROJECT      |                                    | OPTIONAL - Langchain project name                            |
-| LANGCHAIN_API_KEY      |                                    | OPTIONAL - Langchain API key                                 |
+| Variable Name          | Default value                      | Description                                                               |
+|------------------------|------------------------------------|---------------------------------------------------------------------------|
+| OLLAMA_BASE_URL        | http://host.docker.internal:11434  | REQUIRED - URL to Ollama LLM API                                          |
+| NEO4J_URI              | neo4j://database:7687              | REQUIRED - URL to Neo4j database                                          |
+| NEO4J_USERNAME         | neo4j                              | REQUIRED - Username for Neo4j database                                    |
+| NEO4J_PASSWORD         | password                           | REQUIRED - Password for Neo4j database                                    |
+| LLM                    | llama2                             | REQUIRED - Can be any Ollama model tag, or gpt-4 or gpt-3.5 or claudev2   |
+| EMBEDDING_MODEL        | sentence_transformer               | REQUIRED - Can be sentence_transformer, openai, aws or ollama             |
+| AWS_ACCESS_KEY_ID      |                                    | REQUIRED - Only if LLM=claudev2 or embedding_model=aws                    |
+| AWS_SECRET_ACCESS_KEY  |                                    | REQUIRED - Only if LLM=claudev2 or embedding_model=aws                    |
+| AWS_DEFAULT_REGION     |                                    | REQUIRED - Only if LLM=claudev2 or embedding_model=aws                    |
+| OPENAI_API_KEY         |                                    | REQUIRED - Only if LLM=gpt-4 or LLM=gpt-3.5 or embedding_model=openai     |
+| LANGCHAIN_ENDPOINT     | "https://api.smith.langchain.com"  | OPTIONAL - URL to Langchain Smith API                                     |
+| LANGCHAIN_TRACING_V2   | false                              | OPTIONAL - Enable Langchain tracing v2                                    |
+| LANGCHAIN_PROJECT      |                                    | OPTIONAL - Langchain project name                                         |
+| LANGCHAIN_API_KEY      |                                    | OPTIONAL - Langchain API key                                              |
 
 ## LLM Configuration
 MacOS and Linux users can use any LLM that's available via Ollama. Check the "tags" section under the model page you want to use on https://ollama.ai/library and write the tag for the value of the environment variable `LLM=` in the `.env` file.
requirements.txt

@@ -12,3 +12,4 @@ torch==2.0.1
 pydantic
 uvicorn
 sse-starlette
+boto3