Repository: github.com/NVIDIA/TensorRT-LLM

Method run

tensorrt_llm/runtime/session.py:268–298

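Runs the TensorRT engine with the given inputs and outputs. Both inputs and outputs are dicts that map a tensor name to either a raw device pointer or a torch tensor. stream is the CUDA stream to enqueue the engine on, and context is the TensorRT execution context (the session's default context when None). Returns True if the enqueue succeeded. Because the enqueue is asynchronous, True does not mean that execution has finished.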

run(self,
    inputs: Dict[str, Any],
    outputs: Dict[str, Any],
    stream,
    context=None) -> bool
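Each value in inputs and outputs can be either a torch tensor or a raw device pointer. The sketch below shows that resolution rule. It is a minimal sketch under stated assumptions: it replaces the source's isinstance(tensor, torch.Tensor) test with a duck-typed hasattr(value, "data_ptr") check so that it runs without PyTorch. FakeTensor, resolve_ptr, resolve_all, and the tensor names input_ids and position_ids are hypothetical and appear only for illustration.

```python
from typing import Any, Dict


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only exposes data_ptr()."""

    def __init__(self, addr: int) -> None:
        self._addr = addr

    def data_ptr(self) -> int:
        return self._addr


def resolve_ptr(value: Any) -> int:
    # Tensor-like values contribute their address. Anything else is assumed
    # to already be an integer pointer and is passed through unchanged.
    # (Assumption: duck typing here; the source uses isinstance(..., torch.Tensor).)
    return value.data_ptr() if hasattr(value, "data_ptr") else value


def resolve_all(tensors: Dict[str, Any]) -> Dict[str, int]:
    # Map tensor name -> address, mirroring what run() hands to
    # context.set_tensor_address for each entry.
    return {name: resolve_ptr(v) for name, v in tensors.items()}


# Hypothetical tensor names; one tensor-like value and one raw pointer.
inputs = {"input_ids": FakeTensor(0x1000), "position_ids": 0x2000}
print(resolve_all(inputs))
```

Mixing both kinds of value in one dict is allowed, because run() resolves each entry independently.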


def run(self,
        inputs: Dict[str, Any],
        outputs: Dict[str, Any],
        stream,
        context=None) -> bool:
    '''
    @brief: Run the TensorRT engine with the given inputs and outputs
    @param inputs: dict of input tensors, key is tensor name, value is tensor pointer or torch tensor
    @param outputs: dict of output tensors, key is tensor name, value is tensor pointer or torch tensor
    @param stream: cuda stream to enqueue the TensorRT engine on
    @param context: TensorRT execution context, if None, use the default context
    @return: True if enqueue succeeded, note the enqueue is an async call,
        returning True does not mean the execution is finished
    '''
    # enqueue to the default context if context is not specified
    if context is None:
        context = self.context

    import torch
    for tensor_name in inputs:
        tensor = inputs[tensor_name]
        ptr = tensor.data_ptr() if isinstance(tensor,
                                              torch.Tensor) else tensor
        context.set_tensor_address(tensor_name, ptr)
    for tensor_name in outputs:
        tensor = outputs[tensor_name]
        ptr = tensor.data_ptr() if isinstance(tensor,
                                              torch.Tensor) else tensor
        context.set_tensor_address(tensor_name, ptr)
    ok = context.execute_async_v3(stream)
    return ok
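The method binds every input name to an address, then every output name, and finally enqueues the engine on the given stream. The sketch below reproduces that control flow against a context that records calls instead of running on a GPU. It also shows the fallback to the session's default context when context is None. RecordingContext and MiniSession are hypothetical illustration classes, not TensorRT or TensorRT-LLM types. Only the method names set_tensor_address and execute_async_v3 come from the source, and the duck-typed data_ptr check is an assumption that replaces the source's isinstance test.

```python
from typing import Any, Dict, List, Optional, Tuple


class RecordingContext:
    """Hypothetical stand-in for a TensorRT execution context.

    Logs the calls that run() makes instead of touching a GPU.
    """

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, ...]] = []

    def set_tensor_address(self, name: str, ptr: int) -> bool:
        self.calls.append(("set_tensor_address", name, ptr))
        return True

    def execute_async_v3(self, stream_handle: int) -> bool:
        self.calls.append(("execute_async_v3", stream_handle))
        return True


class MiniSession:
    """Hypothetical session mirroring the control flow of Session.run."""

    def __init__(self, context: RecordingContext) -> None:
        self.context = context  # default context, as in the source

    def run(self, inputs: Dict[str, Any], outputs: Dict[str, Any],
            stream: int, context: Optional[RecordingContext] = None) -> bool:
        if context is None:
            context = self.context  # fall back to the default context
        # Bind all inputs first, then all outputs, by tensor name.
        for table in (inputs, outputs):
            for name, value in table.items():
                ptr = value.data_ptr() if hasattr(value, "data_ptr") else value
                context.set_tensor_address(name, ptr)
        # Only enqueues; completion must be awaited on the stream separately.
        return context.execute_async_v3(stream)


default_ctx = RecordingContext()
session = MiniSession(default_ctx)
ok = session.run({"x": 0x10}, {"y": 0x20}, stream=0)
print(ok, default_ctx.calls)
```

With a real engine, the stream argument is typically a raw CUDA stream handle, such as torch.cuda.current_stream().cuda_stream. Because execute_async_v3 only enqueues the work, the caller should synchronize that stream before reading the output tensors.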

Callers (15)

_debug_run · method · 0.95
wrapper · function · 0.45
build_cpp_examples · function · 0.45
run_cmd · function · 0.45
verify_l0_test_lists · function · 0.45
verify_qa_test_lists · function · 0.45
verify_waive_list · function · 0.45
_check_banned_symbols · function · 0.45
compress · function · 0.45
get_wheel_from_package · function · 0.45
run_shell_command · function · 0.45

Calls (1)

data_ptr · method · 0.45

Tested by (5)

wrapper · function · 0.36
verify_l0_test_lists · function · 0.36
verify_qa_test_lists · function · 0.36
verify_waive_list · function · 0.36