# Source code for perceptor.losses.super_resolution.discriminator

import torch
from basicsr.utils.download_util import load_file_from_url

from perceptor.losses.interface import LossInterface
from .unet_discriminator_sn import UNetDiscriminatorSN


# Maps each supported checkpoint name to the URL its weights are fetched from.
checkpoints = dict(
    RealESRGAN_x4plus_netD=(
        "https://github.com/xinntao/Real-ESRGAN/releases/download/"
        "v0.2.2.3/RealESRGAN_x4plus_netD.pth"
    ),
)


class SuperResolutionDiscriminator(LossInterface):
    """Loss derived from a pretrained ``UNetDiscriminatorSN`` network.

    On construction the checkpoint registered under ``name`` in the
    module-level ``checkpoints`` mapping is downloaded, a
    ``UNetDiscriminatorSN`` is built with gradients disabled on its
    parameters, and the checkpoint's ``"params"`` state dict is loaded into it.
    """

    def __init__(self, name: str = "RealESRGAN_x4plus_netD"):
        """Download the named checkpoint and load it into the discriminator.

        Args:
            name: Key into the module-level ``checkpoints`` mapping.

        Raises:
            KeyError: If ``name`` is not a key of ``checkpoints``.
        """
        super().__init__()
        self.name = name

        # Same exception type as a plain dict lookup, but the message tells
        # the caller which names are actually available.
        if self.name not in checkpoints:
            raise KeyError(
                f"Unknown checkpoint {self.name!r}; "
                f"available: {sorted(checkpoints)}"
            )

        # The second argument is presumably the directory the file is saved
        # into -- confirm against basicsr's load_file_from_url.
        checkpoint_path = load_file_from_url(
            checkpoints[self.name],
            "models",
        )

        self.model = UNetDiscriminatorSN(
            num_in_ch=3,
            num_feat=64,
            skip_connection=True,
        )
        # The discriminator's own parameters are not to be trained.
        self.model.requires_grad_(False)

        # NOTE(security): torch.load unpickles the downloaded file. The URL is
        # hard-coded above, but consider ``weights_only=True`` if the minimum
        # supported torch version allows it.
        weights = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        # strict=True: every key must match the architecture exactly.
        self.model.load_state_dict(weights["params"], strict=True)

    def forward(self, images):
        """Return the negated mean discriminator output scaled by ``0.001``.

        Args:
            images: Input passed unchanged to the discriminator
                (presumably a batch of 3-channel images, given
                ``num_in_ch=3`` -- confirm the expected layout and range).

        Returns:
            ``-self.model(images).mean() * 0.001``.
        """
        return -self.model(images).mean() * 0.001