diff --git a/mypy/partially_defined.py b/mypy/partially_defined.py index c8db4bc5960c..c2c925e0477c 100644 --- a/mypy/partially_defined.py +++ b/mypy/partially_defined.py @@ -7,6 +7,7 @@ AssignmentExpr, AssignmentStmt, BreakStmt, + ClassDef, Context, ContinueStmt, DictionaryComprehension, @@ -271,13 +272,16 @@ def variable_may_be_undefined(self, name: str, context: Context) -> None: if self.msg.errors.is_error_code_enabled(errorcodes.PARTIALLY_DEFINED): self.msg.variable_may_be_undefined(name, context) + def process_definition(self, name: str) -> None: + # Was this name previously used? If yes, it's a use-before-definition error. + refs = self.tracker.pop_undefined_ref(name) + for ref in refs: + self.var_used_before_def(name, ref) + self.tracker.record_definition(name) + def process_lvalue(self, lvalue: Lvalue | None) -> None: if isinstance(lvalue, NameExpr): - # Was this name previously used? If yes, it's a use-before-definition error. - refs = self.tracker.pop_undefined_ref(lvalue.name) - for ref in refs: - self.var_used_before_def(lvalue.name, ref) - self.tracker.record_definition(lvalue.name) + self.process_definition(lvalue.name) elif isinstance(lvalue, StarExpr): self.process_lvalue(lvalue.expr) elif isinstance(lvalue, (ListExpr, TupleExpr)): @@ -327,7 +331,7 @@ def visit_match_stmt(self, o: MatchStmt) -> None: self.tracker.end_branch_statement() def visit_func_def(self, o: FuncDef) -> None: - self.tracker.record_definition(o.name) + self.process_definition(o.name) self.tracker.enter_scope() super().visit_func_def(o) self.tracker.exit_scope() @@ -476,6 +480,12 @@ def visit_with_stmt(self, o: WithStmt) -> None: self.process_lvalue(idx) o.body.accept(self) + def visit_class_def(self, o: ClassDef) -> None: + self.process_definition(o.name) + self.tracker.enter_scope() + super().visit_class_def(o) + self.tracker.exit_scope() + def visit_import(self, o: Import) -> None: for mod, alias in o.ids: if alias is not None: diff --git a/test-data/unit/check-partially-defined.test b/test-data/unit/check-partially-defined.test index 52822f98ab53..7c10306684ca 100644 --- a/test-data/unit/check-partially-defined.test +++ b/test-data/unit/check-partially-defined.test @@ -140,6 +140,40 @@ def f0(b: bool) -> None: fn = lambda: 2 y = fn # E: Name "fn" may be undefined +[case testUseBeforeDefClass] +# flags: --enable-error-code partially-defined --enable-error-code use-before-def +def f(x: A): # No error here. + pass +y = A() # E: Name "A" is used before definition +class A: pass + +[case testClassScope] +# flags: --enable-error-code partially-defined --enable-error-code use-before-def +class C: + x = 0 + def f0(self) -> None: pass + + def f2(self) -> None: + f0() # No error. + self.f0() # No error. + +f0() # E: Name "f0" is used before definition +def f0() -> None: pass +y = x # E: Name "x" is used before definition +x = 1 + +[case testClassInsideFunction] +# flags: --enable-error-code partially-defined --enable-error-code use-before-def +def f() -> None: + class C: pass + +c = C() # E: Name "C" is used before definition +class C: pass + +[case testUseBeforeDefFunc] +# flags: --enable-error-code partially-defined --enable-error-code use-before-def +foo() # E: Name "foo" is used before definition +def foo(): pass [case testGenerator] # flags: --enable-error-code partially-defined if int():