from typing import Optional
import torch
import lantern
class Conditioning(torch.nn.Module):
    """Container module bundling encodings with optional inpainting data.

    The encodings are wrapped in a ``torch.nn.Parameter`` with
    ``requires_grad=False``, so they are registered on the module but are
    not trainable. The inpainting tensors are assigned as given, without
    any wrapping.

    Args:
        model_name: Name of the model these encodings belong to. Stored
            unchanged on the instance.
        encodings: Tensor of encodings to hold.
        inpainting_latent_masks: Optional tensor annotated as ``NCHW``.
        inpainting_latents: Optional tensor annotated as ``NCHW``.
    """

    def __init__(
        self,
        model_name: str,
        encodings: lantern.Tensor,
        inpainting_latent_masks: Optional[lantern.Tensor.dims("NCHW")] = None,
        inpainting_latents: Optional[lantern.Tensor.dims("NCHW")] = None,
    ):
        super().__init__()
        self.model_name = model_name
        # Frozen parameter: registered on the module, excluded from gradients.
        self.encodings = torch.nn.Parameter(encodings, requires_grad=False)
        self.inpainting_latent_masks = inpainting_latent_masks
        self.inpainting_latents = inpainting_latents

    @property
    def device(self):
        """Return the device on which the encodings currently live."""
        return self.encodings.device

    def __neg__(self):
        """Return a new ``Conditioning`` with negated encodings.

        The model name and the inpainting tensors are carried over
        unchanged; the inpainting tensors are the same objects, not copies.
        """
        # Arguments are passed by keyword: ``model_name`` is the first
        # positional parameter of ``__init__``, so passing the negated
        # encodings positionally would bind them to the wrong parameter
        # and leave ``encodings`` missing.
        return Conditioning(
            model_name=self.model_name,
            encodings=-self.encodings,
            inpainting_latent_masks=self.inpainting_latent_masks,
            inpainting_latents=self.inpainting_latents,
        )