commit e28a809004 (parent 2b53d71991)
Author: OlivierDehaene
Date: 2023-07-01 19:25:41 +02:00 (committed by GitHub)

16 changed files with 376 additions and 258 deletions

Cargo.lock (generated): 505 lines changed; diff suppressed because it is too large.

Cargo.toml

@@ -8,7 +8,7 @@ members = [
 ]
 
 [workspace.package]
-version = "0.8.2"
+version = "0.9.0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"

Dockerfile

@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.70 AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

README.md

@@ -84,7 +84,7 @@ model=bigscience/bloom-560m
 num_shard=2
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.8 --model-id $model --num-shard $num_shard
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.9 --model-id $model --num-shard $num_shard
 ```
 **Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
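Once the container is up, any HTTP client can exercise the server. As a hedged illustration only (assuming the `reqwest` crate with its `blocking` and `json` features plus `serde_json`; the prompt and parameters are made up), a Rust smoke test might look like:

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The `-p 8080:80` mapping above exposes the router on localhost:8080.
    let response = reqwest::blocking::Client::new()
        .post("http://localhost:8080/generate")
        .json(&json!({
            "inputs": "What is Deep Learning?",
            "parameters": { "max_new_tokens": 20 }
        }))
        .send()?;
    println!("{}", response.text()?);
    Ok(())
}
```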

aml/README.md (deleted)

@@ -1,15 +0,0 @@
-# Azure ML endpoint
-
-## Create all resources
-
-```shell
-az ml model create -f model.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
-az ml online-endpoint create -f endpoint.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
-az ml online-deployment create -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
-```
-
-## Update deployment
-
-```shell
-az ml online-deployment update -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
-```

aml/deployment.yaml (deleted)

@@ -1,38 +0,0 @@
-$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
-name: bloom-deployment
-endpoint_name: bloom-inference
-model: azureml:bloom-safetensors:1
-model_mount_path: /var/azureml-model
-environment_variables:
-  WEIGHTS_CACHE_OVERRIDE: /var/azureml-model/bloom-safetensors
-  MODEL_ID: bigscience/bloom
-  NUM_SHARD: 8
-environment:
-  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2.0
-  inference_config:
-    liveness_route:
-      port: 80
-      path: /health
-    readiness_route:
-      port: 80
-      path: /health
-    scoring_route:
-      port: 80
-      path: /generate
-instance_type: Standard_ND96amsr_A100_v4
-request_settings:
-  request_timeout_ms: 90000
-  max_concurrent_requests_per_instance: 256
-liveness_probe:
-  initial_delay: 600
-  timeout: 90
-  period: 120
-  success_threshold: 1
-  failure_threshold: 5
-readiness_probe:
-  initial_delay: 600
-  timeout: 90
-  period: 120
-  success_threshold: 1
-  failure_threshold: 5
-instance_count: 1

aml/endpoint.yaml (deleted)

@@ -1,3 +0,0 @@
-$schema: https://azuremlsdk2.blob.core.windows.net/latest/managedOnlineEndpoint.schema.json
-name: bloom-inference
-auth_mode: key

aml/model.yaml (deleted)

@@ -1,3 +0,0 @@
-$schema: https://azuremlschemas.azureedge.net/latest/model.schema.json
-name: bloom-safetensors
-path: /data/bloom-safetensors

docs/openapi.json

@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "0.8.2"
+    "version": "0.9.0"
   },
   "paths": {
     "/": {
@@ -270,6 +270,35 @@
         }
       }
     },
+    "/health": {
+      "get": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Health check method",
+        "description": "Health check method",
+        "operationId": "health",
+        "responses": {
+          "200": {
+            "description": "Everything is working fine"
+          },
+          "503": {
+            "description": "Text generation inference is down",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "unhealthy",
+                  "error_type": "healthcheck"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/info": {
       "get": {
         "tags": [
launcher/src/main.rs

@@ -1040,14 +1040,18 @@ fn main() -> Result<(), LauncherError> {
         return Ok(());
     }
 
-    let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;
+    let mut webserver =
+        spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
+            shutdown_shards(shutdown.clone(), &shutdown_receiver);
+            err
+        })?;
 
     // Default exit code
     let mut exit_code = Ok(());
 
     while running.load(Ordering::SeqCst) {
         if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-            tracing::error!("Shard {rank} failed to start");
+            tracing::error!("Shard {rank} crashed");
             if let Some(err) = err {
                 tracing::error!("{err}");
             }
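The `map_err` wrapper runs shard teardown on the failure path before `?` propagates the error. In isolation the pattern looks like this (a sketch; `start_server` and `stop_workers` are made-up stand-ins for the launcher's functions):

```rust
// Standalone sketch of the cleanup-on-error pattern introduced above.
fn start_server() -> Result<(), String> {
    Err("bind failed".to_string())
}

fn stop_workers() {
    println!("tearing down worker shards");
}

fn run() -> Result<(), String> {
    // map_err runs teardown exactly once on the failure path, then hands
    // the original error back to `?` unchanged.
    start_server().map_err(|err| {
        stop_workers();
        err
    })?;
    Ok(())
}

fn main() {
    if let Err(err) = run() {
        eprintln!("launcher error: {err}");
    }
}
```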

router/Cargo.toml

@@ -22,11 +22,11 @@ text-generation-client = { path = "client" }
clap = { version = "4.1.4", features = ["derive", "env"] }
flume = "0.10.14"
futures = "0.3.26"
metrics = "0.20.1"
metrics-exporter-prometheus = { version = "0.11.0", features = [] }
metrics = "0.21.0"
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
opentelemetry = { version = "0.19.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.12.0"
rand = "0.8.5"
reqwest = { version = "0.11.14", features = [] }
serde = "1.0.152"
@@ -36,7 +36,7 @@ tokenizers = "0.13.3"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tower-http = { version = "0.4.0", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.18.0"
tracing-opentelemetry = "0.19.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
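As a quick compatibility check for the bumped metrics stack, here is a minimal sketch against `metrics` 0.21 and `metrics-exporter-prometheus` 0.12 (the counter name is illustrative):

```rust
use metrics_exporter_prometheus::PrometheusBuilder;

fn main() {
    // Install a recorder and keep a handle for rendering scrape output.
    let handle = PrometheusBuilder::new()
        .install_recorder()
        .expect("failed to install Prometheus recorder");

    // Record a counter through the macro API.
    metrics::increment_counter!("tgi_request_count");

    // Print the metrics in Prometheus text exposition format.
    println!("{}", handle.render());
}
```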

router/client/Cargo.toml

@@ -11,10 +11,10 @@ grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.11"
thiserror = "^1.0"
tokio = { version = "^1.25", features = ["sync"] }
tonic = "^0.8"
tonic = "^0.9"
tower = "^0.4"
tracing = "^0.1"
[build-dependencies]
tonic-build = "0.8.4"
tonic-build = "0.9.2"
prost-build = "0.11.6"
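With `tonic-build` 0.9, client stubs are regenerated from the protos by a `build.rs` along these lines (a sketch; the proto path is a stand-in, not necessarily this repo's layout):

```rust
// build.rs: regenerate the gRPC client stubs at compile time.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    tonic_build::configure()
        .build_client(true)
        .build_server(false)
        .compile(&["../proto/generate.proto"], &["../proto"])?;
    Ok(())
}
```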

router/grpc-metadata/Cargo.toml

@@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
opentelemetry = "0.18.0"
tonic = "^0.8"
opentelemetry = "^0.19"
tonic = "^0.9"
tracing = "^0.1"
tracing-opentelemetry = "0.18.0"
tracing-opentelemetry = "^0.19"

router/src/server.rs

@@ -532,6 +532,7 @@ pub async fn run(
     #[derive(OpenApi)]
     #[openapi(
         paths(
+            health,
             get_model_info,
             compat_generate,
             generate,
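The new `health` entry in `paths(...)` resolves to a handler annotated with `#[utoipa::path]`. A self-contained sketch of that mechanism (simplified; not the router's actual handler):

```rust
use utoipa::OpenApi;

/// Health check method
#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/health",
    responses(
        (status = 200, description = "Everything is working fine"),
        (status = 503, description = "Text generation inference is down")
    )
)]
#[allow(dead_code)]
async fn health() {}

#[derive(OpenApi)]
#[openapi(paths(health))]
struct ApiDoc;

fn main() {
    // The generated spec now lists /health under "paths".
    println!("{}", ApiDoc::openapi().to_pretty_json().unwrap());
}
```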

rust-toolchain.toml

@@ -1,3 +1,3 @@
 [toolchain]
-channel = "1.69.0"
+channel = "1.70.0"
 components = ["rustfmt", "clippy"]

server/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "0.8.2"
+version = "0.9.0"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]