# Source code for perceptor.losses.super_resolution.super_resolution

import torch
import torch.nn.functional as F

from perceptor.losses.interface import LossInterface
from perceptor import transforms


[docs]class SuperResolution(LossInterface): def __init__(self, name="x2", pre_downscale=None, half=True, mode="bicubic"): super().__init__() self.transform = transforms.SuperResolution(name, half) self.mode = mode self.pre_downscale = ( self.transform.model.scale if pre_downscale is None else pre_downscale )
[docs] def forward(self, images): with torch.no_grad(): downsampled_size = torch.tensor(images.shape[-2:]) // self.pre_downscale scale = downsampled_size / torch.tensor(images.shape[-2:]) downsampled = transforms.resize( images, scale.tolist(), downsampled_size, resample=self.mode ) upsampled = self.transform(downsampled) if upsampled.shape != images.shape: scale = torch.tensor(images.shape[-2:]) / torch.tensor( upsampled.shape[-2:] ) upsampled = transforms.resize( upsampled, scale.tolist(), out_shape=tuple(images.shape[-2:]), resample=self.mode, ) return (images - upsampled).square().mean()