Source code for perceptor.models.stable_diffusion.predictions

from __future__ import annotations
from typing import Callable
import torch
import lantern

from perceptor.transforms.clamp_with_grad import clamp_with_grad
from . import diffusion_space


class Predictions(lantern.FunctionalBase):
    """
    A noise prediction for a batch of diffused latents.

    Bundles the diffused latents, the indices into the noise schedule they
    were diffused to, the predicted noise, and the schedule itself, and offers
    helpers to derive the denoised latents and to move to other noise levels.

    All methods that return ``Predictions`` return a new object created with
    ``self.replace``; nothing is mutated in place.
    """

    # Latents that ``predicted_noise`` was predicted for.
    from_diffused_latents: lantern.Tensor.dims("NCHW")
    # Schedule index of each sample in ``from_diffused_latents``.
    from_indices: lantern.Tensor.dims("N")
    predicted_noise: lantern.Tensor.dims("NCHW")
    # Lookup tables indexed by schedule index (see ``alphas`` / ``sigmas``).
    schedule_alphas: lantern.Tensor
    schedule_sigmas: lantern.Tensor
    # Callables used by ``dynamic_threshold`` / ``denoised_images``.
    # NOTE(review): presumably image -> latent and latent -> image; confirm
    # against the model that constructs this object.
    encode: Callable[[lantern.Tensor.dims("NCHW")], lantern.Tensor.dims("NCHW")]
    decode: Callable[[lantern.Tensor.dims("NCHW")], lantern.Tensor.dims("NCHW")]

    @property
    def device(self):
        """Device of ``predicted_noise``; all lookups are moved to it."""
        return self.predicted_noise.device

    def indices(self, indices) -> lantern.Tensor:
        """
        Normalize schedule indices to a 1D ``long`` tensor on ``self.device``.

        Args:
            indices: a Python scalar, a 0-dim or 1D tensor, or anything else
                ``torch.as_tensor`` can convert to one of those.

        Raises:
            ValueError: if the converted value has more than one dimension.
        """
        if not isinstance(indices, torch.Tensor):
            indices = torch.as_tensor(indices)
        if indices.ndim == 0:
            indices = indices[None]
        if indices.ndim != 1:
            raise ValueError("indices must be a scalar or a 1D tensor")
        return indices.long().to(self.device)

    def _schedule_lookup(self, schedule, indices) -> lantern.Tensor:
        """Look up ``schedule`` at ``indices`` and shape it as (N, 1, 1, 1)."""
        # Move the schedule first so indexing works even if it was created on
        # another device than the predictions.
        return schedule.to(self.device)[self.indices(indices)][:, None, None, None]

    def alphas(self, indices) -> lantern.Tensor:
        """Return ``schedule_alphas`` at ``indices``, broadcastable to NCHW."""
        return self._schedule_lookup(self.schedule_alphas, indices)

    def sigmas(self, indices) -> lantern.Tensor:
        """Return ``schedule_sigmas`` at ``indices``, broadcastable to NCHW."""
        return self._schedule_lookup(self.schedule_sigmas, indices)

    @property
    def from_alphas(self) -> lantern.Tensor:
        """Alphas at ``from_indices``."""
        return self.alphas(self.from_indices)

    @property
    def from_sigmas(self) -> lantern.Tensor:
        """Sigmas at ``from_indices``."""
        return self.sigmas(self.from_indices)

    @property
    def denoised_latents(self) -> lantern.Tensor:
        """
        Latents with the predicted noise removed:
        ``(from_diffused_latents - from_sigmas * predicted_noise) / from_alphas``.
        """
        # Clamp the divisor so an alpha of zero does not divide by zero.
        return (
            self.from_diffused_latents - self.from_sigmas * self.predicted_noise
        ) / self.from_alphas.clamp(min=1e-7)

    @property
    def denoised_images(self) -> lantern.Tensor:
        """``decode`` applied to ``denoised_latents``."""
        return self.decode(self.denoised_latents)

    def step(self, to_indices, eta=0.0) -> lantern.Tensor:
        """
        Reduce noise level to `to_indices`

        Args:
            to_indices: Union[Tensor, Tensor.shape("N"), float]
            eta: float, amount of fresh noise to mix in; 0 is deterministic

        Returns:
            diffused_latents: torch.Tensor.shape("NCHW")
        """
        to_alphas, to_sigmas = self.alphas(to_indices), self.sigmas(to_indices)
        # The properties recompute on every access, so evaluate them once.
        denoised_latents = self.denoised_latents

        if eta > 0.0:
            from_alphas, from_sigmas = self.from_alphas, self.from_sigmas
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add
            ddim_sigma = (
                eta
                * (to_sigmas**2 / from_sigmas**2).sqrt()
                * (1 - from_alphas**2 / to_alphas**2).sqrt()
            )
            adjusted_sigma = (to_sigmas**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step
            to_diffused_latents = (
                denoised_latents * to_alphas + self.predicted_noise * adjusted_sigma
            )

            # Add the correct amount of fresh noise
            noise = torch.randn_like(to_diffused_latents)
            return to_diffused_latents + noise * ddim_sigma

        return denoised_latents * to_alphas + self.predicted_noise * to_sigmas

    # TODO: do not need to calculate denoised latents? this could introduce errors?

    def correction(self, previous: Predictions) -> Predictions:
        """
        Combine this prediction with ``previous`` by averaging their denoised
        latents, and express the result as predictions at ``previous``'s
        diffused latents and indices.
        """
        # Alternatives that were tried (kept as notes from the original code):
        # - averaging eps scaled by the sigma difference: k-diffusion has
        #   alphas=1 always so this should not work here
        # - averaging velocities: blurry
        # The variant below looks ok but showed no apparent difference from
        # simply doubling the step budget.
        return previous.forced_denoised_latents(
            (self.denoised_latents + previous.denoised_latents) / 2
        )

    def reverse_step(self, to_indices) -> lantern.Tensor:
        """
        Re-diffuse the denoised latents to ``to_indices`` reusing the
        predicted noise (no fresh noise is drawn).

        Raises:
            ValueError: if any ``from_indices`` is greater than the matching
                ``to_indices``.
        """
        if (torch.as_tensor(self.from_indices) > torch.as_tensor(to_indices)).any():
            raise ValueError("from_indices must be less than to_indices")

        # An earlier variant replaced part of the noise with fresh noise of
        # variance sigma(to)^2 - sigma(from)^2; see ``noisy_reverse_step``.
        to_alphas, to_sigmas = self.alphas(to_indices), self.sigmas(to_indices)
        return self.denoised_latents * to_alphas + self.predicted_noise * to_sigmas

    # TODO: do not need to calculate denoised latents? this could introduce errors?

    def resample(self, resample_indices) -> lantern.Tensor:
        """
        Harmonizing resampling from https://github.com/andreas128/RePaint
        """
        return (
            self.denoised_latents * self.from_alphas
            + self.resample_noise(resample_indices) * self.from_sigmas
        )

    def resample_noise(self, resample_indices) -> lantern.Tensor:
        """
        Mix the predicted noise (weighted by sigma at ``resample_indices``)
        with fresh noise so the total has scale ``from_sigmas``, then
        normalize by ``from_sigmas``.

        Raises:
            ValueError: if any ``from_indices`` is less than the matching
                ``resample_indices``.
        """
        if (
            torch.as_tensor(self.from_indices) < torch.as_tensor(resample_indices)
        ).any():
            raise ValueError("from_indices must be greater than resample_indices")

        from_sigmas = self.from_sigmas
        resample_sigmas = self.sigmas(resample_indices)
        resampled_noise_sigma = (
            resample_sigmas * self.predicted_noise
            + (from_sigmas**2 - resample_sigmas**2).sqrt()
            * torch.randn_like(self.predicted_noise)
        )
        return resampled_noise_sigma / from_sigmas

    def noisy_reverse_step(self, to_indices) -> lantern.Tensor:
        """
        Re-diffuse the denoised latents to ``to_indices``, keeping the
        predicted noise at its current scale and adding fresh noise with
        variance ``to_sigmas**2 - from_sigmas**2``.
        """
        to_alphas, to_sigmas = self.alphas(to_indices), self.sigmas(to_indices)
        from_sigmas = self.from_sigmas
        noise_sigma = from_sigmas * self.predicted_noise + (
            to_sigmas**2 - from_sigmas**2
        ).sqrt() * torch.randn_like(self.predicted_noise)
        return self.denoised_latents * to_alphas + noise_sigma

    def guided(self, guiding, guidance_scale=0.5, clamp_value=1e-6) -> Predictions:
        """
        Shift the predicted noise along ``guiding``.

        ``guiding`` is clamped to ``[-clamp_value, clamp_value]`` and divided
        by ``clamp_value`` (so each element ends up in [-1, 1]) before being
        scaled by ``guidance_scale * from_sigmas``.
        """
        return self.replace(
            predicted_noise=self.predicted_noise
            + guidance_scale
            * self.from_sigmas
            * guiding.clamp(-clamp_value, clamp_value)
            / clamp_value
        )

    def latent_dynamic_threshold(self, quantile=0.95) -> Predictions:
        """
        Clamp the predicted noise per sample to its ``quantile`` absolute
        value (but never below 2.5). ``quantile=None`` disables the clamp.
        """
        if quantile is None:
            return self

        # Shape (N, 1, 1, 1) so each sample is clamped by its own threshold;
        # a bare (N,) tensor would broadcast along the width axis instead.
        dynamic_threshold = torch.quantile(
            self.predicted_noise.flatten(start_dim=1).abs(), quantile, dim=1
        ).clamp(min=2.5)[:, None, None, None]
        predicted_noise = clamp_with_grad(
            self.predicted_noise,
            -dynamic_threshold,
            dynamic_threshold,
        )
        return self.forced_predicted_noise(predicted_noise)

    def dynamic_threshold(self, quantile=0.95) -> Predictions:
        """
        Thresholding heuristic from imagen paper

        The denoised latents are decoded and mapped with
        ``diffusion_space.encode``, clamped per sample to the ``quantile``
        absolute value (at least 1.0) and divided by it, then mapped back and
        re-encoded. ``quantile=None`` disables the thresholding.
        """
        if quantile is None:
            return self

        denoised_xs = diffusion_space.encode(self.decode(self.denoised_latents))
        # Shape (N, 1, 1, 1) so the threshold applies per sample; a bare (N,)
        # tensor would broadcast along the width axis instead.
        dynamic_threshold = torch.quantile(
            denoised_xs.flatten(start_dim=1).abs(), quantile, dim=1
        ).clamp(min=1.0)[:, None, None, None]
        denoised_xs = (
            clamp_with_grad(
                denoised_xs,
                -dynamic_threshold,
                dynamic_threshold,
            )
            / dynamic_threshold
        )
        return self.forced_denoised_latents(
            self.encode(diffusion_space.decode(denoised_xs))
        )

    def forced_denoised_latents(self, denoised_latents) -> Predictions:
        """
        Return predictions whose noise is chosen so that ``denoised_latents``
        becomes the given tensor (inverse of the ``denoised_latents`` formula).
        """
        # Clamp the divisor so a sigma of zero does not divide by zero.
        predicted_noise = (
            self.from_diffused_latents - denoised_latents * self.from_alphas
        ) / self.from_sigmas.clamp(min=1e-7)
        return self.replace(predicted_noise=predicted_noise)

    def forced_predicted_noise(self, predicted_noise) -> Predictions:
        """Return predictions with ``predicted_noise`` replaced."""
        return self.replace(predicted_noise=predicted_noise)

    def _noise_quantile_residuals(self) -> lantern.Tensor:
        """
        Per-sample difference between the sorted predicted noise and the
        standard normal quantiles at evenly spaced probabilities.
        """
        sorted_noise = self.predicted_noise.flatten(start_dim=1).sort(dim=1)[0]
        n = sorted_noise.shape[1]
        # Midpoints of n equal probability bins, avoiding 0 and 1 where the
        # inverse CDF is infinite.
        margin = 0.5 / n
        points = torch.linspace(margin, 1 - margin, n)
        expected_noise = torch.distributions.Normal(0, 1).icdf(points)
        return sorted_noise - expected_noise[None].to(sorted_noise)

    def wasserstein_distance(self) -> lantern.Tensor:
        """
        Mean absolute difference between the sorted predicted noise and the
        matching standard normal quantiles, averaged over the whole batch.
        """
        return self._noise_quantile_residuals().abs().mean()

    def wasserstein_square_distance(self) -> lantern.Tensor:
        """
        Mean squared difference between the sorted predicted noise and the
        matching standard normal quantiles, averaged over the whole batch.
        """
        return self._noise_quantile_residuals().square().mean()

    def classifier_free_guidance(
        self, positive_predictions: Predictions, guidance_scale=7.0
    ) -> Predictions:
        """
        Extrapolate from this prediction's noise towards
        ``positive_predictions``' noise by a factor of ``guidance_scale``.
        """
        return self.replace(
            predicted_noise=self.predicted_noise
            + (positive_predictions.predicted_noise - self.predicted_noise)
            * guidance_scale
        )