# Source code for perceptor.transforms.resize.resize_right

# https://github.com/assafshocher/ResizeRight/blob/master/resize_right.py
from typing import Optional, Tuple
import warnings
from math import ceil
from . import interpolation_methods
from fractions import Fraction


class NoneClass:
    """Empty stand-in bound to ``nnModuleWrapped`` when PyTorch is missing."""


# Both backends are optional. A missing backend is bound to None so code below
# can select a framework with identity checks such as `fw is numpy`.
try:
    import torch
    from torch import nn

    # nn.Module when PyTorch is available, otherwise the NoneClass placeholder.
    # NOTE(review): not used in this file; presumably imported by other
    # modules - confirm before removing.
    nnModuleWrapped = nn.Module
except ImportError:
    warnings.warn("No PyTorch found, will work only with Numpy")
    torch = None
    # note: `nn` is left unbound on this path
    nnModuleWrapped = NoneClass

try:
    import numpy
except ImportError:
    warnings.warn("No Numpy found, will work only with PyTorch")
    numpy = None


# At least one backend is required; fail at import time otherwise.
if numpy is None and torch is None:
    raise ImportError("Must have either Numpy or PyTorch but both not found")


def resize(
    input,
    scale_factors: float = None,
    out_shape: Tuple[int, int] = None,
    resample: Optional[str] = None,
    support_sz: Optional[int] = None,
    antialiasing: bool = True,
    by_convs: bool = False,
    scale_tolerance: Optional[float] = None,
    max_numerator: int = 10,
    pad_mode: str = "constant",
):
    """
    Alternative resize method. May be useful but pytorch 1.11 has added
    optional antialiasing too.

    Args:
        input: the input image/tensor, a Numpy array or Torch tensor.
        scale_factors: can be specified as
            1. one scalar scale - then it is assumed that you want to resize
               the first two dims with this scale for Numpy or the last two
               dims for PyTorch.
            2. a list or tuple of scales - one for each dimension you want to
               resize. If the length of the list is L then the first L dims
               are rescaled for Numpy and the last L for PyTorch.
            3. not specified - then it is computed per dim as
               out_shape / in_shape (not recommended: a rounded output size
               may not reflect the scale you intended).
        out_shape: A list or tuple. If shorter than input.shape then only the
            first (Numpy) or last (PyTorch) dims are resized. If not
            specified, it is computed as ceil(scale_factor * in_size).
        resample: Name of the interpolation method, used as a key into
            ``interpolation_methods.methods``. If not specified, "lanczos3"
            is used when no dim grows and "bicubic" otherwise.
        support_sz: The support of the interpolation function, i.e. the
            length of its non-zero segment over its 1d input domain. Defaults
            to the ``support_sz`` attribute of the chosen method.
        antialiasing: Similar to MATLAB's default; only relevant for
            downscaling. If true the kernel is stretched by 1/scale_factor
            (and its support widened accordingly), which low-pass filters the
            input to prevent aliasing.
        by_convs: Whether to allow efficient calculation using convolutions
            for dims whose scale_factor is close enough to a fraction M/N
            (see scale_tolerance and max_numerator). Useful for large tensors
            (batches or high resolution). PyTorch only; ignored for Numpy
            inputs.
        scale_tolerance: Maximum allowed distance between the float
            scale_factor and its closest fraction M/N. If the fraction is
            closer than this, it is used and the convolution path runs;
            otherwise that dim falls back to the regular path. Defaults to
            float32 machine epsilon.
        max_numerator: When by_convs is on, the scale_factor is approximated
            by a fraction M/N with M limited by this parameter. The number of
            convolutions used for a dim equals M.
        pad_mode: Padding mode passed to the framework's pad function.
            PyTorch: 'constant', 'reflect', 'replicate', 'circular'.
            Numpy: 'constant', 'edge', 'linear_ramp', 'maximum', 'mean',
            'median', 'minimum', 'reflect', 'symmetric', 'wrap', 'empty'.

    Returns:
        The resized array/tensor, of the same framework as ``input``.

    Raises:
        ValueError: if neither scale_factors nor out_shape is given.
        KeyError: if ``resample`` is not a known interpolation method.
    """
    # get properties of the input tensor
    in_shape, n_dims = input.shape, input.ndim

    # fw stands for framework (numpy or torch), determined by the input type.
    # Checking `numpy is not None` first keeps torch-only installs working,
    # and isinstance also accepts ndarray subclasses.
    if numpy is not None and isinstance(input, numpy.ndarray):
        fw = numpy
        # the convolution path (fw_conv) is implemented for torch only
        by_convs = False
    else:
        fw = torch
    eps = fw.finfo(fw.float32).eps
    device = input.device if fw is torch else None

    # set missing scale factors or output shape one according to another,
    # scream if both are missing. this is also where all the default policies
    # take place, including normalizing by_convs to one flag per dim.
    scale_factors, out_shape, by_convs = set_scale_and_out_sz(
        in_shape,
        out_shape,
        scale_factors,
        by_convs,
        scale_tolerance,
        max_numerator,
        eps,
        fw,
    )

    if resample is None:
        # Compare every dim so the choice is correct for both conventions
        # (numpy resizes leading dims, torch trailing dims); dims that are
        # not resized have out_sz == in_sz and do not affect the decision.
        downscaling = all(
            out_sz <= in_sz for out_sz, in_sz in zip(out_shape, in_shape)
        )
        resample = "lanczos3" if downscaling else "bicubic"
    interp_method = interpolation_methods.methods[resample]

    # sort indices of dimensions according to scale of each dimension.
    # since we are going dim by dim this is efficient
    sorted_filtered_dims_and_scales = [
        (dim, scale_factors[dim], by_convs[dim], in_shape[dim], out_shape[dim])
        for dim in sorted(range(n_dims), key=lambda ind: scale_factors[ind])
        if scale_factors[dim] != 1.0
    ]

    # unless support size is specified by the user, it is an attribute
    # of the interpolation method
    if support_sz is None:
        support_sz = interp_method.support_sz

    # output begins identical to input and changes with each iteration
    output = input

    for dim, scale_factor, dim_by_convs, in_sz, out_sz in sorted_filtered_dims_and_scales:
        # STEP 1- PROJECTED GRID: The non-integer locations of the projection
        # of output pixel locations to the input tensor
        projected_grid = get_projected_grid(
            in_sz, out_sz, scale_factor, fw, dim_by_convs, device
        )

        # STEP 1.5- ANTIALIASING: If antialiasing is taking place, we modify
        # the window size and the interpolation method (see inside function)
        cur_interp_method, cur_support_sz = apply_antialiasing_if_needed(
            interp_method, support_sz, scale_factor, antialiasing
        )

        # STEP 2- FIELDS OF VIEW: for each output pixel, map the input pixels
        # that influence it
        field_of_view = get_field_of_view(
            projected_grid, cur_support_sz, fw, eps, device
        )

        # STEP 2.5- CALCULATE PAD AND UPDATE: according to the field of view,
        # the input should be padded to handle the boundaries, and coordinates
        # shifted accordingly. actual padding only occurs when weights are
        # applied (step 4). if using by_convs for this dim, then we calc
        # left and right boundaries for each filter instead.
        pad_sz, projected_grid, field_of_view = calc_pad_sz(
            in_sz,
            out_sz,
            field_of_view,
            projected_grid,
            scale_factor,
            dim_by_convs,
            fw,
            device,
        )

        # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in
        # the field of view for each output pixel
        weights = get_weights(cur_interp_method, projected_grid, field_of_view)

        # STEP 4- APPLY WEIGHTS: Each output pixel is the weighted sum of the
        # pixel values in its field of view, computed by broadcasting, or by
        # strided convolutions when by_convs is on for this dim (equivalent
        # but faster).
        if not dim_by_convs:
            output = apply_weights(
                output, field_of_view, weights, dim, n_dims, pad_sz, pad_mode, fw
            )
        else:
            output = apply_convs(
                output, scale_factor, in_sz, out_sz, weights, dim, pad_sz, pad_mode, fw
            )
    return output
def get_projected_grid(in_sz, out_sz, scale_factor, fw, by_convs, device=None):
    """Project output pixel locations onto (non-integer) input coordinates.

    Args:
        in_sz: input length along the current dim.
        out_sz: output length along the current dim.
        scale_factor: the dim's scale (a Fraction when by_convs is set).
        fw: the numpy or torch module.
        by_convs: if true, only ``scale_factor.numerator`` grid points are
            produced - one per distinct convolution filter.
        device: torch device for the grid (unused for numpy).

    Returns:
        1d array/tensor of projected input coordinates.
    """
    # the output coordinates are just integer locations; with by_convs we only
    # need one point per distinct filter (the numerator of the scale factor)
    grid_sz = out_sz if not by_convs else scale_factor.numerator
    out_coordinates = fw_arange(grid_sz, fw, device)
    # formula derived in "From Discrete to Continuous Convolutions" by
    # Shocher et al.; the two offsets align the centers of both grids
    return (
        out_coordinates / float(scale_factor)
        + (in_sz - 1) / 2
        - (out_sz - 1) / (2 * float(scale_factor))
    )


def get_field_of_view(projected_grid, cur_support_sz, fw, eps, device):
    """For each projected point, list the integer input positions in its window.

    Returns an integer array of shape
    (len(projected_grid), ceil(cur_support_sz - eps)); row i holds consecutive
    input indices starting at the leftmost neighbor of projected_grid[i].
    Indices may fall outside the input; padding handles that later.
    """
    # leftmost neighbor is half a window to the left; eps guards against float
    # round-off pushing an exact-integer boundary up by one
    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)
    # then count window-size pixels from the left boundary
    ordinal_numbers = fw_arange(ceil(cur_support_sz - eps), fw, device)
    return left_boundaries[:, None] + ordinal_numbers


def calc_pad_sz(
    in_sz, out_sz, field_of_view, projected_grid, scale_factor, dim_by_convs, fw, device
):
    """Compute padding for the current dim and shift coordinates to match.

    Returns ``(pad_sz, projected_grid, field_of_view)``. Without by_convs,
    pad_sz is one ``[left, right]`` pair and both coordinate arrays are
    shifted in place by the left pad. With by_convs, pad_sz is a list of
    (left, right) pairs, one per convolution filter. Negative values mean
    cropping.
    """
    if not dim_by_convs:
        # determine padding according to neighbor coords out of bound.
        # this is a generalized notion of padding, when pad<0 it means crop
        pad_sz = [
            -field_of_view[0, 0].item(),
            field_of_view[-1, -1].item() - in_sz + 1,
        ]
        # since the input will be changed by padding, coordinates of both
        # field_of_view and projected_grid need to be updated
        field_of_view += pad_sz[0]
        projected_grid += pad_sz[0]
    else:
        # the number of distinct convolutions is the numerator of the scale
        # factor; each conv moves along the input with the denominator stride
        num_convs, stride = scale_factor.numerator, scale_factor.denominator
        # left can be negative and right can exceed in_sz; such cases imply
        # padding. if both are in-bounds the conv runs on a crop of the input.
        left_pads = -field_of_view[:, 0]
        # explanation by rows:
        # 1) count output pixels between each filter's first position and the
        #    right boundary of the output
        # 2) divide by the number of filters: how many jumps each filter does
        # 3) multiply by the stride: input distance covered by these jumps
        # 4) add the filter's right boundary in its leftmost position, giving
        #    that filter's rightmost input coordinate
        # 5) subtract the rightmost input coordinate: positive means pad,
        #    negative means shave off pixel columns
        right_pads = (
            ((out_sz - fw_arange(num_convs, fw, device) - 1) // num_convs)  # (1) (2)
            * stride  # (3)
            + field_of_view[:, -1]  # (4)
            - in_sz
            + 1
        )  # (5)
        # one (left, right) pair per filter
        pad_sz = list(zip(left_pads, right_pads))
    return pad_sz, projected_grid, field_of_view


def get_weights(interp_method, projected_grid, field_of_view):
    """Evaluate the interpolation kernel for every (output, neighbor) pair.

    Weights are the kernel applied to the directed distances between
    projected grid locations and the field-of-view pixel centers, normalized
    so each row sums to 1 (rows summing to 0 are left as zeros).
    """
    weights = interp_method(projected_grid[:, None] - field_of_view)
    sum_weights = weights.sum(1, keepdims=True)
    # avoid division by zero for rows whose kernel is zero everywhere
    sum_weights[sum_weights == 0] = 1
    return weights / sum_weights


def apply_weights(input, field_of_view, weights, dim, n_dims, pad_sz, pad_mode, fw):
    """Resize ``input`` along ``dim`` as a weighted sum over each field of view."""
    # the operation assumes the resized dim is the first one, so we transpose
    # and transpose back after multiplying
    tmp_input = fw_swapaxes(input, dim, 0, fw)
    tmp_input = fw_pad(tmp_input, fw, pad_sz, pad_mode)
    # field_of_view is (n_out, window): for each output location along the
    # current dim, the 1d neighbor locations. indexing gathers, for every
    # multi-dim location, the neighbor values along the current dim, giving a
    # tensor of order image_dims + 1.
    neighbors = tmp_input[field_of_view]
    # augment weights (n_out, window) with singleton dims so that, when
    # broadcast, they only affect the leading dims
    tmp_weights = fw.reshape(weights, (*weights.shape, *[1] * (n_dims - 1)))
    # multiply and sum over the field of view: one value per output pixel
    tmp_output = (neighbors * tmp_weights).sum(1)
    return fw_swapaxes(tmp_output, 0, dim, fw)


def apply_convs(input, scale_factor, in_sz, out_sz, weights, dim, pad_sz, pad_mode, fw):
    """Resize ``input`` along ``dim`` with interleaved strided 1d convolutions.

    Uses ``scale_factor.numerator`` filters (one per row of ``weights``), each
    with stride ``scale_factor.denominator``; results are written to every
    numerator-th output position. Relies on fw_conv, which is torch-only.
    """
    # the operation assumes the resized dim is the last one
    input = fw_swapaxes(input, dim, -1, fw)
    stride, num_convs = scale_factor.denominator, scale_factor.numerator
    tmp_out_shape = list(input.shape)
    tmp_out_shape[-1] = out_sz
    # allocate in the input's dtype so results are not silently down-cast
    tmp_output = fw_empty(tuple(tmp_out_shape), fw, input.device, input.dtype)
    pad_dim = input.ndim - 1
    for conv_ind, (filt_pad_sz, filt) in enumerate(zip(pad_sz, weights)):
        # pad (or crop, for negative values) the last dim for this filter
        tmp_input = fw_pad(input, fw, filt_pad_sz, pad_mode, dim=pad_dim)
        # store with a positional stride so that, once the loop completes,
        # the results of all convs are interleaved
        tmp_output[..., conv_ind::num_convs] = fw_conv(tmp_input, filt, stride)
    return fw_swapaxes(tmp_output, -1, dim, fw)


def set_scale_and_out_sz(
    in_shape,
    out_shape,
    scale_factors,
    by_convs,
    scale_tolerance,
    max_numerator,
    eps,
    fw,
):
    """Fill in whichever of scale_factors / out_shape is missing.

    Returns ``(scale_factors, out_shape, by_convs)`` as lists with one entry
    per input dim. Partial specifications cover the first dims for numpy and
    the last dims for torch. Each scale becomes a Fraction when by_convs is
    set for that dim and a fraction with numerator <= max_numerator lies
    within scale_tolerance (default ``eps``); otherwise it becomes a float and
    by_convs is cleared for that dim. A caller-supplied by_convs sequence is
    copied, never mutated.

    Raises:
        ValueError: if both scale_factors and out_shape are None.
    """
    if scale_factors is None and out_shape is None:
        raise ValueError("either scale_factors or out_shape should be provided")

    if out_shape is not None:
        # fill unspecified dims from in_shape: trailing for numpy, leading
        # for torch
        out_shape = (
            list(out_shape) + list(in_shape[len(out_shape):])
            if fw is numpy
            else list(in_shape[: -len(out_shape)]) + list(out_shape)
        )
        if scale_factors is None:
            # no scale given: use the out to in ratio (not recommended)
            scale_factors = [
                out_sz / in_sz for out_sz, in_sz in zip(out_shape, in_shape)
            ]

    if scale_factors is not None:
        # a single number means resizing two dims (images)
        if not isinstance(scale_factors, (list, tuple)):
            scale_factors = [scale_factors, scale_factors]
        n_missing = len(in_shape) - len(scale_factors)
        scale_factors = (
            list(scale_factors) + [1] * n_missing
            if fw is numpy
            else [1] * n_missing + list(scale_factors)
        )
        if out_shape is None:
            # out_shape from scale (not recommended)
            out_shape = [
                ceil(scale_factor * in_sz)
                for scale_factor, in_sz in zip(scale_factors, in_shape)
            ]

    # intentionally after out_shape is determined. always build a fresh list
    # so tuples work and the caller's sequence is never mutated; missing dims
    # are padded with False on the same side as scale_factors.
    if isinstance(by_convs, (list, tuple)):
        n_missing = len(out_shape) - len(by_convs)
        by_convs = (
            list(by_convs) + [False] * n_missing
            if fw is numpy
            else [False] * n_missing + list(by_convs)
        )
    else:
        by_convs = [by_convs] * len(out_shape)

    if scale_tolerance is None:
        scale_tolerance = eps

    # make each scale either a Fraction (convolution path) or a float
    for ind, (sf, dim_by_convs) in enumerate(zip(scale_factors, by_convs)):
        frac = None
        if dim_by_convs:
            # approximate the reciprocal so that, after flipping, the scale's
            # numerator (= number of convs) is bounded by max_numerator
            approx = Fraction(1 / sf).limit_denominator(max_numerator)
            frac = Fraction(numerator=approx.denominator, denominator=approx.numerator)
        if frac is not None and abs(frac - sf) < scale_tolerance:
            scale_factors[ind] = frac
        else:
            scale_factors[ind] = float(sf)
            by_convs[ind] = False
    return scale_factors, out_shape, by_convs


def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor, antialiasing):
    """Stretch the kernel for downscaling (low-pass filtering).

    For scale_factor < 1 with antialiasing on, returns the kernel
    ``s * k(s * x)`` and support ``support_sz / s``; otherwise returns the
    inputs unchanged.
    """
    scale_factor = float(scale_factor)
    if scale_factor >= 1.0 or not antialiasing:
        return interp_method, support_sz

    def cur_interp_method(arg):
        return scale_factor * interp_method(scale_factor * arg)

    cur_support_sz = support_sz / scale_factor
    return cur_interp_method, cur_support_sz


def fw_ceil(x, fw):
    """Elementwise ceil returning integers, for numpy or torch."""
    if fw is numpy:
        return fw.int_(fw.ceil(x))
    else:
        return x.ceil().long()


def fw_floor(x, fw):
    """Elementwise floor returning integers, for numpy or torch."""
    if fw is numpy:
        return fw.int_(fw.floor(x))
    else:
        return x.floor().long()


def fw_cat(x, fw):
    """Concatenate a sequence of arrays/tensors along the first axis."""
    if fw is numpy:
        return fw.concatenate(x)
    else:
        return fw.cat(x)


def fw_swapaxes(x, ax_1, ax_2, fw):
    """Swap two axes, for numpy or torch."""
    if fw is numpy:
        return fw.swapaxes(x, ax_1, ax_2)
    else:
        return x.transpose(ax_1, ax_2)


def fw_pad(x, fw, pad_sz, pad_mode, dim=0):
    """Pad (or crop) ``x`` along ``dim``.

    Args:
        x: numpy array or torch tensor of any rank.
        fw: the numpy or torch module.
        pad_sz: ``(left, right)`` widths; may be a list or tuple of ints or
            0-dim tensors. Negative values crop that many elements instead.
        pad_mode: mode string forwarded to numpy.pad / torch's F.pad.
        dim: axis to pad.

    Returns:
        The padded array/tensor; ``x`` itself when both widths are zero.
    """
    left, right = (int(p) for p in pad_sz)
    if left == 0 and right == 0:
        return x
    # numpy.pad rejects negative widths and torch's non-constant modes do not
    # crop either, so do any cropping with plain slicing first
    if left < 0 or right < 0:
        index = [slice(None)] * x.ndim
        index[dim] = slice(max(-left, 0), x.shape[dim] - max(-right, 0))
        x = x[tuple(index)]
        left, right = max(left, 0), max(right, 0)
        if left == 0 and right == 0:
            return x
    if fw is numpy:
        pad_vec = [(0, 0)] * x.ndim
        pad_vec[dim] = (left, right)
        return fw.pad(x, pad_width=pad_vec, mode=pad_mode)
    # torch pads trailing dims: move `dim` last, pad, then move it back
    moved = x.transpose(dim, -1)
    if pad_mode == "constant":
        padded = fw.nn.functional.pad(moved, pad=[left, right], mode=pad_mode)
    else:
        # reflect/replicate/circular need a batched (N, C, L) input; fold all
        # other axes into C and restore them afterwards
        lead_shape = moved.shape[:-1]
        flat = moved.reshape(1, -1, moved.shape[-1])
        padded = fw.nn.functional.pad(flat, pad=[left, right], mode=pad_mode)
        padded = padded.reshape(*lead_shape, padded.shape[-1])
    return padded.transpose(dim, -1)


def fw_conv(input, filter, stride):
    """Apply a strided 1d convolution along the last dim of an nd tensor.

    The input is reshaped to (1, 1, rest, last) and convolved with a 1xK
    conv2d, so every other position is filtered identically. Torch only.
    """
    # TODO: numpy support
    reshaped_input = input.reshape(1, 1, -1, input.shape[-1])
    # conv2d requires matching dtypes; the weights come from a float grid and
    # may not match the input (e.g. float64 input)
    kernel = filter.to(dtype=reshaped_input.dtype).view(1, 1, 1, -1)
    reshaped_output = torch.nn.functional.conv2d(
        reshaped_input, kernel, stride=(1, stride)
    )
    return reshaped_output.reshape(*input.shape[:-1], -1)


def fw_arange(upper_bound, fw, device):
    """Return ``arange(upper_bound)``, on ``device`` for torch."""
    if fw is numpy:
        return fw.arange(upper_bound)
    else:
        return fw.arange(upper_bound, device=device)


def fw_empty(shape, fw, device, dtype=None):
    """Allocate an uninitialized array/tensor (framework default dtype if None)."""
    if fw is numpy:
        return fw.empty(shape, dtype=dtype)
    else:
        return fw.empty(size=(*shape,), device=device, dtype=dtype)