Source code for perceptor.models.simulacra_aesthetic.simulacra_aesthetic

"""
https://github.com/crowsonkb/simulacra-aesthetic-models/blob/master/LICENSE
"""

import torch
from torch import nn
from torch.nn import functional as F
from basicsr.utils.download_util import load_file_from_url

from perceptor import models

# Download locations of the linear-probe checkpoints, keyed by CLIP model name.
# The first three entries are served from the crowsonkb/simulacra-aesthetic-models
# repository; the remaining ones from the samedii/perceptor repository.
_SAC_MODELS_ROOT = (
    "https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models"
    "/master/models"
)
_PERCEPTOR_WEIGHTS_ROOT = (
    "https://raw.githubusercontent.com/samedii/perceptor"
    "/master/perceptor/models/simulacra_aesthetic/weights"
)

CHECKPOINT_URLS = {
    "ViT-B-32": f"{_SAC_MODELS_ROOT}/sac_public_2022_06_29_vit_b_32_linear.pth",
    "ViT-B-16": f"{_SAC_MODELS_ROOT}/sac_public_2022_06_29_vit_b_16_linear.pth",
    "ViT-L-14": f"{_SAC_MODELS_ROOT}/sac_public_2022_06_29_vit_l_14_linear.pth",
    "RN50": f"{_PERCEPTOR_WEIGHTS_ROOT}/RN50.pth",
    "RN101": f"{_PERCEPTOR_WEIGHTS_ROOT}/RN101.pth",
    "RN50x4": f"{_PERCEPTOR_WEIGHTS_ROOT}/RN50x4.pth",
    "RN50x16": f"{_PERCEPTOR_WEIGHTS_ROOT}/RN50x16.pth",
    "RN50x64": f"{_PERCEPTOR_WEIGHTS_ROOT}/RN50x64.pth",
    "ViT-L-14-336": f"{_PERCEPTOR_WEIGHTS_ROOT}/ViT-L-14-336px.pth",
}


class SimulacraAesthetic(nn.Module):
    """Linear aesthetic-rating probe applied on top of CLIP image encodings."""

    def __init__(self, model_name="ViT-B-32"):
        """
        Simulacra aesthetic loss based on a CLIP linear regression probe that
        predicts the aesthetic rating of an image.

        Args:
            model_name (str): Name of the CLIP model; must be a key of
                ``CHECKPOINT_URLS``. Available models:

                - ViT-B-32
                - ViT-B-16
                - ViT-L-14
                - RN50
                - RN101
                - RN50x4
                - RN50x16
                - RN50x64
                - ViT-L-14-336
        """
        super().__init__()
        backbone = models.CLIP(model_name)

        # Fetch the probe checkpoint for this CLIP variant into "models".
        weights_path = load_file_from_url(CHECKPOINT_URLS[model_name], "models")
        probe_weights = torch.load(weights_path, map_location="cpu")

        # The probe's input width is read from the checkpoint itself.
        in_features = probe_weights["linear.weight"].shape[1]
        self.linear = nn.Linear(in_features, 1)

        # At this point `linear` is the only registered submodule, so the
        # checkpoint is matched against the probe alone; the CLIP model is
        # attached afterwards on purpose.
        self.load_state_dict(probe_weights)

        # The probe is used frozen: eval mode, no gradients for its parameters.
        self.linear.eval().requires_grad_(False)

        self.clip_model = backbone

    def forward(self, images):
        """
        Return the linear probe's output for ``images``.

        The images are encoded with the CLIP model, the encodings are
        L2-normalized along the last dimension and rescaled by the square
        root of the encoding width before being fed to the linear layer.
        """
        embeddings = self.clip_model.encode_images(images)
        rescale = embeddings.shape[-1] ** 0.5
        unit_embeddings = F.normalize(embeddings, dim=-1)
        return self.linear(unit_embeddings * rescale)
def test_simulacra_aesthetic():
    """Smoke test: run the default model once on a random image batch on CUDA."""
    scorer = SimulacraAesthetic().cuda()
    random_images = torch.randn(1, 3, 256, 256).cuda()
    scorer(random_images)