mirror of
https://github.com/pinokiofactory/flux-webui.git
synced 2024-10-05 23:57:57 +03:00
add save to outputs
This commit is contained in:
21
app.py
21
app.py
@@ -5,6 +5,8 @@ import torch
|
|||||||
from diffusers import FluxTransformer2DModel, FluxPipeline
|
from diffusers import FluxTransformer2DModel, FluxPipeline
|
||||||
from transformers import T5EncoderModel, CLIPTextModel
|
from transformers import T5EncoderModel, CLIPTextModel
|
||||||
from optimum.quanto import QuantizedDiffusersModel, QuantizedTransformersModel
|
from optimum.quanto import QuantizedDiffusersModel, QuantizedTransformersModel
|
||||||
|
from datetime import datetime
|
||||||
|
from PIL import Image
|
||||||
import json
|
import json
|
||||||
import devicetorch
|
import devicetorch
|
||||||
import os
|
import os
|
||||||
@@ -25,6 +27,21 @@ nav {
|
|||||||
display: inline;
|
display: inline;
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
#save all generated images into an output folder with unique name
|
||||||
|
def save_images(images):
|
||||||
|
output_folder = "output"
|
||||||
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
|
saved_paths = []
|
||||||
|
|
||||||
|
for i, img in enumerate(images):
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
filename = f"flux_{timestamp}_{i}.png"
|
||||||
|
filepath = os.path.join(output_folder, filename)
|
||||||
|
img.save(filepath)
|
||||||
|
saved_paths.append(filepath)
|
||||||
|
|
||||||
|
return saved_paths
|
||||||
|
|
||||||
def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, guidance_scale=0.0, num_images_per_prompt=1, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
|
def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, guidance_scale=0.0, num_images_per_prompt=1, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
|
||||||
global pipe
|
global pipe
|
||||||
global selected
|
global selected
|
||||||
@@ -79,7 +96,9 @@ def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, guidan
|
|||||||
print(f"Inference finished!")
|
print(f"Inference finished!")
|
||||||
devicetorch.empty_cache(torch)
|
devicetorch.empty_cache(torch)
|
||||||
print(f"emptied cache")
|
print(f"emptied cache")
|
||||||
return images, seed
|
saved_paths = save_images(images) #save the images into the output folder
|
||||||
|
return images, seed, saved_paths
|
||||||
|
|
||||||
def update_slider(checkpoint, num_inference_steps):
|
def update_slider(checkpoint, num_inference_steps):
|
||||||
if checkpoint == "sayakpaul/FLUX.1-merged":
|
if checkpoint == "sayakpaul/FLUX.1-merged":
|
||||||
return 8
|
return 8
|
||||||
|
|||||||
Reference in New Issue
Block a user