Source code for perceptor.models.monster_diffusion.monster_diffusion

from typing import Optional
from tqdm import tqdm
from scipy import integrate
import torch
import torch.nn as nn
import torch.nn.functional as F
from lantern import module_device, Tensor
from basicsr.utils.download_util import load_file_from_url

from . import settings
from . import standardize
from . import base
from . import diffusion
from .prediction import PredictionBatch

ALL_CHECKPOINT_URL = "https://s3.eu-central-1.wasabisys.com/nextml-model-data/monster-diffusion/6b70ff1e6c7f4c00ad8cb59879f7d88d.pt"
TINY_HERO_CHECKPOINT_URL = "https://s3.eu-central-1.wasabisys.com/nextml-model-data/monster-diffusion/f47af8975b744d4bae2b905bac223003.pt"


class MonsterDiffusion(nn.Module):
    """Pretrained monster diffusion model with two sampling generators.

    Wraps a ``base.Model`` denoising network, loads one of the checkpoints
    referenced by ``ALL_CHECKPOINT_URL`` / ``TINY_HERO_CHECKPOINT_URL`` into
    it, and puts the module in eval mode with gradients disabled.

    Args:
        name: Checkpoint to load, either ``"all"`` or ``"tiny-hero"``.

    Raises:
        ValueError: If ``name`` is not one of the known checkpoint names.
    """

    def __init__(self, name="all"):
        super().__init__()
        self.network = base.Model(
            mapping_cond_dim=9,
        )

        if name == "all":
            checkpoint_url = ALL_CHECKPOINT_URL
        elif name == "tiny-hero":
            checkpoint_url = TINY_HERO_CHECKPOINT_URL
        else:
            raise ValueError(f"Unknown model name {name}")
        checkpoint_path = load_file_from_url(checkpoint_url, "models")

        # map_location="cpu" lets a checkpoint that was saved from CUDA
        # tensors load on hosts without a GPU; load_state_dict copies the
        # values into the (CPU) parameters created above either way.
        # NOTE(review): torch.load unpickles the downloaded file, so this
        # trusts the checkpoint host.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        self.load_state_dict(state_dict)
        self.eval().requires_grad_(False)

    @property
    def device(self):
        """Device of this module, as reported by ``module_device``."""
        return module_device(self)

    @staticmethod
    def training_ts(size):
        """Draw ``size`` random training timesteps.

        Each value is ``exp(P_mean + P_std * z)`` with ``z`` standard normal.
        """
        random_ts = (diffusion.P_mean + torch.randn(size) * diffusion.P_std).exp()
        return random_ts

    def _schedule_ts(self, n_steps):
        """Return ``n_steps`` timesteps going from ``sigma_max`` to ``sigma_min``.

        The values are interpolated linearly in ``sigma ** (1 / rho)`` space
        and raised back to the power ``rho``.
        """
        ramp = torch.linspace(0, 1, n_steps).to(self.device)
        min_inv_rho = diffusion.sigma_min ** (1 / diffusion.rho)
        max_inv_rho = diffusion.sigma_max ** (1 / diffusion.rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** diffusion.rho

    def schedule_ts(self, n_steps):
        """Iterate over consecutive ``(from_ts, to_ts)`` pairs of the schedule.

        Yields ``n_steps - 1`` pairs; nothing when ``n_steps < 2``.
        """
        schedule_ts = self._schedule_ts(n_steps)
        return zip(schedule_ts[:-1], schedule_ts[1:])

    @staticmethod
    def sigmas(ts):
        """Delegate to ``PredictionBatch.sigmas``."""
        return PredictionBatch.sigmas(ts)

    @staticmethod
    def alphas(ts):
        """Delegate to ``PredictionBatch.alphas``."""
        return PredictionBatch.alphas(ts)

    @staticmethod
    def random_noise(size):
        """Return ``size`` decoded pure-noise images scaled by ``sigma_max``."""
        return standardize.decode(
            torch.randn(size, *settings.INPUT_SHAPE) * diffusion.sigma_max
        )

    @staticmethod
    def diffuse(
        images: Tensor.dims("NCHW").shape(-1, *settings.INPUT_SHAPE).float(),
        ts,
        noise=None,
    ):
        """Add noise of level ``sigmas(ts)`` to ``images``.

        Args:
            images: Images to diffuse.
            ts: Timestep(s). A Python number or a 0-dim tensor is broadcast
                to one value per image; any other tensor is used as given.
            noise: Optional noise with the same shape as the encoded images.
                Standard normal noise is drawn when omitted.

        Returns:
            The decoded noisy images.
        """
        x0 = standardize.encode(images)
        batch_size = x0.shape[0]
        if isinstance(ts, (int, float)):
            # Plain ints used to fall through to ``ts.ndim`` and raise
            # AttributeError; float() keeps the filled tensor floating point.
            ts = torch.full((batch_size,), float(ts)).to(x0.device)
        elif ts.ndim == 0:
            ts = torch.full((batch_size,), ts).to(x0.device)
        if noise is None:
            # randn_like already allocates on x0's device.
            noise = torch.randn_like(x0)
        assert x0.shape == noise.shape
        return standardize.decode(x0 + noise * MonsterDiffusion.sigmas(ts))

    def c_skip(self, ts):
        """Skip-connection weight ``sigma_data^2 / (sigma_data^2 + sigma^2)``."""
        return diffusion.sigma_data**2 / (
            diffusion.sigma_data**2 + self.sigmas(ts) ** 2
        )

    def c_out(self, ts):
        """Output scale ``sigma * sigma_data / sqrt(sigma_data^2 + sigma^2)``."""
        return (
            self.sigmas(ts)
            * diffusion.sigma_data
            / torch.sqrt(diffusion.sigma_data**2 + self.sigmas(ts) ** 2)
        )

    def c_in(self, ts):
        """Input scale ``1 / sqrt(sigma_data^2 + sigma^2)``."""
        return 1 / torch.sqrt(diffusion.sigma_data**2 + self.sigmas(ts) ** 2)

    def c_noise(self, ts):
        """Noise conditioning ``log(sigma) / 4``, flattened to one dimension."""
        return 1 / 4 * self.sigmas(ts).log().view(-1)

    def denoised_(
        self,
        diffused_images: Tensor.dims("NCHW").shape(-1, *settings.INPUT_SHAPE).float(),
        ts: Tensor.dims("N"),
        nonleaky_augmentations: Optional[Tensor.dims("NK")] = None,
    ) -> Tensor.dims("NCHW").shape(-1, *settings.INPUT_SHAPE):
        """
        Parameterization from https://arxiv.org/pdf/2206.00364.pdf

        Runs the network on the scaled, encoded input and combines its output
        with the input as ``c_skip * x + c_out * network(c_in * x, c_noise)``.

        Args:
            diffused_images: Noisy images; moved to the model device.
            ts: Timesteps; moved to the model device.
            nonleaky_augmentations: Optional conditioning passed to the
                network as ``mapping_cond``. Zeros of shape
                ``(len(diffused_images), settings.N_AUGMENTATIONS)`` are used
                when omitted.

        Returns:
            The denoised prediction in the encoded (standardized) space.
        """
        diffused_xs = standardize.encode(diffused_images.to(self.device))
        ts = ts.to(self.device)

        if nonleaky_augmentations is None:
            nonleaky_augmentations = torch.zeros(
                len(diffused_images), settings.N_AUGMENTATIONS
            )
        nonleaky_augmentations = nonleaky_augmentations.to(self.device)

        # NOTE(review): c_noise applies self.sigmas itself, so sigmas is
        # applied twice on this argument; left as-is because the loaded
        # checkpoints were presumably trained with this exact call.
        output = self.network(
            self.c_in(ts) * diffused_xs,
            self.c_noise(self.sigmas(ts).flatten()),
            mapping_cond=nonleaky_augmentations,
        )
        return self.c_skip(ts) * diffused_xs + self.c_out(ts) * output

    def forward(
        self,
        diffused_images: Tensor.dims("NCHW"),
        ts: Tensor.dims("N"),
        nonleaky_augmentations: Optional[Tensor.dims("NK")] = None,
    ):
        """Denoise ``diffused_images`` and wrap the result in a ``PredictionBatch``.

        The batch stores the denoised prediction, the input images and the
        flattened timesteps (moved to the model device).
        """
        denoised_xs = self.denoised_(
            diffused_images,
            ts,
            nonleaky_augmentations,
        )
        return PredictionBatch(
            denoised_xs=denoised_xs,
            diffused_images=diffused_images,
            ts=torch.as_tensor(ts).flatten().to(self.device),
        )

    def predictions_(
        self,
        diffused_images: Tensor.dims("NCHW"),
        ts: Tensor.dims("N"),
        nonleaky_augmentations: Optional[Tensor.dims("NK")] = None,
    ):
        """Same as :meth:`forward`; has no training-mode check."""
        return self.forward(
            diffused_images,
            ts,
            nonleaky_augmentations,
        )

    def predictions(
        self,
        diffused_images: Tensor.dims("NCHW"),
        ts: Tensor.dims("N"),
        nonleaky_augmentations: Optional[Tensor.dims("NK")] = None,
    ):
        """Return predictions, refusing to run while in training mode.

        Raises:
            RuntimeError: If the module is in training mode.
        """
        if self.training:
            # RuntimeError is still caught by existing ``except Exception``.
            raise RuntimeError(
                "Cannot run predictions method while in training mode. Use predictions_"
            )
        return self.predictions_(
            diffused_images,
            ts,
            nonleaky_augmentations,
        )

    @staticmethod
    def gamma(ts, n_steps):
        """Relative noise-level increase applied before a sampling step.

        Returns ``min(S_churn / n_steps, sqrt(2) - 1)`` where ``ts`` lies in
        ``[S_tmin, S_tmax]`` and zero elsewhere.
        """
        inside_churn_range = (ts >= diffusion.S_tmin) & (ts <= diffusion.S_tmax)
        churn = torch.minimum(
            torch.tensor(diffusion.S_churn / n_steps),
            torch.tensor(2).sqrt() - 1,
        ).to(ts)
        return torch.where(inside_churn_range, churn, torch.zeros_like(ts))

    @staticmethod
    def reversed_ts(ts, n_steps):
        """Return the raised timestep ``ts + gamma(ts, n_steps) * ts``."""
        return ts + MonsterDiffusion.gamma(ts, n_steps) * ts

    def inject_noise(self, diffused_images, ts, reversed_ts):
        """Add fresh noise to take images from level ``ts`` to ``reversed_ts``.

        The added noise has standard deviation
        ``sqrt(sigmas(reversed_ts)^2 - sigmas(ts)^2) * S_noise``.
        """
        diffused_xs = standardize.encode(diffused_images).to(self.device)
        added_std = (
            self.sigmas(reversed_ts).square() - self.sigmas(ts).square()
        ).sqrt()
        reversed_diffused_xs = (
            diffused_xs
            + added_std * torch.randn_like(diffused_xs) * diffusion.S_noise
        )
        return standardize.decode(reversed_diffused_xs)

    def sample(
        self,
        size,
        n_evaluations=100,
        progress=False,
        diffused_images=None,
    ):
        """Default sampler; delegates to :meth:`elucidated_sample`."""
        return self.elucidated_sample(
            size,
            n_evaluations,
            progress,
            diffused_images,
        )

    def elucidated_sample(
        self,
        size,
        n_evaluations=100,
        progress=False,
        diffused_images=None,
    ):
        """
        Elucidated stochastic sampling from https://arxiv.org/pdf/2206.00364.pdf

        Generator that yields the current denoised images, clamped to
        ``[0, 1]``, after every step and once more after a final denoising
        pass.

        Args:
            size: Number of images to sample.
            n_evaluations: Network evaluation budget; the schedule uses
                ``n_evaluations // 2`` timesteps. Must be at least 4.
            progress: Whether to display a tqdm progress bar.
            diffused_images: Optional starting images. Random noise is drawn
                when omitted.

        Raises:
            RuntimeError: If the module is in training mode.
            ValueError: If ``n_evaluations`` is too small to form a step.
        """
        if self.training:
            raise RuntimeError("Cannot run sample method while in training mode.")

        n_steps = n_evaluations // 2
        if n_steps < 2:
            # With fewer than two schedule points there is no (from, to)
            # pair and the code after the loop would hit an unbound name.
            raise ValueError(
                f"n_evaluations must be at least 4, got {n_evaluations}"
            )

        if diffused_images is None:
            diffused_images = self.random_noise(size).to(self.device)

        nonleaky_augmentations = torch.zeros(
            (size, settings.N_AUGMENTATIONS), dtype=torch.float32, device=self.device
        )

        # The context manager closes the bar even when the consumer stops
        # iterating this generator early.
        with tqdm(total=n_steps, disable=not progress, leave=False) as progress_bar:
            for from_ts, to_ts in self.schedule_ts(n_steps):
                reversed_ts = self.reversed_ts(from_ts, n_steps).clamp(
                    max=diffusion.sigma_max
                )
                reversed_diffused_images = self.inject_noise(
                    diffused_images, from_ts, reversed_ts
                )

                # First evaluation: step from the raised noise level to to_ts.
                predictions = self.predictions(
                    reversed_diffused_images,
                    reversed_ts,
                    nonleaky_augmentations,
                )
                reversed_eps = predictions.eps
                diffused_images = predictions.step(to_ts)

                # Second evaluation: correct the step using the prediction
                # made at to_ts.
                predictions = self.predictions(
                    diffused_images,
                    to_ts,
                    nonleaky_augmentations,
                )
                diffused_images = predictions.correction(
                    reversed_diffused_images, reversed_ts, reversed_eps
                )

                progress_bar.update()
                yield predictions.denoised_images.clamp(0, 1)

        # Final denoising pass at the last timestep of the schedule.
        reversed_ts = self.reversed_ts(to_ts, n_steps)
        diffused_images = self.inject_noise(diffused_images, to_ts, reversed_ts)
        predictions = self.predictions(
            diffused_images,
            reversed_ts,
            nonleaky_augmentations,
        )
        yield predictions.denoised_images.clamp(0, 1)

    @staticmethod
    def linear_multistep_coeff(order, sigmas, from_index, to_index):
        """Integrate one Lagrange basis polynomial over a schedule interval.

        The polynomial is built on the nodes ``sigmas[from_index - k]`` for
        ``k in range(order)``, equals one at ``sigmas[from_index - to_index]``
        and is integrated from ``sigmas[from_index]`` to
        ``sigmas[from_index + 1]``.

        Raises:
            ValueError: If fewer than ``order`` nodes exist at ``from_index``.
        """
        if order - 1 > from_index:
            raise ValueError(f"Order {order} too high for step {from_index}")

        def fn(tau):
            prod = 1.0
            for k in range(order):
                if to_index == k:
                    continue
                prod *= (tau - sigmas[from_index - k]) / (
                    sigmas[from_index - to_index] - sigmas[from_index - k]
                )
            return prod

        return integrate.quad(
            fn, sigmas[from_index], sigmas[from_index + 1], epsrel=1e-4
        )[0]

    def linear_multistep_sample(
        self,
        size,
        n_evaluations=100,
        progress=False,
        diffused_images=None,
        order=4,
    ):
        """
        Katherine Crowson's linear multistep method from https://github.com/crowsonkb/k-diffusion/blob/4fdb34081f7a09f16c33d3344a042e5bea8e69ee/k_diffusion/sampling.py

        Generator that yields the current denoised images, clamped to
        ``[0, 1]``, after every step and once more after a final denoising
        pass.

        Args:
            size: Number of images to sample.
            n_evaluations: Number of schedule timesteps. Must be at least 2.
            progress: Whether to display a tqdm progress bar.
            diffused_images: Optional starting images. Random noise is drawn
                when omitted.
            order: Maximum number of past ``eps`` predictions combined in a
                step.

        Raises:
            RuntimeError: If the module is in training mode.
            ValueError: If ``n_evaluations`` is too small to form a step.
        """
        if self.training:
            raise RuntimeError("Cannot run sample method while in training mode.")

        n_steps = n_evaluations
        if n_steps < 2:
            # Without a (from, to) pair the code after the loop would hit an
            # unbound name.
            raise ValueError(
                f"n_evaluations must be at least 2, got {n_evaluations}"
            )

        if diffused_images is None:
            diffused_images = self.random_noise(size)

        nonleaky_augmentations = torch.zeros(
            (size, settings.N_AUGMENTATIONS), dtype=torch.float32, device=self.device
        )

        diffused_images = diffused_images.to(self.device)
        schedule_ts = self._schedule_ts(n_steps)
        # Loop-invariant: computed once instead of once per coefficient.
        schedule_sigmas = self.sigmas(schedule_ts).cpu().flatten()

        epses = list()
        # The context manager closes the bar even when the consumer stops
        # iterating this generator early.
        with tqdm(total=n_steps, disable=not progress, leave=False) as progress_bar:
            for from_index, (from_ts, to_ts) in enumerate(self.schedule_ts(n_steps)):
                predictions = self.predictions(
                    diffused_images,
                    from_ts,
                    nonleaky_augmentations,
                )
                # Keep only the most recent ``order`` eps predictions.
                epses.append(predictions.eps)
                if len(epses) > order:
                    epses.pop(0)

                current_order = len(epses)
                coeffs = [
                    self.linear_multistep_coeff(
                        current_order,
                        schedule_sigmas,
                        from_index,
                        to_index,
                    )
                    for to_index in range(current_order)
                ]

                # Coefficient 0 belongs to the newest eps, hence reversed().
                diffused_xs = standardize.encode(diffused_images)
                diffused_xs = diffused_xs + sum(
                    coeff * eps for coeff, eps in zip(coeffs, reversed(epses))
                )
                diffused_images = standardize.decode(diffused_xs)

                progress_bar.update()
                yield predictions.denoised_images.clamp(0, 1)

        # Final denoising pass at the last timestep of the schedule.
        predictions = self.predictions(
            diffused_images,
            to_ts,
            nonleaky_augmentations,
        )
        yield predictions.denoised_images.clamp(0, 1)
def test_monster_diffusion():
    """Run the default sampler on CUDA and save the last yielded images."""
    from perceptor import utils

    diffusion_model = MonsterDiffusion()
    diffusion_model = diffusion_model.cuda()

    sampler = diffusion_model.sample(size=1, n_evaluations=50)
    for images in sampler:
        # Intermediate results are discarded; only the final one is saved.
        pass

    final_image = utils.pil_image(images)
    final_image.save("tests/monster_diffusion.png")


def test_monster_diffusion_lms():
    """Run the linear multistep sampler on CUDA and save the last yielded images."""
    from perceptor import utils

    diffusion_model = MonsterDiffusion()
    diffusion_model = diffusion_model.cuda()

    sampler = diffusion_model.linear_multistep_sample(size=1, n_evaluations=50)
    for images in sampler:
        # Intermediate results are discarded; only the final one is saved.
        pass

    final_image = utils.pil_image(images)
    final_image.save("tests/monster_diffusion_lms.png")