mirror of
				https://github.com/huggingface/text-generation-inference.git
				synced 2023-08-23 10:47:54 +03:00 
			
		
		
		
	feat(router): add tests to validation (#237)
This commit is contained in:
		
							
								
								
									
										3
									
								
								.github/workflows/tests.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/tests.yaml
									
									
									
									
										vendored
									
									
								
							| @@ -67,6 +67,9 @@ jobs: | ||||
|         run: | | ||||
|           pip install pytest | ||||
|           HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests | ||||
|       - name: Run Clippy | ||||
|         run: | | ||||
|           cargo clippy | ||||
|       - name: Run Rust tests | ||||
|         run: | | ||||
|           cargo test | ||||
|   | ||||
| @@ -276,3 +276,19 @@ pub(crate) struct ErrorResponse { | ||||
|     pub error: String, | ||||
|     pub error_type: String, | ||||
| } | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod tests{ | ||||
|     use std::io::Write; | ||||
|     use tokenizers::Tokenizer; | ||||
|  | ||||
|     pub(crate) async fn get_tokenizer() -> Tokenizer{ | ||||
|         if !std::path::Path::new("tokenizer.json").exists(){ | ||||
|             let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap(); | ||||
|              let mut file = std::fs::File::create("tokenizer.json").unwrap(); | ||||
|             file.write_all(&content).unwrap(); | ||||
|         } | ||||
|         Tokenizer::from_file("tokenizer.json").unwrap() | ||||
|     } | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -141,6 +141,7 @@ impl State { | ||||
|  | ||||
|     // Get the next batch | ||||
|     fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> { | ||||
|  | ||||
|         if self.entries.is_empty() { | ||||
|             return None; | ||||
|         } | ||||
| @@ -430,7 +431,17 @@ mod tests { | ||||
|         let (entry3, _guard3) = default_entry(); | ||||
|         queue.append(entry3); | ||||
|  | ||||
|         // Not enough requests pending | ||||
|         assert!(queue.next_batch(Some(2), 2).await.is_none()); | ||||
|         // Not enough token budget | ||||
|         assert!(queue.next_batch(Some(1), 0).await.is_none()); | ||||
|         // Ok | ||||
|         let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap(); | ||||
|         assert_eq!(entries2.len(), 1); | ||||
|         assert!(entries2.contains_key(&2)); | ||||
|         assert!(entries2.get(&2).unwrap().batch_time.is_some()); | ||||
|         assert_eq!(batch2.id, 1); | ||||
|         assert_eq!(batch2.size, 1); | ||||
|     } | ||||
|  | ||||
|     #[tokio::test] | ||||
|   | ||||
| @@ -741,3 +741,4 @@ impl From<InferError> for Event { | ||||
|             .unwrap() | ||||
|     } | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -382,7 +382,8 @@ pub enum ValidationError { | ||||
| #[cfg(test)] | ||||
| mod tests{ | ||||
|     use super::*; | ||||
|     use std::io::Write; | ||||
|     use crate::default_parameters; | ||||
|     use crate::tests::get_tokenizer; | ||||
|  | ||||
|     #[tokio::test] | ||||
|     async fn test_validation_max_new_tokens(){ | ||||
| @@ -401,15 +402,6 @@ mod tests{ | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     async fn get_tokenizer() -> Tokenizer{ | ||||
|         if !std::path::Path::new("tokenizer.json").exists(){ | ||||
|             let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap(); | ||||
|              let mut file = std::fs::File::create("tokenizer.json").unwrap(); | ||||
|             file.write_all(&content).unwrap(); | ||||
|         } | ||||
|         Tokenizer::from_file("tokenizer.json").unwrap() | ||||
|     } | ||||
|  | ||||
|     #[tokio::test] | ||||
|     async fn test_validation_input_length(){ | ||||
|         let tokenizer = Some(get_tokenizer().await); | ||||
| @@ -426,4 +418,73 @@ mod tests{ | ||||
|             _ => panic!("Unexpected not max new tokens") | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     #[tokio::test] | ||||
|     async fn test_validation_best_of_sampling(){ | ||||
|         let tokenizer = Some(get_tokenizer().await); | ||||
|         let max_best_of = 2; | ||||
|         let max_stop_sequence = 3; | ||||
|         let max_input_length = 4; | ||||
|         let max_total_tokens = 5; | ||||
|         let workers = 1; | ||||
|         let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); | ||||
|         match validation.validate(GenerateRequest{ | ||||
|             inputs: "Hello".to_string(), | ||||
|             parameters: GenerateParameters{ | ||||
|                 best_of: Some(2), | ||||
|                 do_sample: false, | ||||
|                 ..default_parameters() | ||||
|             } | ||||
|         }).await{ | ||||
|             Err(ValidationError::BestOfSampling) => (), | ||||
|             _ => panic!("Unexpected not best of sampling") | ||||
|         } | ||||
|  | ||||
|     } | ||||
|  | ||||
|     #[tokio::test] | ||||
|     async fn test_validation_top_p(){ | ||||
|         let tokenizer = Some(get_tokenizer().await); | ||||
|         let max_best_of = 2; | ||||
|         let max_stop_sequence = 3; | ||||
|         let max_input_length = 4; | ||||
|         let max_total_tokens = 5; | ||||
|         let workers = 1; | ||||
|         let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); | ||||
|         match validation.validate(GenerateRequest{ | ||||
|             inputs: "Hello".to_string(), | ||||
|             parameters: GenerateParameters{ | ||||
|                 top_p: Some(1.0), | ||||
|                 ..default_parameters() | ||||
|             } | ||||
|         }).await{ | ||||
|             Err(ValidationError::TopP) => (), | ||||
|             _ => panic!("Unexpected top_p") | ||||
|         } | ||||
|  | ||||
|         match validation.validate(GenerateRequest{ | ||||
|             inputs: "Hello".to_string(), | ||||
|             parameters: GenerateParameters{ | ||||
|                 top_p: Some(0.99), | ||||
|                 max_new_tokens: 1, | ||||
|                 ..default_parameters() | ||||
|             } | ||||
|         }).await{ | ||||
|             Ok(_) => (), | ||||
|             _ => panic!("Unexpected top_p error") | ||||
|         } | ||||
|  | ||||
|         let valid_request = validation.validate(GenerateRequest{ | ||||
|             inputs: "Hello".to_string(), | ||||
|             parameters: GenerateParameters{ | ||||
|                 top_p: None, | ||||
|                 max_new_tokens: 1, | ||||
|                 ..default_parameters() | ||||
|             } | ||||
|         }).await.unwrap(); | ||||
|         // top_p == 1.0 is invalid for users to ask for but it's the default resolved value. | ||||
|         assert_eq!(valid_request.parameters.top_p, 1.0); | ||||
|          | ||||
|  | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Nicolas Patry
					Nicolas Patry