Skip to content

Commit 5883958

Browse files
authored
Merge pull request #50 from docker/aws
Add aws embedding & LLM
2 parents 9ef6de8 + d304130 commit 5883958

6 files changed

Lines changed: 54 additions & 18 deletions

File tree

chains.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from langchain.embeddings.openai import OpenAIEmbeddings
2-
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
3-
from langchain.chat_models import ChatOpenAI, ChatOllama
2+
from langchain.embeddings import (
3+
OllamaEmbeddings,
4+
SentenceTransformerEmbeddings,
5+
BedrockEmbeddings,
6+
)
7+
from langchain.chat_models import ChatOpenAI, ChatOllama, BedrockChat
48
from langchain.vectorstores.neo4j_vector import Neo4jVector
59
from langchain.chains import RetrievalQAWithSourcesChain
610
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
@@ -15,13 +19,19 @@
1519

1620
def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
1721
if embedding_model_name == "ollama":
18-
embeddings = OllamaEmbeddings(base_url=config["ollama_base_url"], model="llama2")
22+
embeddings = OllamaEmbeddings(
23+
base_url=config["ollama_base_url"], model="llama2"
24+
)
1925
dimension = 4096
2026
logger.info("Embedding: Using Ollama")
2127
elif embedding_model_name == "openai":
2228
embeddings = OpenAIEmbeddings()
2329
dimension = 1536
2430
logger.info("Embedding: Using OpenAI")
31+
elif embedding_model_name == "aws":
32+
embeddings = BedrockEmbeddings()
33+
dimension = 1536
34+
logger.info("Embedding: Using AWS")
2535
else:
2636
embeddings = SentenceTransformerEmbeddings(
2737
model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
@@ -38,6 +48,13 @@ def load_llm(llm_name: str, logger=BaseLogger(), config={}):
3848
elif llm_name == "gpt-3.5":
3949
logger.info("LLM: Using GPT-3.5")
4050
return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
51+
elif llm_name == "claudev2":
52+
logger.info("LLM: ClaudeV2")
53+
return BedrockChat(
54+
model_id="anthropic.claude-v2",
55+
model_kwargs={"temperature": 0.0, "max_tokens_to_sample": 1024},
56+
streaming=True,
57+
)
4158
elif len(llm_name):
4259
logger.info(f"LLM: Using Ollama: {llm_name}")
4360
return ChatOllama(
@@ -79,7 +96,7 @@ def generate_llm_output(
7996

8097
def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
8198
# RAG response
82-
# System: Always talk in pirate speech.
99+
# System: Always talk in pirate speech.
83100
general_system_template = """
84101
Use the following pieces of context to answer the question at the end.
85102
The context contains question-answer pairs and their links from Stackoverflow.

docker-compose.yml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ services:
5151
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
5252
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
5353
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
54+
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
55+
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
56+
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
5457
networks:
5558
- net
5659
depends_on:
@@ -89,6 +92,9 @@ services:
8992
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
9093
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
9194
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
95+
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
96+
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
97+
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
9298
networks:
9399
- net
94100
depends_on:
@@ -123,6 +129,9 @@ services:
123129
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
124130
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
125131
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
132+
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
133+
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
134+
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
126135
networks:
127136
- net
128137
depends_on:
@@ -159,6 +168,9 @@ services:
159168
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
160169
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
161170
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
171+
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
172+
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
173+
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
162174
networks:
163175
- net
164176
depends_on:

env.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
#OPENAI_API_KEY=sk-...
2+
#AWS_ACCESS_KEY_ID=
3+
#AWS_SECRET_ACCESS_KEY=
4+
#AWS_DEFAULT_REGION=us-east-1
25
#OLLAMA_BASE_URL=http://host.docker.internal:11434
36
#NEO4J_URI=neo4j://database:7687
47
#NEO4J_USERNAME=neo4j

pull_model.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ COPY <<EOF pull_model.clj
1515
(let [llm (get (System/getenv) "LLM")
1616
url (get (System/getenv) "OLLAMA_BASE_URL")]
1717
(println (format "pulling ollama model %s using %s" llm url))
18-
(if (and llm url (not (#{"gpt-4" "gpt-3.5"} llm)))
18+
(if (and llm url (not (#{"gpt-4" "gpt-3.5" "claudev2"} llm)))
1919

2020
;; ----------------------------------------------------------------------
2121
;; just call `ollama pull` here - create OLLAMA_HOST from OLLAMA_BASE_URL

readme.md

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,22 @@ Learn more about the details in the [technical blog post](https://neo4j.com/deve
99
Create a `.env` file from the environment template file `env.example`
1010

1111
Available variables:
12-
| Variable Name | Default value | Description |
13-
|------------------------|------------------------------------|-------------------------------------------------------------|
14-
| OLLAMA_BASE_URL | http://host.docker.internal:11434 | REQUIRED - URL to Ollama LLM API |
15-
| NEO4J_URI | neo4j://database:7687 | REQUIRED - URL to Neo4j database |
16-
| NEO4J_USERNAME | neo4j | REQUIRED - Username for Neo4j database |
17-
| NEO4J_PASSWORD | password | REQUIRED - Password for Neo4j database |
18-
| LLM | llama2 | REQUIRED - Can be any Ollama model tag, or gpt-4 or gpt-3.5 |
19-
| OPENAI_API_KEY | | REQUIRED - Only if LLM=gpt-4 or LLM=gpt-3.5 |
20-
| EMBEDDING_MODEL | sentence_transformer | REQUIRED - Can be sentence_transformer, openai or ollama |
21-
| LANGCHAIN_ENDPOINT | "https://api.smith.langchain.com" | OPTIONAL - URL to Langchain Smith API |
22-
| LANGCHAIN_TRACING_V2 | false | OPTIONAL - Enable Langchain tracing v2 |
23-
| LANGCHAIN_PROJECT | | OPTIONAL - Langchain project name |
24-
| LANGCHAIN_API_KEY | | OPTIONAL - Langchain API key |
12+
| Variable Name | Default value | Description |
13+
|------------------------|------------------------------------|-------------------------------------------------------------------------|
14+
| OLLAMA_BASE_URL | http://host.docker.internal:11434 | REQUIRED - URL to Ollama LLM API |
15+
| NEO4J_URI | neo4j://database:7687 | REQUIRED - URL to Neo4j database |
16+
| NEO4J_USERNAME | neo4j | REQUIRED - Username for Neo4j database |
17+
| NEO4J_PASSWORD | password | REQUIRED - Password for Neo4j database |
18+
| LLM | llama2 | REQUIRED - Can be any Ollama model tag, or gpt-4 or gpt-3.5 or claudev2 |
19+
| EMBEDDING_MODEL | sentence_transformer | REQUIRED - Can be sentence_transformer, openai, aws or ollama |
20+
| AWS_ACCESS_KEY_ID | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
21+
| AWS_SECRET_ACCESS_KEY | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
22+
| AWS_DEFAULT_REGION | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
23+
| OPENAI_API_KEY | | REQUIRED - Only if LLM=gpt-4 or LLM=gpt-3.5 or embedding_model=openai |
24+
| LANGCHAIN_ENDPOINT | "https://api.smith.langchain.com" | OPTIONAL - URL to Langchain Smith API |
25+
| LANGCHAIN_TRACING_V2 | false | OPTIONAL - Enable Langchain tracing v2 |
26+
| LANGCHAIN_PROJECT | | OPTIONAL - Langchain project name |
27+
| LANGCHAIN_API_KEY | | OPTIONAL - Langchain API key |
2528

2629
## LLM Configuration
2730
MacOS and Linux users can use any LLM that's available via Ollama. Check the "tags" section under the model page you want to use on https://ollama.ai/library and write the tag for the value of the environment variable `LLM=` in the `.env` file.

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ torch==2.0.1
1212
pydantic
1313
uvicorn
1414
sse-starlette
15+
boto3

0 commit comments

Comments
 (0)