Finds the path by contracting the best pair until the input list is exhausted. The best pair is found by minimizing the tuple ``(-prod(indices_removed), cost)``. In practice this prioritizes matrix multiplication or inner product operations, then Hadamard-like operations, and finally outer products, which are limited by ``memory_limit``.
(input_sets, output_set, idx_dict, memory_limit)
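The greedy strategy summarized above can be illustrated with a small, self-contained sketch. It assumes three things: that "indices removed" means the indices of a pair that neither the remaining inputs nor the output still need, that the cost of a pair is the size of its joint index space (doubled when something is summed away), and that the memory limit applies to every candidate intermediate. The names `contract_pair` and `greedy_path_sketch` are hypothetical. This is not NumPy's implementation. For clarity it re-scores every pair at every step and omits the `naive_cost` sieve and the `known_contractions` bookkeeping that the listing sets up.

```python
import itertools
from math import prod


def contract_pair(positions, input_sets, output_set):
    """Merge the inputs at `positions`; return (result, removed, contracted, remaining)."""
    contracted = set().union(*(input_sets[p] for p in positions))
    # Indices still needed by another input or by the output must survive.
    still_needed = set(output_set).union(
        *(s for i, s in enumerate(input_sets) if i not in positions)
    )
    result = contracted & still_needed
    removed = contracted - result
    remaining = [s for i, s in enumerate(input_sets) if i not in positions]
    remaining.append(result)  # the new intermediate goes to the end of the list
    return result, removed, contracted, remaining


def greedy_path_sketch(input_sets, output_set, idx_dict, memory_limit):
    """Greedy pair selection ranked by (-prod(removed index sizes), cost)."""
    input_sets = [set(s) for s in input_sets]
    path = []
    while len(input_sets) > 2:
        candidates = []
        for allow_outer in (False, True):
            for positions in itertools.combinations(range(len(input_sets)), 2):
                i, j = positions
                if not allow_outer and input_sets[i].isdisjoint(input_sets[j]):
                    continue  # outer products are only considered as a fallback
                result, removed, contracted, remaining = contract_pair(
                    positions, input_sets, output_set)
                if prod(idx_dict[k] for k in result) > memory_limit:
                    continue  # the intermediate would be too large
                # Assumed cost model: joint index-space size, doubled when
                # at least one index is summed away.
                cost = prod(idx_dict[k] for k in contracted) * (2 if removed else 1)
                key = (-prod(idx_dict[k] for k in removed), cost)
                candidates.append((key, positions, remaining))
            if candidates:
                break
        if not candidates:
            # Nothing fits in memory: contract everything left in one step.
            path.append(tuple(range(len(input_sets))))
            return path
        key, positions, input_sets = min(candidates, key=lambda c: c[0])
        path.append(positions)
    path.append(tuple(range(len(input_sets))))
    return path


sizes = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
print(greedy_path_sketch([set('abd'), set('ac'), set('bdc')], set(), sizes, 5000))
```

On the docstring's example, the pair `(0, 2)` removes `b` and `d` (product 8). That beats `(1, 2)`, which removes `c` (3), and `(0, 1)`, which removes `a` (1). The sketch therefore reproduces the documented `[(0, 2), (0, 1)]`.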
```python
    return mod_results

def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard-like
    operations, and finally outer products. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs of the einsum subscript
    output_set : set
        Set that represents the rhs of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost
    contract = _find_contraction(
        range(len(input_sets)), input_sets, output_set
    )
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(
        idx_contract, idx_removed, len(input_sets), idx_dict
    )

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):
```
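The naive-cost step above calls `_find_contraction` and `_flop_count`, whose bodies are not part of this listing. The sketch below gives hypothetical stand-ins, `find_contraction_sketch` and `flop_count_sketch`. They return the same four-tuple order that the listing unpacks (`idx_result, new_input_sets, idx_removed, idx_contract`). Their internals and the cost formula are assumptions for illustration, and NumPy's private helpers may differ.

```python
from math import prod


def find_contraction_sketch(positions, input_sets, output_set):
    """Return (idx_result, new_input_sets, idx_removed, idx_contract) for the inputs at `positions`."""
    idx_contract = set()
    idx_remain = set(output_set)
    remaining = []
    for ind, value in enumerate(input_sets):
        if ind in positions:
            idx_contract |= value      # indices touched by this contraction
        else:
            remaining.append(value)    # untouched inputs carry over unchanged
            idx_remain |= value        # ...and keep their indices alive
    idx_result = idx_remain & idx_contract
    idx_removed = idx_contract - idx_result
    remaining.append(idx_result)       # the new intermediate is appended last
    return idx_result, remaining, idx_removed, idx_contract


def flop_count_sketch(idx_contract, idx_removed, num_terms, idx_dict):
    """Assumed cost model: joint index-space size times an operation factor."""
    size = prod(idx_dict[i] for i in idx_contract)
    op_factor = max(1, num_terms - 1)  # multiplications per element
    if idx_removed:
        op_factor += 1                 # one extra addition when summing indices away
    return size * op_factor


isets = [set('abd'), set('ac'), set('bdc')]
sizes = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
_, _, removed, contract = find_contraction_sketch(range(len(isets)), isets, set())
naive_cost = flop_count_sketch(contract, removed, len(isets), sizes)
```

Under this assumed model, contracting all three docstring inputs at once touches a 1·2·3·4 = 24-element index space. It sums every index away, so the operation factor is 2 + 1 = 3, which gives a naive cost of 72.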