main
reng 5 months ago
parent 028f10f0ec
commit 92bf193f1a
  1. img2img.py (9 changes)
  2. main.py (11 changes)

img2img.py

@@ -78,7 +78,8 @@ class Pipeline:
             use_tiny_vae=True,
             device=device,
             dtype=torch_dtype,
-            t_index_list=[35, 45],
+            # t_index_list=[35, 45],
+            t_index_list=[1],
             frame_buffer_size=1,
             width=params.width,
             height=params.height,
@@ -101,12 +102,12 @@ class Pipeline:
             prompt=default_prompt,
             negative_prompt=default_negative_prompt,
             num_inference_steps=50,
-            guidance_scale=1.2,
+            guidance_scale=1.0,
         )

     def predict(self, image: Image.Image, params: "Pipeline.InputParams") -> Image.Image:
-        image_tensor = self.stream.preprocess_image(image)
+        # image_tensor = self.stream.preprocess_image(image)
         # output_image = self.stream(image=image_tensor, prompt=params.prompt)
-        output_image = self.stream(image=image_tensor, prompt=params.prompt)
+        output_image = self.stream(image=image, prompt=params.prompt)
         return output_image

main.py

@@ -8,6 +8,7 @@ import SpoutGL
 from OpenGL.GL import GL_RGBA
 import time
 import img2img
+from multiprocessing import Queue

 def main():
     TARGET_FPS = 60
@@ -19,6 +20,8 @@ def main():
     timestamp = datetime.datetime.now()
     fps = 30.0

+    prompt_queue = Queue()
+
     print("Initializing StreamDiffusion pipeline...")
     global pipeline
     try:
@@ -39,6 +42,7 @@ def main():
     async def update_prompt(update: PromptUpdate):
         global PROMPT
         PROMPT = update.prompt
+        prompt_queue.put(PROMPT)
         print(f"Prompt updated to: {PROMPT}")
         return {"message": "Prompt updated successfully", "new_prompt": PROMPT}
@@ -78,10 +82,17 @@ def main():
                 continue

             image_rgb_array = image_bgra[:, :, [2,1,0]]
+            # image_rgb_array = image_rgb_array.astype(np.float32) / 255.0
             input_image = Image.fromarray(image_rgb_array, 'RGB')
             # input_image.save("debug_input.png")

+            if not prompt_queue.empty():
+                new_prompt = prompt_queue.get(block=False)
+                if new_prompt:
+                    print(f"Received new prompt from queue: {new_prompt}")
+                    PROMPT = new_prompt
+            # print(f"current prompt: {PROMPT}")
             params = img2img.Pipeline.InputParams(prompt=PROMPT)
             output_image = pipeline.predict(image=input_image, params=params)
             # output_image.save("debug_output.png")
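
The main.py changes hand prompt updates from the update_prompt handler to the render loop through a multiprocessing Queue: the handler only enqueues the new prompt, and the loop drains the queue once per frame without blocking. The following is a self-contained sketch of that hand-off pattern only; the producer thread, prompt strings, and loop length are stand-ins, not project code:

    import threading
    import time
    from multiprocessing import Queue

    prompt_queue = Queue()
    PROMPT = "initial prompt"

    def producer():
        # Stand-in for the update_prompt handler: it only puts new prompts on the queue.
        for text in ["a watercolor city", "a neon-lit street at night"]:
            time.sleep(0.5)
            prompt_queue.put(text)

    threading.Thread(target=producer, daemon=True).start()

    for _ in range(20):  # stand-in for the per-frame render loop
        if not prompt_queue.empty():
            new_prompt = prompt_queue.get(block=False)
            if new_prompt:
                PROMPT = new_prompt
                print(f"Received new prompt from queue: {PROMPT}")
        time.sleep(0.1)  # one "frame" of work

Because the render loop is the only consumer, checking empty() before a non-blocking get() works here; with more than one consumer the get() could still raise queue.Empty, which you would catch instead.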

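For reference, the unchanged lines around the new queue-drain block convert the incoming BGRA frame (presumably from the SpoutGL receiver) to an RGB PIL image by reordering the channels and dropping alpha. A small stand-alone sketch using a dummy frame in place of a received one:

    import numpy as np
    from PIL import Image

    # Dummy 4x4 BGRA frame standing in for a received Spout frame.
    h, w = 4, 4
    image_bgra = np.random.randint(0, 256, size=(h, w, 4), dtype=np.uint8)

    # Reorder B,G,R -> R,G,B and drop the alpha channel. Fancy indexing copies,
    # so the result is a contiguous array that PIL can consume directly.
    image_rgb_array = image_bgra[:, :, [2, 1, 0]]
    input_image = Image.fromarray(image_rgb_array, 'RGB')
    print(input_image.size, input_image.mode)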