update docstring

Author: siyunzhao
Date: 2024-01-19 09:33:46 +00:00
Committed by: Huiqiang Jiang
Parent: afaaef2b05
Commit: 6d7bbc3c21


@@ -17,6 +17,32 @@ encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
class PromptCompressor:
"""
PromptCompressor is designed for compressing prompts based on a given language model.
This class initializes with the language model and its configuration, preparing it for prompt compression tasks.
The PromptCompressor class is versatile and can be adapted to various models and specific requirements in prompt processing.
Users can specify different model names and configurations as needed for their particular use case. The architecture is
based on the paper: Jiang, Huiqiang, Qianhui Wu, Chin-Yew Lin, Yuqing Yang, and Lili Qiu. "LLMLingua: Compressing Prompts
for Accelerated Inference of Large Language Models." arXiv preprint arXiv:2310.05736 (2023).
Args:
model_name (str, optional): The name of the language model to be loaded. Default is "NousResearch/Llama-2-7b-hf".
device_map (str, optional): The device to load the model onto, e.g., "cuda" for GPU. Default is "cuda".
model_config (dict, optional): A dictionary containing the configuration parameters for the model. Default is an empty dictionary.
open_api_config (dict, optional): A dictionary containing configuration for OpenAI APIs that may be used in conjunction with the model. Default is an empty dictionary.
Example:
>>> compress_method = PromptCompressor(model_name="gpt2", device_map="cuda")
>>> context = ["This is the first context sentence.", "Here is another context sentence."]
>>> result = compress_method.compress_prompt(context)
>>> print(result["compressed_prompt"])
# This will print the compressed version of the context.
Note:
The `PromptCompressor` class requires the Hugging Face Transformers library and an appropriate environment to load and run the models.
"""
def __init__(
self,
model_name: str = "NousResearch/Llama-2-7b-hf",
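
For context, a minimal usage sketch of the constructor documented above, assuming the package's top-level PromptCompressor export; the small model name and CPU device are illustrative choices, not the documented defaults:

from llmlingua import PromptCompressor

# "gpt2" keeps the sketch lightweight; the documented default is
# "NousResearch/Llama-2-7b-hf". "cpu" avoids requiring a GPU.
compressor = PromptCompressor(
    model_name="gpt2",
    device_map="cpu",
    model_config={},     # extra configuration forwarded to the model loader
    open_api_config={},  # optional OpenAI API settings
)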
@@ -156,6 +182,50 @@ class PromptCompressor:
rank_method: str = "llmlingua",
concate_question: bool = True,
):
"""
Compresses the given context, instruction, and question into a shorter prompt.
Args:
context (List[str]): List of context strings that form the basis of the prompt.
instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string.
question (str, optional): A specific question that the prompt is addressing. Default is an empty string.
ratio (float, optional): The compression ratio target to be achieved. Default is 0.5, meaning the compressed prompt
should retain roughly half of the original tokens. The actual compression generally meets or exceeds this target,
but there can be fluctuations due to differences in tokenizers. Ignored when target_token is specified.
target_token (float, optional): The maximum number of tokens allowed in the compressed prompt. Default is -1,
indicating no specific target. The actual number of tokens after compression should generally be less than the
specified target_token, but there can be fluctuations due to differences in tokenizers. If specified, compression
is based on target_token as the sole criterion, overriding the ratio.
iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200.
force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
force_context_number (int, optional): The number of context sections to forcibly include. Default is None.
use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False.
use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True.
use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True.
keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False.
keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0.
keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0.
keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0.
high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100.
context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100".
token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4.
condition_in_question (str, optional): Specific condition to apply to the question in the context. Default is "none".
reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original".
dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0.
condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False.
add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False.
rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua".
concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True.
Returns:
dict: A dictionary containing:
- "compressed_prompt" (str): The resulting compressed prompt.
- "origin_tokens" (int): The original number of tokens in the input.
- "compressed_tokens" (int): The number of tokens in the compressed output.
- "ratio" (str): The compression ratio achieved, in a human-readable format.
- "saving" (str): Estimated savings in GPT-4 token usage.
"""
if not context:
context = [" "]
if isinstance(context, str):
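
For reference, a hedged end-to-end sketch of the compress_prompt call documented above, reusing the compressor from the earlier sketch; the context strings, instruction, question, and budgets are illustrative, and only arguments and return keys named in the docstring are used:

result = compressor.compress_prompt(
    context=[
        "This is the first context sentence.",
        "Here is another context sentence.",
    ],
    instruction="Summarize the following passages.",
    question="What do the passages describe?",
    target_token=50,        # overrides ratio when set, per the docstring
    keep_first_sentence=1,  # forcibly preserve the opening sentence
)

# The returned dict follows the documented schema.
print(result["compressed_prompt"])
print(result["origin_tokens"], "->", result["compressed_tokens"])
print(result["ratio"], result["saving"])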