Source code for perceptor.drawers.raw.raw

from torch import nn

from ..interface import DrawingInterface
from perceptor.transforms import resize
from .init.fractal import fractal
from .init.gradient import gradient


class Raw(DrawingInterface):
    """Minimal container for a single ``nn.Parameter`` with init helpers.

    The wrapped tensor is stored as ``self.images`` and is returned as-is by
    :meth:`synthesize`.

    Usage::

        images = Raw.random_fractal_image((1, 3, 256, 256))
        # or, starting from an existing tensor:
        images = Raw(torch.zeros(1, 3, 256, 256))
    """

    def __init__(self, init_images):
        """Wrap ``init_images`` in an ``nn.Parameter`` stored as ``self.images``.

        Args:
            init_images: Initial tensor. It is handed straight to
                ``nn.Parameter``; this class makes no copy of it. The last
                two dimensions are treated as the target size by
                :meth:`encode`.
        """
        super().__init__()
        self.images = nn.Parameter(init_images)

    def synthesize(self, _=None):
        """Return the stored parameter ``self.images``.

        Args:
            _: Ignored; accepted only so the method can be called with one
                positional argument.
        """
        return self.images

    def encode(self, images, mode="bilinear"):
        """Resize ``images`` to the size of the stored parameter.

        Args:
            images: Input passed to ``resize``.
            mode: Forwarded to ``resize`` as its ``resample`` argument.

        Returns:
            The result of ``resize`` with ``out_shape`` set to the last two
            dimensions of ``self.images``.
        """
        target_size = tuple(self.images.shape[-2:])
        return resize(images, out_shape=target_size, resample=mode)

    def replace_(self, images):
        """Overwrite the stored parameter values in place and return ``self``.

        Args:
            images: Tensor whose ``.data`` is copied into ``self.images.data``.

        Returns:
            ``self``, so calls can be chained.
        """
        # Copy through ``.data`` so the parameter object itself (and anything
        # holding a reference to it) stays the same; only its values change.
        self.images.data.copy_(images.data)
        return self

    # The factories below are classmethods (rather than staticmethods that
    # hard-code ``Raw``) so that calling them on a subclass builds that
    # subclass. ``Raw.random_*_image(shape)`` behaves exactly as before.

    @classmethod
    def random_fractal_image(cls, shape):
        """Build an instance initialised with ``fractal(shape)``.

        Args:
            shape: Forwarded unchanged to ``fractal``.
        """
        return cls(fractal(shape))

    @classmethod
    def random_gradient_image(cls, shape):
        """Build an instance initialised with ``gradient(shape)``.

        Args:
            shape: Forwarded unchanged to ``gradient``.
        """
        return cls(gradient(shape))
def test_raw():
    """Smoke test: a ``Raw`` can be built from a plain zero tensor."""
    import torch

    zeros = torch.zeros(1, 3, 128, 128)
    Raw(zeros)


def test_raw_fractal():
    """Smoke test: the fractal factory runs for a 256x256 RGB shape."""
    shape = (1, 3, 256, 256)
    Raw.random_fractal_image(shape)


def test_raw_gradient():
    """Smoke test: the gradient factory runs for a 256x256 RGB shape."""
    shape = (1, 3, 256, 256)
    Raw.random_gradient_image(shape)