add save to outputs

This commit is contained in:
Lytan
2024-08-11 17:10:16 +12:00
committed by GitHub
parent 5e922a1bfc
commit bf77deb530

21
app.py
View File

@@ -5,6 +5,8 @@ import torch
from diffusers import FluxTransformer2DModel, FluxPipeline from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel, CLIPTextModel from transformers import T5EncoderModel, CLIPTextModel
from optimum.quanto import QuantizedDiffusersModel, QuantizedTransformersModel from optimum.quanto import QuantizedDiffusersModel, QuantizedTransformersModel
from datetime import datetime
from PIL import Image
import json import json
import devicetorch import devicetorch
import os import os
@@ -25,6 +27,21 @@ nav {
display: inline; display: inline;
} }
""" """
#save all generated images into an output folder with unique name
def save_images(images):
output_folder = "output"
os.makedirs(output_folder, exist_ok=True)
saved_paths = []
for i, img in enumerate(images):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"flux_{timestamp}_{i}.png"
filepath = os.path.join(output_folder, filename)
img.save(filepath)
saved_paths.append(filepath)
return saved_paths
def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, guidance_scale=0.0, num_images_per_prompt=1, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)): def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, guidance_scale=0.0, num_images_per_prompt=1, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
global pipe global pipe
global selected global selected
@@ -79,7 +96,9 @@ def infer(prompt, checkpoint="black-forest-labs/FLUX.1-schnell", seed=42, guidan
print(f"Inference finished!") print(f"Inference finished!")
devicetorch.empty_cache(torch) devicetorch.empty_cache(torch)
print(f"emptied cache") print(f"emptied cache")
return images, seed saved_paths = save_images(images) #save the images into the output folder
return images, seed, saved_paths
def update_slider(checkpoint, num_inference_steps): def update_slider(checkpoint, num_inference_steps):
if checkpoint == "sayakpaul/FLUX.1-merged": if checkpoint == "sayakpaul/FLUX.1-merged":
return 8 return 8