Feature(LLMLingua): add LongLLMLingua & documents

This commit is contained in:
Huiqiang Jiang
2023-10-08 11:56:35 +00:00
parent 44899b19fd
commit b917a05576
9 changed files with 892 additions and 167 deletions

1
.gitignore vendored
View File

@@ -396,3 +396,4 @@ FodyWeavers.xsd
# JetBrains Rider
*.sln.iml
*.egg-info

121
DOCUMENT.md Normal file
View File

@@ -0,0 +1,121 @@
# LLMLingua Documentation
## Principles
- The most important principle is that **the sensitivity to compression varies among the components of a prompt**: instructions and questions are more sensitive, while context or documents are less sensitive. It is therefore advisable to separate the components of the prompt and pass them in as demonstrations (context), instruction, and question, as in the sketch after this list.
- **Divide demonstrations and context into independent granularities**, such as individual documents in multi-document QA or individual examples in few-shot learning. This benefits the budget controller and document reordering.
- **Preserving essential tokens required by specific scenarios** through rule-based retention will be supported soon.
- Experiment with different target compression ratios and other hyperparameters to optimize performance.
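For example, a multi-document QA prompt can be kept as separate components before compression (a minimal sketch; the documents and question are placeholders):

```python
# Hypothetical multi-document QA input, kept as separate components
# rather than one concatenated string.
documents = [  # context: low sensitivity to compression, one entry per document
    "Document 1: ...",
    "Document 2: ...",
]
instruction = "Answer the question based on the given documents."  # high sensitivity
question = "Question: Where was the 1994 World Cup final played?"  # high sensitivity

# These three components are later passed to `compress_prompt`
# (see "Function Call" below) as context, instruction, and question.
```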
## Initialization
```python
from llmlingua import PromptCompressor
llm_lingua = PromptCompressor(
model_name: str = "NousResearch/Llama-2-7b-hf",
device_map: str = "cuda",
use_auth_token: bool = False,
open_api_config: dict = {},
)
```
### Parameters
- **model_name**(str), the name of the small language model from Hugging Face. Default set to "NousResearch/Llama-2-7b-hf";
- **device_map**(str), the device on which to run the small language model, e.g. "cuda", "cpu", "balanced", "balanced_low_0", "auto". Default set to "cuda" (see the sketch after this list);
- **use_auth_token**(bool, optional), controls whether to use the Hugging Face auth token. Default set to False;
- **open_api_config**(dict, optional), the OpenAI configuration used for OpenAI embeddings in coarse-level prompt compression. Default set to {}.
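For example, to run the compressor on CPU with a smaller model (a minimal sketch; "gpt2" is an illustrative choice of small language model, not the package default):

```python
from llmlingua import PromptCompressor

# Hypothetical lightweight setup: a small GPT-2 model on CPU.
llm_lingua = PromptCompressor(
    model_name="gpt2",
    device_map="cpu",
)
```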
## Function Call
```python
compressed_prompt = llm_lingua.compress_prompt(
context: List[str],
instruction: str = "",
question: str = "",
ratio: float = 0.5,
target_token: float = -1,
iterative_size: int = 200,
force_context_ids: List[int] = None,
force_context_number: int = None,
use_sentence_level_filter: bool = False,
use_context_level_filter: bool = True,
use_token_level_filter: bool = True,
keep_split: bool = False,
keep_first_sentence: int = 0,
keep_last_sentence: int = 0,
keep_sentence_number: int = 0,
high_priority_bonus: int = 100,
context_budget: str = "+100",
token_budget_ratio: float = 1.4,
condition_in_question: str = "none",
reorder_context: str = "original",
dynamic_context_compression_ratio: float = 0.0,
condition_compare: bool = False,
add_instruction: bool = False,
rank_method: str = "llmlingua",
concate_question: bool = True,
)
# > {'compressed_prompt': 'Question: Sam bought a dozen boxes, each with 30 highlighter pens inside, for $10 each box. He reanged five of boxes into packages of sixlters each and sold them $3 per. He sold the rest theters separately at the of three pens $2. How much did make in total, dollars?\nLets think step step\nSam bought 1 boxes x00 oflters.\nHe bought 12 * 300ters in total\nSam then took 5 boxes 6ters0ters.\nHe sold these boxes for 5 *5\nAfterelling these boxes there were 3030 highlighters remaining.\nThese form 330 / 3 = 110 groups of three pens.\nHe sold each of these groups for $2 each, so made 110 * 2 = $220 from them.\nIn total, then, he earned $220 + $15 = $235.\nSince his original cost was $120, he earned $235 - $120 = $115 in profit.\nThe answer is 115',
# 'origin_tokens': 2365,
# 'compressed_tokens': 211,
# 'ratio': '11.2x',
# 'saving': ', Saving $0.1 in GPT-4.'}
```
### Parameters
- **context**(str or List[str]), the context, documents, or demonstrations in the prompt; low sensitivity to compression;
- **instruction**(str), the general instruction placed before the context; high sensitivity to compression;
- **question**(str), the general question placed after the context; high sensitivity to compression;
- **ratio**(float, optional), the target compression ratio; the larger the value, the fewer tokens are retained. Mutually exclusive with **target_token**. Default set to 0.5;
- **target_token**(float), the target number of compressed tokens; mutually exclusive with **ratio**. Default set to -1;
- **iterative_size**(int), the segment size in Iterative Token-level Prompt Compression. Default set to 200;
- **force_context_ids**(List[int], optional), the indices of **context** entries to forcibly retain. Default set to None;
- **force_context_number**(int, optional), the number of contexts to forcibly retain in Coarse-level Prompt Compression. Default set to None;
- **use_sentence_level_filter**(bool, optional), controls the usage of sentence-level prompt compression. Default set to False;
- **use_context_level_filter**(bool, optional), controls the usage of coarse-level prompt compression. Default set to True;
- **use_token_level_filter**(bool, optional), controls the usage of token-level prompt compression. Default set to True;
- **keep_split**(bool, optional), controls whether to retain all newline separators "\n\n" in the prompt. Default set to False;
- **keep_first_sentence**(int, optional), the number of leading sentences to retain in each context. Default set to 0;
- **keep_last_sentence**(int, optional), the number of trailing sentences to retain in each context. Default set to 0;
- **keep_sentence_number**(int, optional), the number of sentences to retain in each context. Default set to 0;
- **high_priority_bonus**(int, optional), the perplexity bonus added to retained sentences; only used when **keep_first_sentence** or **keep_last_sentence** is set. Default set to 100;
- **context_budget**(str, optional), the budget in Coarse-level Prompt Compression, expressed as an operator applied to the target, e.g. "*1.5" or "+100". Default set to "+100";
- **token_budget_ratio**(float, optional), the budget ratio in sentence-level prompt compression. Default set to 1.4;
- **condition_in_question**(str, optional), controls whether to use question-aware coarse-level prompt compression; supports "none", "after", and "before". For LongLLMLingua it must be set to "after" or "before". Default set to "none";
- **reorder_context**(str, optional), controls the document reordering applied before compression in LongLLMLingua; supports "original", "sort", and "two_stage". Default set to "original";
- **dynamic_context_compression_ratio**(float, optional), controls the ratio of dynamic context compression in LongLLMLingua. Default set to 0.0;
- **condition_compare**(bool, optional), controls whether to use Iterative Token-level Question-aware Fine-Grained Compression in LongLLMLingua. Default set to False;
- **add_instruction**(bool, optional), controls whether to add the instruction before the prompt in Iterative Token-level Question-aware Fine-Grained Compression. Default set to False;
- **rank_method**(str, optional), the ranking method used in Coarse-level Prompt Compression; supports "llmlingua", "longllmlingua", "bm25", "gzip", "sentbert", and "openai". Default set to "llmlingua";
- **concate_question**(bool, optional), controls whether to include the question in the compressed prompt. Default set to True. A combined LongLLMLingua configuration is sketched after this list.
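A typical LongLLMLingua call combines the question-aware options above (a minimal sketch; `documents` and `question` are placeholders, and the hyperparameter values are illustrative rather than recommended defaults):

```python
# Hypothetical long-context QA call using the LongLLMLingua options.
compressed = llm_lingua.compress_prompt(
    context=documents,                    # one entry per document
    question=question,
    target_token=2000,
    rank_method="longllmlingua",          # question-aware coarse-level ranking
    condition_in_question="after",        # required for LongLLMLingua: "after" or "before"
    reorder_context="sort",               # reorder retained documents by relevance
    dynamic_context_compression_ratio=0.3,
)
print(compressed["compressed_prompt"])
```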
### Response
- **compressed_prompt**(str), the compressed prompt;
- **origin_tokens**(int), the number of tokens in the original prompt;
- **compressed_tokens**(int), the number of tokens in the compressed prompt;
- **ratio**(str), the actual compression ratio;
- **saving**(str), the estimated cost saving in GPT-4, as illustrated below.
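The returned dictionary can be inspected directly (a minimal sketch; `compressed` is the result of the call above):

```python
# Inspect the compression statistics returned by compress_prompt.
print(compressed["origin_tokens"], "->", compressed["compressed_tokens"])
print("compression ratio:", compressed["ratio"])  # e.g. "11.2x"
print("estimated saving:", compressed["saving"])  # GPT-4 cost estimate
```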
## Post-processing
```python
recovered_response = llm_lingua.recover(
original_prompt: str,
compressed_prompt: str,
response: str,
)
```
### Parameters
- **original_prompt**(str), the original prompt;
- **compressed_prompt**(str), the compressed prompt;
- **response**(str), the black-box LLM's response to the compressed prompt;
### Response
- **recovered_response**(str), the recovered response (see the sketch after this list).
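For example, after querying a black-box LLM with the compressed prompt (a minimal sketch; `original_prompt` and `llm_response` are placeholder variables):

```python
# Map spans of the LLM's answer back to the original prompt wording.
recovered_response = llm_lingua.recover(
    original_prompt=original_prompt,
    compressed_prompt=compressed["compressed_prompt"],
    response=llm_response,
)
print(recovered_response)
```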

View File

@@ -2,16 +2,21 @@
<img src="images/LLMLingua_logo.png" alt="LLMLingua" style="width: 20%; min-width: 100px; display: block; margin: auto;">
</p>
# LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models
# LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models [[paper]()] & LongLLMLingua [[paper]()]
This repo contains the code for LLMLingua, a project that compresses prompts and speeds up inference for LLMs with minimal loss of performance.
https://github.com/microsoft/LLMLingua/assets/30883354/ef52995c-ef3c-4eac-a9fd-1acb491c325b
[LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models]() ().
## TL;DR
LLMLingua uses a compact, well-trained language model after alignment, such as GPT2-small or LLaMA-7B, to detect unimportant tokens in the prompt, enabling inference with the compressed prompt in black-box LLMs and achieving up to 20x compression with minimal performance loss.
[LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models]() (EMNLP 2023).
_Huiqiang Jiang, Qianhui Wu, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
LongLLMLingua is a method that enhances LLMs' ability to perceive key information in long-context scenarios using prompt compression, achieving up to $28.5 in cost savings per 1,000 samples while also improving performance.
PS: We also release a hackathon demo to show our idea. Please check [here](https://hackbox.microsoft.com/hackathons/hackathon2023/project/26540).
[LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression]() (Under Review).
_Huiqiang Jiang, Qianhui Wu, Xufang Luo, Dongsheng Li, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
## 🎥 Overview
@@ -25,12 +30,12 @@ Large language models, such as ChatGPT and GPT-4, impress us with their amazing
![image](./images/LLMLingua_framework.png)
Now you can use **LLMLingua**!
Now you can use **LLMLingua** & **LongLLMLingua**!
A simple and efficient method to compress prompts by up to **20x**.
- 💰 **Saving cost**, on both the prompt and the generation length;
- 📝 **Support longer contexts**;
- 📝 **Support longer contexts** while delivering enhanced performance;
- ⚖️ **Robustness**, no training needed for the LLMs;
- 🕵️ **Keeping** the original prompt knowledge like ICL, reasoning, etc.
- 📜 **KV-Cache compression**, speeding up inference;
@@ -43,6 +48,16 @@ If you find this repo helpful, please cite the following paper:
@inproceedings{jiang-etal-2023-llmlingua,
title = "LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models",
author = "Huiqiang Jiang, Qianhui Wu, Chin-Yew Lin, Yuqing Yang and Lili Qiu",
booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing",
month = dec,
year = "2023",
publisher = "Association for Computational Linguistics",
}
```
```bibtex
@inproceedings{jiang-etal-2023-longllmlingua,
title = "LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression",
author = "Huiqiang Jiang, Qianhui Wu, Xufang Luo, Dongsheng Li, Chin-Yew Lin, Yuqing Yang and Lili Qiu",
}
```
@@ -59,8 +74,8 @@ Then, you can use LLMLingua to compress your prompt,
```python
from llmlingua import PromptCompressor
llmlingua = PromptCompressor()
compressed_prompt = llmlingua.compress_prompt(prompt, instruction="", question="", target_token=200)
llm_lingua = PromptCompressor()
compressed_prompt = llm_lingua.compress_prompt(prompt, instruction="", question="", target_token=200)
# > {'compressed_prompt': 'Question: Sam bought a dozen boxes, each with 30 highlighter pens inside, for $10 each box. He reanged five of boxes into packages of sixlters each and sold them $3 per. He sold the rest theters separately at the of three pens $2. How much did make in total, dollars?\nLets think step step\nSam bought 1 boxes x00 oflters.\nHe bought 12 * 300ters in total\nSam then took 5 boxes 6ters0ters.\nHe sold these boxes for 5 *5\nAfterelling these boxes there were 3030 highlighters remaining.\nThese form 330 / 3 = 110 groups of three pens.\nHe sold each of these groups for $2 each, so made 110 * 2 = $220 from them.\nIn total, then, he earned $220 + $15 = $235.\nSince his original cost was $120, he earned $235 - $120 = $115 in profit.\nThe answer is 115',
# 'origin_tokens': 2365,
@@ -69,6 +84,11 @@ compressed_prompt = llmlingua.compress_prompt(prompt, instruction="", question="
# 'saving': ', Saving $0.1 in GPT-4.'}
```
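The compressed prompt can then be sent to any black-box LLM. For example, with the OpenAI API (a minimal sketch; the API key, model name, and client setup are assumptions, not part of this repo):

```python
import openai

# Send the compressed prompt to a black-box LLM (assumed OpenAI setup).
openai.api_key = "YOUR_API_KEY"
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": compressed_prompt["compressed_prompt"]}],
)
print(response["choices"][0]["message"]["content"])
```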
You can refer to this [document](./DOCUMENT.md) for more recommendations on how to use LLMLingua effectively.
## Frequently Asked Questions
See [Transparency_FAQ.md](./Transparency_FAQ.md).
## Contributing

View File

@@ -1,13 +1,3 @@
# TODO: The maintainer of this repo has not yet edited this file
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
- **No CSS support:** Fill out this template with information about how to file issues and get help.
- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
# Support
## How to file issues and get help
@@ -16,9 +6,7 @@ This project uses GitHub Issues to track bugs and feature requests. Please searc
issues before filing new issues to avoid duplicates. For new issues, file your bug or
feature request as a new Issue.
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
For help and questions about using this project, please refer to the [document](./DOCUMENT.md).
## Microsoft Support Policy

41
Transparency_FAQ.md Normal file
View File

@@ -0,0 +1,41 @@
# LLMLingua's Responsible AI FAQ
## What is LLMLingua?
- LLMLingua is a simple and efficient method to compress prompts by up to 20x while keeping the original prompt knowledge, such as ICL and reasoning.
- LLMLingua takes user-defined prompts and compression goals as input, and outputs a compressed prompt, which may often result in a form of expression that is difficult for humans to understand.
## What can LLMLingua do?
- LLMLingua can simultaneously reduce the length of prompts and the output of LLMs (20%-30%), thus saving API costs;
- Compressed prompts from LLMLingua can be directly used with black-box LLMs, such as ChatGPT, GPT-4, and Claude;
- By compressing prompts, LLMLingua allows for more information to be included within the original token length, thereby improving model performance;
- LLMLingua relies on a small language model, like GPT-2 or LLaMA-7b, for perplexity calculations, which is a relatively low-cost approach;
- Compressed prompts generated by LLMLingua can be understood by LLMs, preserving their original capabilities in downstream tasks and keeping the original prompt knowledge like ICL, reasoning, etc. LLMs can also recover the essential information from the compressed prompts;
- LLMLingua is robust and requires no training of the LLMs;
- Additionally, LLMLingua can be used to compress KV-Cache, which speeds up inference.
## What is/are LLMLingua's intended use(s)?
- Users who call black-box LLM APIs similar to GPT-4, those who utilize ChatGPT to handle longer content, as well as model deployers and cloud service providers, can benefit from these techniques.
## How was LLMLingua evaluated? What metrics are used to measure performance?
- In our experiments, we conducted a detailed evaluation of the performance of compressed prompts across various tasks, particularly in those involving LLM-specific capabilities, such as In-Context Learning, reasoning tasks, summarization, and conversation tasks. We assessed our approach using compression ratio and performance loss as evaluation metrics.
## What are the limitations of LLMLingua? How can users minimize the impact of LLMLingua's limitations when using the system?
- The potential for harmful, false, or biased responses when using the compressed prompts is likely unchanged. Thus, using LLMLingua has no inherent benefits or risks with respect to those types of responsible AI issues.
- LLMLingua may struggle to perform well at particularly high compression ratios, especially when the original prompts are already quite short.
## What operational factors and settings allow for effective and responsible use of LLMLingua?
- Users can set parameters such as the boundaries between different components (instruction, context, question) in the prompt, compression goals, and the small model used for compression calculations. Afterward, they can input the compressed prompt into black-box LLMs for use.
## What is instruction, context, and question?
In our approach, we divide the prompts into three distinct modules: instruction, context, and question. Each prompt necessarily contains a question, but the presence of context and instruction is not always guaranteed.
- Question: This refers to the directives given by the user to the LLMs, such as inquiries, questions, or requests. Positioned after the instruction and context modules, the question module has a high sensitivity to compression.
- Context: This module provides the supplementary context needed to address the question, such as documents, demonstrations, web search results, or API call results. Located between the instruction and question modules, its sensitivity to compression is relatively low.
- Instruction: This module consists of directives given by the user to the LLMs, such as task descriptions. Placed before the context and question modules, the instruction module exhibits a high sensitivity to compression (see the sketch after this list).
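In code, this separation maps directly onto the arguments of `compress_prompt` (a minimal sketch; the strings are placeholders):

```python
# Instruction -> placed first, highly sensitive to compression.
# Context     -> placed in the middle, less sensitive, passed as a list.
# Question    -> placed last, highly sensitive to compression.
compressed = llm_lingua.compress_prompt(
    instruction="Summarize the key findings of the documents below.",
    context=["<document 1>", "<document 2>"],
    question="What are the main conclusions?",
    ratio=0.5,
)
```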

View File

@@ -1,9 +1,11 @@
import bisect
from collections import defaultdict
from typing import List
import nltk
import numpy as np
import torch
import nltk
import tiktoken
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -13,24 +15,37 @@ encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
class PromptCompressor:
def __init__(
self, model_name: str = "NousResearch/Llama-2-7b-hf", device_map: str = "cuda"
self,
model_name: str = "NousResearch/Llama-2-7b-hf",
device_map: str = "cuda",
use_auth_token: bool = False,
open_api_config: dict = {},
):
self.load_model(model_name, device_map)
self.load_model(model_name, device_map, use_auth_token)
self.sbert = None
self.open_api_config = open_api_config
def load_model(self, model_name: str, device_map: str = "cuda"):
def load_model(
self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
):
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = (
config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
)
if device_map == "cuda":
self.device = (
device_map if any(key in device_map for key in ["cuda", "cpu"]) else "cuda"
)
if "cuda" in device_map or "cpu" in device_map:
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
config=config,
ignore_mismatched_sizes=True,
).cuda()
).to(device_map)
if device_map == "cpu":
model = model.type(torch.float32)
else:
model = AutoModelForCausalLM.from_pretrained(
model_name,
@@ -40,10 +55,11 @@ class PromptCompressor:
offload_folder="/tmp/offload",
offload_state_dict=True,
cache_dir="/tmp/cache",
use_auth_token=True,
use_auth_token=use_auth_token,
)
self.tokenizer = tokenizer
self.model = model
self.context_idxs = []
def get_ppl(
self,
@@ -51,63 +67,93 @@ class PromptCompressor:
granularity: str = "sentence",
input_ids=None,
attention_mask=None,
past_key_values=None,
return_kv=False,
end=None,
condition_mode: str = "none",
condition_pos_id: int = 0,
):
if input_ids is None:
tokenized_text = self.tokenizer(text, return_tensors="pt")
input_ids = tokenized_text["input_ids"].cuda()
attention_mask = tokenized_text["attention_mask"].cuda()
input_ids = tokenized_text["input_ids"].to(self.device)
attention_mask = tokenized_text["attention_mask"].to(self.device)
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
else:
past_length = 0
if end is None:
end = input_ids.shape[1]
end = min(end, past_length + 4096)
with torch.no_grad():
response = self.model(input_ids, attention_mask=attention_mask)
response = self.model(
input_ids[:, past_length:end],
attention_mask=attention_mask[:, :end],
past_key_values=past_key_values,
use_cache=True,
)
past_key_values = response.past_key_values
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
shift_logits = response.logits[
..., :-1, :
].contiguous() # batch_size x seq_len x vocab_size
shift_labels = input_ids[..., 1:].contiguous() # batch_size x seq_len
shift_logits = response.logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., past_length + 1 : end].contiguous()
# Flatten the tokens
active = (attention_mask == 1)[..., :-1].view(-1)
active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
active_labels = shift_labels.view(-1)[active]
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(active_logits, active_labels)
if granularity == "token":
return loss
elif granularity == "sentence":
return loss.mean()
if condition_mode == "before":
loss = loss[:condition_pos_id]
elif condition_mode == "after":
loss = loss[condition_pos_id:]
res = loss.mean() if granularity == "sentence" else loss
return (res, past_key_values) if return_kv else res
def __call__(self, *args, **kwargs):
return self.compress_prompt(*args, **kwargs)
def compress_prompt(
self,
demonstrations: List[str],
context: List[str],
instruction: str = "",
question: str = "",
ratio: float = 0.5,
target_token: float = -1,
iterative_size: int = 200,
length_ratio: float = 0.0,
force_demonstrations_ids: List[int] = None,
force_context_ids: List[int] = None,
force_context_number: int = None,
use_sentence_level_filter: bool = False,
use_demonstrate_level_filter: bool = True,
use_context_level_filter: bool = True,
use_token_level_filter: bool = True,
keep_split: bool = False,
keep_first_sentence: int = 0,
keep_last_sentence: int = 0,
keep_sentence_number: int = 0,
high_priority_bonus: int = 100,
context_budget: str = "+100",
token_budget_ratio: float = 1.4,
condition_in_question: bool = False,
condition_in_question: str = "none",
reorder_context: str = "original",
dynamic_context_compression_ratio: float = 0.0,
condition_compare: bool = False,
add_instruction: bool = False,
rank_method: str = "llmlingua",
concate_question: bool = True,
):
if isinstance(demonstrations, str):
demonstrations = [demonstrations]
if isinstance(context, str):
context = [context]
assert not (
rank_method == "longllmlingua" and not question
), "In the LongLLMLingua, it is necessary to set a question."
if rank_method == "longllmlingua":
if condition_in_question == "none":
condition_in_question = "after"
elif rank_method == "llmlingua":
condition_in_question = "none"
origin_tokens = len(
encoding.encode(
"\n\n".join([instruction] + demonstrations + [question]).strip()
encoding.encode("\n\n".join([instruction] + context + [question]).strip())
)
)
demonstrations_tokens_length = [
self.get_token_length(demonstration) for demonstration in demonstrations
]
context_tokens_length = [self.get_token_length(c) for c in context]
instruction_tokens_length, question_tokens_length = self.get_token_length(
instruction
), self.get_token_length(question)
@@ -116,105 +162,256 @@ class PromptCompressor:
(
instruction_tokens_length
+ question_tokens_length
+ sum(demonstrations_tokens_length)
+ sum(context_tokens_length)
)
* (1 - ratio)
- instruction_tokens_length
- question_tokens_length
- (question_tokens_length if concate_question else 0)
)
if len(demonstrations) > 1 and use_demonstrate_level_filter:
demonstrations = self.control_demonstrations_budget(
demonstrations,
demonstrations_tokens_length,
condition_flag = "_condition" in condition_in_question
condition_in_question = condition_in_question.replace("_condition", "")
if len(context) > 1 and use_context_level_filter:
context, dynamic_ratio = self.control_context_budget(
context,
context_tokens_length,
target_token,
length_ratio,
force_demonstrations_ids,
force_context_ids,
force_context_number,
question,
condition_in_question,
reorder_context=reorder_context,
dynamic_context_compression_ratio=dynamic_context_compression_ratio,
rank_method=rank_method,
context_budget=context_budget,
)
else:
dynamic_ratio = [0.0] * len(context)
if use_sentence_level_filter:
demonstrations = self.control_sentence_budget(
demonstrations,
context = self.control_sentence_budget(
context,
target_token,
keep_first_sentence=keep_first_sentence,
keep_last_sentence=keep_last_sentence,
keep_sentence_number=keep_sentence_number,
high_priority_bonus=high_priority_bonus,
token_budget_ratio=token_budget_ratio,
question=question,
condition_in_question=condition_in_question,
rank_method=rank_method,
)
if condition_in_question:
demonstrations = [question] + demonstrations
if condition_flag:
if add_instruction:
context = [question + "\n\n" + instruction] + context
start = self.get_token_length(question + "\n\n" + instruction) + 2
else:
context = [question] + context
start = self.get_token_length(question) + 2
else:
start = 0
demonstrations = self.iterative_compress_prompt(
demonstrations,
if use_token_level_filter:
context = self.iterative_compress_prompt(
context,
target_token,
iterative_size=iterative_size,
keep_split=keep_split,
start=start,
dynamic_ratio=dynamic_ratio,
condition_compare=condition_compare,
)
compressed_prompt = (
self.tokenizer.batch_decode(context[0])[0]
.replace("<s> ", "")
.replace("<s>", "")
)
else:
compressed_prompt = "\n\n".join(context)
context = self.tokenizer.batch_decode(demonstrations[0])[0].replace("<s> ", "")
if instruction:
context = instruction + "\n\n" + context
if question:
context = context + "\n\n" + question
compressed_prompt = instruction + "\n\n" + compressed_prompt
if question and concate_question:
compressed_prompt = compressed_prompt + "\n\n" + question
compressed_tokens = len(encoding.encode(context))
compressed_tokens = len(encoding.encode(compressed_prompt))
saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
return {
"compressed_prompt": context,
"compressed_prompt": compressed_prompt,
"origin_tokens": origin_tokens,
"compressed_tokens": compressed_tokens,
"ratio": f"{origin_tokens/compressed_tokens:.1f}x",
"saving": f", Saving ${saving:.1f} in GPT-4.",
}
def get_token_length(self, text: str):
return len(self.tokenizer(text).input_ids)
def get_token_length(self, text: str, add_special_tokens: bool = True):
return len(
self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
)
def control_demonstrations_budget(
def get_condition_ppl(
self,
demonstrations: List[str],
demonstrations_tokens_length: List[int],
target_token: float,
length_ratio: float = 0.0,
force_demonstrations_ids: List[int] = None,
text: str,
question: str,
condition_in_question: str = "none",
granularity: str = "sentence",
):
if force_demonstrations_ids is not None:
return [demonstrations[ii] for ii in force_demonstrations_ids]
demonstrations_ppl = [
self.get_ppl(d) - dl * 2 / 250 * length_ratio
for d, dl in zip(demonstrations, demonstrations_tokens_length)
]
if condition_in_question == "none":
return self.get_ppl(text, granularity=granularity)
elif condition_in_question == "before":
return self.get_ppl(
question + text,
granularity=granularity,
condition_mode="after",
condition_pos_id=self.get_token_length(question) - 1,
)
elif condition_in_question == "after":
return self.get_ppl(
text + question,
granularity=granularity,
condition_mode="after",
condition_pos_id=self.get_token_length(text) - 1,
)
def get_dynamic_compression_ratio(
self,
context: list,
target_token: float,
iterative_size: int,
dynamic_ratio: list,
start: int,
):
def get_ratio(base: float, delta: float):
return max(min(1, base + delta), 0)
context_length = [self.get_token_length(ii, False) + 2 for ii in context]
if start:
context_length = context_length[1:]
tau = target_token / (sum(context_length) + 1)
res, idx, last, last_target = [], 0, 1, []
while idx < len(context_length):
if last + context_length[idx] >= iterative_size:
last_target.append(
(iterative_size - last, get_ratio(tau, dynamic_ratio[idx]))
)
res.append(last_target)
last = last + context_length[idx] - iterative_size
if last > iterative_size:
k = last // iterative_size
res.extend(
[[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k
)
last -= k * iterative_size
last_target = (
[(last, get_ratio(tau, dynamic_ratio[idx]))] if last else []
)
else:
last += context_length[idx]
last_target.append(
(context_length[idx], get_ratio(tau, dynamic_ratio[idx]))
)
idx += 1
if last_target:
res.append(last_target)
return res
def control_context_budget(
self,
context: List[str],
context_tokens_length: List[int],
target_token: float,
force_context_ids: List[int] = None,
force_context_number: int = None,
question: str = "",
condition_in_question: str = "none",
reorder_context: str = "original",
dynamic_context_compression_ratio: float = 0.0,
rank_method: str = "longllmlingua",
context_budget: str = "+100",
):
if force_context_ids is not None:
return [context[ii] for ii in force_context_ids]
demostrations_sort = self.get_rank_results(
context,
question,
rank_method,
condition_in_question,
context_tokens_length,
)
if target_token < 0:
target_token = 100
target_token += 100
target_token = eval("target_token" + context_budget)
res = []
for idx, ppl in sorted(enumerate(demonstrations_ppl), key=lambda x: -x[1]):
target_token -= demonstrations_tokens_length[idx]
res.append(demonstrations[idx])
if target_token < 0:
used = force_context_ids if force_context_ids is not None else []
self.context_idxs.append([x for idx, (x, _) in enumerate(demostrations_sort)])
for idx, _ in demostrations_sort:
if idx >= len(context_tokens_length):
continue
target_token -= context_tokens_length[idx]
if idx not in used:
used.append(idx)
if target_token < 0 or (
force_context_number is not None and len(res) >= force_context_number
):
break
return res
original_used = used
if reorder_context == "original":
used = sorted(used)
elif reorder_context == "two_stage":
l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [
_ for idx, _ in enumerate(used) if idx % 2 == 1
]
used = l + r[::-1]
if dynamic_context_compression_ratio > 0:
N = len(used)
if condition_in_question:
rank = [
i
for i, _ in self.get_rank_results(
context,
question,
"longllmlingua",
"after",
context_tokens_length,
)
]
used = sorted(used, key=lambda x: rank.index(x))
dynamic_ratio = [
i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
for i in range(-(N - 1), N, 2)
][::-1]
dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)}
dynamic_ratio = [dynamic_ratio_map[i] for i in used]
else:
dynamic_ratio = [0.0] * len(used)
res = [context[idx] for idx in used if idx < len(context)]
return res, dynamic_ratio
def control_sentence_budget(
self,
demonstrations: List[str],
context: List[str],
target_token: float,
keep_first_sentence: int = 0,
keep_last_sentence: int = 0,
keep_sentence_number: int = 0,
high_priority_bonus: int = 100,
token_budget_ratio: float = 1.4,
question: str = "",
condition_in_question: str = "none",
rank_method: str = "longllmlingua",
):
def keep_sentence(dem_idx: int, sent_keep: int):
idxs = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep]
for idx in idxs:
sentence_ppl[idx] += high_priority_bonus
sentences = [
nltk.sent_tokenize(demonstration) for demonstration in demonstrations
]
sentences = [nltk.sent_tokenize(c) for c in context]
dem_g, s2de, idx = defaultdict(set), defaultdict(int), 0
for idx_d, s in enumerate(sentences):
for _ in s:
@@ -222,35 +419,55 @@ class PromptCompressor:
s2de[idx] = idx_d
idx += 1
demonstrations_sentences = [s for ii in sentences for s in ii]
context_sentences = [s for ii in sentences for s in ii]
sentence_tokens_length = [
self.get_token_length(sentence) for sentence in demonstrations_sentences
self.get_token_length(sentence) for sentence in context_sentences
]
N = len(context_sentences)
flags = list(range(len(context_sentences)))
if len(sentence_tokens_length) == 1:
return demonstrations
return context
if rank_method == "longllmlingua":
sentence_ppl = [
self.get_ppl(sentence).cpu().numpy().item()
for sentence in demonstrations_sentences
self.get_condition_ppl(sentence, question, condition_in_question)
.cpu()
.numpy()
.item()
for sentence in context_sentences
]
if keep_first_sentence:
sentence_ppl[:keep_first_sentence] = [
ii + high_priority_bonus for ii in sentence_ppl[:keep_first_sentence]
ii + high_priority_bonus
for ii in sentence_ppl[:keep_first_sentence]
]
if keep_last_sentence:
sentence_ppl[-keep_last_sentence:] = [
ii + high_priority_bonus for ii in sentence_ppl[-keep_last_sentence:]
ii + high_priority_bonus
for ii in sentence_ppl[-keep_last_sentence:]
]
if keep_sentence_number:
for dem_idx in range(len(sentences)):
keep_sentence(dem_idx, keep_sentence_number)
sort_direct = -1 if condition_in_question == "none" else 1
sent_sort = sorted(
enumerate(sentence_ppl), key=lambda x: sort_direct * x[1]
)
else:
sent_sort = self.get_rank_results(
context_sentences,
question,
rank_method,
condition_in_question,
[0] * len(context_sentences),
)
N = len(demonstrations_sentences)
sentence_flags = [False] * N
if target_token < 0:
target_token = 100
target_token *= token_budget_ratio
res = []
for idx, ppl in sorted(enumerate(sentence_ppl), key=lambda x: -x[1]):
for idx, _ in sent_sort:
idx = flags[idx]
target_token -= sentence_tokens_length[idx]
sentence_flags[idx] = True
if target_token < 0:
@@ -273,10 +490,51 @@ class PromptCompressor:
threshold=0.5,
keep_flag=None,
split_token_id: int = 13,
start: int = 0,
self_loss=None,
self_input_ids=None,
self_attention_mask=None,
):
if self_loss is not None:
need_idx = torch.concat(
[
loss[:start] > 0,
self_loss[: loss[start:].shape[0]] - loss[start:] > threshold,
loss[:1] > 0,
]
)
else:
need_idx = torch.concat([loss > threshold, loss[:1] > 0])
need_idx[end:] = 1
need_idx[: end - iterative_size] = 1
loss = loss[need_idx[:-1]]
if self_loss is not None:
if need_idx.shape[0] < self_loss.shape[0] + start + 1:
need_idx = torch.cat(
[
need_idx,
torch.ones(
self_loss.shape[0] - need_idx.shape[0] + start + 1,
dtype=torch.bool,
).to(need_idx.device),
]
)
self_loss = self_loss[need_idx[start:-1]]
if need_idx.shape[0] < input_ids.shape[1]:
need_idx = torch.cat(
[
need_idx,
torch.ones(
input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
).to(need_idx.device),
]
)
elif need_idx.shape[0] > input_ids.shape[1]:
need_idx = need_idx[: input_ids.shape[1]]
if keep_flag is not None:
need_idx[keep_flag == 1] = 1
last = -1
if keep_flag is not None:
for ii in range(end - iterative_size, end):
@@ -295,32 +553,84 @@ class PromptCompressor:
compressed_attention_mask = attention_mask[attention_mask == 1][
need_idx
].unsqueeze(0)
if self_loss is not None:
self_compressed_input_ids = self_input_ids[self_attention_mask == 1][
need_idx[start:]
].unsqueeze(0)
self_compressed_attention_mask = self_attention_mask[
self_attention_mask == 1
][need_idx[start:]].unsqueeze(0)
else:
self_compressed_input_ids, self_compressed_attention_mask = None, None
if keep_flag is not None:
if len(keep_flag) > len(need_idx):
keep_flag = torch.cat(
[
keep_flag[:start],
keep_flag[start : len(need_idx) + start][need_idx],
keep_flag[start + len(need_idx) :],
]
)
else:
keep_flag = keep_flag[need_idx]
end -= (need_idx[:end] == 0).sum()
return compressed_input_ids, compressed_attention_mask, keep_flag, end
return (
compressed_input_ids,
compressed_attention_mask,
keep_flag,
end,
loss,
self_loss,
self_compressed_input_ids,
self_compressed_attention_mask,
)
def get_estimate_threshold_base_distribution(self, ppl, target_token: int):
target_token = max(0, min(len(ppl) - 1, int(target_token)))
return ppl.sort(descending=True).values[target_token].detach().cpu().item()
def get_estimate_threshold_base_distribution(
self, ppl, ratio: float, condition_flag: bool = False
):
ppl = ppl[ppl != 10000]
target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1))
return (
ppl.sort(descending=not condition_flag)
.values[target_token]
.detach()
.cpu()
.item()
)
def iterative_compress_prompt(
self,
demonstrations: List[str],
context: List[str],
target_token: float,
iterative_size: int = 200,
keep_split: bool = False,
split_token_id: int = 13,
start: int = 0,
dynamic_ratio: list = None,
condition_compare: bool = False,
):
demonstrations = "\n\n".join(demonstrations)
tokenized_text = self.tokenizer(demonstrations, return_tensors="pt")
input_ids = tokenized_text["input_ids"].cuda()
attention_mask = tokenized_text["attention_mask"].cuda()
iterative_ratios = self.get_dynamic_compression_ratio(
context, target_token, iterative_size, dynamic_ratio, start
)
context = "\n\n".join(context)
tokenized_text = self.tokenizer(context, return_tensors="pt")
input_ids = tokenized_text["input_ids"].to(self.device)
attention_mask = tokenized_text["attention_mask"].to(self.device)
N = (attention_mask == 1).sum()
end = iterative_size + start
compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
if condition_compare:
self_input_ids, self_attention_mask = (
input_ids[:, start:],
attention_mask[:, start:],
)
self_compressed_input_ids, self_compressed_attention_mask = (
self_input_ids,
self_attention_mask,
)
end = min(iterative_size + start, compressed_input_ids.shape[1])
threshold, keep_flag = None, None
if keep_split:
input_ids_numpy = input_ids.cpu().detach().numpy()[0]
@@ -340,32 +650,275 @@ class PromptCompressor:
)
for ii in range(N)
]
keep_flag = torch.tensor(keep_flag).cuda()
while end < compressed_input_ids.shape[1]:
loss = self.get_ppl(
"", "token", compressed_input_ids, compressed_attention_mask
keep_flag = torch.tensor(keep_flag).to(self.device)
past_key_values, past_loss, ready_end = None, None, 0
self_past_key_values, self_past_loss, self_ready_end = None, None, 0
idx = 0
while end <= compressed_input_ids.shape[1]:
loss, past_key_values = self.get_ppl(
"",
"token",
compressed_input_ids,
compressed_attention_mask,
past_key_values=past_key_values,
return_kv=True,
end=end if idx else None,
)
# if threshold is None:
if past_loss is not None:
if end - 1 > len(past_loss):
past_loss = torch.cat(
[past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
)
past_loss[ready_end : end - 1] = loss
loss = past_loss
else:
past_loss = loss
if idx:
past_key_values = [
[k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
for k, v in past_key_values
]
else:
past_key_values = None
if condition_compare:
self_loss, self_past_key_values = self.get_ppl(
"",
"token",
self_compressed_input_ids,
self_compressed_attention_mask,
past_key_values=self_past_key_values,
return_kv=True,
end=end - start if idx else None,
)
if self_past_loss is not None:
if end - start - 1 > len(self_past_loss):
self_past_loss = torch.cat(
[
self_past_loss,
torch.zeros_like(self_loss)[
: end - 1 - start - len(self_past_loss)
],
]
)
self_past_loss[self_ready_end : end - start - 1] = self_loss
self_loss = self_past_loss
else:
self_past_loss = self_loss
if idx:
self_past_key_values = [
[
k[:, :, : end - iterative_size - start],
v[:, :, : end - iterative_size - start],
]
for k, v in self_past_key_values
]
else:
self_past_key_values = None
self_ready_end = (
end - start - iterative_size if not (start and idx == 0) else 0
)
ready_end = end - iterative_size if not (start and idx == 0) else 0
for delta_end, ratio in iterative_ratios[idx]:
loss = past_loss
if condition_compare:
self_loss = self_past_loss
threshold = self.get_estimate_threshold_base_distribution(
loss, target_token
self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False
)
else:
threshold = self.get_estimate_threshold_base_distribution(
loss, ratio, False
)
if keep_split:
loss[keep_flag[:-1] == 1] = 100
(
compressed_input_ids,
compressed_attention_mask,
keep_flag,
end,
past_loss,
self_past_loss,
self_compressed_input_ids,
self_compressed_attention_mask,
) = self.get_compressed_input(
loss,
compressed_input_ids,
compressed_attention_mask,
end,
iterative_size=iterative_size,
end - iterative_size + delta_end,
iterative_size=delta_end,
threshold=threshold,
keep_flag=keep_flag,
split_token_id=split_token_id,
start=start,
self_loss=self_loss if condition_compare else None,
self_input_ids=self_compressed_input_ids
if condition_compare
else None,
self_attention_mask=self_compressed_attention_mask
if condition_compare
else None,
)
end += iterative_size
idx += 1
return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]
def recover(
self,
original_prompt: str,
compressed_prompt: str,
response: str,
):
def match_from_compressed(response_word):
response_input_ids = self.tokenizer(
response_word, add_special_tokens=False
)["input_ids"]
response_set, response_c = set(response_input_ids), defaultdict(list)
for idx in range(M):
if original_input_ids[idx] in response_set:
response_c[original_input_ids[idx]].append(idx)
res, res_min, res_c = None, float("inf"), 1
n = len(response_input_ids)
for l in response_c[response_input_ids[0]]:
x, y, c = 0, l, 1
for x in range(1, n):
idx = bisect.bisect_right(response_c[response_input_ids[x]], y)
if (
idx >= len(response_c[response_input_ids[x]])
or response_c[response_input_ids[x]][idx] - y > 10
):
continue
c += 1
y = response_c[response_input_ids[x]][idx]
if c > res_c:
res_c = c
res_min = y - l + 1
res = (l, y + 1)
elif c == res_c and y - l + 1 < res_min:
res_min = y - l + 1
res = (l, y + 1)
if res is None:
return response_word
# while l > 0 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"):
# l -= 1
# while r < M - 1 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"):
# l -= 1
return self.tokenizer.decode(original_input_ids[res[0] : res[1]])
response_words = response.split(" ")
original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[
"input_ids"
]
N, M = len(response_words), len(original_input_ids)
recovered_response_words = []
l = 0
while l < N:
if response_words[l] not in compressed_prompt:
recovered_response_words.append(response_words[l])
l += 1
continue
r = l
while (
r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt
):
r += 1
match_words = match_from_compressed(" ".join(response_words[l : r + 1]))
recovered_response_words.append(match_words)
l = r + 1
return " ".join(recovered_response_words)
def get_rank_results(
self,
context: list,
question: str,
rank_method: str,
condition_in_question: str,
context_tokens_length: list,
):
def get_distance_bm25(corpus, query):
from rank_bm25 import BM25Okapi
tokenized_corpus = [doc.split(" ") for doc in corpus]
bm25 = BM25Okapi(tokenized_corpus)
tokenized_query = query.split(" ")
doc_scores = bm25.get_scores(tokenized_query)
idx = [(ii, 0) for ii in (-doc_scores).argsort()]
return idx
def get_distance_gzip(corpus, query):
def get_score(x, y):
cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
cxy = len(gzip.compress(f"{x} {y}".encode()))
return (cxy - min(cx, cy)) / max(cx, cy)
import gzip
doc_scores = [get_score(doc, query) for doc in corpus]
idx = [(ii, 0) for ii in np.argsort(doc_scores)]
return idx
def get_distance_sentbert(corpus, query):
from sentence_transformers import SentenceTransformer, util
if self.sbert is None:
self.sbert = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
doc_embeds = self.sbert.encode(corpus)
query = self.sbert.encode(query)
doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
idx = [(ii, 0) for ii in np.argsort(doc_scores)]
return idx
def get_distance_openai(corpus, query):
import openai
from sentence_transformers import util
openai.api_key = self.open_api_config.get("api_key", "")
openai.api_base = self.open_api_config.get(
"api_base", "https://api.openai.com/v1"
)
openai.api_type = self.open_api_config.get("api_type", "open_ai")
openai.api_version = self.open_api_config.get("api_version", "2023-05-15")
engine = self.open_api_config.get("engine", "text-embedding-ada-002")
def get_embed(text):
return openai.Embedding.create(
input=[text.replace("\n", " ")], engine=engine
)["data"][0]["embedding"]
doc_embeds = [get_embed(i) for i in corpus]
query = get_embed(query)
doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
idx = [(ii, 0) for ii in np.argsort(doc_scores)]
return idx
def get_distance_longllmlingua(corpus, query):
context_ppl = [
self.get_condition_ppl(
d,
query
+ " We can get the answer to this question in the given documents.",
condition_in_question,
)
- dl * 2 / 250 * 0
for d, dl in zip(corpus, context_tokens_length)
]
sort_direct = -1 if condition_in_question == "none" else 1
ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1])
return ys
method = None
if rank_method == "bm25":
method = get_distance_bm25
elif rank_method == "gzip":
method = get_distance_gzip
elif rank_method == "sentbert":
method = get_distance_sentbert
elif rank_method == "openai":
method = get_distance_openai
elif rank_method in ["longllmlingua", "llmlingua"]:
method = get_distance_longllmlingua
return method(context, question)

View File

@@ -1,11 +1,11 @@
_MAJOR = "0"
_MINOR = "0"
_MINOR = "1"
# On master and in a nightly release the patch should be one ahead of the last
# released build.
_PATCH = "1"
_PATCH = "0"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = "dev0"
_SUFFIX = ""
VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
VERSION = "{0}.{1}.{2}.{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)

View File

View File

@@ -24,6 +24,7 @@ INSTALL_REQUIRES = [
"torch",
"tiktoken",
"nltk",
"numpy",
]
QUANLITY_REQUIRES = [
"black==21.4b0",