mirror of https://github.com/pinokiofactory/flux-webui.git

commit: cuda qfloat8
app.py
@@ -2,9 +2,16 @@ import gradio as gr
 import numpy as np
 import random
 import torch
-from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel
+from diffusers import FluxTransformer2DModel, FluxPipeline
+from transformers import T5EncoderModel, CLIPTextModel
+from optimum.quanto import QuantizedDiffusersModel, QuantizedTransformersModel
 import json
 import devicetorch
 import os
+
+class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
+    base_class = FluxTransformer2DModel
 
 dtype = torch.bfloat16
 #dtype = torch.float32
 device = devicetorch.get(torch)
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
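Background on the new class: optimum-quanto's QuantizedDiffusersModel is meant to be subclassed with base_class set, so that from_pretrained() knows which diffusers module tree to rebuild before loading the quantized weights. As a rough sketch (assuming the optimum-quanto feat-hub-support branch pinned in requirements.txt; the output directory is illustrative), a prequantized checkpoint like cocktailpeanut/flux1-schnell-q8 could be produced once and reused:

    import torch
    from diffusers import FluxTransformer2DModel
    from optimum.quanto import QuantizedDiffusersModel, qfloat8

    class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
        # ties the quantized wrapper to the concrete diffusers class
        base_class = FluxTransformer2DModel

    # load the full-precision transformer, quantize its weights to qfloat8,
    # and serialize the result so it can be reloaded without requantizing
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    qtransformer = QuantizedFluxTransformer2DModel.quantize(transformer, weights=qfloat8)
    qtransformer.save_pretrained("./flux1-schnell-q8")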
@@ -18,32 +25,48 @@ nav {
   display: inline;
 }
 """
-def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, num_images_per_prompt=1, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
+def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, guidance_scale=0.0, num_images_per_prompt=1, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
     global pipe
     global selected
     # if the new checkpoint is different from the selected one, re-instantiate the pipe
     if selected != checkpoint:
         if checkpoint == "sayakpaul/FLUX.1-merged":
-            transformer = FluxTransformer2DModel.from_pretrained("sayakpaul/FLUX.1-merged", torch_dtype=dtype)
-            pipe = FluxPipeline.from_pretrained("cocktailpeanut/xulf-d", transformer=transformer, torch_dtype=dtype)
+            bfl_repo = "cocktailpeanut/xulf-d"
+            if device == "mps":
+                transformer = QuantizedFluxTransformer2DModel.from_pretrained("cocktailpeanut/flux1-merged-qint8")
+            else:
-                #transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors")
-                #pipe = FluxPipeline.from_pretrained(checkpoint, transformer=transformer, torch_dtype=torch.bfloat16)
-                pipe = FluxPipeline.from_pretrained(checkpoint, torch_dtype=dtype)
+                print("initializing quantized transformer...")
+                transformer = QuantizedFluxTransformer2DModel.from_pretrained("cocktailpeanut/flux1-merged-q8")
+                print("initialized!")
         else:
             bfl_repo = "cocktailpeanut/xulf-s"
             if device == "mps":
                 transformer = QuantizedFluxTransformer2DModel.from_pretrained("cocktailpeanut/flux1-schnell-qint8")
             else:
+                print("initializing quantized transformer...")
+                transformer = QuantizedFluxTransformer2DModel.from_pretrained("cocktailpeanut/flux1-schnell-q8")
+                print("initialized!")
         print(f"moving device to {device}")
         transformer.to(device=device, dtype=dtype)
         print(f"initializing pipeline...")
         pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)
         print("initialized!")
         pipe.transformer = transformer
         pipe.to(device)
         pipe.enable_attention_slicing()
         pipe.vae.enable_slicing()
         pipe.vae.enable_tiling()
 
         if device == "cuda":
-            #pipe.enable_model_cpu_offload()
-            pipe.enable_sequential_cpu_offload()
+            print(f"enable model cpu offload...")
+            pipe.enable_model_cpu_offload()
+            #pipe.enable_sequential_cpu_offload()
+            print(f"done!")
         selected = checkpoint
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     generator = torch.Generator().manual_seed(seed)
+    print(f"Started the inference. Wait...")
     images = pipe(
         prompt = prompt,
         width = width,
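The assembly pattern above — load the pipeline skeleton with transformer=None, then attach the separately loaded quantized transformer — keeps the full-precision checkpoint out of memory entirely. A condensed sketch of the same flow (repo names taken from the diff, everything else illustrative):

    import torch
    from diffusers import FluxPipeline

    dtype = torch.bfloat16
    # pipeline without its ~12B-parameter transformer; only the text encoders,
    # VAE, and scheduler are loaded here
    pipe = FluxPipeline.from_pretrained("cocktailpeanut/xulf-s", transformer=None, torch_dtype=dtype)
    pipe.transformer = transformer  # the quantized model loaded earlier
    pipe.enable_attention_slicing()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    # model-level offload moves whole submodules to the GPU as they are needed;
    # it is much faster than sequential (per-layer) offload, which this commit
    # retires now that the qfloat8 transformer fits more easily in VRAM
    pipe.enable_model_cpu_offload()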
@@ -51,9 +74,11 @@ def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, num_im
         num_inference_steps = num_inference_steps,
         generator = generator,
         num_images_per_prompt = num_images_per_prompt,
-        guidance_scale=0.0
+        guidance_scale=guidance_scale
     ).images
+    print(f"Inference finished!")
     devicetorch.empty_cache(torch)
+    print(f"emptied cache")
     return images, seed
 
 def update_slider(checkpoint, num_inference_steps):
     if checkpoint == "sayakpaul/FLUX.1-merged":
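With the hard-coded 0.0 replaced by the new guidance_scale parameter, the value chosen in the UI flows through to the pipeline call. For the timestep-distilled FLUX.1-schnell weights 0.0 remains the sensible default, while the merged dev/schnell weights respond to higher values — hence a user-facing control instead of a constant:

    # illustrative direct call; 0.0 for schnell, higher values for the merged weights
    images = pipe(
        prompt="a watercolor fox",
        num_inference_steps=4,
        guidance_scale=0.0,
        generator=torch.Generator().manual_seed(42),
    ).images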
@@ -72,7 +97,7 @@ with gr.Blocks(css=css) as demo:
                 container=False,
             )
             run_button = gr.Button("Run", scale=0)
-        result = gr.Gallery(label="Result", show_label=False)
+        result = gr.Gallery(label="Result", show_label=False, object_fit="contain")
         checkpoint = gr.Dropdown(
             label="Model",
             value= "black-forest-labs/FLUX.1-schnell",
@@ -119,11 +144,18 @@ with gr.Blocks(css=css) as demo:
                     step=1,
                     value=4,
                 )
+                guidance_scale = gr.Number(
+                    label="Guidance Scale",
+                    minimum=0,
+                    maximum=50,
+                    value=0.0,
+                )
     checkpoint.change(fn=update_slider, inputs=[checkpoint], outputs=[num_inference_steps])
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn = infer,
-        inputs = [prompt, checkpoint, seed, num_images_per_prompt, randomize_seed, width, height, num_inference_steps],
+        inputs = [prompt, checkpoint, seed, guidance_scale, num_images_per_prompt, randomize_seed, width, height, num_inference_steps],
         outputs = [result, seed]
     )
 demo.launch()
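One thing to keep in mind with this wiring: gr.on() passes inputs positionally, so the component list must stay in exactly the order of infer()'s parameters — guidance_scale slotted between seed and num_images_per_prompt in both the signature and the inputs list, as done above. A mismatch would not raise; it would silently feed, say, the image count into guidance_scale. A quick self-check (hypothetical helper, not part of the app):

    import inspect

    def check_wiring(fn, input_names):
        # compare the handler's positional parameters against the UI wiring order,
        # ignoring the gr.Progress keyword that Gradio injects
        params = [p for p in inspect.signature(fn).parameters if p != "progress"]
        assert params[:len(input_names)] == input_names, f"order mismatch: {params} vs {input_names}"

    check_wiring(infer, ["prompt", "checkpoint", "seed", "guidance_scale",
                         "num_images_per_prompt", "randomize_seed",
                         "width", "height", "num_inference_steps"])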
requirements.txt

@@ -1,8 +1,13 @@
 gradio
 devicetorch
-accelerate
-git+https://github.com/peanutcocktail/diffusers.git
+git+https://github.com/huggingface/accelerate.git@test-clear-memory-cpu-offload
+git+https://github.com/huggingface/diffusers.git
 transformers==4.42.4
 sentencepiece
 protobuf
 einops
+git+https://github.com/huggingface/optimum-quanto.git@feat-hub-support
+
+#git+https://github.com/peanutcocktail/optimum-quanto.git@feat-hub-support
+#accelerate
+#git+https://github.com/peanutcocktail/diffusers.git
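The requirements now mix PyPI packages with pinned git branches: accelerate comes from the test-clear-memory-cpu-offload branch, diffusers from upstream main, and optimum-quanto from the feat-hub-support branch that adds hub loading of prequantized weights — the feature app.py relies on. A plain `pip install -r requirements.txt` inside the app's venv resolves all of them; since the pins are branches rather than tags, reinstalling later can pull different code. A quick sanity check (a sketch) that the pinned imports resolve:

    import accelerate, diffusers, transformers
    from optimum.quanto import QuantizedDiffusersModel
    print(accelerate.__version__, diffusers.__version__, transformers.__version__)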
torch.js
@@ -7,7 +7,7 @@ module.exports = {
       "params": {
         "venv": "{{args && args.venv ? args.venv : null}}",
         "path": "{{args && args.path ? args.path : '.'}}",
-        "message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 {{args && args.xformers ? 'xformers' : ''}} --index-url https://download.pytorch.org/whl/cu121"
+        "message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0 {{args && args.xformers ? 'xformers' : ''}} --index-url https://download.pytorch.org/whl/cu121"
       }
     },
     // windows amd
@@ -27,7 +27,7 @@ module.exports = {
       "params": {
         "venv": "{{args && args.venv ? args.venv : null}}",
         "path": "{{args && args.path ? args.path : '.'}}",
-        "message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1"
+        "message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0"
       }
     },
     // mac
@@ -37,7 +37,8 @@ module.exports = {
       "params": {
         "venv": "{{args && args.venv ? args.venv : null}}",
         "path": "{{args && args.path ? args.path : '.'}}",
-        "message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1"
+        //"message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0"
+        "message": "pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu"
       }
     },
     // linux nvidia
@@ -47,7 +48,7 @@ module.exports = {
       "params": {
         "venv": "{{args && args.venv ? args.venv : null}}",
         "path": "{{args && args.path ? args.path : '.'}}",
-        "message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 {{args && args.xformers ? 'xformers' : ''}} --index-url https://download.pytorch.org/whl/cu121"
+        "message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0 {{args && args.xformers ? 'xformers' : ''}} --index-url https://download.pytorch.org/whl/cu121"
       }
     },
     // linux rocm (amd)
@@ -57,7 +58,7 @@ module.exports = {
       "params": {
         "venv": "{{args && args.venv ? args.venv : null}}",
         "path": "{{args && args.path ? args.path : '.'}}",
-        "message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/rocm6.0"
+        "message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/rocm6.0"
       }
     },
     // linux cpu
@@ -67,7 +68,7 @@ module.exports = {
       "params": {
         "venv": "{{args && args.venv ? args.venv : null}}",
         "path": "{{args && args.path ? args.path : '.'}}",
-        "message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu"
+        "message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu"
       }
     }
   ]
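torch.js is the Pinokio install script: each entry's "message" is a shell command rendered after the {{...}} template expressions are evaluated against the launcher's args. After this change, the nvidia paths render to, for example (with args.xformers set):

    pip install torch==2.4.0 torchvision torchaudio==2.4.0 xformers --index-url https://download.pytorch.org/whl/cu121

torchvision is left unpinned, letting pip pick the build that matches torch 2.4.0 from each index, and the mac path now installs the CPU nightly wheels instead of a pinned release.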