# Source code for perceptor.models.velocity_diffusion.velocity_diffusion

from functools import partial
import torch
import lantern
from basicsr.utils.download_util import load_file_from_url

from perceptor.utils import cache
from perceptor import models
from .models import get_model
from .model_urls import model_urls
from . import diffusion_space, utils
from .predictions import Predictions


# @cache
class VelocityDiffusion(torch.nn.Module):
    """Frozen wrapper around a pretrained velocity-predicting diffusion model.

    The wrapped network is built with ``get_model(name)``, loaded from the
    checkpoint at ``model_urls[name]`` and kept in eval mode with gradients
    disabled. Helpers are provided to build timestep schedules, convert
    timesteps to alphas/sigmas, predict velocities and (re-)noise images.
    """

    def __init__(self, name="yfcc_2"):
        """
        Args:
            name: The name of the model. Available models are:
                - yfcc_2
                - yfcc_1
                - cc12m_1_cfg (conditioned)
                - wikiart
        """
        super().__init__()
        self.name = name
        self.model = get_model(name)()
        checkpoint_path = load_file_from_url(model_urls[name], "models")
        self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        print(f"Loaded checkpoint {checkpoint_path}")
        self.model.eval()
        self.model.requires_grad_(False)

    def to(self, device):
        """Move the module to ``device``; on any CUDA device, cast the model to half.

        ``device`` may be a ``torch.device``, a string such as ``"cuda"`` or
        ``"cuda:1"``, or a CUDA device index. Anything else (e.g. a dtype) is
        forwarded to :meth:`torch.nn.Module.to` without the half cast.

        Returns:
            ``self``, so calls can be chained.
        """
        super().to(device)
        # Compare the device *type*: ``torch.device("cuda:0")`` and the string
        # ``"cuda"`` are not equal to ``torch.device("cuda")``.
        if (
            isinstance(device, (str, int, torch.device))
            and torch.device(device).type == "cuda"
        ):
            self.model.half()
        return self

    @property
    def device(self):
        """Device of the module's first parameter."""
        return next(iter(self.parameters())).device

    @property
    def shape(self):
        """The ``shape`` attribute exposed by the wrapped model."""
        return self.model.shape

    @staticmethod
    def schedule_ts(n_steps=500, from_ts=1.0, to_ts=1e-2, rho=7.0) -> lantern.Tensor:
        """Build a schedule of ``(from_t, to_t)`` pairs for ``n_steps`` steps.

        The endpoint timesteps are converted to a sigma-like quantity
        ``sqrt(1 / exp(log_snr))``, clamped to at most 150 at the start and at
        least 1e-3 at the end. ``n_steps + 1`` values are spaced with
        ``(max^(1/rho) + ramp * (min^(1/rho) - max^(1/rho))) ** rho`` and
        mapped back to timesteps.

        Returns:
            Tensor of shape ``(n_steps, 2)``. Row ``i`` holds consecutive
            timesteps ``(t_i, t_{i+1})``.
        """
        from_alpha, from_sigma = utils.t_to_alpha_sigma(torch.as_tensor(from_ts))
        to_alpha, to_sigma = utils.t_to_alpha_sigma(torch.as_tensor(to_ts))
        from_log_snr = utils.alpha_sigma_to_log_snr(from_alpha, from_sigma)
        to_log_snr = utils.alpha_sigma_to_log_snr(to_alpha, to_sigma)
        elucidated_from_sigma = (1 / from_log_snr.exp()).sqrt().clamp(max=150)
        elucidated_to_sigma = (1 / to_log_snr.exp()).sqrt().clamp(min=1e-3)
        ramp = torch.linspace(0, 1, n_steps + 1)
        min_inv_rho = elucidated_to_sigma ** (1 / rho)
        max_inv_rho = elucidated_from_sigma ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        log_snr = utils.alpha_sigma_to_log_snr(torch.ones_like(sigmas), sigmas)
        alpha, sigma = utils.log_snr_to_alpha_sigma(log_snr)
        schedule_ts = utils.alpha_sigma_to_t(alpha, sigma)
        return torch.stack([schedule_ts[:-1], schedule_ts[1:]], dim=1)

    def random_diffused(self, shape) -> lantern.Tensor:
        """Sample standard-normal noise of ``shape`` and decode it to image space.

        The noise is drawn on the CPU and then moved to ``self.device``.
        """
        return diffusion_space.decode(torch.randn(shape)).to(self.device)

    @staticmethod
    def sigmas_to_ts(sigmas) -> lantern.Tensor:
        """Convert sigma values (scalar, sequence or tensor) to timesteps."""
        sigmas = torch.as_tensor(sigmas)
        return utils.sigma_to_t(sigmas)

    @staticmethod
    def _as_1d_ts(ts):
        """Return ``ts`` as a 1D floating tensor; scalars become length 1.

        Raises:
            ValueError: if ``ts`` has more than one dimension.
        """
        ts = torch.as_tensor(ts)
        if not ts.is_floating_point():
            # Integer timesteps (e.g. ``0`` or ``1``) are promoted to float.
            ts = ts.float()
        if ts.ndim == 0:
            ts = ts[None]
        if ts.ndim != 1:
            raise ValueError("t must be a scalar or a 1D tensor")
        return ts

    @staticmethod
    def _broadcast_ts(ts, like):
        """Expand a scalar ``ts`` to shape ``(like.shape[0],)``, matching ``like``.

        Non-scalar tensors are returned unchanged.
        """
        if isinstance(ts, (int, float)) or ts.ndim == 0:
            ts = torch.full((like.shape[0],), ts).to(like)
        return ts

    def alphas(self, ts) -> lantern.Tensor:
        """Alphas for timesteps ``ts``, shaped ``(batch, 1, 1, 1)`` on ``self.device``.

        Raises:
            ValueError: if ``ts`` is not a scalar or 1D.
        """
        alphas, _ = utils.t_to_alpha_sigma(self._as_1d_ts(ts))
        return alphas[:, None, None, None].to(self.device)

    def sigmas(self, ts) -> lantern.Tensor:
        """Sigmas for timesteps ``ts``, shaped ``(batch, 1, 1, 1)`` on ``self.device``.

        Raises:
            ValueError: if ``ts`` is not a scalar or 1D.
        """
        _, sigmas = utils.t_to_alpha_sigma(self._as_1d_ts(ts))
        return sigmas[:, None, None, None].to(self.device)

    @torch.cuda.amp.autocast()
    def velocities(self, diffused, t, conditioning=None) -> lantern.Tensor:
        """Run the model on ``diffused`` images at timestep(s) ``t``.

        Args:
            diffused: Diffused images. They are encoded into diffusion space
                before being passed to the model.
            t: Scalar timestep (broadcast over the batch) or a 1D tensor.
            conditioning: Required when the model has a ``clip_model``. It is
                passed as ``clip_embed`` after ``squeeze(dim=1)``.

        Returns:
            Model output cast to float32.

        Raises:
            ValueError: if the model is conditioned and ``conditioning`` is None.
        """
        x = diffusion_space.encode(diffused)
        if hasattr(self.model, "clip_model"):
            if conditioning is None:
                raise ValueError(
                    f"Model {self.name!r} is conditioned; pass `conditioning` "
                    "(e.g. from VelocityDiffusion.conditioning)"
                )
            model_fn = partial(self.model, clip_embed=conditioning.squeeze(dim=1))
        else:
            model_fn = self.model
        t = self._broadcast_ts(t, x)
        velocities = model_fn(x, t)
        return velocities.float()

    def forward(self, diffused_images, ts, conditioning=None) -> Predictions:
        """Predict velocities and wrap them in a :class:`Predictions`.

        A scalar ``ts`` is broadcast to one timestep per image.
        """
        ts = self._broadcast_ts(ts, diffused_images)
        return Predictions(
            from_diffused_images=diffused_images,
            from_ts=ts,
            velocities=self.velocities(diffused_images, ts, conditioning),
        )

    def predictions(self, diffused_images, ts, conditioning=None) -> Predictions:
        """Alias for :meth:`forward`."""
        return self.forward(diffused_images, ts, conditioning)

    def conditioning(self, texts=None, images=None, encodings=None) -> lantern.Tensor:
        """Average CLIP encodings of texts, images and/or precomputed encodings.

        Each provided source contributes one tensor. The tensors are stacked,
        averaged over sources, and given a leading dimension of size 1.
        NOTE(review): stacking requires every source to have the same shape
        (e.g. the same number of texts and images). Confirm this is intended.

        Raises:
            ValueError: if none of ``texts``, ``images`` or ``encodings`` is given.
        """
        if texts is None and images is None and encodings is None:
            raise ValueError("Must provide at least one of texts, images, or encodings")
        all_encodings = list()
        if texts is not None or images is not None:
            # Only build the CLIP wrapper when something actually needs encoding.
            clip_model = models.CLIP(self.model.clip_model)
            if texts is not None:
                all_encodings.append(clip_model.encode_texts(texts))
            if images is not None:
                all_encodings.append(clip_model.encode_images(images))
        if encodings is not None:
            all_encodings.append(encodings)
        return torch.stack(all_encodings, dim=0).mean(dim=0)[None]

    def diffuse(self, denoised_images, ts, noise=None) -> lantern.Tensor:
        """Noise clean images to timestep(s) ``ts``: ``x * alpha + noise * sigma``.

        Args:
            denoised_images: Clean images, encoded into diffusion space first.
            ts: Scalar timestep (broadcast over the batch) or a 1D tensor.
            noise: Optional noise. Defaults to ``randn_like`` of the encoded images.

        Returns:
            The diffused images, decoded back from diffusion space.
        """
        denoised_xs = diffusion_space.encode(denoised_images)
        ts = self._broadcast_ts(ts, denoised_xs)
        if noise is None:
            noise = torch.randn_like(denoised_xs)
        alphas, sigmas = self.alphas(ts), self.sigmas(ts)
        return diffusion_space.decode(denoised_xs * alphas + noise * sigmas)

    def inject_noise(
        self, diffused_images, ts, reversed_ts, extra_noise_multiplier=1.003
    ) -> lantern.Tensor:
        """Move images diffused at ``ts`` to the noisier timestep ``reversed_ts``.

        The images are scaled by ``alpha(reversed_ts) / alpha(ts)``. Fresh noise
        is then added with the std needed to reach ``sigma(reversed_ts)``,
        multiplied by ``extra_noise_multiplier``. ``reversed_ts`` is expected to
        be >= ``ts``. If it is not, the required variance is negative, and it is
        clamped to 0 so that no noise is added (instead of producing NaN).
        """
        diffused_xs = diffusion_space.encode(diffused_images).to(self.device)
        diffused_multiplier = self.alphas(reversed_ts) / self.alphas(ts)
        target_sigmas = self.sigmas(reversed_ts)
        additional_noise_std = (
            (
                target_sigmas.square()
                - self.sigmas(ts).square() * diffused_multiplier.square()
            )
            # Rounding can make the variance slightly negative when
            # reversed_ts ~= ts; sqrt of that would be NaN.
            .clamp(min=0)
            .sqrt()
        )
        reversed_diffused_xs = (
            diffused_xs * diffused_multiplier
            + additional_noise_std
            * torch.randn_like(diffused_xs)
            * extra_noise_multiplier
        )
        return diffusion_space.decode(reversed_diffused_xs)
def test_velocity_diffusion():
    """End-to-end unconditioned sampling with noise injection and correction."""
    from perceptor import utils

    device = torch.device("cuda")
    # Scope no-grad to this test instead of flipping global autograd state
    # for every test that runs afterwards.
    with torch.no_grad():
        diffusion = models.VelocityDiffusion("yfcc_2").to(device)
        diffused_images = diffusion.random_diffused((1, 3, 512, 512)).to(device)
        for from_ts, to_ts in diffusion.schedule_ts(n_steps=50):
            if (from_ts < 1.0).all():
                new_from_ts = from_ts * 1.003
                diffused_images = diffusion.predictions(
                    diffused_images, from_ts
                ).noisy_reverse_step(new_from_ts)
                from_ts = new_from_ts
            predictions = diffusion.predictions(
                diffused_images,
                from_ts,
            )
            diffused_images = predictions.step(to_ts)
            diffused_images = (
                diffusion.predictions(diffused_images, to_ts)
                .correction(predictions)
                .step(to_ts)
            )
        utils.pil_image(
            diffusion.predictions(diffused_images, to_ts).denoised_images
        ).save("tests/velocity_diffusion_yfcc_2.png")


def test_conditioned_velocity_diffusion():
    """End-to-end text-conditioned sampling with the cc12m_1_cfg model."""
    from perceptor import utils

    device = torch.device("cuda")
    with torch.no_grad():
        diffusion = VelocityDiffusion("cc12m_1_cfg").to(device)
        diffused_images = diffusion.random_diffused((1, 3, 256, 256)).to(device)
        conditioning = diffusion.conditioning(texts=["photo of a cute cat"])
        for from_ts, to_ts in diffusion.schedule_ts(n_steps=50):
            predictions = diffusion.predictions(diffused_images, from_ts, conditioning)
            diffused_images = predictions.step(to_ts)
        utils.pil_image(
            diffusion.predictions(diffused_images, to_ts, conditioning).denoised_images
        ).save("tests/velocity_diffusion_cc12m_1.png")


def test_convert_sigma_ts():
    """sigmas_to_ts inverts sigmas for a scalar timestep."""
    diffusion = VelocityDiffusion("cc12m_1_cfg")
    from_ts = 0.3
    assert (
        from_ts - diffusion.sigmas_to_ts(diffusion.sigmas(from_ts)).squeeze()
    ).abs() <= 1e-5


def test_schedule_ts():
    """The first scheduled timestep equals ``from_ts``."""
    # schedule_ts is a staticmethod: no need to download a checkpoint.
    from_ts = 0.6
    assert torch.allclose(
        VelocityDiffusion.schedule_ts(n_steps=50, from_ts=from_ts)[0, 0],
        torch.as_tensor(from_ts),
    )


def test_utils_conversion():
    """t -> (alpha, sigma) -> t round-trips within float tolerance."""
    t = torch.as_tensor(0.3)
    alpha, sigma = utils.t_to_alpha_sigma(t)
    assert torch.allclose(utils.sigma_to_t(sigma), t)
    # Exact equality is brittle after trigonometric round-trips.
    assert torch.allclose(utils.alpha_sigma_to_t(alpha, sigma), t)