Source code for perceptor.drawers.deep_image_prior

import torch
from torch import nn

from .interface import DrawingInterface
from perceptor import models


class DeepImagePrior(DrawingInterface):
    """Drawer that adds a learnable image to a deep-image-prior network output.

    The synthesized result is ``models.DeepImagePrior`` applied to a fixed
    (non-trainable) latent tensor, plus a trainable ``images`` tensor that is
    initialised to zeros.
    """

    def __init__(self, size, n_feature_channels=64, output_channels=3):
        """Build the prior network, its frozen latents and the image offset.

        Args:
            size: Iterable of spatial dimensions; it is unpacked after the
                channel dimension in both the latent shape and the image
                shape (presumably ``(height, width)`` -- confirm with callers).
            n_feature_channels: Leading dimension of the shape handed to
                ``models.DeepImagePrior``.
            output_channels: Channel count passed to the network and used for
                the ``images`` parameter.
        """
        super().__init__()

        # Registration order (network, latents, images) is kept as-is so that
        # parameter/state-dict ordering stays stable.
        latent_shape = (n_feature_channels, *size)
        self.deep_image_prior = models.DeepImagePrior(
            shape=latent_shape,
            output_channels=output_channels,
        )

        # The latents are stored as a parameter but excluded from gradient
        # updates via requires_grad=False.
        initial_latents = self.deep_image_prior.random_latents()
        self.latents = nn.Parameter(initial_latents, requires_grad=False)

        # Trainable offset, starting at zero, with a leading dimension of 1.
        image_shape = (1, output_channels, *size)
        self.images = nn.Parameter(torch.zeros(image_shape))

    def synthesize(self, _=None):
        """Return the network output for the stored latents plus ``images``.

        The single positional argument is accepted but ignored.
        """
        prior_output = self.deep_image_prior(self.latents)
        return prior_output + self.images

    def loss(self):
        """Return the mean absolute value of ``images`` scaled by 0.0001."""
        mean_magnitude = self.images.abs().mean()
        return mean_magnitude.mul(0.0001)