# Source code for perceptor.losses.aesthetic_visual_assessment

import torch
from torch import nn
from torch.nn import functional as F
from basicsr.utils.download_util import load_file_from_url

from .interface import LossInterface
from perceptor import models


class AestheticVisualAssessment(LossInterface):
    def __init__(self, aesthetic_target=10, mode="expected"):
        """
        Aesthetic visual assessment loss based on a classifier that predicts the
        aesthetic rating of an image.

        The classifier is a frozen linear head (512 -> 10) applied to CLIP
        ViT-B-16 image embeddings, with one output per rating 1..10.

        Args:
            aesthetic_target (int): Target aesthetic rating of the image (1-10).
            mode (str): How to compute the loss. Available modes are:
                - "logit": compute the loss based on the logits of the classifier.
                - "expected": compute the loss based on the expected value of the
                  classifier (squared distance between the expected rating and
                  ``aesthetic_target``).
                - "probability": compute the loss based on the probability of the
                  classifier.
        """
        super().__init__()
        self.aesthetic_target = aesthetic_target
        self.mode = mode

        self.model = models.CLIP("ViT-B-16")

        checkpoint_path = load_file_from_url(
            "http://batbot.tv/ai/models/v-diffusion/ava_vit_b_16_full.pth",
            "models",
        )
        # Linear head mapping a 512-d CLIP embedding to 10 rating scores.
        # It is only used for scoring, so it is put in eval mode and frozen.
        self.aesthetic_head = nn.Linear(512, 10)
        self.aesthetic_head.load_state_dict(
            torch.load(checkpoint_path, map_location="cpu")
        )
        self.aesthetic_head.eval()
        self.aesthetic_head.requires_grad_(False)

    def forward(self, images):
        """
        Compute the aesthetic loss for a batch of images.

        Args:
            images: Image tensor accepted by ``self.model.encode_images``
                (presumably (batch, 3, H, W) -- TODO confirm against the CLIP
                wrapper).

        Returns:
            Scalar loss tensor averaged over the batch; lower is better.

        Raises:
            ValueError: If ``self.mode`` is not "logit", "expected" or
                "probability".
        """
        # The linear head returns unnormalized scores (logits), one per rating.
        logits = self.aesthetic_head(self.model.encode_images(images))
        if self.mode == "logit":
            return -logits[..., self.aesthetic_target - 1].mean().mul(0.01)
        elif self.mode == "expected":
            probs = F.softmax(logits, dim=-1)
            ratings = torch.arange(1, 11, device=probs.device, dtype=probs.dtype)
            # Expected rating E[r] = sum_r p(r) * r: one value per image.
            # The sum over the rating axis is required; without it this would
            # compare each individual p(r) * r term against the target.
            expected_rating = (probs * ratings).sum(dim=-1)
            return (expected_rating - self.aesthetic_target).square().mean().mul(0.01)
        elif self.mode == "probability":
            return -F.softmax(logits, dim=-1)[..., self.aesthetic_target - 1].mean()
        else:
            raise ValueError(f"Unknown mode: {self.mode}")
def test_aesthetic_visual_assessment():
    """Smoke-test every loss mode on a random image batch.

    Runs on CUDA when available and falls back to CPU otherwise. Constructing
    the loss may download model weights on first use.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss = AestheticVisualAssessment().to(device)
    images = torch.randn(1, 3, 256, 256, device=device)
    # Reuse one instance (one model load) and switch modes via the attribute
    # that forward() reads.
    for mode in ("logit", "expected", "probability"):
        loss.mode = mode
        value = loss(images)
        assert value.dim() == 0, f"{mode}: expected a scalar loss"
        assert torch.isfinite(value), f"{mode}: loss is not finite"