"""
|
|
The Entrypoint File for Transformer Lab's API Server.
|
|
"""
|
|
|
|
import os
import argparse
import asyncio

import json
import signal
import subprocess
from contextlib import asynccontextmanager
import sys
from werkzeug.utils import secure_filename

import fastapi
import httpx

# Using torch to test for CUDA and MPS support.
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastchat.constants import (
    ErrorCode,
)
from fastchat.protocol.openai_api_protocol import (
    ErrorResponse,
)

import transformerlab.db as db
from transformerlab.routers import (
    data,
    model,
    serverinfo,
    train,
    plugins,
    evals,
    config,
    jobs,
    workflows,
    tasks,
    prompts,
    tools,
    batched_prompts,
)
import torch
from pynvml import nvmlShutdown
from transformerlab import fastchat_openai_api
from transformerlab.routers.experiment import experiment
from transformerlab.shared import dirs
from transformerlab.shared import shared
from transformerlab.shared import galleries

# The following environment variable can be used by other scripts
# that need to connect to the root DB, for example.
os.environ["LLM_LAB_ROOT_PATH"] = dirs.ROOT_DIR
# Environment variables that start with _ are used internally to set constants
# that are shared between separate processes. They are not meant to be
# overridden by the user.
os.environ["_TFL_WORKSPACE_DIR"] = dirs.WORKSPACE_DIR
os.environ["_TFL_SOURCE_CODE_DIR"] = dirs.TFL_SOURCE_CODE_DIR

from transformerlab.routers.job_sdk import get_xmlrpc_router, get_trainer_xmlrpc_router

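# Illustrative only (not part of the original file): a separate process, such as
# a plugin or training script, could locate the shared directories through these
# variables, e.g.:
#
#     import os
#     workspace_dir = os.environ.get("_TFL_WORKSPACE_DIR")
#     root_path = os.environ.get("LLM_LAB_ROOT_PATH")
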
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Docs on lifespan events: https://fastapi.tiangolo.com/advanced/events/"""
    # Do the following at API Startup:
    print_launch_message()
    galleries.update_gallery_cache()
    spawn_fastchat_controller_subprocess()
    await db.init()
    if "--reload" in sys.argv:
        await install_all_plugins()
    # Run the migration:
    asyncio.create_task(migrate())
    asyncio.create_task(run_over_and_over())
    print("FastAPI LIFESPAN: 🏁 🏁 🏁 Begin API Server 🏁 🏁 🏁", flush=True)
    yield
    # Do the following at API Shutdown:
    await db.close()
    # Run the clean up function
    cleanup_at_exit()
    print("FastAPI LIFESPAN: Complete")

# The migrate function only runs the conversion if no tasks are already present.
async def migrate():
    if len(await tasks.tasks_get_all()) == 0:
        for exp in await experiment.experiments_get_all():
            await tasks.convert_all_to_tasks(exp["id"])

async def run_over_and_over():
    """Every three seconds, check for new jobs to run."""
    while True:
        await asyncio.sleep(3)
        await jobs.start_next_job()
        await workflows.start_next_step_in_workflow()

description = "Transformerlab API helps you do awesome stuff. 🚀"

tags_metadata = [
    {
        "name": "datasets",
        "description": "Actions used to manage the datasets used by Transformer Lab.",
    },
    {"name": "train", "description": "Actions for training models."},
    {"name": "experiment", "description": "Actions for managing experiments."},
    {
        "name": "model",
        "description": "Actions for interacting with Hugging Face models",  # TODO: is this true?
    },
    {
        "name": "serverinfo",
        "description": "Actions for interacting with the Transformer Lab server.",
    },
]

app = fastapi.FastAPI(
    title="Transformerlab API",
    description=description,
    summary="An API for working with LLMs.",
    version="0.0.1",
    terms_of_service="http://example.com/terms/",
    license_info={
        "name": "Apache 2.0",
        "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
    },
    lifespan=lifespan,
    openapi_tags=tags_metadata,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def create_error_response(code: int, message: str) -> JSONResponse:
    return JSONResponse(ErrorResponse(message=message, code=code).model_dump(), status_code=400)

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))

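# Illustrative only (not from the original source): a request that fails FastAPI
# validation is answered with HTTP 400 and a body built from the ErrorResponse
# model above, containing at least the message and the numeric error code, e.g.:
#
#     {"message": "1 validation error for Request ...", "code": <ErrorCode.VALIDATION_TYPE_ERROR>}
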
### END GENERAL API - NOT OPENAI COMPATIBLE ###

app.include_router(model.router)
app.include_router(serverinfo.router)
app.include_router(train.router)
app.include_router(data.router)
app.include_router(experiment.router)
app.include_router(plugins.router)
app.include_router(evals.router)
app.include_router(jobs.router)
app.include_router(workflows.router)
app.include_router(tasks.router)
app.include_router(config.router)
app.include_router(prompts.router)
app.include_router(tools.router)
app.include_router(batched_prompts.router)
app.include_router(fastchat_openai_api.router)
app.include_router(get_xmlrpc_router())
app.include_router(get_trainer_xmlrpc_router())

controller_process = None
worker_process = None

def spawn_fastchat_controller_subprocess():
    global controller_process
    logfile = open("controller.log", "w")
    port = "21001"
    controller_process = subprocess.Popen(
        [sys.executable, "-m", "fastchat.serve.controller", "--port", port], stdout=logfile, stderr=logfile
    )
    print(f"Started fastchat controller on port {port}")

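# For reference (an illustration, not part of the original file): the subprocess
# above is equivalent to running the FastChat controller by hand with
#
#     python -m fastchat.serve.controller --port 21001
#
# with its stdout/stderr redirected to controller.log in the working directory.
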
async def install_all_plugins():
    all_plugins = await plugins.list_plugins()
    print("Re-copying all plugin files from source to workspace")
    for plugin in all_plugins:
        plugin_id = plugin["uniqueId"]
        print(f"Refreshing workspace plugin: {plugin_id}")
        await plugins.copy_plugin_files_to_workspace(plugin_id)

# @app.get("/")
|
|
# async def home():
|
|
# return {"msg": "Welcome to Transformer Lab!"}
|
|
|
|
|
|
@app.get("/server/controller_start", tags=["serverinfo"])
|
|
async def server_controler_start():
|
|
spawn_fastchat_controller_subprocess()
|
|
return {"message": "OK"}
|
|
|
|
|
|
@app.get("/server/controller_stop", tags=["serverinfo"])
|
|
async def server_controller_stop():
|
|
controller_process.terminate()
|
|
return {"message": "OK"}
|
|
|
|
|
|
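# Illustrative only (not from the original source): these endpoints can be hit
# directly against the default API port, e.g.:
#
#     curl http://localhost:8338/server/controller_start
#     curl http://localhost:8338/server/controller_stop
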
def set_worker_process_id(process):
    global worker_process
    worker_process = process

@app.get("/server/worker_start", tags=["serverinfo"])
|
|
async def server_worker_start(
|
|
model_name: str,
|
|
adaptor: str = "",
|
|
model_filename: str | None = None,
|
|
eight_bit: bool = False,
|
|
cpu_offload: bool = False,
|
|
inference_engine: str = "default",
|
|
experiment_id: str = None,
|
|
inference_params: str = "",
|
|
):
|
|
global worker_process
|
|
|
|
# the first priority for inference params should be the inference params passed in, then the inference parameters in the experiment
|
|
# first we check to see if any inference params were passed in
|
|
if inference_params != "":
|
|
try:
|
|
inference_params = json.loads(inference_params)
|
|
except json.JSONDecodeError:
|
|
return {"status": "error", "message": "malformed inference params passed"}
|
|
# then we check to see if we are an experiment
|
|
elif experiment_id is not None:
|
|
try:
|
|
experiment = await db.experiment_get(experiment_id)
|
|
experiment_config = experiment["config"]
|
|
experiment_config = json.loads(experiment_config)
|
|
inference_params = experiment_config["inferenceParams"]
|
|
inference_params = json.loads(inference_params)
|
|
except json.JSONDecodeError:
|
|
return {"status": "error", "message": "malformed inference params passed"}
|
|
# if neither are true, then we have an issue
|
|
else:
|
|
return {"status": "error", "message": "malformed inference params passed"}
|
|
|
|
    engine = inference_engine
    if "inferenceEngine" in inference_params and engine == "default":
        engine = inference_params.get("inferenceEngine")

    if engine == "default":
        return {"status": "error", "message": "no inference engine specified"}

    inference_engine = engine

    plugin_name = inference_engine
    plugin_location = dirs.plugin_dir_by_name(plugin_name)

    model = model_name
    if model_filename is not None and model_filename != "":
        model = model_filename

    if adaptor != "":
        adaptor = f"{dirs.WORKSPACE_DIR}/adaptors/{secure_filename(model)}/{adaptor}"

    params = [
        dirs.PLUGIN_HARNESS,
        "--plugin_dir",
        plugin_location,
        "--model-path",
        model,
        "--adaptor-path",
        adaptor,
        "--parameters",
        json.dumps(inference_params),
    ]

    job_id = await db.job_create(type="LOAD_MODEL", status="STARTED", job_data="{}", experiment_id=experiment_id)

    print("Loading plugin loader instead of default worker")

    with open(dirs.GLOBAL_LOG_PATH, "a") as global_log:
        global_log.write(f"🏃 Loading Inference Server for {model_name} with {inference_params}\n")

    worker_process = await shared.async_run_python_daemon_and_update_status(
        python_script=params,
        job_id=job_id,
        begin_string="Application startup complete.",
        set_process_id_function=set_worker_process_id,
    )
    exitcode = worker_process.returncode
    if exitcode == 99:
        with open(dirs.GLOBAL_LOG_PATH, "a") as global_log:
            global_log.write(
                "GPU (CUDA) Out of Memory: Please try a smaller model or a different inference engine. Restarting the server may free up resources.\n"
            )
        return {
            "status": "error",
            "message": "GPU (CUDA) Out of Memory: Please try a smaller model or a different inference engine. Restarting the server may free up resources.",
        }
    if exitcode is not None and exitcode != 0:
        with open(dirs.GLOBAL_LOG_PATH, "a") as global_log:
            global_log.write(f"Error loading model: {model_name} with exit code {exitcode}\n")
        error_msg = await db.job_get_error_msg(job_id)
        if not error_msg:
            error_msg = f"Exit code {exitcode}"
        await db.job_update_status(job_id, "FAILED", error_msg)
        return {"status": "error", "message": error_msg}
    with open(dirs.GLOBAL_LOG_PATH, "a") as global_log:
        global_log.write(f"Model loaded successfully: {model_name}\n")
    return {"status": "success", "job_id": job_id}

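# Illustrative only (not from the original source): starting a worker from the
# command line might look like the following, where the model id and engine name
# are placeholders for whatever is installed locally:
#
#     curl "http://localhost:8338/server/worker_start?model_name=<model-id>&inference_engine=<plugin-name>"
#
# inference_params, when provided, must be a JSON string; otherwise the
# experiment's stored inferenceParams are used.
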
@app.get("/server/worker_stop", tags=["serverinfo"])
|
|
async def server_worker_stop():
|
|
global worker_process
|
|
print(f"Stopping worker process: {worker_process}")
|
|
if worker_process is not None:
|
|
worker_process.terminate()
|
|
worker_process = None
|
|
# check if there is a file called worker.pid, if so kill the related process:
|
|
if os.path.isfile("worker.pid"):
|
|
with open("worker.pid", "r") as f:
|
|
pid = f.readline()
|
|
print(f"Killing worker process with PID: {pid}")
|
|
os.kill(int(pid), signal.SIGTERM)
|
|
# delete the worker.pid file:
|
|
os.remove("worker.pid")
|
|
return {"message": "OK"}
|
|
|
|
|
|
@app.get("/server/worker_healthz", tags=["serverinfo"])
|
|
async def server_worker_health(request: Request):
|
|
models = []
|
|
result = []
|
|
try:
|
|
models = await fastchat_openai_api.show_available_models()
|
|
except httpx.HTTPError as exc:
|
|
print(f"HTTP Exception for {exc.request.url} - {exc}")
|
|
raise HTTPException(status_code=503, detail="No worker")
|
|
|
|
# We create a new object with JUST the id of the models
|
|
# we do this so that we get a clean object that can be used
|
|
# by react to see if the object changed. If we returned the whole
|
|
# model object, you would see some changes in the object that are
|
|
# not relevant to the user -- triggering renders in React
|
|
for model_data in models.data:
|
|
result.append({"id": model_data.id})
|
|
|
|
return result
|
|
|
|
|
|
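# Illustrative only (not from the original source): with a single worker serving
# one model, this endpoint returns something like
#
#     [{"id": "<model-id>"}]
#
# and raises HTTP 503 with detail "No worker" if the model list cannot be fetched.
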
# Add an endpoint that serves the static files in the ~/.transformerlab/webapp directory:
app.mount("/", StaticFiles(directory=dirs.STATIC_FILES_DIR, html=True), name="application")

def cleanup_at_exit():
    if controller_process is not None:
        print("🔴 Quitting spawned controller.")
        controller_process.kill()
    if worker_process is not None:
        print("🔴 Quitting spawned workers.")
        try:
            worker_process.kill()
        except ProcessLookupError:
            print(f"Process {worker_process.pid} doesn't exist so nothing to kill")
    if os.path.isfile("worker.pid"):
        with open("worker.pid", "r") as f:
            pid = f.readline()
        os.remove("worker.pid")
        os.kill(int(pid), signal.SIGTERM)
    # Perform NVML Shutdown if CUDA is available
    if torch.cuda.is_available():
        try:
            print("🔴 Releasing allocated GPU Resources")
            nvmlShutdown()
        except Exception as e:
            print(f"Error shutting down NVML: {e}")
    print("🔴 Quitting Transformer Lab API server.")

def parse_args():
    parser = argparse.ArgumentParser(description="FastChat ChatGPT-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="host name")
    parser.add_argument("--port", type=int, default=8338, help="port number")
    parser.add_argument("--allow-credentials", action="store_true", help="allow credentials")
    parser.add_argument("--allowed-origins", type=json.loads, default=["*"], help="allowed origins")
    parser.add_argument("--allowed-methods", type=json.loads, default=["*"], help="allowed methods")
    parser.add_argument("--allowed-headers", type=json.loads, default=["*"], help="allowed headers")
    # Declared as an optional flag: a positional argument with type=bool would be
    # required on every launch and would coerce any non-empty string to True.
    parser.add_argument("--auto-reinstall-plugins", action="store_true", default=False, help="auto reinstall plugins")

    return parser.parse_args()

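# Illustrative only (not from the original source): a typical launch restricting
# CORS to a hypothetical local front end might look like
#
#     python api.py --port 8338 --allowed-origins '["http://localhost:1212"]'
#
# The JSON-typed arguments (--allowed-origins, --allowed-methods, --allowed-headers)
# are parsed with json.loads, so they must be valid JSON arrays.
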
def print_launch_message():
    # Print the welcome message to the CLI
    with open(os.path.join(os.path.dirname(__file__), "transformerlab/launch_header_text.txt"), "r") as f:
        text = f.read()
        shared.print_in_rainbow(text)
    print("http://www.transformerlab.ai\nhttps://github.com/transformerlab/transformerlab-api\n")

def run():
    args = parse_args()

    print(f"args: {args}")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    uvicorn.run("api:app", host=args.host, port=args.port, log_level="warning")

if __name__ == "__main__":
|
|
run()
|