`DefinedVariableTracker` manages the scope stack and per-scope branch state used by `UndefinedVariablesVisitor`.
| 198 | |
| 199 | |
| 200 | class DefinedVariableTracker: |
| 201 | """DefinedVariableTracker manages the state and scope for the UndefinedVariablesVisitor.""" |
| 202 | |
def __init__(self) -> None:
    """Start tracking with a single global scope.

    The scope stack is never empty, and every scope on it holds at least one
    ``BranchStatement``; the first one in each scope is its root statement.
    """
    root_statement = BranchStatement()
    self.scopes: list[Scope] = [Scope([root_statement], ScopeType.Global)]
    # While True, skip_branch() does nothing, so a return/raise/etc. does not
    # cause the current branch to be skipped. Useful for try/except/finally.
    self.disable_branch_skip = False
    # NOTE(review): read/written outside this view; presumably set while a
    # ``finally`` block is being visited -- confirm against the visitor.
    self.in_finally = False
| 210 | |
def copy(self) -> DefinedVariableTracker:
    """Return a new tracker mirroring this one.

    Every scope is duplicated through ``Scope.copy()`` (how deep that copy goes
    is decided by ``Scope``), and both boolean flags are carried over.
    """
    clone = DefinedVariableTracker()
    clone.scopes = [scope.copy() for scope in self.scopes]
    clone.disable_branch_skip, clone.in_finally = (
        self.disable_branch_skip,
        self.in_finally,
    )
    return clone
| 217 | |
| 218 | def _scope(self) -> Scope: |
| 219 | assert len(self.scopes) > 0 |
| 220 | return self.scopes[-1] |
| 221 | |
def enter_scope(self, scope_type: ScopeType) -> None:
    """Push a new scope of kind ``scope_type`` onto the scope stack.

    The new scope's root ``BranchStatement`` normally gets no initial state.
    A generator scope instead starts from the last branch of the enclosing
    scope's innermost branch statement, because generators inherit the outer
    scope.
    """
    enclosing_stmts = self._scope().branch_stmts
    assert len(enclosing_stmts) > 0
    inherited = (
        enclosing_stmts[-1].branches[-1]
        if scope_type == ScopeType.Generator
        else None
    )
    self.scopes.append(Scope([BranchStatement(inherited)], scope_type))
| 229 | |
def exit_scope(self) -> None:
    """Pop the innermost scope, returning to its enclosing scope.

    The global scope created in ``__init__`` must stay on the stack: every
    other method reaches the current scope through ``_scope()``, which asserts
    the stack is non-empty. Exiting while only the global scope remains
    therefore indicates unbalanced ``enter_scope``/``exit_scope`` calls.
    """
    # Fail fast on an unbalanced exit instead of silently emptying the stack,
    # which would only surface later as an unrelated assertion in _scope().
    assert len(self.scopes) > 1
    self.scopes.pop()
| 232 | |
def in_scope(self, scope_type: ScopeType) -> bool:
    """Return True when the innermost scope is of kind ``scope_type``."""
    innermost = self._scope()
    return innermost.scope_type == scope_type
| 235 | |
def start_branch_statement(self) -> None:
    """Open a nested branch statement in the current scope.

    The new statement is seeded with the last branch of the currently
    innermost branch statement.
    """
    stmts = self._scope().branch_stmts
    assert len(stmts) > 0
    seed = stmts[-1].branches[-1]
    stmts.append(BranchStatement(seed))
| 241 | |
def next_branch(self) -> None:
    """Advance the innermost nested branch statement to its next branch."""
    stmts = self._scope().branch_stmts
    # Index 0 is the scope's root statement; a nested statement must be open.
    assert len(stmts) > 1
    stmts[-1].next_branch()
| 245 | |
def end_branch_statement(self) -> None:
    """Close the innermost branch statement and hand its result to its parent.

    The closed statement's ``done()`` result is passed to the enclosing
    statement's ``record_nested_branch``.
    """
    stmts = self._scope().branch_stmts
    # The root statement (index 0) can never be closed.
    assert len(stmts) > 1
    summary = stmts.pop().done()
    stmts[-1].record_nested_branch(summary)
| 250 | |
def skip_branch(self) -> None:
    """Ask the innermost nested branch statement to skip its current branch.

    Does nothing while only the scope's root branch statement is open, or
    while ``disable_branch_skip`` is set.
    """
    stmts = self._scope().branch_stmts
    # The root statement (index 0) is never skipped.
    if len(stmts) <= 1 or self.disable_branch_skip:
        return
    stmts[-1].skip_branch()
| 255 | |
| 256 | def record_definition(self, name: str) -> None: |
| 257 | assert len(self.scopes) > 0 |