cuda qfloat8

Author: cocktailpeanut
Date: 2024-08-08 19:44:23 -04:00
commit ffbffa54e0 (parent a913229e85)

3 changed files with 62 additions and 24 deletions

app.py (64 changes)
@@ -2,9 +2,16 @@ import gradio as gr
 import numpy as np
 import random
 import torch
-from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel
+from diffusers import FluxTransformer2DModel, FluxPipeline
+from transformers import T5EncoderModel, CLIPTextModel
+from optimum.quanto import QuantizedDiffusersModel, QuantizedTransformersModel
 import json
 import devicetorch
 import os
+class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
+    base_class = FluxTransformer2DModel
+dtype = torch.bfloat16
+#dtype = torch.float32
+device = devicetorch.get(torch)
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
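
Note: the hunk above is what makes Hub loading of pre-quantized weights possible. A minimal sketch of how the new subclass is meant to be used, assuming the feat-hub-support branch of optimum-quanto pinned in requirements.txt; the repo name is one of the checkpoints referenced later in this diff:

# Sketch only: assumes optimum-quanto's feat-hub-support branch, which adds
# from_pretrained() Hub loading on QuantizedDiffusersModel.
import torch
from diffusers import FluxTransformer2DModel
from optimum.quanto import QuantizedDiffusersModel

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    # Tells optimum-quanto which diffusers class the serialized
    # quantized weights map back onto.
    base_class = FluxTransformer2DModel

# Loads an already-quantized (qfloat8) transformer from the Hub, so startup
# skips the expensive on-the-fly quantization pass.
transformer = QuantizedFluxTransformer2DModel.from_pretrained(
    "cocktailpeanut/flux1-schnell-q8"
)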
@@ -18,42 +25,60 @@ nav {
   display: inline;
 }
 """
-def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, num_images_per_prompt=1, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
+def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, guidance_scale=0.0, num_images_per_prompt=1, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
     global pipe
     global selected
     # if the new checkpoint is different from the selected one, re-instantiate the pipe
     if selected != checkpoint:
         if checkpoint == "sayakpaul/FLUX.1-merged":
-            transformer = FluxTransformer2DModel.from_pretrained("sayakpaul/FLUX.1-merged", torch_dtype=dtype)
-            pipe = FluxPipeline.from_pretrained("cocktailpeanut/xulf-d", transformer=transformer, torch_dtype=dtype)
+            bfl_repo = "cocktailpeanut/xulf-d"
+            if device == "mps":
+                transformer = QuantizedFluxTransformer2DModel.from_pretrained("cocktailpeanut/flux1-merged-qint8")
+            else:
+                print("initializing quantized transformer...")
+                transformer = QuantizedFluxTransformer2DModel.from_pretrained("cocktailpeanut/flux1-merged-q8")
+                print("initialized!")
         else:
-            #transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors")
-            #pipe = FluxPipeline.from_pretrained(checkpoint, transformer=transformer, torch_dtype=torch.bfloat16)
-            pipe = FluxPipeline.from_pretrained(checkpoint, torch_dtype=dtype)
+            bfl_repo = "cocktailpeanut/xulf-s"
+            if device == "mps":
+                transformer = QuantizedFluxTransformer2DModel.from_pretrained("cocktailpeanut/flux1-schnell-qint8")
+            else:
+                print("initializing quantized transformer...")
+                transformer = QuantizedFluxTransformer2DModel.from_pretrained("cocktailpeanut/flux1-schnell-q8")
+                print("initialized!")
+        print(f"moving device to {device}")
+        transformer.to(device=device, dtype=dtype)
+        print(f"initializing pipeline...")
+        pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)
+        print("initialized!")
+        pipe.transformer = transformer
         pipe.to(device)
         pipe.enable_attention_slicing()
         pipe.vae.enable_slicing()
         pipe.vae.enable_tiling()
         if device == "cuda":
-            #pipe.enable_model_cpu_offload()
-            pipe.enable_sequential_cpu_offload()
+            print(f"enable model cpu offload...")
+            pipe.enable_model_cpu_offload()
+            #pipe.enable_sequential_cpu_offload()
+            print(f"done!")
         selected = checkpoint
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     generator = torch.Generator().manual_seed(seed)
     print(f"Started the inference. Wait...")
     images = pipe(
         prompt = prompt,
         width = width,
         height = height,
         num_inference_steps = num_inference_steps,
         generator = generator,
         num_images_per_prompt = num_images_per_prompt,
-        guidance_scale=0.0
+        guidance_scale=guidance_scale
     ).images
     print(f"Inference finished!")
     devicetorch.empty_cache(torch)
+    print(f"emptied cache")
     return images, seed
 def update_slider(checkpoint, num_inference_steps):
     if checkpoint == "sayakpaul/FLUX.1-merged":
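
Taken together, the hunk above swaps the stock transformer for a pre-quantized one and, on CUDA, replaces sequential CPU offload with model-level offload: whole submodules move to the GPU for their forward pass instead of individual parameter groups, which is considerably faster at the cost of more VRAM. A consolidated sketch of the new schnell load path under those assumptions (the merged branch follows the same shape with the flux1-merged-q8 / xulf-d repos):

# Sketch only: mirrors the CUDA/MPS branch added in this commit.
import devicetorch
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from optimum.quanto import QuantizedDiffusersModel

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

dtype = torch.bfloat16
device = devicetorch.get(torch)

# MPS gets qint8 weights; everything else gets the qfloat8 (q8) checkpoint.
repo = ("cocktailpeanut/flux1-schnell-qint8" if device == "mps"
        else "cocktailpeanut/flux1-schnell-q8")
transformer = QuantizedFluxTransformer2DModel.from_pretrained(repo)
transformer.to(device=device, dtype=dtype)

# Build the pipeline without a transformer, then attach the quantized one.
pipe = FluxPipeline.from_pretrained("cocktailpeanut/xulf-s", transformer=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.enable_attention_slicing()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
if device == "cuda":
    # Offloads whole submodules between CPU and GPU per forward pass;
    # this commit swaps it in for the slower per-parameter sequential offload.
    pipe.enable_model_cpu_offload()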
@@ -72,7 +97,7 @@ with gr.Blocks(css=css) as demo:
             container=False,
         )
         run_button = gr.Button("Run", scale=0)
-    result = gr.Gallery(label="Result", show_label=False)
+    result = gr.Gallery(label="Result", show_label=False, object_fit="contain")
     checkpoint = gr.Dropdown(
         label="Model",
         value= "black-forest-labs/FLUX.1-schnell",
@@ -119,11 +144,18 @@ with gr.Blocks(css=css) as demo:
         step=1,
         value=4,
     )
+    guidance_scale = gr.Number(
+        label="Guidance Scale",
+        minimum=0,
+        maximum=50,
+        value=0.0,
+    )
     checkpoint.change(fn=update_slider, inputs=[checkpoint], outputs=[num_inference_steps])
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn = infer,
-        inputs = [prompt, checkpoint, seed, num_images_per_prompt, randomize_seed, width, height, num_inference_steps],
+        inputs = [prompt, checkpoint, seed, guidance_scale, num_images_per_prompt, randomize_seed, width, height, num_inference_steps],
         outputs = [result, seed]
     )
 demo.launch()
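
The UI hunks wire the new guidance_scale control straight through gr.on() into infer(), whose 0.0 default matches the value that was hard-coded before this commit. A stripped-down, runnable sketch of that round trip, with a stand-in infer() so the wiring can be tested without loading any model:

# Sketch only: infer() here is a stand-in for the real one above.
import gradio as gr

def infer(prompt, guidance_scale=0.0):
    # Echoes what it received; the real infer() passes guidance_scale
    # into the FluxPipeline call.
    return f"{prompt!r} with guidance_scale={guidance_scale}"

with gr.Blocks() as demo:
    prompt = gr.Text(label="Prompt")
    guidance_scale = gr.Number(label="Guidance Scale", minimum=0, maximum=50, value=0.0)
    result = gr.Textbox(label="Result")
    gr.on(triggers=[prompt.submit], fn=infer, inputs=[prompt, guidance_scale], outputs=[result])

demo.launch()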

requirements.txt

@@ -1,8 +1,13 @@
 gradio
 devicetorch
-accelerate
-git+https://github.com/peanutcocktail/diffusers.git
+git+https://github.com/huggingface/accelerate.git@test-clear-memory-cpu-offload
+git+https://github.com/huggingface/diffusers.git
 transformers==4.42.4
 sentencepiece
 protobuf
 einops
+git+https://github.com/huggingface/optimum-quanto.git@feat-hub-support
+#git+https://github.com/peanutcocktail/optimum-quanto.git@feat-hub-support
+#accelerate
+#git+https://github.com/peanutcocktail/diffusers.git
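
These pins matter because app.py depends on APIs that only exist on the referenced branches. A quick post-install sanity check, as a sketch:

# Sketch: verifies the pinned branches expose what app.py imports.
# feat-hub-support is what provides Hub loading on QuantizedDiffusersModel.
from optimum.quanto import QuantizedDiffusersModel, QuantizedTransformersModel
from diffusers import FluxTransformer2DModel, FluxPipeline

assert hasattr(QuantizedDiffusersModel, "from_pretrained")
print("requirements look usable")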

torch.js

@@ -7,7 +7,7 @@ module.exports = {
"params": {
"venv": "{{args && args.venv ? args.venv : null}}",
"path": "{{args && args.path ? args.path : '.'}}",
"message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 {{args && args.xformers ? 'xformers' : ''}} --index-url https://download.pytorch.org/whl/cu121"
"message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0 {{args && args.xformers ? 'xformers' : ''}} --index-url https://download.pytorch.org/whl/cu121"
}
},
// windows amd
@@ -27,7 +27,7 @@ module.exports = {
"params": {
"venv": "{{args && args.venv ? args.venv : null}}",
"path": "{{args && args.path ? args.path : '.'}}",
"message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1"
"message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0"
}
},
// mac
@@ -37,7 +37,8 @@ module.exports = {
"params": {
"venv": "{{args && args.venv ? args.venv : null}}",
"path": "{{args && args.path ? args.path : '.'}}",
"message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1"
//"message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0"
"message": "pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu"
}
},
// linux nvidia
@@ -47,7 +48,7 @@ module.exports = {
"params": {
"venv": "{{args && args.venv ? args.venv : null}}",
"path": "{{args && args.path ? args.path : '.'}}",
"message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 {{args && args.xformers ? 'xformers' : ''}} --index-url https://download.pytorch.org/whl/cu121"
"message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0 {{args && args.xformers ? 'xformers' : ''}} --index-url https://download.pytorch.org/whl/cu121"
}
},
// linux rocm (amd)
@@ -57,7 +58,7 @@ module.exports = {
"params": {
"venv": "{{args && args.venv ? args.venv : null}}",
"path": "{{args && args.path ? args.path : '.'}}",
"message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/rocm6.0"
"message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/rocm6.0"
}
},
// linux cpu
@@ -67,7 +68,7 @@ module.exports = {
"params": {
"venv": "{{args && args.venv ? args.venv : null}}",
"path": "{{args && args.path ? args.path : '.'}}",
"message": "pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu"
"message": "pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu"
}
}
]
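
Across every platform the pin moves from torch 2.3.1 to 2.4.0 (and to nightly CPU wheels on mac). A post-install sanity check, as a sketch:

# Sketch: confirms the expected wheels resolved for the current platform.
import torch

print(torch.__version__)                  # e.g. 2.4.0+cu121 on the NVIDIA paths
print(torch.cuda.is_available())          # True on the cu121 / rocm6.0 installs
print(torch.backends.mps.is_available())  # True on the mac nightly path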