# Source code for perceptor.utils.gradient_checkpoint

from lantern import FunctionalBase, Tensor
import torch


class GradientCheckpoint(FunctionalBase):
    """Cut point in an autograd graph.

    ``original`` is the tensor handed to the constructor. ``detached`` is
    ``tensor.detach().requires_grad_()``: losses built from it (see
    :meth:`tensor`) accumulate their gradients in ``detached.grad`` instead
    of propagating into the graph behind ``original``. That gradient is then
    pushed through ``original`` with :meth:`continue_backward`.
    """

    original: Tensor
    detached: Tensor

    def __init__(self, tensor):
        """Store ``tensor`` together with a detached, grad-enabled twin."""
        super().__init__(original=tensor, detached=tensor.detach().requires_grad_())

    def zero_grad_(self):
        """Zero ``detached.grad`` in place and return ``self``.

        Does nothing when no gradient has been accumulated yet, so it is safe
        to call before the first backward pass.
        """
        grad = self.detached.grad
        if grad is not None:
            grad.zero_()
        return self

    def backward(self, loss):
        """Backpropagate ``loss`` and return the gradient at the checkpoint.

        Calls ``loss.backward()``, returns a clone of ``detached.grad`` and
        zeroes the accumulated gradient afterwards.

        Raises:
            ValueError: if ``loss.backward()`` left ``detached.grad`` unset.
        """
        loss.backward()
        grad = self.detached.grad
        if grad is None:
            # Explicit error instead of an AttributeError on ``None.clone()``.
            raise ValueError("Gradient is not defined")
        gradients = grad.clone()
        self.zero_grad_()
        return gradients

    def continue_backward(self, gradients=None, retain_graph=False):
        """Backpropagate through ``original``.

        Args:
            gradients: gradient to feed into ``original.backward``. When
                ``None``, the gradient accumulated in ``detached.grad`` is
                used.
            retain_graph: forwarded to ``original.backward``.

        Returns:
            Whatever ``original.backward`` returns.

        Raises:
            ValueError: if ``gradients`` is ``None`` and ``detached.grad`` is
                not set either.
        """
        if gradients is None:
            # The accumulated gradient is only required as the fallback;
            # explicitly supplied gradients do not depend on it.
            gradients = self.detached.grad
            if gradients is None:
                raise ValueError("Gradient is not defined")
        return self.original.backward(gradients, retain_graph=retain_graph)

    def tensor(self):
        """Return the detached tensor that losses should be computed from."""
        return self.detached

    @staticmethod
    def nonzero_mean(gradients, dim=0):
        """Average ``gradients`` along ``dim``, counting only non-zero entries.

        Args:
            gradients: a tensor, or a list/tuple of tensors that is stacked
                along a new leading dimension first.
            dim: dimension to reduce.

        Returns:
            The sum along ``dim`` divided by the number of non-zero entries
            along ``dim`` (plus ``1e-6``, so all-zero positions divide
            safely).
        """
        if isinstance(gradients, (list, tuple)):
            gradients = torch.stack(gradients)
        nonzero_count = gradients.ne(0).sum(dim).add(1e-6)
        return gradients.sum(dim).div(nonzero_count)

    @staticmethod
    def nonzero_scale(tensor, dim=None):
        """Divide ``tensor`` by the standard deviation of its non-zero entries.

        Args:
            tensor: a tensor, or a list/tuple of tensors that is stacked
                along a new leading dimension first.
            dim: dimension along which the statistics are computed. ``None``
                computes them over the flattened tensor.

        Returns:
            The scaled tensor, with the same shape as the (stacked) input.
        """
        if isinstance(tensor, (list, tuple)):
            tensor = torch.stack(tensor)
        shape = tensor.shape
        if dim is None:
            tensor = tensor.flatten()
            dim = 0

        # Zeros add nothing to the sums, so dividing by the non-zero count
        # gives the moments of the non-zero entries only.
        nonzero_count = tensor.ne(0).sum(dim).add(1e-6)
        mean_square = tensor.square().sum(dim) / nonzero_count
        mean = tensor.sum(dim) / nonzero_count
        # E[x^2] - E[x]^2 can come out slightly negative through floating
        # point cancellation; clamp so sqrt never produces NaN.
        variance = (mean_square - mean.square()).clamp(min=0)
        std = variance.sqrt().add(1e-6)
        # The second epsilon is kept so results match the previous numerics.
        scaled_tensor = tensor / std.unsqueeze(dim).add(1e-6)
        # reshape(shape) also handles 0-d input, where view(*shape) would be
        # called without arguments.
        return scaled_tensor.reshape(shape)


def gradient_checkpoint(tensor: Tensor) -> GradientCheckpoint:
    """
    Gradient checkpointing to save compute for common part of graph.

    Losses are evaluated on ``checkpoint.tensor()``; their gradients
    accumulate at the checkpoint and are sent through the graph that
    produced ``tensor`` by ``continue_backward()``.

    Usage:

        >>> checkpoint = gradient_checkpoint(images)
        >>> for text_loss in text_losses:
        ...     text_loss(checkpoint.tensor()).backward()
        >>> checkpoint.continue_backward()
    """
    return GradientCheckpoint(tensor)
def test_gradient_checkpoint():
    """Check that gradients reach both the checkpoint and the leaf tensor."""
    import torch

    with torch.enable_grad():
        images = torch.zeros(1, 3, 64, 64).requires_grad_()
        checkpoint = gradient_checkpoint(images * 2)

        # The backward pass from the loss must stop at the checkpoint ...
        checkpoint.tensor().pow(2).mean().backward()
        assert checkpoint.detached.grad is not None

        # ... and continue_backward must carry it on to the leaf tensor.
        checkpoint.continue_backward()
        assert images.grad is not None