Source code for perceptor.losses.ruclip

import torch

from perceptor import models
from perceptor.losses.interface import LossInterface


class RuCLIP(LossInterface):
    """Loss comparing image encodings against stored RuCLIP target encodings.

    Targets are accumulated with ``add_texts_``, ``add_images_`` or
    ``add_encodings_``. Each target carries a weight that scales its
    contribution to the value returned by ``forward``.
    """

    def __init__(self, name="ruclip-vit-base-patch32-224"):
        """Create the loss around the ``models.RuCLIP`` model called ``name``."""
        super().__init__()
        self.name = name
        self.model = models.RuCLIP(name)
        # Both stay ``None`` until the first ``add_*_`` call stores targets.
        self.encodings = None
        self.weights = None

    @property
    def device(self):
        """Device of the first parameter of the wrapped model."""
        return next(iter(self.model.parameters())).device

    def add_texts_(self, texts, weights=None):
        """Encode ``texts`` with the model and add them as targets.

        Returns ``self`` so calls can be chained.
        """
        return self.add_encodings_(self.model.encode_texts(texts), weights)

    def add_images_(self, images, weights=None):
        """Encode ``images`` with the model and add them as targets.

        Returns ``self`` so calls can be chained.
        """
        return self.add_encodings_(self.model.encode_images(images), weights)

    def add_encodings_(
        self,
        encodings,
        weights=None,
    ):
        """Append target encodings and their weights.

        Args:
            encodings: 2-D tensor with one encoding per row.
            weights: ``None`` (every new encoding gets weight 1), a single
                number or one-element sequence/tensor (applied to every new
                encoding), or a 1-D list/tuple/tensor with exactly one
                weight per row of ``encodings``.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: if ``weights`` is neither a single value nor a 1-D
                collection with one entry per encoding.
        """
        n_encodings = encodings.shape[0]

        if weights is None:
            weights = torch.ones_like(encodings[:, 0])
        else:
            # ``as_tensor`` covers lists, tuples, plain numbers and tensors.
            weights = torch.as_tensor(weights)
            if weights.numel() == 1:
                # Materialize one weight per encoding so the stored weights
                # stay aligned with the stored encodings across later adds.
                weights = weights.reshape(1).repeat(n_encodings)
            elif tuple(weights.shape) != (n_encodings,):
                raise ValueError(
                    "weights must be a single value or contain one weight "
                    f"per encoding: got shape {tuple(weights.shape)} for "
                    f"{n_encodings} encodings"
                )

        # All validation is done above, so a failure never leaves the stored
        # encodings and weights partially updated.
        encodings = encodings.to(self.device)
        weights = weights.to(self.device)
        if self.encodings is not None:
            encodings = torch.cat([self.encodings, encodings])
            weights = torch.cat([self.weights, weights])

        # Stored as non-trainable parameters, exactly one weight per encoding.
        self.encodings = torch.nn.Parameter(encodings, requires_grad=False)
        self.weights = torch.nn.Parameter(weights, requires_grad=False)
        return self

    def forward(self, images):
        """Return the weighted mean distance between ``images`` and the targets.

        Every image encoding is compared with every stored encoding using
        ``2 * arcsin(||a - b|| / 2) ** 2``; the resulting matrix is multiplied
        by the stored weights and averaged over all entries.

        NOTE(review): ``arcsin`` is only defined on [-1, 1], so this presumably
        expects encodings whose pairwise distance is at most 2 (e.g.
        L2-normalized vectors) -- confirm against ``models.RuCLIP``.

        Raises:
            RuntimeError: if no targets have been added yet.
        """
        if self.encodings is None:
            raise RuntimeError(
                "RuCLIP loss has no targets; call add_texts_, add_images_ "
                "or add_encodings_ before forward"
            )

        image_encodings = self.model.encode_images(images)
        # Broadcast to (n_images, n_targets, dim), then reduce over dim.
        spherical_distance = (
            (image_encodings[:, None] - self.encodings[None, :])
            .norm(dim=2)
            .div(2)
            .arcsin()
            .square()
            .mul(2)
        )
        return (spherical_distance * self.weights).mean()