MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / quantize_per_token

Function quantize_per_token

tensorrt_llm/quantization/functional.py:790–858  ·  view source on GitHub ↗
(
    x: Tensor,
    clamp_val: Optional[Tensor] = None,
    scale_dtype='float32',
    sum_per_token: bool = False,
    sum_dtype='float32',
)

Source from the content-addressed store, hash-verified

788
789
790def quantize_per_token(
791 x: Tensor,
792 clamp_val: Optional[Tensor] = None,
793 scale_dtype='float32',
794 sum_per_token: bool = False,
795 sum_dtype='float32',
796) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]:
797 if not default_net().plugin_config.quantize_per_token_plugin:
798 x = cast(x, 'float32')
799 xmax = x.abs().max(-1, keepdim=True)
800 scales = xmax / 127.0
801 out = x * 127.0 / xmax
802 out = round(out)
803 out = clip(out, -128, 127)
804 quantized = cast(out, 'int8')
805 if not sum_per_token:
806 return quantized, scales
807 sums = sum(x, -1, keepdim=True)
808 if sum_dtype is not None and str_dtype_to_trt(sum_dtype) != sums.dtype:
809 sums = cast(sums, sum_dtype)
810 return quantized, scales, sums
811
812 plg_creator = trt.get_plugin_registry().get_plugin_creator(
813 'QuantizePerToken', '1', TRT_LLM_PLUGIN_NAMESPACE)
814 assert plg_creator is not None
815
816 output_type = trt.PluginField("type_id", np.array([int(trt.int8)],
817 np.int32),
818 trt.PluginFieldType.INT32)
819 quant_mode = trt.PluginField(
820 "quant_mode",
821 np.array([int(QuantMode.use_smooth_quant(per_token=True))], np.int32),
822 trt.PluginFieldType.INT32)
823 clamp_enabled = trt.PluginField("clamp_enabled",
824 np.array([clamp_val is not None], np.int8),
825 trt.PluginFieldType.INT8)
826
827 sum_per_token_pf = trt.PluginField("sum_per_token",
828 np.array([int(sum_per_token)], np.int32),
829 trt.PluginFieldType.INT32)
830
831 pfc = trt.PluginFieldCollection(
832 [output_type, quant_mode, clamp_enabled, sum_per_token_pf])
833 quantize_plug = plg_creator.create_plugin("quantize_per_token_plugin", pfc)
834
835 plug_inputs = [x.trt_tensor]
836 if clamp_val:
837 plug_inputs += [clamp_val.trt_tensor]
838 layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
839 if not default_net().strongly_typed:
840 layer.get_output(0).set_dynamic_range(-127, 127)
841 _add_plugin_info(layer, plg_creator, "quantize_per_token_plugin", pfc)
842
843 quantized = _create_tensor(layer.get_output(0), layer)
844 scales = _create_tensor(layer.get_output(1), layer)
845
846 # TODO: The plugin should be able to directly output float16 scales to avoid a cast
847 if scale_dtype is not None and str_dtype_to_trt(

Callers 10

forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85
smooth_quant_layer_norm · Function · 0.85
smooth_quant_rms_norm · Function · 0.85

Calls 13

default_net · Function · 0.85
cast · Function · 0.85
sum · Function · 0.85
str_dtype_to_trt · Function · 0.85
default_trtnet · Function · 0.85
_add_plugin_info · Function · 0.85
_create_tensor · Function · 0.85
abs · Method · 0.80
use_smooth_quant · Method · 0.80
create_plugin · Method · 0.80
clip · Function · 0.50
max · Method · 0.45

Tested by 1