MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / create_allreduce_plugin

Function create_allreduce_plugin

tensorrt_llm/functional.py:3980–4042  ·  view source on GitHub ↗
(
    network: trt.INetworkDefinition,
    tensor: trt.ITensor,
    workspace: Optional[trt.ITensor],
    group: np.array,
    dtype: trt.DataType,
    all_reduce_params: AllReduceParams,
)

Source from the content-addressed store, hash-verified

3978
3979
3980def create_allreduce_plugin(
3981 network: trt.INetworkDefinition,
3982 tensor: trt.ITensor,
3983 workspace: Optional[trt.ITensor],
3984 group: np.array,
3985 dtype: trt.DataType,
3986 all_reduce_params: AllReduceParams,
3987):
3988 allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
3989 'AllReduce', '1', TRT_LLM_PLUGIN_NAMESPACE)
3990 assert allreduce_plg_creator is not None
3991
3992 pf_group = trt.PluginField("group", group, trt.PluginFieldType.INT32)
3993 pf_dtype = trt.PluginField("type_id", np.array([int(dtype)], np.int32),
3994 trt.PluginFieldType.INT32)
3995 pfc = [pf_group, pf_dtype]
3996 p_strategy = trt.PluginField(
3997 "strategy", np.array([int(all_reduce_params.strategy)], np.int8),
3998 trt.PluginFieldType.INT8)
3999 pfc.append(p_strategy)
4000 p_fusion_op = trt.PluginField(
4001 "fusion_op", np.array([int(all_reduce_params.fusion_op)], np.int8),
4002 trt.PluginFieldType.INT8)
4003 pfc.append(p_fusion_op)
4004 p_eps = trt.PluginField(
4005 "eps", np.array([float(all_reduce_params.eps)], np.float32),
4006 trt.PluginFieldType.FLOAT32)
4007 pfc.append(p_eps)
4008 p_affine = trt.PluginField(
4009 "affine", np.array([int(all_reduce_params.has_affine())], np.int8),
4010 trt.PluginFieldType.INT8)
4011 pfc.append(p_affine)
4012 p_bias = trt.PluginField(
4013 "bias", np.array([int(all_reduce_params.has_bias())], np.int8),
4014 trt.PluginFieldType.INT8)
4015 pfc.append(p_bias)
4016 p_scale = trt.PluginField(
4017 "scale", np.array([int(all_reduce_params.has_scale())], np.int8),
4018 trt.PluginFieldType.INT8)
4019 pfc.append(p_scale)
4020
4021 pfc = trt.PluginFieldCollection(pfc)
4022 ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc)
4023 plug_inputs = [tensor]
4024 if all_reduce_params.strategy not in {
4025 AllReduceStrategy.NCCL, AllReduceStrategy.UB,
4026 AllReduceStrategy.NCCL_SYMMETRIC
4027 }:
4028 plug_inputs.append(workspace)
4029 if all_reduce_params.fusion_op != AllReduceFusionOp.NONE:
4030 if all_reduce_params.has_bias() == 1:
4031 plug_inputs.append(all_reduce_params.bias.trt_tensor)
4032 plug_inputs.append(all_reduce_params.residual.trt_tensor)
4033 if all_reduce_params.has_affine() == 1:
4034 plug_inputs.append(all_reduce_params.norm_weight.trt_tensor)
4035 if all_reduce_params.fusion_op == AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM:
4036 plug_inputs.append(
4037 all_reduce_params.norm_pre_residual_weight.trt_tensor)

Callers 1

allreduce · Function · 0.85

Calls 5

has_affine · Method · 0.80
has_scale · Method · 0.80
create_plugin · Method · 0.80
append · Method · 0.45
has_bias · Method · 0.45

Tested by

no test coverage detected