mirror of
				https://github.com/huggingface/text-generation-inference.git
				synced 2023-08-15 01:09:35 +03:00 
			
		
		
		
	feat(router): add endpoint info to /info route (#228)
This commit is contained in:
		@@ -21,6 +21,7 @@ pub struct HubModelInfo {
 | 
			
		||||
 | 
			
		||||
#[derive(Clone, Debug, Serialize, ToSchema)]
 | 
			
		||||
pub struct Info {
 | 
			
		||||
    /// Model info
 | 
			
		||||
    #[schema(example = "bigscience/blomm-560m")]
 | 
			
		||||
    pub model_id: String,
 | 
			
		||||
    #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
 | 
			
		||||
@@ -31,6 +32,26 @@ pub struct Info {
 | 
			
		||||
    pub model_device_type: String,
 | 
			
		||||
    #[schema(nullable = true, example = "text-generation")]
 | 
			
		||||
    pub model_pipeline_tag: Option<String>,
 | 
			
		||||
    /// Router Parameters
 | 
			
		||||
    #[schema(example = "128")]
 | 
			
		||||
    pub max_concurrent_requests: usize,
 | 
			
		||||
    #[schema(example = "2")]
 | 
			
		||||
    pub max_best_of: usize,
 | 
			
		||||
    #[schema(example = "4")]
 | 
			
		||||
    pub max_stop_sequences: usize,
 | 
			
		||||
    #[schema(example = "1024")]
 | 
			
		||||
    pub max_input_length: usize,
 | 
			
		||||
    #[schema(example = "2048")]
 | 
			
		||||
    pub max_total_tokens: usize,
 | 
			
		||||
    #[schema(example = "1.2")]
 | 
			
		||||
    pub waiting_served_ratio: f32,
 | 
			
		||||
    #[schema(example = "32000")]
 | 
			
		||||
    pub max_batch_total_tokens: u32,
 | 
			
		||||
    #[schema(example = "20")]
 | 
			
		||||
    pub max_waiting_tokens: usize,
 | 
			
		||||
    #[schema(example = "2")]
 | 
			
		||||
    pub validation_workers: usize,
 | 
			
		||||
    /// Router Info
 | 
			
		||||
    #[schema(example = "0.5.0")]
 | 
			
		||||
    pub version: &'static str,
 | 
			
		||||
    #[schema(nullable = true, example = "null")]
 | 
			
		||||
 
 | 
			
		||||
@@ -78,22 +78,8 @@ async fn compat_generate(
 | 
			
		||||
    responses((status = 200, description = "Served model info", body = Info))
 | 
			
		||||
)]
 | 
			
		||||
#[instrument]
 | 
			
		||||
async fn get_model_info(
 | 
			
		||||
    model_info: Extension<HubModelInfo>,
 | 
			
		||||
    shard_info: Extension<ShardInfo>,
 | 
			
		||||
) -> Json<Info> {
 | 
			
		||||
    let model_info = model_info.0;
 | 
			
		||||
    let shard_info = shard_info.0;
 | 
			
		||||
    let info = Info {
 | 
			
		||||
        version: env!("CARGO_PKG_VERSION"),
 | 
			
		||||
        sha: option_env!("VERGEN_GIT_SHA"),
 | 
			
		||||
        model_id: model_info.model_id,
 | 
			
		||||
        model_sha: model_info.sha,
 | 
			
		||||
        model_dtype: shard_info.dtype,
 | 
			
		||||
        model_device_type: shard_info.device_type,
 | 
			
		||||
        model_pipeline_tag: model_info.pipeline_tag,
 | 
			
		||||
    };
 | 
			
		||||
    Json(info)
 | 
			
		||||
async fn get_model_info(info: Extension<Info>) -> Json<Info> {
 | 
			
		||||
    Json(info.0)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Health check method
 | 
			
		||||
@@ -632,6 +618,26 @@ pub async fn run(
 | 
			
		||||
        .allow_headers([http::header::CONTENT_TYPE])
 | 
			
		||||
        .allow_origin(allow_origin);
 | 
			
		||||
 | 
			
		||||
    // Endpoint info
 | 
			
		||||
    let info = Info {
 | 
			
		||||
        model_id: model_info.model_id,
 | 
			
		||||
        model_sha: model_info.sha,
 | 
			
		||||
        model_dtype: shard_info.dtype,
 | 
			
		||||
        model_device_type: shard_info.device_type,
 | 
			
		||||
        model_pipeline_tag: model_info.pipeline_tag,
 | 
			
		||||
        max_concurrent_requests,
 | 
			
		||||
        max_best_of,
 | 
			
		||||
        max_stop_sequences,
 | 
			
		||||
        max_input_length,
 | 
			
		||||
        max_total_tokens,
 | 
			
		||||
        waiting_served_ratio,
 | 
			
		||||
        max_batch_total_tokens,
 | 
			
		||||
        max_waiting_tokens,
 | 
			
		||||
        validation_workers,
 | 
			
		||||
        version: env!("CARGO_PKG_VERSION"),
 | 
			
		||||
        sha: option_env!("VERGEN_GIT_SHA"),
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // Create router
 | 
			
		||||
    let app = Router::new()
 | 
			
		||||
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
 | 
			
		||||
@@ -650,8 +656,7 @@ pub async fn run(
 | 
			
		||||
        .route("/ping", get(health))
 | 
			
		||||
        // Prometheus metrics route
 | 
			
		||||
        .route("/metrics", get(metrics))
 | 
			
		||||
        .layer(Extension(model_info))
 | 
			
		||||
        .layer(Extension(shard_info))
 | 
			
		||||
        .layer(Extension(info))
 | 
			
		||||
        .layer(Extension(compat_return_full_text))
 | 
			
		||||
        .layer(Extension(infer))
 | 
			
		||||
        .layer(Extension(prom_handle))
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user