Find all assert statements in *mod* and rewrite them.
(self, mod: ast.Module)
| 681 | ) |
| 682 | |
| 683 | def run(self, mod: ast.Module) -> None: |
| 684 | """Find all assert statements in *mod* and rewrite them.""" |
| 685 | if not mod.body: |
| 686 | # Nothing to do. |
| 687 | return |
| 688 | |
| 689 | # We'll insert some special imports at the top of the module, but after any |
| 690 | # docstrings and __future__ imports, so first figure out where that is. |
| 691 | doc = getattr(mod, "docstring", None) |
| 692 | expect_docstring = doc is None |
| 693 | if doc is not None and self.is_rewrite_disabled(doc): |
| 694 | return |
| 695 | pos = 0 |
| 696 | for item in mod.body: |
| 697 | match item: |
| 698 | case ast.Expr(value=ast.Constant(value=str() as doc)) if ( |
| 699 | expect_docstring |
| 700 | ): |
| 701 | if self.is_rewrite_disabled(doc): |
| 702 | return |
| 703 | expect_docstring = False |
| 704 | case ast.ImportFrom(level=0, module="__future__"): |
| 705 | pass |
| 706 | case _: |
| 707 | break |
| 708 | pos += 1 |
| 709 | # Special case: for a decorated function, set the lineno to that of the |
| 710 | # first decorator, not the `def`. Issue #4984. |
| 711 | if isinstance(item, ast.FunctionDef) and item.decorator_list: |
| 712 | lineno = item.decorator_list[0].lineno |
| 713 | else: |
| 714 | lineno = item.lineno |
| 715 | # Now actually insert the special imports. |
| 716 | aliases = [ |
| 717 | ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0), |
| 718 | ast.alias( |
| 719 | "_pytest.assertion.rewrite", |
| 720 | "@pytest_ar", |
| 721 | lineno=lineno, |
| 722 | col_offset=0, |
| 723 | ), |
| 724 | ] |
| 725 | imports = [ |
| 726 | ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases |
| 727 | ] |
| 728 | mod.body[pos:pos] = imports |
| 729 | |
| 730 | # Collect asserts. |
| 731 | self.scope = (mod,) |
| 732 | nodes: list[ast.AST | Sentinel] = [mod] |
| 733 | while nodes: |
| 734 | node = nodes.pop() |
| 735 | if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef): |
| 736 | self.scope = tuple((*self.scope, node)) |
| 737 | nodes.append(_SCOPE_END_MARKER) |
| 738 | if node == _SCOPE_END_MARKER: |
| 739 | self.scope = self.scope[:-1] |
| 740 | continue |
no test coverage detected