import os
import sys
import time
from multiprocessing import Process, Queue, get_context
from queue import Empty, Full
from threading import Thread
from typing import Optional

import fire
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline
from diffusers.utils import load_image

# Make the sibling StreamDiffusion checkout importable; the imports below
# (presumably `streamdiffusion` and `utils.*`) rely on this path entry.
sys.path.insert(0, os.path.abspath('../StreamDiffusion'))

from streamdiffusion import StreamDiffusion
from streamdiffusion.image_utils import postprocess_image
from utils.viewer import receive_images
from utils.wrapper import StreamDiffusionWrapper
from perlin import perlin_2d, rand_perlin_2d, rand_perlin_2d_octaves, perlin_2d_octaves
from scene_prompt import surreal_prompt_parts
from scene_prompt import surreal_prompts
from scene_prompt import regret_prompts
from spout_util import send_spout_image, get_spout_image
from osc import start_osc_server

# Upper bound on buffered FPS samples. Nothing reads fps_queue while the viewer
# process is disabled, so an unbounded queue would grow for the life of the run.
_FPS_QUEUE_MAXSIZE = 8


def _poll_prompt(prompt_queue: Queue) -> Optional[str]:
    """Return the next pending prompt from *prompt_queue*, or None if none is waiting.

    ``multiprocessing.Queue.empty()`` is not reliable, so checking ``empty()``
    and then calling ``get(block=False)`` can still raise ``queue.Empty``.
    Asking for an item directly and handling ``Empty`` avoids that race.
    """
    try:
        return prompt_queue.get_nowait()
    except Empty:
        return None


def image_generation_process(
    queue: Queue,
    fps_queue: Queue,
    prompt_queue: Queue,
    input_queue: Queue,
    model_id_or_path: str,
) -> None:
    """Run the img2img loop forever: read frames from *input_queue*, emit results to *queue*.

    Args:
        queue: receives every generated output (the wrapper is configured
            with ``output_type="pil"``).
        fps_queue: receives the per-iteration frame rate as a float. A value is
            dropped instead of blocking when the queue is full.
        prompt_queue: optional new prompts. At most one is consumed per frame;
            a falsy prompt means "keep the current prompt".
        input_queue: source frames; each iteration blocks until one arrives.
        model_id_or_path: model handed to ``StreamDiffusionWrapper``.
    """
    stream = StreamDiffusionWrapper(
        model_id_or_path=model_id_or_path,
        t_index_list=[0],
        frame_buffer_size=1,
        warmup=10,
        acceleration="tensorrt",
        use_lcm_lora=False,
        mode="img2img",
        cfg_type="none",
        use_denoising_batch=True,
        output_type="pil",
    )

    start_prompt = "A glowing, vintage phone booth standing in surreal landscapes across different scene"
    stream.prepare(
        prompt=start_prompt,
        num_inference_steps=50,
    )

    while True:
        start_time = time.time()
        # Blocks until the Spout input process delivers the next frame.
        input_image = input_queue.get(block=True)

        new_prompt = _poll_prompt(prompt_queue)
        if new_prompt:
            x_output = stream.img2img(image=input_image, prompt=new_prompt)
            print(f"Received new prompt from queue: {new_prompt}")
        else:
            # No prompt pending (or an empty one): generate with the current
            # prompt. Previously an empty prompt skipped generation and re-sent
            # a stale frame (or hit a NameError on the very first frame).
            x_output = stream.img2img(image=input_image)

        queue.put(x_output, block=False)

        # FPS covers the whole iteration, including the wait for input.
        elapsed_time = time.time() - start_time
        fps = 1 / elapsed_time if elapsed_time > 0 else float('inf')
        try:
            fps_queue.put_nowait(fps)
        except Full:
            # No reader is keeping up; dropping a stats sample is harmless.
            pass


def main(model_id_or_path: str = "stabilityai/sd-turbo") -> None:
    """Start the OSC, Spout-in, generation and Spout-out processes and supervise them.

    Blocks until the generation process exits or Ctrl+C is pressed, then
    terminates every child process that is still running.

    Args:
        model_id_or_path: model used by the generation process; exposed as a
            command-line flag through ``fire``.
    """
    ctx = get_context('spawn')
    # Queues must come from the same context as the processes: a queue created
    # in the default context (fork on Linux) cannot be pickled into a spawned
    # child and raises a RuntimeError at Process.start().
    queue = ctx.Queue()
    fps_queue = ctx.Queue(maxsize=_FPS_QUEUE_MAXSIZE)
    spout_in_queue = ctx.Queue()
    prompt_queue = ctx.Queue()

    # Only processes that actually started are tracked, so cleanup never
    # touches an unassigned or unstarted Process.
    processes = []
    try:
        process_osc = ctx.Process(target=start_osc_server, args=(prompt_queue,))
        process_osc.start()
        processes.append(process_osc)

        print("Starting spout input process")
        process_spout_in = ctx.Process(
            target=get_spout_image,
            args=(spout_in_queue, 512, 512),
        )
        process_spout_in.start()
        processes.append(process_spout_in)

        print("Starting image generation process")
        process_gen = ctx.Process(
            target=image_generation_process,
            args=(queue, fps_queue, prompt_queue, spout_in_queue, model_id_or_path),
        )
        process_gen.start()
        processes.append(process_gen)

        # The FPS viewer (utils.viewer.receive_images) is currently disabled;
        # fps_queue is bounded so generation never stalls on it.
        process_spout_out = ctx.Process(target=send_spout_image, args=(queue, 512, 512))
        process_spout_out.start()
        processes.append(process_spout_out)

        # The generator loops forever, so this returns only if it dies; the
        # helper processes can do nothing useful without it.
        process_gen.join()
    except KeyboardInterrupt:
        print("Process interrupted")
    finally:
        for process in processes:
            if process.is_alive():
                process.terminate()
        for process in processes:
            process.join(timeout=5)


if __name__ == "__main__":
    fire.Fire(main)