Source code for perceptor.losses.owlvit

from typing import List
import torch

from perceptor.models.owlvit.owlvit import OWLViTEncodings
from perceptor import models
from perceptor.losses.interface import LossInterface


class OWLViT(LossInterface):
    """Loss built on the OWL-ViT zero-shot text-conditioned bounding box model.

    The loss is designed in such a way that only one example of the text
    prompt is expected to be found in the image.

    Exactly one set of encodings can be registered per instance (via
    ``add_texts_`` or ``add_encodings_``). ``forward`` then scores images
    against every registered label, scaled by that label's weight.
    """

    def __init__(self):
        """Create the wrapped OWL-ViT model with no encodings registered."""
        super().__init__()
        self.model = models.OWLViT()
        # Both stay None until add_encodings_ is called.
        self.encodings = None
        self.weights = None

    @property
    def device(self):
        """Device of the first parameter of the wrapped model."""
        return next(iter(self.model.parameters())).device

    def to(self, device):
        """Move the loss and any registered encodings to ``device``.

        Returns whatever the parent class's ``to`` returns.
        """
        if self.encodings is not None:
            moved = self.encodings.to(device)
            # NOTE(review): OWLViTEncodings.to is defined elsewhere. If it is
            # out-of-place (like Tensor.to) its result must be kept, otherwise
            # the encodings silently stay on the old device. If it moves
            # in place and returns None, the existing object is kept.
            if moved is not None:
                self.encodings = moved
        # The previous bare ``self.weights.to(device)`` was a no-op because
        # Tensor.to never moves in place. The weights are an nn.Parameter
        # attribute, so they are presumably moved by the parent ``to`` below
        # (assumes LossInterface is a torch.nn.Module -- TODO confirm).
        return super().to(device)

    def cuda(self):
        """Shortcut for ``self.to(torch.device("cuda"))``."""
        return self.to(torch.device("cuda"))

    def cpu(self):
        """Shortcut for ``self.to(torch.device("cpu"))``."""
        return self.to(torch.device("cpu"))

    def add_texts_(self, texts: List[str], weights=None):
        """Encode ``texts`` with the model and register them as targets.

        The prompts are wrapped in an outer list before encoding (presumably
        a batch of one prompt set -- confirm against ``encode_texts``).

        Args:
            texts: Text prompts to look for.
            weights: Optional per-prompt weights; see ``add_encodings_``.

        Returns:
            ``self``, to allow chaining.
        """
        return self.add_encodings_(self.model.encode_texts([texts]), weights)

    def add_images_(self, images, weights=None):
        """Image prompts are not supported.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def add_encodings_(
        self,
        encodings: OWLViTEncodings,
        weights=None,
    ):
        """Register pre-computed encodings and their weights.

        Args:
            encodings: Encodings to score images against.
            weights: A list/tuple (converted with ``torch.tensor``), an
                existing tensor, or ``None`` for a weight of one per entry
                of ``encodings.texts``.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If encodings have already been registered.
        """
        # Reject a second registration before doing any other work.
        if self.encodings is not None:
            raise ValueError("OWLViT can only have one set of encodings")

        if weights is None:
            weights = torch.ones(len(encodings.texts))
        elif isinstance(weights, (list, tuple)):
            weights = torch.tensor(weights)

        self.encodings = encodings
        # Stored as a frozen Parameter: never trained, only used for scaling.
        self.weights = torch.nn.Parameter(
            weights.to(self.device),
            requires_grad=False,
        )
        return self

    def forward(self, images, top_k=5):
        """Compute the loss for ``images`` against the registered encodings.

        For each label, the logits are flattened over every non-batch
        position, log-softmaxed across those positions, and the ``top_k``
        largest log-probabilities per image are averaged. The loss is the
        negative weighted sum of these averages, scaled by 0.01.

        Args:
            images: Input passed straight to the wrapped model.
            top_k: Number of highest-scoring positions to average per image.
                Should be at least 1; with 0 the slice ``[-0:]`` selects
                every position instead of none.

        Returns:
            A scalar tensor on ``self.device``.

        Raises:
            ValueError: If no encodings have been registered yet.
        """
        if self.encodings is None or self.weights is None:
            raise ValueError(
                "OWLViT has no encodings; call add_texts_ or add_encodings_ first"
            )

        predictions = self.model(images, self.encodings)
        loss = torch.tensor(0.0, device=self.device)
        for label_index, weight in enumerate(self.weights):
            label_log_probs = (
                predictions.logits[:, :, label_index]
                .flatten(start_dim=1)
                .log_softmax(dim=1)
            )
            # Ascending sort, so the last top_k columns are the largest values.
            top_log_probs = label_log_probs.sort(dim=1)[0][:, -top_k:]
            loss = loss - top_log_probs.mean() * weight
        return loss * 0.01
def test_owlvit_loss():
    """Smoke test: build the loss from two prompts and run one forward pass.

    Runs on CUDA when it is available and falls back to CPU otherwise, so
    the test does not crash on machines without a GPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    text_prompts = [
        "hello",
        "world",
    ]
    loss = OWLViT().add_texts_(text_prompts).to(device)
    loss(torch.zeros(3, 3, 480, 480, device=device))