#!/usr/bin/env python3
import inspect

import gradio as gr
import numpy as np
import random
import torch
from diffusers import (
    StableDiffusion3Pipeline,
    SD3Transformer2DModel,
    FlowMatchEulerDiscreteScheduler,
    AutoencoderTiny,
)
from typing import Any, Callable, Dict, List, Optional, Union

# import spaces

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16

repo = "stabilityai/stable-diffusion-3-medium-diffusers"
pipe = StableDiffusion3Pipeline.from_pretrained(repo, torch_dtype=dtype).to(device)

# Tiny autoencoder (TAESD3) used for cheap per-step preview decodes; the decoder is compiled.
taesd3 = (
    AutoencoderTiny.from_pretrained("madebyollin/taesd3", torch_dtype=torch.float16)
    .half()
    .eval()
    .requires_grad_(False)
    .to(device)
)
taesd3.decoder.layers = torch.compile(
    taesd3.decoder.layers,
    fullgraph=True,
    dynamic=False,
    mode="max-autotune-no-cudagraphs",
)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1344
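
# Preview strategy used below (a summary inferred from the code): at each denoising step,
# an estimate of the clean latents (x0) is formed from the current sample and the model's
# velocity prediction, decoded with the lightweight TAESD3 decoder, and yielded as a live
# preview. Only the final latents go through the full SD3 VAE.
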
def get_pred_original_sample(sched, model_output, timestep, sample):
    # Flow-matching x0 estimate: x0 = x_t - sigma_t * v, where v is the transformer's
    # velocity prediction and sigma_t is looked up by matching the current timestep.
    return (
        sample
        - sched.sigmas[(sched.timesteps == timestep).nonzero().item()] * model_output
    )
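
# retrieve_timesteps below follows the standard helper found in diffusers pipelines: it calls
# scheduler.set_timesteps and returns the resulting schedule, forwarding custom `timesteps`
# or `sigmas` only when the scheduler's set_timesteps signature accepts them.
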
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
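
# Illustrative use (mirrors how it is called later in this file):
#     timesteps, num_inference_steps = retrieve_timesteps(
#         pipe.scheduler, num_inference_steps=28, device=device
#     )
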
def sd3_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 3.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
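    # A generator variant of StableDiffusion3Pipeline.__call__: it yields a TAESD3-decoded
    # preview image after every denoising step, then finishes with a full-VAE decode.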
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode the prompts (and negatives, when classifier-free guidance is enabled)
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
    )

    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat(
            [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
        )

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps
    )
    num_warmup_steps = max(
        len(timesteps) - num_inference_steps * self.scheduler.order, 0
    )
    self._num_timesteps = len(timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
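
    # Note: the `if True:` below stands in for the commented-out progress-bar context
    # manager, preserving the loop body's original indentation.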
    # 6. Denoising loop
    # with self.progress_bar(total=num_inference_steps) as progress_bar:
    if True:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2)
                if self.do_classifier_free_guidance
                else latents
            )
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # decode an x0 estimate with TAESD3 and yield it as a live preview
            x0_pred = get_pred_original_sample(self.scheduler, noise_pred, t, latents)
            yield self.image_processor.postprocess(taesd3.decode(x0_pred)[0])[0]

            # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
            #     progress_bar.update()

    # final image: decode the last latents with the full SD3 VAE (un-scale and un-shift first)
    yield self.image_processor.postprocess(
        self.vae.decode(
            (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor,
            return_dict=False,
        )[0]
    )[0]
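

# infer() is the Gradio callback: it resolves the seed, builds a torch.Generator, and
# streams preview images from the generator above so the result Image updates live.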
# @spaces.GPU
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    yield from sd3_pipe_call_that_returns_an_iterable_of_images(
        pipe,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )

examples = [
    "A beautiful discovery in the north cascades",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 580px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            # Demo [Stable Diffusion 3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium) with real-time [TAESD3](https://huggingface.co/madebyollin/taesd3) previews
            Learn more about the [Stable Diffusion 3 series](https://stability.ai/news/stable-diffusion-3). Try it on the [Stability AI API](https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post), in [Stable Assistant](https://stability.ai/stable-assistant), or on Discord via [Stable Artisan](https://stability.ai/stable-artisan). Run it locally with [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [diffusers](https://github.com/huggingface/diffusers).
            """
        )

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

        gr.Examples(examples=examples, inputs=[prompt])

    gr.on(
        triggers=[run_button.click, prompt.submit, negative_prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=result,
    )

demo.launch(share=True)
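
# Note: share=True asks Gradio for a temporary public *.gradio.live link when run locally;
# on Hugging Face Spaces the flag is not needed and is effectively ignored.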
