| 788 | |
| 789 | |
| 790 | def quantize_per_token( |
| 791 | x: Tensor, |
| 792 | clamp_val: Optional[Tensor] = None, |
| 793 | scale_dtype='float32', |
| 794 | sum_per_token: bool = False, |
| 795 | sum_dtype='float32', |
| 796 | ) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]: |
| 797 | if not default_net().plugin_config.quantize_per_token_plugin: |
| 798 | x = cast(x, 'float32') |
| 799 | xmax = x.abs().max(-1, keepdim=True) |
| 800 | scales = xmax / 127.0 |
| 801 | out = x * 127.0 / xmax |
| 802 | out = round(out) |
| 803 | out = clip(out, -128, 127) |
| 804 | quantized = cast(out, 'int8') |
| 805 | if not sum_per_token: |
| 806 | return quantized, scales |
| 807 | sums = sum(x, -1, keepdim=True) |
| 808 | if sum_dtype is not None and str_dtype_to_trt(sum_dtype) != sums.dtype: |
| 809 | sums = cast(sums, sum_dtype) |
| 810 | return quantized, scales, sums |
| 811 | |
| 812 | plg_creator = trt.get_plugin_registry().get_plugin_creator( |
| 813 | 'QuantizePerToken', '1', TRT_LLM_PLUGIN_NAMESPACE) |
| 814 | assert plg_creator is not None |
| 815 | |
| 816 | output_type = trt.PluginField("type_id", np.array([int(trt.int8)], |
| 817 | np.int32), |
| 818 | trt.PluginFieldType.INT32) |
| 819 | quant_mode = trt.PluginField( |
| 820 | "quant_mode", |
| 821 | np.array([int(QuantMode.use_smooth_quant(per_token=True))], np.int32), |
| 822 | trt.PluginFieldType.INT32) |
| 823 | clamp_enabled = trt.PluginField("clamp_enabled", |
| 824 | np.array([clamp_val is not None], np.int8), |
| 825 | trt.PluginFieldType.INT8) |
| 826 | |
| 827 | sum_per_token_pf = trt.PluginField("sum_per_token", |
| 828 | np.array([int(sum_per_token)], np.int32), |
| 829 | trt.PluginFieldType.INT32) |
| 830 | |
| 831 | pfc = trt.PluginFieldCollection( |
| 832 | [output_type, quant_mode, clamp_enabled, sum_per_token_pf]) |
| 833 | quantize_plug = plg_creator.create_plugin("quantize_per_token_plugin", pfc) |
| 834 | |
| 835 | plug_inputs = [x.trt_tensor] |
| 836 | if clamp_val: |
| 837 | plug_inputs += [clamp_val.trt_tensor] |
| 838 | layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug) |
| 839 | if not default_net().strongly_typed: |
| 840 | layer.get_output(0).set_dynamic_range(-127, 127) |
| 841 | _add_plugin_info(layer, plg_creator, "quantize_per_token_plugin", pfc) |
| 842 | |
| 843 | quantized = _create_tensor(layer.get_output(0), layer) |
| 844 | scales = _create_tensor(layer.get_output(1), layer) |
| 845 | |
| 846 | # TODO: The plugin should be able to directly output float16 scales to avoid a cast |
| 847 | if scale_dtype is not None and str_dtype_to_trt( |