"""LPIPS perceptual loss (module ``perceptor.losses.lpips``)."""

import lpips

from .interface import LossInterface


class LPIPS(LossInterface):
    """Learned Perceptual Image Patch Similarity (LPIPS) loss.

    Thin wrapper around ``lpips.LPIPS`` that freezes the underlying network
    and accepts images in the ``[0, 1]`` range.
    """

    # Backbone names accepted for ``name``; "vgg16" is kept because lpips
    # treats it as an alias of "vgg".
    _AVAILABLE_NETWORKS = ("alex", "vgg", "vgg16", "squeeze")

    def __init__(self, name="squeeze", linear_layers=True, spatial=False):
        """
        LPIPS loss. Expects images of shape (batch_size, 3, height, width)
        between 0 and 1.

        Args:
            name (str): backbone network used to extract features.
                Available options: ["alex", "vgg", "squeeze"]
                ("vgg16" is also accepted as an alias of "vgg").
            linear_layers (bool): forwarded as ``lpips=`` to ``lpips.LPIPS``;
                when True the learned linear calibration layers are applied
                on top of the backbone features.
            spatial (bool): forwarded to ``lpips.LPIPS``; when True the loss
                is returned as a spatial map rather than one value per image.

        Raises:
            ValueError: if ``name`` is not one of the supported backbones.
        """
        # Validate up front so an unsupported name fails with a clear message
        # instead of an obscure error from inside lpips.
        if name not in self._AVAILABLE_NETWORKS:
            raise ValueError(
                f"Unknown LPIPS network {name!r}; "
                f"expected one of {list(self._AVAILABLE_NETWORKS)}"
            )
        super().__init__()
        self.model = lpips.LPIPS(
            net=name, lpips=linear_layers, spatial=spatial, verbose=False
        )
        # Used purely as a fixed loss: inference mode (e.g. no dropout) and
        # frozen parameters. Gradients w.r.t. the input images still flow.
        self.model.eval()
        self.model.requires_grad_(False)

    def forward(self, images_a, images_b):
        """
        Compute the LPIPS distance between two batches of images.

        Args:
            images_a: images of shape (batch_size, 3, height, width) between 0 and 1
            images_b: images of shape (batch_size, 3, height, width) between 0 and 1

        Returns:
            The tensor produced by ``lpips.LPIPS``: one distance per image
            pair (presumably shape (batch_size, 1, 1, 1) when
            ``spatial=False``; TODO confirm against the installed lpips), or
            a per-pixel map when ``spatial=True``.
        """
        # normalize=True asks lpips to rescale the [0, 1] inputs to the
        # [-1, 1] range its networks expect.
        return self.model(images_a, images_b, normalize=True)
def test_lpips():
    """Smoke test: LPIPS runs on [0, 1] images and returns sane distances."""
    import torch

    model = LPIPS()
    # The loss expects images in [0, 1]; torch.rand samples that range,
    # whereas torch.randn produces values outside it.
    images_a = torch.rand((1, 3, 256, 256))
    images_b = torch.rand((1, 3, 256, 256))

    loss = model(images_a, images_b)
    # One result per image pair, and no NaN/inf.
    assert loss.shape[0] == images_a.shape[0]
    assert torch.isfinite(loss).all()

    # Identical inputs yield identical features, so the distance is zero.
    same = model(images_a, images_a)
    assert torch.allclose(same, torch.zeros_like(same))