parent
f1d8b043ea
commit
028f10f0ec
6 changed files with 297 additions and 41 deletions
@ -0,0 +1,112 @@ |
||||
import sys |
||||
import os |
||||
|
||||
# Append the directory two levels above this file to the module search path
# (presumably so that the `utils` package imported below resolves — verify
# against the repository layout).
sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
||||
|
||||
from utils.wrapper import StreamDiffusionWrapper |
||||
|
||||
import torch |
||||
|
||||
# from config import Args |
||||
from pydantic import BaseModel, Field |
||||
from PIL import Image |
||||
import math |
||||
|
||||
# Local checkpoint directories. Alternative Hugging Face Hub ids for the same
# roles: "stabilityai/sd-turbo" (base model) and "madebyollin/taesd" (tiny VAE).
base_model = "./models/sd-turbo"
taesd_model = "./models/taesd"

# Prompt the pipeline is prepared with and that InputParams.prompt defaults to.
default_prompt = (
    "Portrait of The Joker halloween costume, face painting, with , glare pose, "
    "detailed, intricate, full of colour, cinematic lighting, "
    "trending on artstation, 8k, hyperrealistic, focused, extreme details, "
    "unreal engine 5 cinematic, masterpiece"
)

# Negative prompt handed to the stream's prepare() call.
default_negative_prompt = (
    "black and white, blurry, low resolution, pixelated, "
    "pixel art, low quality, low fidelity"
)
||||
|
||||
# HTML fragment describing this demo; it is the default value of
# Pipeline.Info.page_content.
page_content = """<h1 class="text-3xl font-bold">StreamDiffusion</h1>
<h3 class="text-xl font-bold">Image-to-Image SD-Turbo</h3>
<p class="text-sm">
This demo showcases
<a
href="https://github.com/cumulo-autumn/StreamDiffusion"
target="_blank"
class="text-blue-500 underline hover:no-underline">StreamDiffusion
</a>
Image to Image pipeline using
<a
href="https://huggingface.co/stabilityai/sd-turbo"
target="_blank"
class="text-blue-500 underline hover:no-underline">SD-Turbo</a
> with a MJPEG stream server.
</p>
"""
||||
|
||||
|
||||
class Pipeline:
    """Image-to-image pipeline wrapping a ``StreamDiffusionWrapper``.

    The wrapper is created and prepared once in ``__init__``; ``predict`` then
    pushes one input image through it per call.
    """

    class Info(BaseModel):
        """Static description of the pipeline: name, input mode and page HTML."""

        name: str = "StreamDiffusion img2img"
        input_mode: str = "image"
        page_content: str = page_content

    class InputParams(BaseModel):
        """Parameters accepted by ``predict``.

        Only ``prompt`` is read by ``predict``; ``width`` and ``height`` are read
        once in ``__init__`` (from a default-constructed instance) to size the
        stream. No negative-prompt field is exposed here.
        """

        prompt: str = Field(
            default_prompt,
            title="Prompt",
            field="textarea",
            id="prompt",
        )
        # NOTE(review): min/max are passed as extra Field keywords and do not
        # bracket the default of 512 — confirm how the consumer of this schema
        # interprets them.
        width: int = Field(
            512,
            min=2,
            max=15,
            title="Width",
            disabled=True,
            hide=True,
            id="width",
        )
        height: int = Field(
            512,
            min=2,
            max=15,
            title="Height",
            disabled=True,
            hide=True,
            id="height",
        )

    def __init__(self, device: torch.device, torch_dtype: torch.dtype):
        """Create the stream wrapper and prepare it with the default prompts.

        Args:
            device: Torch device forwarded to the wrapper as ``device``.
            torch_dtype: Torch dtype forwarded to the wrapper as ``dtype``.
        """
        defaults = self.InputParams()

        # Collected in one mapping so the full wrapper configuration can be
        # read at a glance; it is passed unchanged as keyword arguments.
        wrapper_options = dict(
            # Models
            model_id_or_path=base_model,
            vae_id=taesd_model,
            use_tiny_vae=True,
            use_lcm_lora=False,
            # Placement
            device=device,
            dtype=torch_dtype,
            acceleration="xformers",
            # Stream geometry and scheduling
            mode="img2img",
            width=defaults.width,
            height=defaults.height,
            t_index_list=[35, 45],
            frame_buffer_size=1,
            warmup=10,
            use_denoising_batch=True,
            cfg_type="none",
            output_type="pil",
            # Similar-image filter
            enable_similar_image_filter=True,
            similar_image_filter_threshold=0.98,
        )
        self.stream = StreamDiffusionWrapper(**wrapper_options)

        self.last_prompt = default_prompt
        self.stream.prepare(
            prompt=default_prompt,
            negative_prompt=default_negative_prompt,
            num_inference_steps=50,
            guidance_scale=1.2,
        )

    def predict(self, image: Image.Image, params: "Pipeline.InputParams") -> Image.Image:
        """Run one image through the stream using ``params.prompt``.

        Args:
            image: Input frame; it is passed through the stream's
                ``preprocess_image`` before inference.
            params: Request parameters; only ``prompt`` is used.

        Returns:
            Whatever the stream call returns (the wrapper is configured with
            ``output_type="pil"``).
        """
        return self.stream(
            image=self.stream.preprocess_image(image),
            prompt=params.prompt,
        )
||||
@ -0,0 +1,120 @@ |
||||
from fastapi import FastAPI |
||||
from pydantic import BaseModel |
||||
import datetime |
||||
import torch |
||||
from PIL import Image |
||||
import numpy as np |
||||
import SpoutGL |
||||
from OpenGL.GL import GL_RGBA |
||||
import time |
||||
import img2img |
||||
|
||||
def main():
    """Run the Spout -> StreamDiffusion img2img -> Spout loop with an HTTP control API.

    In order, this function:

    1. Builds ``img2img.Pipeline`` on CUDA when available (CPU otherwise)
       with ``torch.float16``.
    2. Defines a FastAPI app with ``GET /health`` and
       ``POST /api/update/prompt``; the latter replaces the prompt used for
       subsequent frames.
    3. Creates a Spout receiver and sender, then serves the app with uvicorn
       from a daemon thread on port 34800.
    4. Loops until Ctrl+C: receive a frame, run the pipeline, send the result.

    Returns early (``None``) if pipeline/app initialization raises. Spout
    resources are released when the loop exits.
    """
    # ``PROMPT`` must be a module global: the HTTP handler below rebinds it via
    # ``global PROMPT`` and the frame loop has to observe that rebinding. As a
    # plain local of main() the handler's update would never reach the loop.
    global pipeline, PROMPT

    TARGET_FPS = 60
    SPOUT_RECEIVER_NAME = "Spout DX11 Sender"
    SPOUT_SENDER_NAME = "Output - StreamDiffusion"
    WIDTH = 512
    HEIGHT = 512
    PROMPT = "a beautiful landscape painting, trending on artstation, 8k, hyperrealistic"

    # Monotonic clock: immune to wall-clock adjustments while measuring frame time.
    timestamp = time.perf_counter()
    fps = 30.0

    print("Initializing StreamDiffusion pipeline...")
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch_dtype = torch.float16
        pipeline = img2img.Pipeline(device, torch_dtype)

        app = FastAPI()

        @app.get("/health")
        def read_root():
            return {"status": "ok"}

        class PromptUpdate(BaseModel):
            prompt: str

        @app.post("/api/update/prompt")
        async def update_prompt(update: PromptUpdate):
            # Runs on the uvicorn thread; the frame loop re-reads PROMPT on
            # every processed frame.
            global PROMPT
            PROMPT = update.prompt
            print(f"Prompt updated to: {PROMPT}")
            return {"message": "Prompt updated successfully", "new_prompt": PROMPT}

        print("Pipeline initialized.")
    except Exception as e:
        print(f"Error initializing StreamDiffusion pipeline: {e}")
        return

    print(f"Initializing Spout receiver for '{SPOUT_RECEIVER_NAME}'...")
    spout_receiver = SpoutGL.SpoutReceiver()
    spout_receiver.setReceiverName(SPOUT_RECEIVER_NAME)

    print(f"Initializing Spout sender as '{SPOUT_SENDER_NAME}'...")
    spout_sender = SpoutGL.SpoutSender()
    spout_sender.setSenderName(SPOUT_SENDER_NAME)

    # Receive buffer, filled in place by receiveImage().
    image_bgra = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)

    # Send buffer, allocated once and reused; the alpha channel never changes.
    output_bgra_array = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
    output_bgra_array[:, :, 3] = 255

    import uvicorn
    import threading

    # NOTE(review): binds to all interfaces and the prompt endpoint has no
    # authentication — confirm this exposure is intended.
    config = uvicorn.Config(app, host="0.0.0.0", port=34800, log_level="info")
    server = uvicorn.Server(config)
    threading.Thread(target=server.run, daemon=True).start()
    print("FastAPI server started at http://0.0.0.0:34800")

    idle_sleep = 1.0 / TARGET_FPS
    # Weight kept on the previous fps estimate; the new sample gets (1 - t).
    t = 0.05

    try:
        print("Starting main loop. Press Ctrl+C to exit.")
        while True:
            received = spout_receiver.receiveImage(image_bgra, GL_RGBA, False, 0)

            if not received:
                # Nothing available: back off instead of spinning.
                time.sleep(idle_sleep)
                continue

            # Skip the frame on which the receiver reports an update.
            # NOTE(review): the buffer is not resized here, so this presumably
            # relies on the sender matching WIDTH x HEIGHT — confirm.
            if spout_receiver.isUpdated():
                continue

            if spout_receiver.isConnected() and SpoutGL.helpers.isBufferEmpty(image_bgra):
                continue

            # The buffer is requested as GL_RGBA but channels 0 and 2 are
            # swapped before building the RGB image, and swapped back before
            # sending. NOTE(review): confirm the channel order the Spout
            # sender actually delivers.
            image_rgb_array = image_bgra[:, :, [2, 1, 0]]
            input_image = Image.fromarray(image_rgb_array, 'RGB')

            params = img2img.Pipeline.InputParams(prompt=PROMPT)
            output_image = pipeline.predict(image=input_image, params=params)

            # Reverse the channel axis of the pipeline output and copy it into
            # the colour planes of the reusable send buffer.
            output_bgra_array[:, :, :3] = np.array(output_image, dtype=np.uint8)[:, :, ::-1]

            spout_sender.sendImage(output_bgra_array, WIDTH, HEIGHT, GL_RGBA, False, 0)

            now = time.perf_counter()
            dt = now - timestamp
            timestamp = now
            if dt > 0:  # guard against a zero interval
                fps = fps * t + (1 / dt) * (1 - t)

            print("\033[92m[ STREAM DIFFUSION ]\033[0m " + f"Frame processed and sent to Spout: {fps:.2f}", end="\r", flush=True)

    except KeyboardInterrupt:
        print("\nExiting...")
    finally:
        print("Releasing Spout resources.")
        spout_receiver.releaseReceiver()
        spout_sender.releaseSender()
||||
|
||||
|
||||
# Start the Spout/StreamDiffusion loop only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
||||
|
||||
Loading…
Reference in new issue