made the dartboard more understandable

This commit is contained in:
nird
2025-02-19 21:36:18 +02:00
parent 9b19b48637
commit 6e1698f962
2 changed files with 148 additions and 62 deletions

View File

@@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -85,9 +85,10 @@
"import os\n",
"import sys\n",
"from dotenv import load_dotenv\n",
"import numpy as np\n",
"from scipy.special import logsumexp\n",
"from typing import Tuple, List\n",
"from typing import Tuple, List, Any\n",
"import numpy as np\n",
"\n",
"# Load environment variables from a .env file\n",
"load_dotenv()\n",
"# Set the OpenAI API key environment variable (comment out if not using OpenAI)\n",
@@ -305,6 +306,13 @@
"show_context(texts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Now for the real part :) "
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -332,88 +340,173 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Greedy Dartboard Search\n",
"\n",
"### Definitions of parameters, and the actual function that optimizes both relevance and diversity \n",
"This is the core function that chooses the top k documents based on relevance and diversity. It uses distances between each candidate document and the query and between candidate documents."
"This is the core algorithm: A search algorithm that selects a diverse set of relevant documents from a collection by balancing two factors: relevance to the query and diversity among selected documents.\n",
"\n",
"Given distances between a query and documents, plus distances between all documents, the algorithm:\n",
"\n",
"1. Selects the most relevant document first\n",
"2. Iteratively selects additional documents by combining:\n",
" - Relevance to the original query\n",
" - Diversity from previously selected documents\n",
"\n",
"The balance between relevance and diversity is controlled by weights:\n",
"- `DIVERSITY_WEIGHT`: Importance of difference from existing selections\n",
"- `RELEVANCE_WEIGHT`: Importance of relevance to query\n",
"- `SIGMA`: Smoothing parameter for probability conversion\n",
"\n",
"The algorithm returns both the selected documents and their selection scores, making it useful for applications like search results where you want relevant but varied results.\n",
"\n",
"For example, when searching news articles, it would first return the most relevant article, then find articles that are both on-topic and provide new information, avoiding redundant selections."
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configuration parameters\n",
"DIVERSITY_WEIGHT = 1.0 # Weight for diversity in document selection\n",
"RELEVANCE_WEIGHT = 1.0 # Weight for relevance to query\n",
"SIGMA = 0.1 # Smoothing parameter for probability distribution\n",
"\n",
"# Adjust these according to your needs, knowledge base density, etc. \n",
"DIVERSITY_WEIGHT=1.0\n",
"RELEVANCE_WEIGHT=1.0\n",
"SIGMA=0.1\n",
"\n",
"\n",
"def greedy_dartsearch(q_dists:np.ndarray, dists_mat:np.ndarray, texts:List[str], k:int) -> Tuple[List[str], List[float]]:\n",
"def greedy_dartsearch(\n",
" query_distances: np.ndarray,\n",
" document_distances: np.ndarray,\n",
" documents: List[str],\n",
" num_results: int\n",
") -> Tuple[List[str], List[float]]:\n",
" \"\"\"\n",
" Perform greedy dartboard search to select top k documents.\n",
" Perform greedy dartboard search to select top k documents balancing relevance and diversity.\n",
" \n",
" Args:\n",
" query_distances: Distance between query and each document\n",
" document_distances: Pairwise distances between documents\n",
" documents: List of document texts\n",
" num_results: Number of documents to return\n",
" \n",
" Returns:\n",
" Tuple containing:\n",
" - List of selected document texts\n",
" - List of selection scores for each document\n",
" \"\"\"\n",
" sigma=np.max([SIGMA,1e-5]) # avoid division by zero\n",
" qprobs = lognorm(q_dists, sigma)\n",
" ccprobmat = lognorm(dists_mat, sigma)\n",
" out_scores=[]\n",
" top_idx = np.argmax(qprobs) # start with the most relevant document\n",
" chosen_inds = np.array([top_idx]) # initialize the array of selected documents\n",
" maxes = ccprobmat[top_idx] # Vector of distances to the most relevant document\n",
" while len(chosen_inds) < k:\n",
" newmaxes = np.maximum(maxes, ccprobmat) # update the maximum distances, note the broadcasting (matrix and vector)\n",
"\n",
" logscores = newmaxes*DIVERSITY_WEIGHT + qprobs*RELEVANCE_WEIGHT # score all the items\n",
" scores = logsumexp(logscores, axis=1) # normalize the scores\n",
" scores[chosen_inds] = -np.inf # avoid selecting the same document twice\n",
" best_idx = np.argmax(scores) # select the best item\n",
" best_score=np.max(scores) # avoid division by zero\n",
" maxes = newmaxes[best_idx] # update the maximum distances\n",
" chosen_inds = np.append(chosen_inds, best_idx) # add the best item to the set\n",
" out_scores.append(best_score) # add the best score to the list\n",
" return [texts[i] for i in chosen_inds],out_scores\n"
" # Avoid division by zero in probability calculations\n",
" sigma = max(SIGMA, 1e-5)\n",
" \n",
" # Convert distances to probability distributions\n",
" query_probabilities = lognorm(query_distances, sigma)\n",
" document_probabilities = lognorm(document_distances, sigma)\n",
" \n",
" # Initialize with most relevant document\n",
" selection_scores = []\n",
" most_relevant_idx = np.argmax(query_probabilities)\n",
" selected_indices = np.array([most_relevant_idx])\n",
" \n",
" # Get initial distances from the first selected document\n",
" max_distances = document_probabilities[most_relevant_idx]\n",
" \n",
" # Select remaining documents\n",
" while len(selected_indices) < num_results:\n",
" # Update maximum distances considering new document\n",
" updated_distances = np.maximum(max_distances, document_probabilities)\n",
" \n",
" # Calculate combined diversity and relevance scores\n",
" combined_scores = (\n",
" updated_distances * DIVERSITY_WEIGHT +\n",
" query_probabilities * RELEVANCE_WEIGHT\n",
" )\n",
" \n",
" # Normalize scores and mask already selected documents\n",
" normalized_scores = logsumexp(combined_scores, axis=1)\n",
" normalized_scores[selected_indices] = -np.inf\n",
" \n",
" # Select best remaining document\n",
" best_idx = np.argmax(normalized_scores)\n",
" best_score = np.max(normalized_scores)\n",
" \n",
" # Update tracking variables\n",
" max_distances = updated_distances[best_idx]\n",
" selected_indices = np.append(selected_indices, best_idx)\n",
" selection_scores.append(best_score)\n",
" \n",
" # Return selected documents and their scores\n",
" selected_documents = [documents[i] for i in selected_indices]\n",
" return selected_documents, selection_scores"
]
},
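To make the selection behavior concrete, here is a minimal toy sketch (an editorial illustration, not part of this commit). It calls `greedy_dartsearch` from the cell above on four hand-made 2D embeddings; the `lognorm` stand-in below is an assumption (a log-Gaussian kernel over distances), since the notebook's actual helper is defined earlier and not shown in this diff, and the document names and vectors are invented for illustration.

```python
import numpy as np

# Stand-in for the notebook's lognorm helper (not shown in this diff);
# assumed here to behave like a log-Gaussian kernel over distances.
def lognorm(dist, sigma):
    return -0.5 * (np.asarray(dist) / sigma) ** 2

# Four toy documents in 2D: A and B are near-duplicates close to the query,
# C is relevant but different, D is off-topic.
docs = ["doc A", "doc B (near-duplicate of A)", "doc C", "doc D"]
vecs = np.array([[1.0, 0.0], [0.99, 0.14], [0.7, 0.7], [0.0, 1.0]])
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)
query = np.array([[1.0, 0.05]])
query /= np.linalg.norm(query)

query_distances = 1 - query @ vecs.T    # shape (1, 4)
document_distances = 1 - vecs @ vecs.T  # shape (4, 4)

selected, scores = greedy_dartsearch(query_distances, document_distances, docs, num_results=2)
print(selected)  # expected: ['doc A', 'doc C'] -- the near-duplicate B is skipped
```

With the default weights, plain top-2 retrieval would return A and its near-duplicate B; the dartboard scoring instead picks C, which is still on-topic but adds new information.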
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dartboard Context Retrieval\n",
"\n",
"### Main function for using the dartboard retrieval. This serves instead of get_context (which is simple RAG) it:\n",
"### Main function for using the dartboard retrieval. This serves instead of get_context (which is simple RAG). It:\n",
"\n",
"1. Takes a text query, vectorzes it, gets the top k documents (and their vectors) via simple RAG. \n",
"2. Uses these vectors to calculate the similarities to query and between candidate matches.\n",
"3. Runs the dartboard algorithm to refine the candidate matches to a final list of k documents.\n",
"4. Returns the final list of documents and their scores."
"1. Takes a text query, vectorizes it, gets the top k documents (and their vectors) via simple RAG\n",
"2. Uses these vectors to calculate the similarities to query and between candidate matches\n",
"3. Runs the dartboard algorithm to refine the candidate matches to a final list of k documents\n",
"4. Returns the final list of documents and their scores"
]
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def get_context_with_dartboard(query:str,k:int=5) -> Tuple[List[str], List[float]]:\n",
"def get_context_with_dartboard(\n",
" query: str,\n",
" num_results: int = 5,\n",
" oversampling_factor: int = 3\n",
") -> Tuple[List[str], List[float]]:\n",
" \"\"\"\n",
" Retrieve top k context items for a query using the dartboard algorithm.\n",
" This function only handles the vectors and indices, the rest is handled by the get_dartboard function and below.\n",
" Retrieve most relevant and diverse context items for a query using the dartboard algorithm.\n",
" \n",
" Args:\n",
" query: The search query string\n",
" num_results: Number of context items to return (default: 5)\n",
" oversampling_factor: Factor to oversample initial results for better diversity (default: 3)\n",
" \n",
" Returns:\n",
" Tuple containing:\n",
" - List of selected context texts\n",
" - List of selection scores\n",
" \n",
" Note:\n",
" The function uses cosine similarity converted to distance. Initial retrieval \n",
" fetches oversampling_factor * num_results items to ensure sufficient diversity \n",
" in the final selection.\n",
" \"\"\"\n",
" q_vec=chunks_vector_store.embedding_function.embed_documents([query]) # embed the query\n",
" _,indices=chunks_vector_store.index.search(np.array(q_vec),k=k*3) # fetch more than k to ensure we overcome density and use diversity\n",
"\n",
" vecs = np.array(chunks_vector_store.index.reconstruct_batch(indices[0])) # reconstruct the vectors of the retrieved documents\n",
" texts = [idx_to_text(i) for i in indices[0]] # convert the indices to texts\n",
"\n",
" # Embed query and retrieve initial candidates\n",
" query_embedding = chunks_vector_store.embedding_function.embed_documents([query])\n",
" _, candidate_indices = chunks_vector_store.index.search(\n",
" np.array(query_embedding),\n",
" k=num_results * oversampling_factor\n",
" )\n",
" \n",
" # calculate similarities and convert them to distances:\n",
" dists_mat = 1-np.dot(vecs,vecs.T) # 1-cosine distance, you may think of better distance functions. This can also be applied to cross-encoder scores. \n",
" q_dists = 1-np.dot(q_vec,vecs.T) # calculate the distances to the query\n",
" # Get document vectors and texts for candidates\n",
" candidate_vectors = np.array(\n",
" chunks_vector_store.index.reconstruct_batch(candidate_indices[0])\n",
" )\n",
" candidate_texts = [idx_to_text(idx) for idx in candidate_indices[0]]\n",
" \n",
" # run the dartboard algorithm\n",
" texts, scores=greedy_dartsearch(q_dists,dists_mat,texts,k)\n",
" return texts,scores\n"
" # Calculate distance matrices\n",
" # Using 1 - cosine_similarity as distance metric\n",
" document_distances = 1 - np.dot(candidate_vectors, candidate_vectors.T)\n",
" query_distances = 1 - np.dot(query_embedding, candidate_vectors.T)\n",
" \n",
" # Apply dartboard selection algorithm\n",
" selected_texts, selection_scores = greedy_dartsearch(\n",
" query_distances,\n",
" document_distances,\n",
" candidate_texts,\n",
" num_results\n",
" )\n",
" \n",
" return selected_texts, selection_scores"
]
},
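One caveat worth making explicit: `1 - np.dot(...)` equals the cosine distance only when the embeddings are unit-normalized, which depends on the embedding model and index configuration. Below is a small defensive sketch (an editorial suggestion, not part of this commit; `cosine_distances` is a hypothetical helper) that normalizes before building the distance matrices.

```python
import numpy as np

def cosine_distances(query_vec: np.ndarray, doc_vecs: np.ndarray):
    """Return (query-to-document, document-to-document) cosine distance matrices,
    normalizing the vectors first so that 1 - dot product is a true cosine distance."""
    q = np.asarray(query_vec, dtype=float)
    d = np.asarray(doc_vecs, dtype=float)
    q = q / np.linalg.norm(q, axis=-1, keepdims=True)
    d = d / np.linalg.norm(d, axis=-1, keepdims=True)
    return 1 - q @ d.T, 1 - d @ d.T

# Hypothetical drop-in inside get_context_with_dartboard, replacing the two
# "1 - np.dot(...)" lines:
# query_distances, document_distances = cosine_distances(
#     np.array(query_embedding), candidate_vectors
# )
```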
{
@@ -490,13 +583,6 @@
"texts,scores=get_context_with_dartboard(test_query,k=3)\n",
"show_context(texts)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -289,7 +289,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.0"
}
},
"nbformat": 4,