# Source code for perceptor.transforms.clamp_with_grad

import torch
from torch.nn import functional as F
from torchvision.transforms import functional as TF

from .interface import TransformInterface


class ClampWithGradFunction(torch.autograd.Function):
    """Clamp with a custom backward pass that filters the incoming gradient.

    The forward pass is a plain ``clamp``. The backward pass forwards the
    incoming gradient only at positions where
    ``grad_in * (input - clamp(input)) >= 0``; elsewhere it is zeroed.
    Inside ``[min, max]`` the difference is zero, so the gradient passes
    through unchanged. Outside the range it passes only when its sign
    matches the sign of the overshoot, which is the case where a step
    against the gradient moves the value back toward the range.
    """

    @staticmethod
    def forward(ctx, input, min=0, max=1):
        """Return ``input`` clamped to ``[min, max]``.

        The unclamped input and both bounds are stored on ``ctx`` so that
        ``backward`` can recompute how far each element was out of range.
        """
        ctx.save_for_backward(input)
        ctx.min, ctx.max = min, max
        return torch.clamp(input, min, max)

    @staticmethod
    def backward(ctx, grad_in):
        """Return the masked gradient for ``input`` and ``None`` for the bounds."""
        (original,) = ctx.saved_tensors
        # Signed distance outside the range; exactly zero for in-range values.
        overshoot = original - original.clamp(ctx.min, ctx.max)
        # Boolean mask: True where the gradient is allowed through.
        passes = (grad_in * overshoot) >= 0
        # One gradient slot per forward argument; min and max get none.
        return grad_in * passes, None, None


def clamp_with_grad(tensor, min=0.0, max=1.0):
    """Clamp ``tensor`` to ``[min, max]`` via ``ClampWithGradFunction``.

    Functional wrapper around ``ClampWithGradFunction.apply``; the backward
    behaviour is whatever that autograd function defines.

    Args:
        tensor: Value to clamp.
        min: Lower bound (default ``0.0``).
        max: Upper bound (default ``1.0``).

    Returns:
        The result of ``ClampWithGradFunction.apply(tensor, min, max)``.
    """
    return ClampWithGradFunction.apply(tensor, min, max)


class ClampWithGrad(TransformInterface):
    """Transform that clamps on ``encode`` and is the identity on ``decode``.

    Args:
        min: Lower clamp bound passed to ``clamp_with_grad`` (default ``0``).
        max: Upper clamp bound passed to ``clamp_with_grad`` (default ``1``).
    """

    def __init__(self, min=0, max=1):
        super().__init__()
        self.min = min
        self.max = max

    def encode(self, tensor):
        """Return ``tensor`` clamped to ``[self.min, self.max]``."""
        return clamp_with_grad(tensor, self.min, self.max)

    def decode(self, tensor):
        """Return ``tensor`` unchanged; clamping is not inverted."""
        return tensor