Mirror of https://github.com/NirDiamant/RAG_Techniques.git (synced 2025-04-07 00:48:52 +03:00)

Commit: made the dartboard more understandable
@@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -85,9 +85,10 @@
"import os\n",
"import sys\n",
"from dotenv import load_dotenv\n",
"import numpy as np\n",
"from scipy.special import logsumexp\n",
"from typing import Tuple, List\n",
"from typing import Tuple, List, Any\n",
"import numpy as np\n",
"\n",
"# Load environment variables from a .env file\n",
"load_dotenv()\n",
"# Set the OpenAI API key environment variable (comment out if not using OpenAI)\n",
@@ -305,6 +306,13 @@
"show_context(texts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Now for the real part :) "
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -332,88 +340,173 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Greedy Dartboard Search\n",
"\n",
"### Definitions of parameters, and the actual function that optimizes both relevance and diversity \n",
"This is the core function that chooses the top k documents based on relevance and diversity. It uses distances between each candidate document and the query and between candidate documents."
"This is the core algorithm: A search algorithm that selects a diverse set of relevant documents from a collection by balancing two factors: relevance to the query and diversity among selected documents.\n",
"\n",
"Given distances between a query and documents, plus distances between all documents, the algorithm:\n",
"\n",
"1. Selects the most relevant document first\n",
"2. Iteratively selects additional documents by combining:\n",
" - Relevance to the original query\n",
" - Diversity from previously selected documents\n",
"\n",
"The balance between relevance and diversity is controlled by weights:\n",
"- `DIVERSITY_WEIGHT`: Importance of difference from existing selections\n",
"- `RELEVANCE_WEIGHT`: Importance of relevance to query\n",
"- `SIGMA`: Smoothing parameter for probability conversion\n",
"\n",
"The algorithm returns both the selected documents and their selection scores, making it useful for applications like search results where you want relevant but varied results.\n",
"\n",
"For example, when searching news articles, it would first return the most relevant article, then find articles that are both on-topic and provide new information, avoiding redundant selections."
]
},
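The probability conversion mentioned above relies on a lognorm helper that the code below calls but that is presumably defined in an earlier notebook cell not shown in this diff. A minimal sketch of what such a helper could look like, assuming a Gaussian kernel over distances with SIGMA as the bandwidth (the repo's actual definition may differ):

import numpy as np

def lognorm(dist: np.ndarray, sigma: float) -> np.ndarray:
    # Map distances to Gaussian log-probabilities: smaller distances get higher
    # scores, and sigma controls how sharply scores fall off with distance.
    return -np.log(np.sqrt(2 * np.pi) * sigma) - dist ** 2 / (2 * sigma ** 2)

The additive constant in front does not change which documents get picked: it shifts every candidate's logsumexp score by the same amount, and the selection only compares scores.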
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configuration parameters\n",
"DIVERSITY_WEIGHT = 1.0 # Weight for diversity in document selection\n",
"RELEVANCE_WEIGHT = 1.0 # Weight for relevance to query\n",
"SIGMA = 0.1 # Smoothing parameter for probability distribution\n",
"\n",
"# Adjust these according to your needs, knowledge base density, etc. \n",
"DIVERSITY_WEIGHT=1.0\n",
"RELEVANCE_WEIGHT=1.0\n",
"SIGMA=0.1\n",
"\n",
"\n",
"def greedy_dartsearch(q_dists:np.ndarray, dists_mat:np.ndarray, texts:List[str], k:int) -> Tuple[List[str], List[float]]:\n",
"def greedy_dartsearch(\n",
" query_distances: np.ndarray,\n",
" document_distances: np.ndarray,\n",
" documents: List[str],\n",
" num_results: int\n",
") -> Tuple[List[str], List[float]]:\n",
" \"\"\"\n",
" Perform greedy dartboard search to select top k documents.\n",
" Perform greedy dartboard search to select top k documents balancing relevance and diversity.\n",
" \n",
" Args:\n",
" query_distances: Distance between query and each document\n",
" document_distances: Pairwise distances between documents\n",
" documents: List of document texts\n",
" num_results: Number of documents to return\n",
" \n",
" Returns:\n",
" Tuple containing:\n",
" - List of selected document texts\n",
" - List of selection scores for each document\n",
" \"\"\"\n",
" sigma=np.max([SIGMA,1e-5]) # avoid division by zero\n",
" qprobs = lognorm(q_dists, sigma)\n",
" ccprobmat = lognorm(dists_mat, sigma)\n",
" out_scores=[]\n",
" top_idx = np.argmax(qprobs) # start with the most relevant document\n",
" chosen_inds = np.array([top_idx]) # initialize the array of selected documents\n",
" maxes = ccprobmat[top_idx] # Vector of distances to the most relevant document\n",
" while len(chosen_inds) < k:\n",
" newmaxes = np.maximum(maxes, ccprobmat) # update the maximum distances, note the broadcasting (matrix and vector)\n",
"\n",
" logscores = newmaxes*DIVERSITY_WEIGHT + qprobs*RELEVANCE_WEIGHT # score all the items\n",
" scores = logsumexp(logscores, axis=1) # normalize the scores\n",
" scores[chosen_inds] = -np.inf # avoid selecting the same document twice\n",
" best_idx = np.argmax(scores) # select the best item\n",
" best_score=np.max(scores) # avoid division by zero\n",
" maxes = newmaxes[best_idx] # update the maximum distances\n",
" chosen_inds = np.append(chosen_inds, best_idx) # add the best item to the set\n",
" out_scores.append(best_score) # add the best score to the list\n",
" return [texts[i] for i in chosen_inds],out_scores\n"
" # Avoid division by zero in probability calculations\n",
|
||||
" sigma = max(SIGMA, 1e-5)\n",
|
||||
" \n",
|
||||
" # Convert distances to probability distributions\n",
|
||||
" query_probabilities = lognorm(query_distances, sigma)\n",
|
||||
" document_probabilities = lognorm(document_distances, sigma)\n",
|
||||
" \n",
|
||||
" # Initialize with most relevant document\n",
|
||||
" selection_scores = []\n",
|
||||
" most_relevant_idx = np.argmax(query_probabilities)\n",
|
||||
" selected_indices = np.array([most_relevant_idx])\n",
|
||||
" \n",
|
||||
" # Get initial distances from the first selected document\n",
|
||||
" max_distances = document_probabilities[most_relevant_idx]\n",
|
||||
" \n",
|
||||
" # Select remaining documents\n",
|
||||
" while len(selected_indices) < num_results:\n",
|
||||
" # Update maximum distances considering new document\n",
|
||||
" updated_distances = np.maximum(max_distances, document_probabilities)\n",
|
||||
" \n",
|
||||
" # Calculate combined diversity and relevance scores\n",
|
||||
" combined_scores = (\n",
|
||||
" updated_distances * DIVERSITY_WEIGHT +\n",
|
||||
" query_probabilities * RELEVANCE_WEIGHT\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # Normalize scores and mask already selected documents\n",
|
||||
" normalized_scores = logsumexp(combined_scores, axis=1)\n",
|
||||
" normalized_scores[selected_indices] = -np.inf\n",
|
||||
" \n",
|
||||
" # Select best remaining document\n",
|
||||
" best_idx = np.argmax(normalized_scores)\n",
|
||||
" best_score = np.max(normalized_scores)\n",
|
||||
" \n",
|
||||
" # Update tracking variables\n",
|
||||
" max_distances = updated_distances[best_idx]\n",
|
||||
" selected_indices = np.append(selected_indices, best_idx)\n",
|
||||
" selection_scores.append(best_score)\n",
|
||||
" \n",
|
||||
" # Return selected documents and their scores\n",
|
||||
" selected_documents = [documents[i] for i in selected_indices]\n",
|
||||
" return selected_documents, selection_scores"
|
||||
]
|
||||
},
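The relevance/diversity trade-off is easiest to see on a tiny synthetic example. This sketch assumes greedy_dartsearch from the cell above and a lognorm helper like the one sketched earlier are in scope; the distances are made up, with two near-duplicate documents about one topic and two about another:

import numpy as np

# query_distances has shape (1, N); document_distances is the (N, N) pairwise matrix.
query_distances = np.array([[0.10, 0.12, 0.15, 0.18]])
document_distances = np.array([
    [0.00, 0.05, 0.70, 0.70],
    [0.05, 0.00, 0.70, 0.70],
    [0.70, 0.70, 0.00, 0.04],
    [0.70, 0.70, 0.04, 0.00],
])
documents = ["A: intro", "A: near-duplicate", "B: overview", "B: near-duplicate"]

selected, scores = greedy_dartsearch(query_distances, document_distances, documents, num_results=2)

# With equal weights and the Gaussian-style lognorm, the first pick is "A: intro"
# (most relevant), and the second should be a B document rather than the A
# near-duplicate, since a near-duplicate adds almost no new information.
# Note that scores only covers picks after the first, which is chosen purely by relevance.
print(selected, scores)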
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dartboard Context Retrieval\n",
"\n",
"### Main function for using the dartboard retrieval. This serves instead of get_context (which is simple RAG) it:\n",
"### Main function for using the dartboard retrieval. This serves instead of get_context (which is simple RAG). It:\n",
"\n",
"1. Takes a text query, vectorzes it, gets the top k documents (and their vectors) via simple RAG. \n",
"2. Uses these vectors to calculate the similarities to query and between candidate matches.\n",
"3. Runs the dartboard algorithm to refine the candidate matches to a final list of k documents.\n",
"4. Returns the final list of documents and their scores."
"1. Takes a text query, vectorizes it, gets the top k documents (and their vectors) via simple RAG\n",
"2. Uses these vectors to calculate the similarities to query and between candidate matches\n",
"3. Runs the dartboard algorithm to refine the candidate matches to a final list of k documents\n",
"4. Returns the final list of documents and their scores"
]
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def get_context_with_dartboard(query:str,k:int=5) -> Tuple[List[str], List[float]]:\n",
"def get_context_with_dartboard(\n",
" query: str,\n",
" num_results: int = 5,\n",
" oversampling_factor: int = 3\n",
") -> Tuple[List[str], List[float]]:\n",
" \"\"\"\n",
" Retrieve top k context items for a query using the dartboard algorithm.\n",
" This function only handles the vectors and indices, the rest is handled by the get_dartboard function and below.\n",
" Retrieve most relevant and diverse context items for a query using the dartboard algorithm.\n",
" \n",
" Args:\n",
" query: The search query string\n",
" num_results: Number of context items to return (default: 5)\n",
" oversampling_factor: Factor to oversample initial results for better diversity (default: 3)\n",
" \n",
" Returns:\n",
" Tuple containing:\n",
" - List of selected context texts\n",
" - List of selection scores\n",
" \n",
" Note:\n",
" The function uses cosine similarity converted to distance. Initial retrieval \n",
" fetches oversampling_factor * num_results items to ensure sufficient diversity \n",
" in the final selection.\n",
" \"\"\"\n",
" q_vec=chunks_vector_store.embedding_function.embed_documents([query]) # embed the query\n",
|
||||
" _,indices=chunks_vector_store.index.search(np.array(q_vec),k=k*3) # fetch more than k to ensure we overcome density and use diversity\n",
|
||||
"\n",
|
||||
" vecs = np.array(chunks_vector_store.index.reconstruct_batch(indices[0])) # reconstruct the vectors of the retrieved documents\n",
|
||||
" texts = [idx_to_text(i) for i in indices[0]] # convert the indices to texts\n",
|
||||
"\n",
|
||||
" # Embed query and retrieve initial candidates\n",
|
||||
" query_embedding = chunks_vector_store.embedding_function.embed_documents([query])\n",
|
||||
" _, candidate_indices = chunks_vector_store.index.search(\n",
|
||||
" np.array(query_embedding),\n",
|
||||
" k=num_results * oversampling_factor\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # calculate similarities and convert them to distances:\n",
|
||||
" dists_mat = 1-np.dot(vecs,vecs.T) # 1-cosine distance, you may think of better distance functions. This can also be applied to cross-encoder scores. \n",
|
||||
" q_dists = 1-np.dot(q_vec,vecs.T) # calculate the distances to the query\n",
|
||||
" # Get document vectors and texts for candidates\n",
|
||||
" candidate_vectors = np.array(\n",
|
||||
" chunks_vector_store.index.reconstruct_batch(candidate_indices[0])\n",
|
||||
" )\n",
|
||||
" candidate_texts = [idx_to_text(idx) for idx in candidate_indices[0]]\n",
|
||||
" \n",
|
||||
" # run the dartboard algorithm\n",
|
||||
" texts, scores=greedy_dartsearch(q_dists,dists_mat,texts,k)\n",
|
||||
" return texts,scores\n"
|
||||
" # Calculate distance matrices\n",
|
||||
" # Using 1 - cosine_similarity as distance metric\n",
|
||||
" document_distances = 1 - np.dot(candidate_vectors, candidate_vectors.T)\n",
|
||||
" query_distances = 1 - np.dot(query_embedding, candidate_vectors.T)\n",
|
||||
" \n",
|
||||
" # Apply dartboard selection algorithm\n",
|
||||
" selected_texts, selection_scores = greedy_dartsearch(\n",
|
||||
" query_distances,\n",
|
||||
" document_distances,\n",
|
||||
" candidate_texts,\n",
|
||||
" num_results\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" return selected_texts, selection_scores"
|
||||
]
|
||||
},
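With the renamed signature, callers now pass num_results (and optionally oversampling_factor) instead of k, so calls written against the old signature, like the k=3 test call further down in the notebook, would need updating. A minimal call sketch, assuming test_query, chunks_vector_store, idx_to_text, and show_context are defined in earlier notebook cells not shown in this diff:

# Retrieve 3 relevant-but-diverse chunks for the query and display them.
texts, scores = get_context_with_dartboard(test_query, num_results=3, oversampling_factor=3)
show_context(texts)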
{
@@ -490,13 +583,6 @@
"texts,scores=get_context_with_dartboard(test_query,k=3)\n",
"show_context(texts)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -289,7 +289,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.0"
}
},
"nbformat": 4,