A reproduction, in Python, of the subscript parsing that einsum performs on the C side. Returns: input_strings (str), the parsed input subscripts; output_string (str), the parsed output subscript; and operands (list of array_like), the operands to use in the NumPy contraction. Examples follow in the source below.
(operands)
| 443 | |
| 444 | |
| 445 | def _parse_einsum_input(operands): |
| 446 | """ |
| 447 | A reproduction of einsum c side einsum parsing in python. |
| 448 | |
| 449 | Returns |
| 450 | ------- |
| 451 | input_strings : str |
| 452 | Parsed input strings |
| 453 | output_string : str |
| 454 | Parsed output string |
| 455 | operands : list of array_like |
| 456 | The operands to use in the numpy contraction |
| 457 | |
| 458 | Examples |
| 459 | -------- |
| 460 | The operand list is simplified to reduce printing: |
| 461 | |
| 462 | >>> np.random.seed(123) |
| 463 | >>> a = np.random.rand(4, 4) |
| 464 | >>> b = np.random.rand(4, 4, 4) |
| 465 | >>> _parse_einsum_input(('...a,...a->...', a, b)) |
| 466 | ('za,xza', 'xz', [a, b]) # may vary |
| 467 | |
| 468 | >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0])) |
| 469 | ('za,xza', 'xz', [a, b]) # may vary |
| 470 | """ |
| 471 | |
| 472 | if len(operands) == 0: |
| 473 | raise ValueError("No input operands") |
| 474 | |
| 475 | if isinstance(operands[0], str): |
| 476 | subscripts = operands[0].replace(" ", "") |
| 477 | operands = [asanyarray(v) for v in operands[1:]] |
| 478 | |
| 479 | # Ensure all characters are valid |
| 480 | for s in subscripts: |
| 481 | if s in '.,->': |
| 482 | continue |
| 483 | if s not in einsum_symbols: |
| 484 | raise ValueError(f"Character {s} is not a valid symbol.") |
| 485 | |
| 486 | else: |
| 487 | tmp_operands = list(operands) |
| 488 | operand_list = [] |
| 489 | subscript_list = [] |
| 490 | for p in range(len(operands) // 2): |
| 491 | operand_list.append(tmp_operands.pop(0)) |
| 492 | subscript_list.append(tmp_operands.pop(0)) |
| 493 | |
| 494 | output_list = tmp_operands[-1] if len(tmp_operands) else None |
| 495 | operands = [asanyarray(v) for v in operand_list] |
| 496 | subscripts = "" |
| 497 | last = len(subscript_list) - 1 |
| 498 | for num, sub in enumerate(subscript_list): |
| 499 | for s in sub: |
| 500 | if s is Ellipsis: |
| 501 | subscripts += "..." |
| 502 | else: |
no test coverage detected
searching dependent graphs…