Add a cast operation. For an input tensor of type INT8, this function sets the dynamic range of the input to [-127, 127] for automatic dequantization. For a cast into INT8, it sets the dynamic range of the output to [-127, 127] for automatic quantization.
(input: Tensor, dtype: Union[str, trt.DataType])
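The dtype handling and the INT8 range described above can be illustrated without TensorRT installed. The following is a minimal sketch, not the library's implementation: `DataType`, `_STR_TO_DTYPE`, `resolve_dtype`, and `quantize_int8` are hypothetical stand-ins for `trt.DataType`, `_str_to_trt_dtype_dict`, and the behavior of the real function. The quantization helper additionally assumes a symmetric INT8 mapping whose scale is the range bound divided by 127, which the source does not state.

```python
from enum import Enum
from typing import Union


class DataType(Enum):
    """Hypothetical stand-in for trt.DataType (only a few members)."""
    FLOAT = "float32"
    HALF = "float16"
    INT8 = "int8"
    INT32 = "int32"


# Hypothetical stand-in for _str_to_trt_dtype_dict in _utils.py.
_STR_TO_DTYPE = {member.value: member for member in DataType}


def resolve_dtype(dtype: Union[str, DataType]) -> DataType:
    """Accept a type name or an enum member and return the enum member."""
    if isinstance(dtype, str):
        # Unknown names raise KeyError in this sketch.
        return _STR_TO_DTYPE[dtype]
    if isinstance(dtype, DataType):
        return dtype
    raise TypeError("%s is not supported" % type(dtype))


def quantize_int8(value: float, dynamic_range: float = 127.0) -> int:
    """Symmetric INT8 quantization under the assumed scale = range / 127."""
    scale = dynamic_range / 127.0
    quantized = round(value / scale)
    # Clamp to the symmetric INT8 interval [-127, 127].
    return max(-127, min(127, quantized))
```

Under these assumptions, a dynamic range of [-127, 127] gives a scale of 1, so quantization reduces to rounding followed by clamping to [-127, 127].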


def cast(input: Tensor, dtype: Union[str, trt.DataType]):
    '''
    Add a cast operation.

    In a network that is not strongly typed, this function sets the dynamic
    range of an INT8 input to [-127, 127] for automatic dequantization. For a
    cast into INT8, it sets the dynamic range of the output to [-127, 127] for
    automatic quantization.

    Parameters:
        input : Tensor
            The input tensor on which the cast is applied.

        dtype : str or trt.DataType
            The data type of the output tensor after the cast. When 'dtype' is
            provided as a string, it must be one of the valid type names.
            See _str_to_trt_dtype_dict in _utils.py for a list of supported
            types and type names.

    Returns:
        The tensor produced by the inserted layer, or the input tensor itself
        when it already has the requested type.
    '''
    if isinstance(dtype, str):
        cvt_dtype = str_dtype_to_trt(dtype)
    elif isinstance(dtype, trt.DataType):
        cvt_dtype = dtype
    else:
        raise TypeError("%s is not supported" % type(dtype))

    if input.dtype == cvt_dtype:
        # If input type and cast dtype are the same, do nothing
        return input

    layer = default_trtnet().add_cast(input.trt_tensor, cvt_dtype)
    if not default_net().strongly_typed:
        layer.set_output_type(0, cvt_dtype)
    output = _create_tensor(layer.get_output(0), layer)
    if not default_net().strongly_typed:
        # Dynamic ranges are only set when the network is not strongly typed.
        if input.dtype == str_dtype_to_trt('int8'):
            layer.get_input(0).set_dynamic_range(-127, 127)
        if cvt_dtype == str_dtype_to_trt('int8'):
            layer.get_output(0).set_dynamic_range(-127, 127)

    return output


def flip(input: Tensor, dims: Sequence[int]) -> Tensor: