Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
e83679f
Support narrowing literals and enums using the in operator in combina…
tyralla Mar 17, 2024
a0d1db3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 17, 2024
2bff3c2
fix
tyralla Mar 17, 2024
5e8a525
Merge branch 'feature/narrow_using_in' of https://github.com/tyralla/…
tyralla Mar 17, 2024
2fa954a
Optional[Expression] -> Expression | None
tyralla Mar 17, 2024
db7b969
Add __eq__ to object in tuple.pyi
tyralla Mar 17, 2024
23bfd4e
simplification: avoid returning None
tyralla Mar 18, 2024
207c56e
replace the critical `in` comparisons for testing
tyralla Mar 19, 2024
dea2614
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
214f51a
Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"
tyralla Mar 19, 2024
a495fee
Revert "replace the critical `in` comparisons for testing"
tyralla Mar 19, 2024
8785fe9
fix mypyc crash
tyralla Mar 21, 2024
6af40ae
make NameExpr mypyc copyable
tyralla Mar 21, 2024
e52eb25
ignore star expressions
tyralla Mar 21, 2024
763c265
update docs
tyralla Mar 21, 2024
44d71eb
Also support list expressions.
tyralla Oct 26, 2024
bd145d1
Merge branch 'master' into feature/narrow_using_in
tyralla Oct 26, 2024
a04126d
Also support set expressions.
tyralla Oct 27, 2024
3e5bdf9
Merge branch 'master' into feature/narrow_using_in
tyralla May 23, 2025
78dbff2
Adjust the `testNarrowingOptionalEqualsNone` test case.
tyralla May 23, 2025
e445651
update comment
tyralla May 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion docs/source/literal_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ perform an exhaustiveness check, you need to update your code to use an

.. code-block:: python

from typing import Literal, NoReturn
from typing import Literal
from typing_extensions import assert_never

PossibleValues = Literal['one', 'two']
Expand Down Expand Up @@ -368,6 +368,19 @@ without a value:
elif x == 'two':
return False

For the sake of brevity, you can use the ``in`` operator in combination with
list, set, or tuple expressions (lists, sets, or tuples created "on the fly"):

.. code-block:: python

PossibleValues = Literal['one', 'two', 'three']

def validate(x: PossibleValues) -> bool:
if x in ['one']:
return True
elif x in ('two', 'three'):
return False

Exhaustiveness checking is also supported for match statements (Python 3.10 and later):

.. code-block:: python
Expand Down
61 changes: 61 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import defaultdict
from collections.abc import Iterable, Iterator, Mapping, Sequence, Set as AbstractSet
from contextlib import ExitStack, contextmanager
from copy import copy
from typing import Callable, Final, Generic, NamedTuple, Optional, TypeVar, Union, cast, overload
from typing_extensions import TypeAlias as _TypeAlias, TypeGuard

Expand Down Expand Up @@ -104,6 +105,7 @@
RaiseStmt,
RefExpr,
ReturnStmt,
SetExpr,
StarExpr,
Statement,
StrExpr,
Expand Down Expand Up @@ -4832,12 +4834,71 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
if self.in_checked_function():
self.fail(message_registry.RETURN_VALUE_EXPECTED, s)

def _transform_sequence_expressions_for_narrowing_with_in(self, e: Expression) -> Expression:
"""
Transform an expression like

(x is None) and (x in (1, 2)) and (x not in [3, 4])

into

(x is None) and (x == 1 or x == 2) and (x != 3 and x != 4)

This transformation is supposed to enable narrowing literals and enums using the
in (and the not in) operator in combination with tuple, list, and set expressions
without the need to implement additional narrowing logic.
"""
if isinstance(e, OpExpr):
e.left = self._transform_sequence_expressions_for_narrowing_with_in(e.left)
e.right = self._transform_sequence_expressions_for_narrowing_with_in(e.right)
return e

if not (
isinstance(e, ComparisonExpr)
and isinstance(left := e.operands[0], NameExpr)
and ((op_in := e.operators[0]) in ("in", "not in"))
and isinstance(litu := e.operands[1], (ListExpr, SetExpr, TupleExpr))
):
return e

op_eq, op_con = (["=="], "or") if (op_in == "in") else (["!="], "and")
line = e.line
left_new = left
comparisons = []
for right in reversed(litu.items):
if isinstance(right, StarExpr):
return e
comparison = ComparisonExpr(op_eq, [left_new, right])
comparison.line = line
comparisons.append(comparison)
left_new = copy(left)
if (nmb := len(comparisons)) == 0:
if op_in == "in":
e = NameExpr("False")
e.fullname = "builtins.False"
e.line = line
return e
e = NameExpr("True")
e.fullname = "builtins.True"
e.line = line
return e
if nmb == 1:
return comparisons[0]
e = OpExpr(op_con, comparisons[1], comparisons[0])
for comparison in comparisons[2:]:
e = OpExpr(op_con, comparison, e)
e.line = line
return e

def visit_if_stmt(self, s: IfStmt) -> None:
"""Type check an if statement."""
# This frame records the knowledge from previous if/elif clauses not being taken.
# Fall-through to the original frame is handled explicitly in each block.
with self.binder.frame_context(can_skip=False, conditional_frame=True, fall_through=0):
for e, b in zip(s.expr, s.body):

e = self._transform_sequence_expressions_for_narrowing_with_in(e)

t = get_proper_type(self.expr_checker.accept(e))

if isinstance(t, DeletedType):
Expand Down
4 changes: 3 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1910,7 +1910,9 @@ class NameExpr(RefExpr):

__match_args__ = ("name", "node")

def __init__(self, name: str) -> None:
def __init__(self, name: str = "?") -> None:
# The default name "?" aims to make NameExpr mypyc copyable.
# Always pass a proper name when manually calling NameExpr.__init__.
super().__init__()
self.name = name # Name referred to
# Is this a l.h.s. of a special form assignment like typed dict or type variable?
Expand Down
9 changes: 9 additions & 0 deletions mypyc/test-data/run-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,12 @@ def test_multiply() -> None:
assert (1,) * 3 == res
assert 3 * (1,) == res
assert multiply((1,), 3) == res

[case testTupleDoNotCrashOnTransformedInComparisons]
def f() -> None:
for n in ["x"]:
if n in ("x", "z") or n.startswith("y"):
print(n)
f()
[out]
x
2 changes: 1 addition & 1 deletion test-data/unit/check-isinstance.test
Original file line number Diff line number Diff line change
Expand Up @@ -2004,7 +2004,7 @@ class C(A): pass

y: Optional[B]
if y in (B(), C()):
reveal_type(y) # N: Revealed type is "__main__.B"
reveal_type(y) # N: Revealed type is "Union[__main__.B, None]"
else:
reveal_type(y) # N: Revealed type is "Union[__main__.B, None]"
[builtins fixtures/tuple.pyi]
Expand Down
105 changes: 103 additions & 2 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1375,9 +1375,9 @@ else:
if val in (None,):
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "__main__.A"
if val not in (None,):
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "__main__.A"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
[builtins fixtures/primitives.pyi]
Expand Down Expand Up @@ -2313,6 +2313,107 @@ def f(x: C) -> None:
f(C(5))
[builtins fixtures/primitives.pyi]

[case testNarrowLiteralsInListOrSetOrTupleExpression]
# flags: --warn-unreachable

from typing import Optional
from typing_extensions import Literal

x: int

def f(v: Optional[Literal[1, 2, 3, 4]]) -> None:
if v in (0, 1, 2):
reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2]]"
elif v in [1]:
reveal_type(v) # E: Statement is unreachable
elif v is None or v in {3, x}:
reveal_type(v) # N: Revealed type is "Union[Literal[3], Literal[4], None]"
elif v in ():
reveal_type(v) # E: Statement is unreachable
else:
reveal_type(v) # N: Revealed type is "Literal[4]"
reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], None]"
[builtins fixtures/primitives.pyi]

[case testNarrowLiteralsNotInListOrSetOrTupleExpression]
# flags: --warn-unreachable

from typing import Optional
from typing_extensions import Literal

x: int

def f(v: Optional[Literal[1, 2, 3, 4, 5]]) -> None:
if v not in {0, 1, 2, 3}:
reveal_type(v) # N: Revealed type is "Union[Literal[4], Literal[5], None]"
elif v not in [1, 2, 3, 4]: # E: Right operand of "and" is never evaluated
reveal_type(v) # E: Statement is unreachable
elif v is not None and v not in (3,):
reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2]]"
elif v not in (x, 3):
reveal_type(v) # E: Statement is unreachable
else:
reveal_type(v) # N: Revealed type is "Literal[3]"
reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], Literal[5], None]"
[builtins fixtures/primitives.pyi]

[case testNarrowEnumsInListOrSetOrTupleExpression]
from enum import Enum
from typing import Final

class E(Enum):
A = 1
B = 2
C = 3
D = 4

A: Final = E.A
C: Final = E.C

def f(v: E) -> None:
reveal_type(v) # N: Revealed type is "__main__.E"
if v in (A, E.B):
reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]"
elif v in [E.A]:
reveal_type(v)
elif v in {C}:
reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]"
elif v in ():
reveal_type(v)
else:
reveal_type(v) # N: Revealed type is "Literal[__main__.E.D]"
reveal_type(v) # N: Revealed type is "__main__.E"
[builtins fixtures/primitives.pyi]

[case testNarrowEnumsNotInListOrSetOrTupleExpression]
from enum import Enum
from typing import Final

class E(Enum):
A = 1
B = 2
C = 3
D = 4
E = 5

A: Final = E.A
C: Final = E.C

def f(v: E) -> None:
reveal_type(v) # N: Revealed type is "__main__.E"
if v not in (A, E.B, E.C):
reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.D], Literal[__main__.E.E]]"
elif v not in [E.A, E.B, E.C, E.C]:
reveal_type(v)
elif v not in {C}:
reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]"
elif v not in []:
reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]"
else:
reveal_type(v)
reveal_type(v) # N: Revealed type is "__main__.E"
[builtins fixtures/primitives.pyi]

[case testNarrowingTypeVarNone]
# flags: --warn-unreachable

Expand Down
3 changes: 2 additions & 1 deletion test-data/unit/fixtures/tuple.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ _Tco = TypeVar('_Tco', covariant=True)

class object:
def __init__(self) -> None: pass
def __new__(cls) -> Self: ...
def __new__(cls) -> Self: pass
def __eq__(self, other: object) -> bool: pass

class type:
def __init__(self, *a: object) -> None: pass
Expand Down