The class to represent symbolic tensors. It only contains dtype and shape information, so that users can write their own shape/dtype inference functions.
| 229 | |
| 230 | |
class SymTensor:
    """The class to represent symbolic tensors.

    Only contains dtype and shape information, so that users can write their
    own shape/dtype inference functions.

    Args:
        dtype: Element type. May be a ``torch.dtype``, a string accepted by
            ``str_dtype_to_trt``, a ``np.dtype`` or numpy scalar type (e.g.
            ``np.float32``), a ``trt.DataType``, or ``None``. It is stored
            normalized to a ``trt.DataType`` (or ``None``).
        shape: Either a ``ShapeExpr`` or a list/tuple of Python ints. The
            object is stored as given (no copy).

    Raises:
        TypeError: If ``dtype`` or ``shape`` has an unsupported type.
    """

    def __init__(
        self,
        dtype: Union[torch.dtype, np.dtype, str, trt.DataType, Type[None]],
        shape: Union[ShapeExpr, Sequence[int]],
    ):
        self.dtype = dtype
        self.shape = shape

    def __repr__(self) -> str:
        return f"{type(self).__name__}(dtype={self.dtype!r}, shape={self.shape!r})"

    @property
    def shape(self) -> Union[ShapeExpr, Sequence[int]]:
        """The shape, either a ``ShapeExpr`` or a list/tuple of ints."""
        return self._shape

    @shape.setter
    def shape(self, shape: Union[ShapeExpr, Sequence[int]]):
        # Explicit raises rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently disable this validation. TypeError
        # also matches what the ``dtype`` setter raises for bad input.
        if not isinstance(shape, (ShapeExpr, list, tuple)):
            raise TypeError(
                f"Unsupported shape type: {type(shape).__name__}; "
                "expected ShapeExpr, list or tuple"
            )
        if isinstance(shape, (list, tuple)):
            for idx, dim in enumerate(shape):
                if not isinstance(dim, int):
                    raise TypeError(
                        f"Shape dimension {idx} must be an int, "
                        f"got {type(dim).__name__}: {dim!r}"
                    )
        self._shape = shape

    @property
    def dtype(self) -> Union[trt.DataType, None]:
        """The element type as a ``trt.DataType``, or ``None`` if unset."""
        return self._dtype

    @dtype.setter
    def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType, Type[None]]):
        if isinstance(dtype, torch.dtype):
            self._dtype = torch_dtype_to_trt(dtype)
        elif isinstance(dtype, str):
            self._dtype = str_dtype_to_trt(dtype)
        elif isinstance(dtype, np.dtype):
            self._dtype = np_dtype_to_trt(dtype)
        elif isinstance(dtype, type) and issubclass(dtype, np.generic):
            # Numpy scalar types (``np.float32``, ``np.int64``, ...) are classes,
            # not ``np.dtype`` instances; normalize them onto the numpy path.
            self._dtype = np_dtype_to_trt(np.dtype(dtype))
        elif isinstance(dtype, trt.DataType):
            self._dtype = dtype
        elif dtype is None:
            self._dtype = None
        else:
            raise TypeError(f"Unsupported dtype: {dtype}")
| 275 | |
| 276 | |
| 277 | def _convert_return_value_to_list(ret): |
no outgoing calls
no test coverage detected