github.com/NVIDIA/TensorRT-LLM

Function _create_tensor

tensorrt_llm/functional.py:605–663

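A helper function that creates a TensorRT LLM Tensor object encapsulating the connection between a TensorRT tensor (trt.ITensor) and the layer (trt.ILayer) that produces it. It is meant to be called right after a new layer is inserted into the network through the TensorRT API, passing the layer's first output together with the layer itself. It also sets the precision of the layer/producer to the default precision of the network.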

(trt_tensor: trt.ITensor, producer: trt.ILayer)
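The calling pattern described in the docstring (add a layer, take its first output, wrap both) can be illustrated without a TensorRT installation. The sketch below is only a model of that pattern: StubTensor, StubLayer, WrappedTensor and wrap_first_output are hypothetical stand-ins invented for this example, not TensorRT or TensorRT LLM classes, and the sketch omits the layer naming and precision handling that the real helper performs.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class StubTensor:
    """Hypothetical stand-in for trt.ITensor: only the fields the wrapper reads."""
    name: str
    dtype: str
    shape: Tuple[int, ...]


class StubLayer:
    """Hypothetical stand-in for trt.ILayer with a single output."""

    def __init__(self, name: str, output: StubTensor):
        self.name = name
        self._output = output

    def get_output(self, index: int) -> StubTensor:
        if index != 0:
            raise IndexError("this stub layer has exactly one output")
        return self._output


@dataclass
class WrappedTensor:
    """Hypothetical stand-in for functional.Tensor."""
    name: str
    dtype: str
    shape: Tuple[int, ...]
    trt_tensor: Any = None
    producer: Any = None


def wrap_first_output(layer: StubLayer) -> WrappedTensor:
    """Mimic the shape of `_create_tensor(layer.get_output(0), layer)`."""
    trt_tensor = layer.get_output(0)
    assert trt_tensor is not None
    assert layer is not None
    wrapped = WrappedTensor(name=trt_tensor.name,
                            dtype=trt_tensor.dtype,
                            shape=trt_tensor.shape)
    # Keep both ends of the connection reachable from the wrapper.
    wrapped.trt_tensor = trt_tensor
    wrapped.producer = layer
    return wrapped


out = StubTensor(name="matmul_out", dtype="float16", shape=(2, 3))
layer = StubLayer(name="matmul", output=out)
wrapped = wrap_first_output(layer)
```

After the call, wrapped.trt_tensor is the stub output and wrapped.producer is the stub layer, which mirrors the two attributes the docstring says the returned Tensor exposes.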

Source from the content-addressed store, hash-verified

def _create_tensor(trt_tensor: trt.ITensor, producer: trt.ILayer) -> Tensor:
    '''
    A helper function to create a TensorRT LLM Tensor object that encapsulates
    the connection between the TensorRT tensor (trt.ITensor) and the layer
    (trt.ILayer) that produces it.

    That function is expected to be used as:

        # Insert a new layer in the network using the TensorRT API:
        layer = default_trtnet().add_<some_layer>(...)
        # Extract the first output of that layer and connect it to the layer.
        return _create_tensor(layer.get_output(0), layer)

    That function also sets the precision of the layer/producer to the default
    precision of the network.

    Parameters:
        trt_tensor : trt.ITensor
            The TensorRT tensor to connect to its producer (the layer).

        producer : trt.ILayer
            The producer.

    Returns:
        The TensorRT LLM tensor (functional.Tensor) that encapsulates the
        TensorRT tensor and the layer that produces it. The former is
        accessible through the attribute 'trt_tensor' and the latter using the
        attribute 'producer'.
    '''
    assert trt_tensor is not None
    assert producer is not None

    # Set the layer name since this is the only
    # centralized location to pass the name from
    # module space to the TRT IR
    default_net()._set_layer_name(producer)

    assert trt_tensor.shape.__len__(
    ) >= 0, f"tensor {trt_tensor.name} has an invalid shape"
    tensor = Tensor(name=trt_tensor.name,
                    dtype=trt_tensor.dtype,
                    shape=trt_tensor.shape,
                    is_network_input=False)
    tensor.trt_tensor = trt_tensor
    tensor.producer = producer

    # tb.print_stack(limit=10) # FOR DEBUGGING: filter producer.name if needed
    if default_net().dtype is not None and not default_net().strongly_typed:
        if producer.type not in [
                trt.LayerType.SHAPE, trt.LayerType.CONSTANT,
                trt.LayerType.GATHER, trt.LayerType.CONCATENATION
        ]:
            producer.precision = default_net().dtype
    assert tensor is not None

    if gw.FLayerInfoMemo.instance().cur_flayer is not None:
        gw.FLayerInfoMemo.instance().cur_flayer.layer_name = producer.name

    return tensor

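The precision branch near the end of the listing combines three conditions: the network has a default dtype, the network is not strongly typed, and the producing layer is not one of four exempt layer types. The sketch below restates that decision as a standalone predicate so it can be checked in isolation. LayerType here is a hypothetical enum standing in for trt.LayerType, and should_set_precision is a name invented for this example; it does not exist in TensorRT LLM.

```python
from enum import Enum, auto
from typing import Optional


class LayerType(Enum):
    """Hypothetical stand-in for the trt.LayerType members used in the listing."""
    SHAPE = auto()
    CONSTANT = auto()
    GATHER = auto()
    CONCATENATION = auto()
    MATRIX_MULTIPLY = auto()  # example of a layer type that is not exempt


# Layer types that the listing leaves at their own precision.
EXEMPT_LAYER_TYPES = frozenset({
    LayerType.SHAPE,
    LayerType.CONSTANT,
    LayerType.GATHER,
    LayerType.CONCATENATION,
})


def should_set_precision(net_dtype: Optional[str], strongly_typed: bool,
                         layer_type: LayerType) -> bool:
    """Return True when the listing would assign the network dtype to the layer."""
    if net_dtype is None or strongly_typed:
        # No default dtype, or a strongly typed network: leave the layer alone.
        return False
    return layer_type not in EXEMPT_LAYER_TYPES
```

Under these assumptions, a matrix-multiply layer in a weakly typed float16 network would have its precision set, while a shape or constant layer in the same network would not.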

Callers 15

flash_attention_op · Function · 0.90
__call__ · Method · 0.85
activation · Function · 0.85
clip · Function · 0.85
cast · Function · 0.85
flip · Function · 0.85
interpolate · Function · 0.85
matmul · Function · 0.85
gemm_swiglu · Function · 0.85
constant · Function · 0.85
slice · Function · 0.85
pad · Function · 0.85

Calls 5

default_net · Function · 0.85
_set_layer_name · Method · 0.80
Tensor · Class · 0.70
__len__ · Method · 0.45
instance · Method · 0.45

Tested by

no test coverage detected