Source code for perceptor.losses.memorability
from torchvision import transforms
from resmem import ResMem
from .interface import LossInterface
[docs]class Memorability(LossInterface):
def __init__(self):
super().__init__()
self.recenter = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(227),
]
)
self.model = ResMem(pretrained=True)
self.model.eval()
self.model.requires_grad_(False)
[docs] def forward(self, images):
image_x = self.recenter(images)
prediction = self.model(image_x)
return prediction.mean() * 0.05