You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

200 lines
5.3 KiB

import os
import sys
import time
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline
from diffusers.utils import load_image
sys.path.insert(0, os.path.abspath('../StreamDiffusion'))
from streamdiffusion import StreamDiffusion
from streamdiffusion.image_utils import postprocess_image
from utils.viewer import receive_images
from utils.wrapper import StreamDiffusionWrapper
from threading import Thread
from multiprocessing import Process, Queue, get_context
from perlin import perlin_2d, rand_perlin_2d, rand_perlin_2d_octaves, perlin_2d_octaves
from scene_prompt import surreal_prompt_parts
from scene_prompt import surreal_prompts
from scene_prompt import regret_prompts
from spout_util import send_spout_image, get_spout_image
from osc import start_osc_server
import fire
def image_generation_process(
    queue: Queue,
    fps_queue: Queue,
    prompt_queue: Queue,
    input_queue: Queue,
    model_id_or_path: str,
) -> None:
    """Run the StreamDiffusion img2img loop in a worker process.

    Blocks on ``input_queue`` for each incoming frame, generates an image
    with the current (or newly received) prompt, and pushes the result to
    ``queue`` and the measured per-frame FPS to ``fps_queue``. Runs forever;
    intended to be launched via ``multiprocessing``.

    Args:
        queue: Output queue receiving generated images (``output_type="pil"``).
        fps_queue: Queue receiving the measured frames-per-second (float).
        prompt_queue: Queue polled non-blockingly for new text prompts
            (fed by the OSC server process).
        input_queue: Queue supplying input frames (fed by the Spout capture
            process).
        model_id_or_path: Hugging Face model id or local path for the pipeline.
    """
    # Local import: the stdlib module name `queue` is shadowed by the parameter.
    from queue import Empty

    stream = StreamDiffusionWrapper(
        model_id_or_path=model_id_or_path,
        t_index_list=[0],
        frame_buffer_size=1,
        warmup=10,
        acceleration="tensorrt",
        use_lcm_lora=False,
        mode="img2img",
        cfg_type="none",
        use_denoising_batch=True,
        output_type="pil",
    )

    start_prompt = "A glowing, vintage phone booth standing in surreal landscapes across different scene"
    stream.prepare(
        prompt=start_prompt,
        num_inference_steps=50,
    )

    while True:
        start_time = time.time()
        # Block until the capture process delivers the next frame.
        input_image = input_queue.get(block=True)

        # EAFP: get_nowait() avoids the empty()/get() race the original had,
        # where the queue could be drained between the check and the get.
        try:
            new_prompt = prompt_queue.get_nowait()
        except Empty:
            new_prompt = None

        if new_prompt:
            x_output = stream.img2img(image=input_image, prompt=new_prompt)
            print(f"Received new prompt from queue: {new_prompt}")
        else:
            # Keep using the prompt from prepare() / the last update.
            x_output = stream.img2img(image=input_image)

        queue.put(x_output, block=False)

        # Report instantaneous FPS for this frame.
        elapsed_time = time.time() - start_time
        fps = 1 / elapsed_time if elapsed_time > 0 else float('inf')
        fps_queue.put(fps)
def main() -> None:
    """Wire up the OSC, Spout-capture, generation, and Spout-output processes.

    Processes are created from the ``spawn`` context so CUDA state is not
    inherited via fork; the queues are created from the same context so they
    are compatible with those processes (the original mixed default-context
    queues with spawn processes, which can fail where the default is fork).

    On KeyboardInterrupt, every process that was actually started is
    terminated — the original referenced process variables that may never
    have been bound, risking a NameError, and left the Spout input process
    running as an orphan.
    """
    ctx = get_context('spawn')
    queue = ctx.Queue()          # generated images -> Spout output
    fps_queue = ctx.Queue()      # per-frame FPS measurements
    spout_in_queue = ctx.Queue() # captured frames -> generator
    prompt_queue = ctx.Queue()   # OSC prompts -> generator

    model_id_or_path = "stabilityai/sd-turbo"

    # Track every started process so the interrupt handler cannot hit an
    # unbound name and so all of them get cleaned up.
    processes = []
    try:
        process_osc = ctx.Process(
            target=start_osc_server,
            args=(prompt_queue,),
        )
        process_osc.start()
        processes.append(process_osc)

        print("Starting spout input process")
        process_spout_in = ctx.Process(
            target=get_spout_image,
            args=(spout_in_queue, 512, 512),
        )
        process_spout_in.start()
        processes.append(process_spout_in)

        print("Starting image generation process")
        process_gen = ctx.Process(
            target=image_generation_process,
            args=(queue, fps_queue, prompt_queue, spout_in_queue, model_id_or_path),
        )
        process_gen.start()
        processes.append(process_gen)

        process_spout_out = ctx.Process(target=send_spout_image, args=(queue, 512, 512))
        process_spout_out.start()
        processes.append(process_spout_out)

        # Same join order as before; the capture process intentionally has
        # no join (it runs until terminated).
        process_gen.join()
        process_spout_out.join()
        process_osc.join()
    except KeyboardInterrupt:
        print("Process interrupted")
        for p in processes:
            p.terminate()
        return
# Entry point: python-fire turns main() into a CLI command (takes no arguments).
if __name__ == "__main__":
    fire.Fire(main)