(input: Tensor,
size: Union[int, List[int]] = None,
scale_factor: Union[float, List[float]] = None,
mode: str = 'nearest',
align_corners: bool = False,
recompute_scale_factor: bool = False,
antialias: bool = False)
| 964 | |
| 965 | |
| 966 | def interpolate(input: Tensor, |
| 967 | size: Union[int, List[int]] = None, |
| 968 | scale_factor: Union[float, List[float]] = None, |
| 969 | mode: str = 'nearest', |
| 970 | align_corners: bool = False, |
| 971 | recompute_scale_factor: bool = False, |
| 972 | antialias: bool = False) -> Tensor: |
| 973 | ## |
| 974 | ## TODO: Document that function! |
| 975 | ## |
| 976 | |
| 977 | assert not input.is_dynamic() |
| 978 | |
| 979 | input_ndim = input.ndim() |
| 980 | |
| 981 | assert 2 < input_ndim < 6, "Only 3D, 4D and 5D input Tensors supported" |
| 982 | assert (size is not None) ^ ( |
| 983 | scale_factor |
| 984 | is not None), "Only one of out_shape or scales should be defined" |
| 985 | |
| 986 | assert mode in ('nearest', 'linear', 'bilinear', 'bicubic', 'trilinear', |
| 987 | 'nearest-exact') |
| 988 | |
| 989 | if mode == 'trilinear' and input_ndim != 5: |
| 990 | raise ValueError("trilinear only supports 5D tensor") |
| 991 | |
| 992 | if mode == "bilinear" and input_ndim != 4: |
| 993 | raise ValueError("bilinear only supports 4D tensor") |
| 994 | |
| 995 | if mode == "linear" and input_ndim != 3: |
| 996 | raise ValueError("linear only supports 3D tensor") |
| 997 | |
| 998 | layer = default_trtnet().add_resize(input.trt_tensor) |
| 999 | |
| 1000 | input_shape = input.size() |
| 1001 | |
| 1002 | updated_shape = [] |
| 1003 | if scale_factor: |
| 1004 | scale_len = 1 if isinstance(scale_factor, |
| 1005 | (float, int)) else len(scale_factor) |
| 1006 | if scale_len == 1 and isinstance(scale_factor, (float, int)): |
| 1007 | updated_scale = [scale_factor for _ in range(input_ndim - 2)] |
| 1008 | |
| 1009 | else: |
| 1010 | updated_scale = scale_factor |
| 1011 | updated_shape = [ |
| 1012 | int(math.floor(updated_scale[i - 2] * |
| 1013 | input_shape[i])) if i > 1 else input_shape[i] |
| 1014 | for i in range(input_ndim) |
| 1015 | ] |
| 1016 | |
| 1017 | else: |
| 1018 | size_len = 1 if isinstance(size, int) else len(size) |
| 1019 | assert size_len == input_ndim - 2 |
| 1020 | if size_len == 1 and isinstance(size, int): |
| 1021 | updated_size = [size for _ in range(input_ndim - 2)] |
| 1022 | else: |
| 1023 | updated_size = size |
no test coverage detected