mirror of
https://github.com/pinokiofactory/flux-webui.git
synced 2024-10-05 23:57:57 +03:00
update
This commit is contained in:
20
app.py
20
app.py
@@ -18,22 +18,29 @@ nav {
|
||||
display: inline;
|
||||
}
|
||||
"""
|
||||
def infer(prompt, checkpoint="black-fores-labs/FLUX.1-schnell", seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
|
||||
def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
|
||||
global pipe
|
||||
global selected
|
||||
# if the new checkpoint is different from the selected one, re-instantiate the pipe
|
||||
if selected != checkpoint:
|
||||
if checkpoint == "sayakpaul/FLUX.1-merged":
|
||||
transformer = FluxTransformer2DModel.from_pretrained("sayakpaul/FLUX.1-merged", torch_dtype=dtype)
|
||||
pipe = FluxPipeline.from_pretrained("cocktailpeanut/xulf-d", transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe = FluxPipeline.from_pretrained("cocktailpeanut/xulf-d", transformer=transformer, torch_dtype=dtype)
|
||||
else:
|
||||
pipe = FluxPipeline.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
|
||||
#transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors")
|
||||
#pipe = FluxPipeline.from_pretrained(checkpoint, transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe = FluxPipeline.from_pretrained(checkpoint, torch_dtype=dtype)
|
||||
|
||||
#pipe.enable_model_cpu_offload()
|
||||
pipe.to(device)
|
||||
pipe.enable_attention_slicing()
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
if device == "cuda":
|
||||
#pipe.enable_model_cpu_offload()
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
selected = checkpoint
|
||||
devicetorch.empty_cache(torch)
|
||||
if randomize_seed:
|
||||
seed = random.randint(0, MAX_SEED)
|
||||
generator = torch.Generator().manual_seed(seed)
|
||||
@@ -45,6 +52,7 @@ def infer(prompt, checkpoint="black-fores-labs/FLUX.1-schnell", seed=42, randomi
|
||||
generator = generator,
|
||||
guidance_scale=0.0
|
||||
).images[0]
|
||||
devicetorch.empty_cache(torch)
|
||||
return image, seed
|
||||
def update_slider(checkpoint, num_inference_steps):
|
||||
if checkpoint == "sayakpaul/FLUX.1-merged":
|
||||
@@ -53,7 +61,7 @@ def update_slider(checkpoint, num_inference_steps):
|
||||
return 4
|
||||
with gr.Blocks(css=css) as demo:
|
||||
with gr.Column(elem_id="col-container"):
|
||||
gr.HTML("<nav><img id='logo' src='file/icon.webp'/></nav>")
|
||||
gr.HTML("<nav><img id='logo' src='file/icon.png'/></nav>")
|
||||
with gr.Row():
|
||||
prompt = gr.Text(
|
||||
label="Prompt",
|
||||
|
||||
@@ -3,7 +3,7 @@ module.exports = {
|
||||
version: "2.0",
|
||||
title: "flux-webui",
|
||||
description: "Minimal Flux Web UI powered by Gradio & Diffusers",
|
||||
icon: "icon.webp",
|
||||
icon: "icon.png",
|
||||
menu: async (kernel, info) => {
|
||||
let installed = info.exists("env")
|
||||
let running = {
|
||||
|
||||
Reference in New Issue
Block a user