import torch
from perceptor import models
from perceptor.losses.interface import LossInterface
class RuCLIP(LossInterface):
    """Weighted distance between image encodings and stored target encodings.

    Targets are accumulated in place with ``add_texts_``, ``add_images_`` or
    ``add_encodings_``. ``forward`` then compares the encodings of a batch of
    images against every stored target and returns a single scalar.
    """

    def __init__(self, name="ruclip-vit-base-patch32-224"):
        """Create the loss around the RuCLIP model identified by ``name``.

        Args:
            name: Model identifier forwarded to ``models.RuCLIP``.
        """
        super().__init__()
        self.name = name
        self.model = models.RuCLIP(name)
        # Both stay None until the first add_*_ call stores a target.
        self.encodings = None
        self.weights = None

    @property
    def device(self):
        """Device of the first parameter of the wrapped model."""
        return next(iter(self.model.parameters())).device

    def add_texts_(self, texts, weights=None):
        """Encode ``texts`` with the model and add them as targets.

        Args:
            texts: Passed unchanged to ``self.model.encode_texts``.
            weights: See ``add_encodings_``.

        Returns:
            ``self``, so calls can be chained.
        """
        return self.add_encodings_(self.model.encode_texts(texts), weights)

    def add_images_(self, images, weights=None):
        """Encode ``images`` with the model and add them as targets.

        Args:
            images: Passed unchanged to ``self.model.encode_images``.
            weights: See ``add_encodings_``.

        Returns:
            ``self``, so calls can be chained.
        """
        return self.add_encodings_(self.model.encode_images(images), weights)

    def add_encodings_(
        self,
        encodings,
        weights=None,
    ):
        """Append target encodings and their weights in place.

        Args:
            encodings: Tensor whose first dimension indexes the targets.
                ``forward`` subtracts it from image encodings and reduces
                over ``dim=2``, so a 2-D tensor is expected.
            weights: One weight per encoding, given as a tensor or as
                anything ``torch.tensor`` accepts (list, tuple, scalar).
                The values are flattened to 1-D; a single weight is
                repeated for every encoding. ``None`` uses a weight of one
                for each encoding.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If the number of weights is neither one nor the
                number of encodings.
        """
        if weights is None:
            weights = torch.ones_like(encodings[:, 0])
        else:
            if not isinstance(weights, torch.Tensor):
                weights = torch.tensor(weights)
            weights = self._align_weights(weights, encodings.shape[0])

        encodings = encodings.to(self.device)
        weights = weights.to(self.device)
        if self.encodings is not None:
            # Later calls extend the targets that are already stored.
            encodings = torch.cat([self.encodings, encodings])
            weights = torch.cat([self.weights, weights])

        self.encodings = torch.nn.Parameter(encodings, requires_grad=False)
        self.weights = torch.nn.Parameter(weights, requires_grad=False)
        return self

    @staticmethod
    def _align_weights(weights, count):
        """Return ``weights`` as a 1-D tensor holding ``count`` entries."""
        weights = weights.reshape(-1)
        if weights.shape[0] == count:
            return weights
        if weights.shape[0] == 1:
            # Materialise the broadcast so that weights and encodings stay
            # aligned when further targets are concatenated later.
            return weights.repeat(count)
        raise ValueError(
            f"expected 1 or {count} weights, got {weights.shape[0]}"
        )

    def forward(self, images):
        """Return the weighted mean distance between images and targets.

        For every (image, target) pair the distance is
        ``2 * arcsin(||a - b|| / 2) ** 2`` where ``a`` and ``b`` are the two
        encodings. Each column of the resulting matrix is scaled by the
        weight of its target and the mean over all entries is returned.

        Args:
            images: Passed unchanged to ``self.model.encode_images``.

        Returns:
            Scalar tensor.

        Raises:
            RuntimeError: If no target has been added yet.
        """
        if self.encodings is None:
            raise RuntimeError(
                "no targets added; call add_texts_, add_images_ or "
                "add_encodings_ before forward"
            )
        image_encodings = self.model.encode_images(images)
        # Pairwise differences have shape (images, targets, features).
        # NOTE(review): arcsin is only defined for chord / 2 <= 1, which
        # presumably relies on unit-norm encodings -- confirm in the model.
        chord = (image_encodings[:, None] - self.encodings[None, :]).norm(dim=2)
        spherical_distance = chord.div(2).arcsin().square().mul(2)
        return (spherical_distance * self.weights).mean()