```python
def group_comparison_operands(
    pairwise_comparisons: Iterable[tuple[str, Expression, Expression]],
    operand_to_literal_hash: Mapping[int, Key],
    operators_to_group: set[str],
) -> list[tuple[str, list[int]]]:
    """Group a series of comparison operands together chained by any operator
    in the 'operators_to_group' set. All other pairwise operands are kept in
    groups of size 2.

    For example, suppose we have the input comparison expression:

        x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8

    If we get these expressions in a pairwise way (e.g. by calling ComparisonExpr's
    'pairwise()' method), we get the following as input:

        [('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('<', x3, x4),
         ('is', x4, x5), ('is', x5, x6), ('is not', x6, x7), ('is not', x7, x8)]

    If `operators_to_group` is the set {'==', 'is'}, this function will produce
    the following "simplified operator list":

        [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
         ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]

    Note that (a) we yield *indices* to the operands rather than the operand
    expressions themselves and that (b) operands used in a consecutive chain
    of '==' or 'is' are grouped together.

    If two of these chains happen to contain operands with the same underlying
    literal hash (e.g. they are assignable and correspond to the same expression),
    we combine those chains together. For example, if we had:

        same == x < y == same

    ...and if 'operand_to_literal_hash' contained the same values for the indices
    0 and 3, we'd produce the following output:

        [("==", [0, 1, 2, 3]), ("<", [1, 2])]

    But if 'operand_to_literal_hash' did *not* contain entries for those indices,
    we'd instead default to returning:

        [("==", [0, 1]), ("<", [1, 2]), ("==", [2, 3])]

    This function is currently only used to assist with type-narrowing refinements
    and is extracted out to a helper function so we can unit test it.
    """
    groups: dict[str, DisjointDict[Key, int]] = {op: DisjointDict() for op in operators_to_group}

    simplified_operator_list: list[tuple[str, list[int]]] = []
    last_operator: str | None = None
    current_indices: set[int] = set()
    current_hashes: set[Key] = set()
    for i, (operator, left_expr, right_expr) in enumerate(pairwise_comparisons):
        if last_operator is None:
            last_operator = operator
```
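The chain-building rule described in the docstring (consecutive uses of a grouped operator form one chain, every other comparison stays a pair) can be illustrated with a small standalone sketch. It is not the function above: `pairwise` and `group_chains` are hypothetical names, operands are plain strings standing in for `Expression` objects, and the literal-hash merging step is deliberately left out.

```python
from typing import Any


def pairwise(operators: list[str], operands: list[Any]) -> list[tuple[str, Any, Any]]:
    """Hypothetical helper: split `a op0 b op1 c ...` into (op, left, right) triples."""
    return [(op, operands[i], operands[i + 1]) for i, op in enumerate(operators)]


def group_chains(
    pairwise_comparisons: list[tuple[str, Any, Any]],
    operators_to_group: set[str],
) -> list[tuple[str, list[int]]]:
    """Sketch: group operand *indices* joined by a consecutive run of one grouped operator."""
    result: list[tuple[str, list[int]]] = []
    current_op = ""
    current: list[int] = []
    for i, (op, _left, _right) in enumerate(pairwise_comparisons):
        # Extend the open chain only when the same grouped operator continues it.
        if current and op == current_op and op in operators_to_group:
            current.append(i + 1)
            continue
        # Otherwise flush the open chain (if any) and start a new one
        # covering the operand pair (i, i + 1).
        if current:
            result.append((current_op, current))
        current_op, current = op, [i, i + 1]
    if current:
        result.append((current_op, current))
    return result


names = [f"x{k}" for k in range(9)]
ops = ["==", "==", "<", "<", "is", "is", "is not", "is not"]
print(group_chains(pairwise(ops, names), {"==", "is"}))
```

Run on the docstring's first example, the sketch prints the same "simplified operator list" shown there: the `==` and `is` runs collapse into `[0, 1, 2]` and `[4, 5, 6]`, while `<` and `is not` stay as overlapping pairs because they are not in `operators_to_group`.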
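The merging step from the docstring (chains of the same grouped operator that share an operand literal hash are combined, transitively) can be sketched with a union-find over chain positions. This is a simplified, hypothetical stand-in rather than mypy's code: `merge_chains_by_hash` is an invented name, literal hashes are plain strings instead of `Key` values, the union-find is written inline instead of using `DisjointDict` (whose API is not shown in this excerpt), and sorting the result by smallest operand index is an assumption chosen to reproduce the orderings in the docstring examples.

```python
def merge_chains_by_hash(
    chains: list[tuple[str, list[int]]],
    operand_to_literal_hash: dict[int, str],
    operators_to_group: set[str],
) -> list[tuple[str, list[int]]]:
    """Sketch: merge same-operator chains whose operands share a literal hash."""
    parent = list(range(len(chains)))  # union-find over chain positions

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    # First chain seen for each (operator, literal hash) pair.
    owner: dict[tuple[str, str], int] = {}
    for c, (op, indices) in enumerate(chains):
        if op not in operators_to_group:
            continue  # pairs for non-grouped operators are never merged
        for idx in indices:
            key = operand_to_literal_hash.get(idx)
            if key is None:
                continue
            if (op, key) in owner:
                parent[find(c)] = find(owner[(op, key)])
            else:
                owner[(op, key)] = c

    # Collect the operand indices of every chain under its union-find root.
    merged: dict[int, tuple[str, set[int]]] = {}
    for c, (op, indices) in enumerate(chains):
        merged.setdefault(find(c), (op, set()))[1].update(indices)

    # Order groups by their smallest operand index (assumed, to match the docstring).
    return sorted(
        ((op, sorted(idx)) for op, idx in merged.values()),
        key=lambda group: group[1][0],
    )


# `same == x < y == same`, already split into chains of size 2.
chains = [("==", [0, 1]), ("<", [1, 2]), ("==", [2, 3])]
print(merge_chains_by_hash(chains, {0: "same", 3: "same"}, {"=="}))
print(merge_chains_by_hash(chains, {}, {"=="}))
```

With matching hashes for indices 0 and 3 the two `==` chains collapse into `("==", [0, 1, 2, 3])`; with an empty hash map the chains come back unchanged, mirroring the two outputs the docstring describes.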