Source code for perceptor.models.stable_diffusion.stable_diffusion

from typing import Optional, Literal, Union, List
from contextlib import contextmanager
from tqdm import tqdm
import copy
import torch
import torch.nn.functional as F
import kornia
import lantern
from transformers import CLIPTokenizer, CLIPTextModel, logging
from diffusers import (
    DDPMScheduler,
    AutoencoderKL,
    UNet2DConditionModel,
)
import diffusers.models

import perceptor

from . import diffusion_space
from .predictions import Predictions
from .conditioning import Conditioning

try:
    from . import attention

    XFORMERS_INSTALLED = True
except ImportError:
    XFORMERS_INSTALLED = False


# @cache
[docs]class StableDiffusion(torch.nn.Module): def __init__( self, name: str = "runwayml/stable-diffusion-v1-5", decoder_name: Optional[str] = "stabilityai/sd-vae-ft-mse", fp16: bool = True, auth_token: Union[bool, str] = True, flash_attention: bool = True, attention_slicing: Optional[Union[int, Literal["auto"]]] = None, ): """ Stable Diffusion text2image model. Args: name (str): Name of the model. Defaults to "runwayml/stable-diffusion-v1-5". Available models are: - runwayml/stable-diffusion-v1-5 (512x512) - runwayml/stable-diffusion-inpainting (512x512) - CompVis/stable-diffusion-v1-4 (512x512) - Huggingface model id - Path to weights decoder_name (str, optional): Name of the decoder model. Defaults to "stabilityai/sd-vae-ft-mse". Available models are: - stabilityai/sd-vae-ft-mse - stabilityai/sd-vae-ft-ema - None (use the original decoder) fp16 (bool): Whether to use mixed precision. Defaults to True. auth_token (bool): Whether to use an auth token. Defaults to True. flash_attention (bool): Whether to use flash attention. Defaults to True. attention_slicing (Union[int, Literal["auto"]], optional): Number of attention steps. Defaults to None. Options are "auto" or an integer. Lowers VRAM usage but increases inference time. """ super().__init__() self.name = name self.decoder_name = decoder_name if XFORMERS_INSTALLED and flash_attention: # monkeypatch xformers flash attention patch = { "AttentionBlock", "FeedForward", "CrossAttention", "SpatialTransformer", } for attribute in patch: setattr( diffusers.models.attention, attribute, getattr(attention, attribute) ) self.vae = AutoencoderKL.from_pretrained( decoder_name or name, use_auth_token=auth_token ) config = ( dict( use_auth_token=auth_token, revision="fp16", torch_dtype=torch.float16, ) if fp16 else dict(use_auth_token=auth_token) ) self.unet = UNet2DConditionModel.from_pretrained( name, subfolder="unet", **config ) self.scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) if attention_slicing is not None: if attention_slicing == "auto": attention_slicing = 2 slice_size = self.unet.config.attention_head_dim // attention_slicing self.unet.set_attention_slice(slice_size) self.schedule_alphas = torch.nn.Parameter( self.scheduler.alphas_cumprod.sqrt(), requires_grad=False ) self.schedule_sigmas = torch.nn.Parameter( (1 - self.scheduler.alphas_cumprod).sqrt(), requires_grad=False, ) self.vae_original_requires_grads = [ parameter.requires_grad for parameter in self.vae.parameters() ] self.vae.eval() self.vae.requires_grad_(False) self.unet.eval() self.unet.requires_grad_(False) @property def device(self): return next(iter(self.parameters())).device @property def shape(self): return self.model.shape
[docs] def schedule_indices( self, n_steps=500, from_index=999, to_index=0, rho=3.0 ) -> lantern.Tensor: if from_index < to_index: raise ValueError("from_index must be greater than to_index") from_alpha, from_sigma = self.alphas(from_index), self.sigmas(from_index) to_alpha, to_sigma = self.alphas(to_index), self.sigmas(to_index) from_log_snr = torch.log(from_alpha**2 / from_sigma**2) to_log_snr = torch.log(to_alpha**2 / to_sigma**2) elucidated_from_sigma = (1 / from_log_snr.exp()).sqrt().clamp(max=150) elucidated_to_sigma = (1 / to_log_snr.exp()).sqrt().clamp(min=1e-3) ramp = torch.linspace(0, 1, n_steps + 1).to(self.device) min_inv_rho = elucidated_to_sigma ** (1 / rho) max_inv_rho = elucidated_from_sigma ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho target_log_snr = torch.log(torch.ones_like(sigmas) ** 2 / sigmas**2) schedule_log_snr = torch.log( self.schedule_alphas**2 / self.schedule_sigmas**2 ) assert target_log_snr.squeeze().ndim == 1 assert schedule_log_snr.squeeze().ndim == 1 schedule_indices = ( (target_log_snr.squeeze()[:, None] - schedule_log_snr.squeeze()[None, :]) .abs() .argmin(dim=1) .unique() .sort(descending=True)[0] ) if len(schedule_indices) <= n_steps * 0.9: raise ValueError( f"Scheduled steps {len(schedule_indices)} is too far from wanted number of steps {n_steps}" ) assert (schedule_indices[:-1] != schedule_indices[1:]).all() return torch.stack([schedule_indices[:-1], schedule_indices[1:]], dim=1)
[docs] @torch.cuda.amp.autocast() def encode( self, images: lantern.Tensor.dims("NCHW").float(), method="mode" ) -> lantern.Tensor.dims("NCHW"): _, _, h, w = images.shape if h % 32 != 0: raise Exception(f"Height must be divisible by 32, got {h}") if w % 32 != 0: raise Exception(f"Width must be divisible by 32, got {w}") distribution = self.vae.encode(diffusion_space.encode(images.to(self.device))) if method == "sample": return 0.18215 * distribution.latent_dist.sample() elif method == "mode": return 0.18215 * distribution.latent_dist.mode() else: raise ValueError(f"Unknown encoding method {method}")
[docs] @torch.cuda.amp.autocast() def decode( self, latents: lantern.Tensor.dims("NCHW").float() ) -> lantern.Tensor.dims("NCHW"): return diffusion_space.decode(self.vae.decode(latents / 0.18215).sample)
[docs] @contextmanager def finetuneable_vae(self): """ with diffusion_model.finetuneable_vae(): images = diffusion_model.decode(latents) """ state_dict = copy.deepcopy(self.vae.state_dict()) try: for parameter, requires_grad in zip( self.vae.parameters(), self.vae_original_requires_grads ): parameter.requires_grad_(requires_grad) yield self finally: self.vae.load_state_dict(state_dict) self.vae.requires_grad_(False)
[docs] @torch.cuda.amp.autocast() def latents( self, images: lantern.Tensor.dims("NCHW").float() ) -> lantern.Tensor.dims("NCHW"): return self.encode(images).float()
[docs] @torch.cuda.amp.autocast() def images( self, latents: lantern.Tensor.dims("NCHW").float() ) -> lantern.Tensor.dims("NCHW"): return self.decode(latents).float()
[docs] def random_diffused_latents(self, shape) -> lantern.Tensor: n, c, h, w = shape if h % 32 != 0: raise ValueError("Height must be divisible by 32") if w % 32 != 0: raise ValueError("Width must be divisible by 32") return ( torch.randn((n, self.unet.in_channels, h // 8, w // 8)).to(self.device) * self.scheduler.init_noise_sigma )
[docs] def indices(self, indices) -> lantern.Tensor: if isinstance(indices, float) or isinstance(indices, int): indices = torch.as_tensor(indices) if indices.ndim == 0: indices = indices[None] if indices.ndim != 1: raise ValueError("indices must be a scalar or a 1-dimensional tensor") return indices.long().to(self.device)
[docs] def alphas(self, indices) -> lantern.Tensor: return self.schedule_alphas[self.indices(indices)][:, None, None, None].to( self.device )
[docs] def sigmas(self, indices) -> lantern.Tensor: return self.schedule_sigmas[self.indices(indices)][:, None, None, None].to( self.device )
[docs] @torch.cuda.amp.autocast() def predicted_noise( self, diffused_latents, from_indices, conditioning: Conditioning, ) -> lantern.Tensor: predicted_noise = self.unet( conditioning.input(diffused_latents), self.indices(from_indices), conditioning.encodings, )["sample"] return predicted_noise.float()
[docs] def forward( self, diffused_latents: lantern.Tensor, indices: lantern.Tensor, conditioning: Optional[Conditioning] = None, ) -> Predictions: indices = self.indices(indices) return Predictions( from_diffused_latents=diffused_latents, from_indices=indices, predicted_noise=self.predicted_noise( diffused_latents, indices, conditioning ), schedule_alphas=self.schedule_alphas, schedule_sigmas=self.schedule_sigmas, encode=self.encode, decode=self.decode, )
[docs] def predictions(self, diffused_latents, indices, conditioning) -> Predictions: return self.forward(diffused_latents, indices, conditioning)
[docs] def text_encodings(self, texts): verbosity = logging.get_verbosity() logging.set_verbosity_error() tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") text_encoder = CLIPTextModel.from_pretrained( "openai/clip-vit-large-patch14" ).to(self.device) logging.set_verbosity(verbosity) tokenized_text = tokenizer( texts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = tokenized_text.input_ids if text_input_ids.shape[-1] > tokenizer.model_max_length: removed_text = tokenizer.batch_decode( text_input_ids[:, tokenizer.model_max_length :] ) print( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : tokenizer.model_max_length] return text_encoder(text_input_ids.to(self.device))[0]
[docs] def latent_masks(self, masks, blur): n, c, h, w = masks.shape if h % 8 != 0: raise ValueError("Height must be divisible by 8") if w % 8 != 0: raise ValueError("Width must be divisible by 8") if c != 1: raise ValueError("Masks must be 1-channel") if masks.gt(1).any() or masks.lt(0).any(): raise ValueError("Masks must be between 0 and 1") if blur is not None and blur > 0: ks = int(blur * 2) + 1 masks = kornia.filters.gaussian_blur2d(masks, (ks, ks), (blur, blur)) return F.interpolate( masks.to(self.device).float(), size=(h // 8, w // 8), mode="bilinear" )
[docs] def conditioning( self, texts: List[str] = [""], inpainting_masks: Optional[lantern.Tensor.dims("NCHW")] = None, inpainting_images: Optional[lantern.Tensor.dims("NCHW")] = None, mask_blur=4.0, ) -> Conditioning: """ Create a conditioning object from a list of texts. Unconditional is an empty string. Args: texts: A list of texts to condition on. Unconditional is an empty string inpainting_masks: A tensor of masks to condition on. Must be 1-channel and between 0 and 1 inpainting_images: A tensor of images to condition on. Must be 3-channel and between 0 and 1 """ if self.name == "runwayml/stable-diffusion-inpainting": inpainting_latent_masks = self.latent_masks(inpainting_masks, mask_blur) inpainting_latents = self.latents( inpainting_images * inpainting_masks.le(0.5) + 0.5 * inpainting_masks.gt(0.5).float() # important that this matches the mask given to the model # not certain that it matches at the moment when doing blurring ) else: inpainting_latent_masks = None inpainting_latents = None return Conditioning( model_name=self.name, encodings=self.text_encodings(texts), inpainting_latent_masks=inpainting_latent_masks, inpainting_latents=inpainting_latents, )
[docs] def diffuse_latents(self, denoised_latents, indices, noise=None) -> lantern.Tensor: indices = self.indices(indices) if noise is None: noise = torch.randn_like(denoised_latents) alphas, sigmas = self.alphas(indices), self.sigmas(indices) return denoised_latents * alphas + noise * sigmas
[docs] @torch.cuda.amp.autocast() @torch.no_grad() def sample( self, text: str, from_index: int = 999, to_index: int = 0, n_steps: int = 50, guidance_scale: float = 7.0, n_resample: int = 0, init_image: Optional[lantern.Tensor] = None, inpainting_mask: Optional[lantern.Tensor] = None, mask_blur: float = 4.0, replace_diffused: bool = True, ): """ Helper function to sample a single image. Args: text: The text to condition on from_index: The index to start sampling from to_index: The index to end sampling at n_steps: The number of steps to take between from_index and to_index guidance_scale: The scale of the guidance signal n_resample: The number of times to resample at each step init_image: The initial image to start sampling from (also used for inpainting) inpainting_mask: The mask to use for inpainting mask_blur: The amount of blur to apply to the inpainting mask replace_diffused: Whether to replace the diffused latents at each step (peeks into the init image so it's not true inpainting) """ neutral_conditioning = self.conditioning( texts=[""], inpainting_masks=inpainting_mask, inpainting_images=init_image, mask_blur=mask_blur, ) positive_conditioning = self.conditioning( texts=[text], inpainting_masks=inpainting_mask, inpainting_images=init_image, mask_blur=mask_blur, ) schedule_indices = self.schedule_indices( from_index=from_index, to_index=to_index, n_steps=n_steps ) from_index = schedule_indices[0, 0] if init_image is None: if from_index != 999: raise ValueError("init_image must be provided if from_index < 999") diffused_latents = self.random_diffused_latents((1, 3, 512, 512)) else: init_latents = self.latents(init_image) diffused_latents = self.diffuse_latents(init_latents, from_index) for from_index, to_index in tqdm(schedule_indices): for _ in range(n_resample): unconditioned_predictions = self.predictions( diffused_latents, from_index, neutral_conditioning, ) positive_predictions = self.predictions( diffused_latents, from_index, positive_conditioning, ) strongly_positive_predictions = ( unconditioned_predictions.classifier_free_guidance( positive_predictions, guidance_scale=guidance_scale ) ) diffused_latents = strongly_positive_predictions.resample(to_index) unconditioned_predictions = self.predictions( diffused_latents, from_index, neutral_conditioning, ) positive_predictions = self.predictions( diffused_latents, from_index, positive_conditioning, ) strongly_positive_predictions = ( unconditioned_predictions.classifier_free_guidance( positive_predictions, guidance_scale=guidance_scale ) ) diffused_latents = strongly_positive_predictions.step(to_index) if replace_diffused and inpainting_mask is not None: # this is peeking into the original masked image diffused_latents = ( self.diffuse_latents(init_latents, to_index) * (1 - positive_conditioning.inpainting_latent_masks) + diffused_latents * positive_conditioning.inpainting_latent_masks ) yield positive_predictions yield self.predictions( diffused_latents, to_index, positive_conditioning, )
def test_stable_diffusion_attention_slicing(): diffusion_model = StableDiffusion(attention_slicing="auto").cuda() diffused_latents = diffusion_model.random_diffused_latents((1, 3, 256, 256)) diffusion_model.predictions(diffused_latents, 100, diffusion_model.conditioning()) def test_stable_diffusion(): for predictions in ( StableDiffusion(fp16=False) .cuda() .sample("photograph of a playful cat", to_index=20) ): pass perceptor.utils.pil_image(predictions.denoised_images.clamp(0, 1)).save( "tests/stable_diffusion.png" ) def test_stable_diffusion_init_image(): import requests from PIL import Image import torchvision.transforms.functional as TF image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" init_image = TF.to_tensor( Image.open(requests.get(image_url, stream=True).raw).resize((512, 512)) )[None].cuda() for predictions in ( StableDiffusion() .cuda() .sample( "photograph of playful lions", from_index=500, to_index=20, init_image=init_image, n_resample=4, guidance_scale=3, ) ): pass perceptor.utils.pil_image(predictions.denoised_images.clamp(0, 1)).save( "tests/stable_diffusion_init_image.png" ) def test_stable_diffusion_inpainting(): import requests from PIL import Image import torch import torchvision.transforms.functional as TF from perceptor import utils image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" init_image = TF.to_tensor( Image.open(requests.get(image_url, stream=True).raw).resize((512, 512)) )[None].cuda() inpainting_mask = torch.zeros_like(init_image[:, :1]) inpainting_mask[:, :, :, 128:] = 1.0 torch.set_grad_enabled(False) diffusion_model = StableDiffusion("runwayml/stable-diffusion-inpainting").cuda() for predictions in diffusion_model.sample( "photograph of playful lions", from_index=600, to_index=20, n_steps=50, guidance_scale=7, init_image=init_image, inpainting_mask=inpainting_mask, replace_diffused=True, n_resample=4, ): pass utils.pil_image(predictions.denoised_images.clamp(0, 1)).save( "tests/stable_diffusion_inpainting.png" ) def test_stable_diffusion_step(): from diffusers import DDIMScheduler, StableDiffusionPipeline torch.set_grad_enabled(False) device = torch.device("cuda") for model_name in [ "runwayml/stable-diffusion-v1-5", "CompVis/stable-diffusion-v1-4", ]: batch_size = 1 height = 512 width = 512 from_index = 999 to_index = 998 texts = ["painting of a dog"] diffusion_model = StableDiffusion(model_name).to(device) diffused_latents = diffusion_model.random_diffused_latents((1, 3, 512, 512)) conditioning = diffusion_model.conditioning(texts) predictions = diffusion_model.predictions( diffused_latents, from_index, conditioning ) compare_next_diffused_latents = predictions.step(to_index) del diffusion_model # compare with diffusers scheduler = DDIMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", ) scheduler.set_timesteps(1000) pipeline = StableDiffusionPipeline.from_pretrained( model_name, scheduler=scheduler, use_auth_token=True, ).to(device) latents_shape = (batch_size, pipeline.unet.in_channels, height // 8, width // 8) assert latents_shape == diffused_latents.shape tokenized_text = pipeline.tokenizer( texts, padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_embeddings = pipeline.text_encoder( tokenized_text.input_ids.to(pipeline.device) )[0] assert text_embeddings.shape == conditioning.encodings.shape assert torch.allclose( text_embeddings[0, 0], conditioning.encodings[0, 0], atol=1e-3 ) assert torch.allclose(text_embeddings, conditioning.encodings, atol=1e-3) index = next(iter(pipeline.scheduler.timesteps)) assert from_index == index predicted_noise = pipeline.unet( diffused_latents, index, encoder_hidden_states=text_embeddings )["sample"] assert torch.allclose(predictions.predicted_noise, predicted_noise, atol=5e-3) next_diffused_latents = pipeline.scheduler.step( predicted_noise, index, diffused_latents )["prev_sample"] assert torch.allclose( next_diffused_latents[0, 0, 0, 0], compare_next_diffused_latents[0, 0, 0, 0], atol=1e-3, ) assert ( next_diffused_latents.sub(compare_next_diffused_latents).abs().max() <= 1e-3 )