| 907 | |
| 908 | |
def quantize_tensor(x, scale):
    """Quantize ``x`` by multiplying it with ``scale`` and producing int8 values.

    Two implementations exist, selected by
    ``default_net().plugin_config.quantize_tensor_plugin``:

    * Plugin disabled: the quantization is built from graph ops.
      ``x * scale`` is rounded, clipped to ``[-128, 127]`` and cast to
      ``'int8'``. When ``scale`` is float32, ``x`` is cast to float32 first
      so both operands of the multiply share a dtype.
    * Plugin enabled: ``scale`` is cast to float32 and both tensors are fed
      to the ``QuantizeTensor`` plugin (version ``'1'``) looked up in
      ``TRT_LLM_PLUGIN_NAMESPACE``. If the network is not strongly typed,
      the plugin output's dynamic range is declared as ``[-127, 127]``.

    Args:
        x: Tensor to quantize.
        scale: Scale tensor multiplied into ``x``. NOTE(review): presumably
            a scalar or a tensor broadcastable against ``x`` — confirm
            against callers.

    Returns:
        The quantized tensor. On the non-plugin path it is cast to
        ``'int8'``; on the plugin path the output dtype is whatever the
        ``QuantizeTensor`` plugin emits (presumably int8 — TODO confirm).

    Raises:
        RuntimeError: if no ``QuantizeTensor`` plugin creator (version
            ``'1'``) is registered in ``TRT_LLM_PLUGIN_NAMESPACE``.
    """
    if not default_net().plugin_config.quantize_tensor_plugin:
        # Match x to a float32 scale so the multiply has uniform operand types.
        if scale.dtype == str_dtype_to_trt('float32'):
            x = cast(x, 'float32')
        scaled = x * scale
        rounded = round(scaled)
        # NOTE(review): this path clips to [-128, 127], whereas the plugin
        # path below declares a dynamic range of [-127, 127]; confirm which
        # range is intended before relying on bit-exact agreement.
        clipped = clip(rounded, -128, 127)
        quantized = cast(clipped, 'int8')
    else:
        # The plugin is always handed a float32 scale.
        scale = cast(scale, 'float32')

        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'QuantizeTensor', '1', TRT_LLM_PLUGIN_NAMESPACE)
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn a missing plugin into an opaque
        # AttributeError on `None.create_plugin`.
        if plg_creator is None:
            raise RuntimeError(
                "QuantizeTensor plugin creator (version '1') not found in "
                f"plugin namespace {TRT_LLM_PLUGIN_NAMESPACE!r}")

        pfc = trt.PluginFieldCollection([])
        quantize_plug = plg_creator.create_plugin("quantize_tensor_plugin", pfc)

        plug_inputs = [x.trt_tensor, scale.trt_tensor]
        layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
        # Dynamic ranges are only set when the network is not strongly typed.
        if not default_net().strongly_typed:
            layer.get_output(0).set_dynamic_range(-127, 127)
        _add_plugin_info(layer, plg_creator, "quantize_tensor_plugin", pfc)

        quantized = _create_tensor(layer.get_output(0), layer)
    return quantized
| 935 | |
| 936 | |
| 937 | def symmetric_quantize_last_axis_of_batched_matrix(weight, quant_mode): |