This commit is contained in:
cocktailpeanut
2024-08-07 11:05:10 -04:00
parent 5dc49230c6
commit f8ac21c51f
4 changed files with 15 additions and 7 deletions

20
app.py
View File

@@ -18,22 +18,29 @@ nav {
display: inline;
}
"""
def infer(prompt, checkpoint="black-fores-labs/FLUX.1-schnell", seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
global pipe
global selected
# if the new checkpoint is different from the selected one, re-instantiate the pipe
if selected != checkpoint:
if checkpoint == "sayakpaul/FLUX.1-merged":
transformer = FluxTransformer2DModel.from_pretrained("sayakpaul/FLUX.1-merged", torch_dtype=dtype)
pipe = FluxPipeline.from_pretrained("cocktailpeanut/xulf-d", transformer=transformer, torch_dtype=torch.bfloat16)
pipe = FluxPipeline.from_pretrained("cocktailpeanut/xulf-d", transformer=transformer, torch_dtype=dtype)
else:
pipe = FluxPipeline.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
#transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors")
#pipe = FluxPipeline.from_pretrained(checkpoint, transformer=transformer, torch_dtype=torch.bfloat16)
pipe = FluxPipeline.from_pretrained(checkpoint, torch_dtype=dtype)
#pipe.enable_model_cpu_offload()
pipe.to(device)
pipe.enable_attention_slicing()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
if device == "cuda":
#pipe.enable_model_cpu_offload()
pipe.enable_sequential_cpu_offload()
selected = checkpoint
devicetorch.empty_cache(torch)
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator().manual_seed(seed)
@@ -45,6 +52,7 @@ def infer(prompt, checkpoint="black-fores-labs/FLUX.1-schnell", seed=42, randomi
generator = generator,
guidance_scale=0.0
).images[0]
devicetorch.empty_cache(torch)
return image, seed
def update_slider(checkpoint, num_inference_steps):
if checkpoint == "sayakpaul/FLUX.1-merged":
@@ -53,7 +61,7 @@ def update_slider(checkpoint, num_inference_steps):
return 4
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.HTML("<nav><img id='logo' src='file/icon.webp'/></nav>")
gr.HTML("<nav><img id='logo' src='file/icon.png'/></nav>")
with gr.Row():
prompt = gr.Text(
label="Prompt",

BIN
icon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

BIN
icon.webp

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.6 KiB

View File

@@ -3,7 +3,7 @@ module.exports = {
version: "2.0",
title: "flux-webui",
description: "Minimal Flux Web UI powered by Gradio & Diffusers",
icon: "icon.webp",
icon: "icon.png",
menu: async (kernel, info) => {
let installed = info.exists("env")
let running = {