def _create_tensor(trt_tensor: trt.ITensor, producer: trt.ILayer) -> Tensor:
    '''
    A helper function to create a TensorRT LLM Tensor object that encapsulates
    the connection between the TensorRT tensor (trt.ITensor) and the layer
    (trt.ILayer) that produces it.

    This function is expected to be used as follows:

        # Insert a new layer in the network using the TensorRT API:
        layer = default_trtnet().add_<some_layer>(...)
        # Extract the first output of that layer and connect it to the layer.
        return _create_tensor(layer.get_output(0), layer)

    This function also sets the precision of the layer/producer to the default
    precision of the network, unless the network has no default dtype, is
    strongly typed, or the layer is a SHAPE, CONSTANT, GATHER, or
    CONCATENATION layer.

    Parameters:
        trt_tensor : trt.ITensor
            The TensorRT tensor to connect to its producer (the layer).

        producer : trt.ILayer
            The layer that produces trt_tensor.

    Returns:
        The TensorRT LLM tensor (functional.Tensor) that encapsulates the
        TensorRT tensor and the layer that produces it. The former is
        accessible through the attribute 'trt_tensor' and the latter through
        the attribute 'producer'.
    '''
    assert trt_tensor is not None
    assert producer is not None

    # Set the layer name since this is the only
    # centralized location to pass the name from
    # module space to the TRT IR
    default_net()._set_layer_name(producer)

    assert trt_tensor.shape.__len__(
    ) >= 0, f"tensor {trt_tensor.name} has an invalid shape"
    tensor = Tensor(name=trt_tensor.name,
                    dtype=trt_tensor.dtype,
                    shape=trt_tensor.shape,
                    is_network_input=False)
    tensor.trt_tensor = trt_tensor
    tensor.producer = producer

    # tb.print_stack(limit=10) # FOR DEBUGGING: filter producer.name if needed
    if default_net().dtype is not None and not default_net().strongly_typed:
        if producer.type not in [
                trt.LayerType.SHAPE, trt.LayerType.CONSTANT,
                trt.LayerType.GATHER, trt.LayerType.CONCATENATION
        ]:
            producer.precision = default_net().dtype
    assert tensor is not None

    if gw.FLayerInfoMemo.instance().cur_flayer is not None:
        gw.FLayerInfoMemo.instance().cur_flayer.layer_name = producer.name
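The precision branch near the end of `_create_tensor` can be hard to follow inside the full function, so here is a minimal sketch of that rule in isolation. It does not use TensorRT: `LayerType`, `FakeLayer`, `FakeNetwork`, and `apply_default_precision` are hypothetical stand-ins introduced only for illustration, and dtypes are plain strings rather than `trt.DataType` values. The assumption is that the rule is exactly what the listing shows: the default precision is applied only when the network has a dtype, is not strongly typed, and the producer is not a SHAPE, CONSTANT, GATHER, or CONCATENATION layer.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class LayerType(Enum):
    # Hypothetical stand-in for trt.LayerType: the four exempt members named
    # in the listing, plus one ordinary compute layer for contrast.
    SHAPE = auto()
    CONSTANT = auto()
    GATHER = auto()
    CONCATENATION = auto()
    MATRIX_MULTIPLY = auto()


# Layer types whose precision the helper leaves untouched.
PRECISION_EXEMPT = {
    LayerType.SHAPE,
    LayerType.CONSTANT,
    LayerType.GATHER,
    LayerType.CONCATENATION,
}


@dataclass
class FakeLayer:
    # Stand-in for trt.ILayer; only the two attributes the rule reads or writes.
    type: LayerType
    precision: Optional[str] = None


@dataclass
class FakeNetwork:
    # Stand-in for the object returned by default_net() in the listing.
    dtype: Optional[str]
    strongly_typed: bool = False


def apply_default_precision(net: FakeNetwork, layer: FakeLayer) -> bool:
    """Apply the network's default dtype to the layer; return True if applied."""
    if net.dtype is None or net.strongly_typed:
        return False  # no default precision, or types are fixed by the network
    if layer.type in PRECISION_EXEMPT:
        return False  # exempt layer types keep their own precision
    layer.precision = net.dtype
    return True


# Usage: a compute layer picks up the default, a shape layer does not.
net = FakeNetwork(dtype="float16")
matmul = FakeLayer(LayerType.MATRIX_MULTIPLY)
shape = FakeLayer(LayerType.SHAPE)
apply_default_precision(net, matmul)
apply_default_precision(net, shape)
```

Keeping the exempt types in a single set makes the condition easy to audit against the list in the original function; the real helper performs the same check inline on `producer.type`.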