Function _greedy_path

numpy/_core/einsumfunc.py:330–442

Finds a contraction path by repeatedly contracting the best pair of operands until the input list is exhausted. The best pair is the one that minimizes the tuple ``(-prod(indices_removed), cost)``. In effect, this prioritizes matrix multiplication or inner product operations, then Hadamard-like operations, and finally outer products.
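To make the ranking concrete, the sketch below scores every pair of operands from the function's docstring example (`abd`, `ac`, `bdc` with an empty output) using the key as the docstring describes it. The helper names are hypothetical, and the cost term uses an assumed model (the size of the index space a pair touches), not NumPy's `_flop_count`.

```python
from itertools import combinations
from math import prod

# Operands, output and index sizes from the docstring example.
isets = [set("abd"), set("ac"), set("bdc")]
oset = set()
sizes = {"a": 1, "b": 2, "c": 3, "d": 4}


def removed_indices(pair, input_sets, output_set):
    """Indices summed away by contracting `pair`: those that appear in
    neither the output nor any operand outside the pair."""
    involved = set().union(*(input_sets[p] for p in pair))
    others = set().union(*(s for n, s in enumerate(input_sets) if n not in pair))
    return involved - (output_set | others)


def greedy_key(pair):
    """Hypothetical helper: the docstring's key (-prod(indices_removed), cost)."""
    removed = removed_indices(pair, isets, oset)
    # Assumed cost model: size of the full index space the pair touches.
    cost = prod(sizes[i] for i in set().union(*(isets[p] for p in pair)))
    return (-prod(sizes[i] for i in removed), cost)


keys = {pair: greedy_key(pair) for pair in combinations(range(len(isets)), 2)}
best = min(keys, key=keys.get)
print(keys)  # {(0, 1): (-1, 24), (0, 2): (-8, 24), (1, 2): (-3, 24)}
print(best)  # (0, 2)
```

Contracting operands 0 and 2 sums away both `b` and `d` (product 8), so that pair wins the first step, which matches the first entry of the docstring's expected path `[(0, 2), (0, 1)]`.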

(input_sets, output_set, idx_dict, memory_limit)
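The arguments are plain Python containers. In the docstring example they encode the explicit subscript `'abd,ac,bdc->'`: one set of labels per operand, the (empty) set of output labels, and a size for every label. A minimal sketch of deriving them from a subscript string and operand shapes follows; `subscripts_to_args` is a hypothetical helper, not a NumPy function, and it handles only explicit `->` subscripts without ellipses or size-1 broadcasting.

```python
def subscripts_to_args(subscripts, shapes):
    """Hypothetical helper: turn an explicit subscript such as 'ij,jk->ik'
    plus operand shapes into (input_sets, output_set, idx_dict)."""
    if "->" not in subscripts:
        raise ValueError("this sketch only handles explicit '->' subscripts")
    lhs, rhs = subscripts.split("->")
    terms = lhs.split(",")
    if len(terms) != len(shapes):
        raise ValueError("one shape per operand is required")
    idx_dict = {}
    for term, shape in zip(terms, shapes):
        if len(term) != len(shape):
            raise ValueError(f"term {term!r} does not match shape {shape}")
        for label, dim in zip(term, shape):
            # Every occurrence of a label must agree on its size.
            if idx_dict.setdefault(label, dim) != dim:
                raise ValueError(f"size mismatch for index {label!r}")
    input_sets = [set(term) for term in terms]
    output_set = set(rhs)
    return input_sets, output_set, idx_dict


# Rebuild the docstring example's arguments.
isets, oset, idx = subscripts_to_args("abd,ac,bdc->", [(1, 2, 4), (1, 3), (2, 4, 3)])
print(isets, oset, idx)
```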

```python
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost
    contract = _find_contraction(
        range(len(input_sets)), input_sets, output_set
    )
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(
        idx_contract, idx_removed, len(input_sets), idx_dict
    )

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):
        ...  # listing truncated here; the function continues to line 442 of einsumfunc.py
```
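A simplified, self-contained version of the greedy loop may help when reading the listing. This is a sketch under stated assumptions, not NumPy's implementation: `greedy_path_sketch` and `flop_estimate` are hypothetical names, the cost model is assumed, and the sketch re-scores every remaining pair on each step instead of caching candidates in `known_contractions` and updating them with `_update_other_results`. It also skips the `naive_cost` comparison computed in the listing.

```python
from itertools import combinations
from math import prod


def flop_estimate(idx_contract, has_inner, num_terms, idx_dict):
    """Assumed cost model: size of the touched index space times an
    operation factor (one extra operation when indices are summed away)."""
    size = prod(idx_dict[i] for i in idx_contract)
    factor = max(1, num_terms - 1) + (1 if has_inner else 0)
    return size * factor


def greedy_path_sketch(input_sets, output_set, idx_dict, memory_limit):
    """Simplified greedy pairwise path search, for illustration only."""
    if len(input_sets) == 1:
        return [(0,)]  # same trivial case as the listing
    sets = [set(s) for s in input_sets]
    path = []
    while len(sets) > 1:
        best = None
        # Re-score every remaining pair on each step (no caching).
        for i, j in combinations(range(len(sets)), 2):
            others = set().union(*(s for n, s in enumerate(sets) if n not in (i, j)))
            contract = sets[i] | sets[j]
            result = contract & (output_set | others)  # indices still needed later
            removed = contract - result  # indices summed away by this pair
            if prod(idx_dict[k] for k in result) > memory_limit:
                continue  # intermediate would exceed the memory limit
            cost = flop_estimate(contract, bool(removed), 2, idx_dict)
            key = (-prod(idx_dict[k] for k in removed), cost)
            if best is None or key < best[0]:
                best = (key, (i, j), result)
        if best is None:
            # Sketch's fallback: contract all remaining operands in one step.
            path.append(tuple(range(len(sets))))
            break
        _, pair, result = best
        path.append(pair)
        # Remove the contracted pair and append the intermediate at the end.
        sets = [s for n, s in enumerate(sets) if n not in pair] + [result]
    return path


# Docstring example: reproduces the path listed there.
print(greedy_path_sketch([set("abd"), set("ac"), set("bdc")], set(),
                         {"a": 1, "b": 2, "c": 3, "d": 4}, 5000))
# → [(0, 2), (0, 1)]
```

Re-scoring every pair costs O(n²) pair evaluations per step over n - 1 steps, so this sketch is O(n³) in the number of operands, in line with the cubic scaling the docstring mentions.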

Callers 1

einsum_path (function) · 0.85
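`_greedy_path` is private; callers normally reach it through `np.einsum_path`, the caller listed above, by passing `optimize='greedy'`. A short usage sketch with arbitrarily chosen operand shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((10, 20))
b = rng.random((20, 30))
c = rng.random((30, 5))

# optimize='greedy' asks einsum_path for the greedy strategy.
path, report = np.einsum_path("ij,jk,kl->il", a, b, c, optimize="greedy")
print(path)    # starts with 'einsum_path', then one tuple per contraction step
print(report)  # human-readable summary of the chosen contractions

# The precomputed path can be passed back to einsum to reuse the order.
result = np.einsum("ij,jk,kl->il", a, b, c, optimize=path)
print(np.allclose(result, a @ b @ c))
```

Passing the whole returned list as `optimize=` lets `np.einsum` skip the path search on repeated calls with operands of the same shapes.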

Calls 5

_find_contraction (function) · 0.85
_flop_count (function) · 0.85
_update_other_results (function) · 0.85
min (function) · 0.70

Tested by

no test coverage detected
