MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / quantize_fp8_per_token

Function quantize_fp8_per_token

tensorrt_llm/quantization/functional.py:861–906  ·  view source on GitHub ↗
(x: Tensor,
                           clamp_val: Optional[Tensor] = None)

Source from the content-addressed store, hash-verified

859
860
def quantize_fp8_per_token(
        x: Tensor,
        clamp_val: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Quantize ``x`` to FP8 using one scale per slice of the last axis.

    The implementation is selected by
    ``default_net().plugin_config.quantize_per_token_plugin``:

    * Plugin disabled: ``x`` is cast to float32, scaled so that the largest
      magnitude along the last axis maps to 448, rounded, clipped to
      [-448, 448] and cast to ``'fp8'``. ``clamp_val`` is not used here.
    * Plugin enabled: the ``QuantizePerToken`` plugin is instantiated with
      an FP8 output type and the FP8 row-wise quant mode, and added to the
      current TensorRT network.

    Args:
        x: Tensor to quantize.
        clamp_val: Optional tensor passed to the plugin as a second input.
            Only consulted on the plugin path.

    Returns:
        A ``(quantized, scales)`` pair. On the non-plugin path ``scales`` is
        ``max(|x|) / 448`` along the last axis with that axis kept.
    """
    if not default_net().plugin_config.quantize_per_token_plugin:
        # NOTE: 448 matches the largest finite magnitude of FP8 E4M3.
        x = cast(x, 'float32')
        xmax = x.abs().max(-1, keepdim=True)
        scale = xmax / 448.0
        # NOTE(review): if a slice is entirely zero, xmax is 0 and this
        # division is 0/0 -- confirm callers never feed such input.
        out = x * 448.0 / xmax
        out = round(out)
        out = clip(out, -448, 448)
        quantized_out = cast(out, 'fp8')
        return quantized_out, scale
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'QuantizePerToken', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        # Evaluate the presence check once so the plugin field and the
        # plugin input list can never disagree.
        has_clamp = clamp_val is not None

        output_type = trt.PluginField("type_id",
                                      np.array([int(trt.fp8)], np.int32),
                                      trt.PluginFieldType.INT32)
        quant_mode = trt.PluginField(
            "quant_mode",
            np.array([int(QuantMode.from_description(use_fp8_rowwise=True))],
                     np.int32), trt.PluginFieldType.INT32)
        clamp_enabled = trt.PluginField("clamp_enabled",
                                        np.array([has_clamp], np.int8),
                                        trt.PluginFieldType.INT8)
        sum_per_token_pf = trt.PluginField("sum_per_token",
                                           np.array([int(False)], np.int32),
                                           trt.PluginFieldType.INT32)
        pfc = trt.PluginFieldCollection(
            [output_type, quant_mode, clamp_enabled, sum_per_token_pf])
        quantize_plug = plg_creator.create_plugin("quantize_per_token_plugin",
                                                  pfc)

        plug_inputs = [x.trt_tensor]
        # Use an explicit None check (not truthiness) to match the
        # "clamp_enabled" field above.
        if has_clamp:
            plug_inputs += [clamp_val.trt_tensor]
        layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
        if not default_net().strongly_typed:
            # Weakly typed networks get an explicit dynamic range on the
            # quantized output.
            layer.get_output(0).set_dynamic_range(-448, 448)
        _add_plugin_info(layer, plg_creator, "quantize_per_token_plugin", pfc)

        # Plugin output 0 is the quantized tensor, output 1 the scales.
        quantized = _create_tensor(layer.get_output(0), layer)
        scales = _create_tensor(layer.get_output(1), layer)

        return quantized, scales
907
908
909def quantize_tensor(x, scale):

Callers 4

forward (Method) · 0.85
forward (Method) · 0.85
forward (Method) · 0.85
forward (Method) · 0.85

Calls 11

default_net (Function) · 0.85
cast (Function) · 0.85
default_trtnet (Function) · 0.85
_add_plugin_info (Function) · 0.85
_create_tensor (Function) · 0.85
abs (Method) · 0.80
from_description (Method) · 0.80
create_plugin (Method) · 0.80
clip (Function) · 0.50
max (Method) · 0.45
get_output (Method) · 0.45

Tested by

no test coverage detected