The project now runs the example script on ollama, but sure to have the

latest version as it's needed by one of the models
This commit is contained in:
Gerald Hewes
2025-02-13 08:38:44 -05:00
parent 350585663f
commit 1fc1c9d7ae
21 changed files with 16 additions and 11 deletions

View File

@@ -34,7 +34,7 @@ def get_azure_openai_async_client_instance():
def get_ollama_async_client_instance():
global global_ollama_client
if global_ollama_client is None:
#global_ollama_client = AsyncClient(host="http://localhost:11434") # Adjust base URL if necessary
# set OLLAMA_HOST
global_ollama_client = AsyncClient(host="http://10.0.1.12:11434") # Adjust base URL if necessary
return global_ollama_client
@@ -208,35 +208,37 @@ async def ollama_complete_if_cache(
# Send the request to Ollama
response = await ollama_client.chat(
model=model,
messages=messages,
**kwargs
messages=messages
)
# print(messages)
# print(response['message']['content'])
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response.response, "model": model}}
{args_hash: {"return": response['message']['content'], "model": model}}
)
await hashing_kv.index_done_callback()
return response.response
return response['message']['content']
async def ollama_complete(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
return await ollama_complete_if_cache(
"deepseek-r1:32b", # For now select your model
#"deepseek-r1:32b", # For now select your model
"gemma2:latest",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
history_messages=history_messages
)
async def ollama_mini_complete(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
return await ollama_complete_if_cache(
"deepseek-r1:latest", # For now select your model
# "deepseek-r1:latest", # For now select your model
"olmo2",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
history_messages=history_messages
)
@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)

View File

@@ -411,9 +411,12 @@ async def extract_entities(
maybe_edges = defaultdict(list)
for record in records:
record = re.search(r"\((.*)\)", record)
print('+++++++')
print(record)
if record is None:
continue
record = record.group(1)
print(record)
record_attributes = split_string_by_multi_markers(
record, [context_base["tuple_delimiter"]]
)