-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Implement chained comparison improvements and related checks #7611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
14be74f
f41f4cb
5bf97a2
6f33ed6
409e25e
b1d38d9
1af2ffc
466a92a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
import itertools | ||
import sys | ||
import tokenize | ||
from collections.abc import Iterator | ||
from collections.abc import Iterator, Sequence | ||
from functools import reduce | ||
from re import Pattern | ||
from typing import TYPE_CHECKING, Any, NamedTuple, Union | ||
|
@@ -21,6 +21,7 @@ | |
from pylint import checkers | ||
from pylint.checkers import utils | ||
from pylint.checkers.utils import node_frame_class | ||
from pylint.graph import get_cycles, get_paths | ||
from pylint.interfaces import HIGH | ||
|
||
if TYPE_CHECKING: | ||
|
@@ -348,7 +349,7 @@ class RefactoringChecker(checkers.BaseTokenChecker): | |
"more idiomatic, although sometimes a bit slower", | ||
), | ||
"R1716": ( | ||
"Simplify chained comparison between the operands", | ||
"Simplify chained comparison between the operands: %s", | ||
"chained-comparison", | ||
"This message is emitted when pylint encounters boolean operation like " | ||
'"a < b and b < c", suggesting instead to refactor it to "a < b < c"', | ||
|
@@ -476,6 +477,17 @@ class RefactoringChecker(checkers.BaseTokenChecker): | |
"value by index lookup. " | ||
"The value can be accessed directly instead.", | ||
), | ||
"R1737": ( | ||
"Simplify cycle to ==", | ||
"comparison-all-equal", | ||
"Emitted when items in a boolean condition are all <= or >=" | ||
"This is equivalent to asking if they all equal.", | ||
), | ||
"R1738": ( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe this should raise the existing message for constant comparison. |
||
"This comparison always evalutes to False", | ||
"impossible-comparison", | ||
"Emitted when there a comparison that is always False.", | ||
), | ||
} | ||
options = ( | ||
( | ||
|
@@ -1296,61 +1308,154 @@ def _check_consider_using_in(self, node: nodes.BoolOp) -> None: | |
confidence=HIGH, | ||
) | ||
|
||
def _check_chained_comparison(self, node: nodes.BoolOp) -> None: | ||
"""Check if there is any chained comparison in the expression. | ||
|
||
Add a refactoring message if a boolOp contains comparison like a < b and b < c, | ||
which can be chained as a < b < c. | ||
|
||
Care is taken to avoid simplifying a < b < c and b < d. | ||
""" | ||
if node.op != "and" or len(node.values) < 2: | ||
def _check_comparisons(self, node: nodes.BoolOp) -> None: | ||
graph_info = self._get_graph_from_comparison_nodes(node) | ||
if not graph_info: | ||
return | ||
( | ||
graph_dict, | ||
symbol_dict, | ||
indegree_dict, | ||
frequency_dict, | ||
) = graph_info | ||
|
||
# Convert graph_dict to all strings to access the get_cycles API | ||
str_dict = { | ||
str(key): {str(dest) for dest in graph_dict[key]} for key in graph_dict | ||
} | ||
cycles = get_cycles(str_dict) | ||
if cycles: | ||
self._handle_cycles(node, symbol_dict, cycles) | ||
return | ||
|
||
def _find_lower_upper_bounds( | ||
comparison_node: nodes.Compare, | ||
uses: collections.defaultdict[str, dict[str, set[nodes.Compare]]], | ||
) -> None: | ||
left_operand = comparison_node.left | ||
for operator, right_operand in comparison_node.ops: | ||
for operand in (left_operand, right_operand): | ||
value = None | ||
if isinstance(operand, nodes.Name): | ||
value = operand.name | ||
elif isinstance(operand, nodes.Const): | ||
value = operand.value | ||
|
||
if value is None: | ||
continue | ||
paths = get_paths(graph_dict, indegree_dict, frequency_dict) | ||
|
||
if operator in {"<", "<="}: | ||
if operand is left_operand: | ||
uses[value]["lower_bound"].add(comparison_node) | ||
elif operand is right_operand: | ||
uses[value]["upper_bound"].add(comparison_node) | ||
elif operator in {">", ">="}: | ||
if operand is left_operand: | ||
uses[value]["upper_bound"].add(comparison_node) | ||
elif operand is right_operand: | ||
uses[value]["lower_bound"].add(comparison_node) | ||
left_operand = right_operand | ||
|
||
uses: collections.defaultdict[ | ||
str, dict[str, set[nodes.Compare]] | ||
] = collections.defaultdict( | ||
lambda: {"lower_bound": set(), "upper_bound": set()} | ||
) | ||
for comparison_node in node.values: | ||
if isinstance(comparison_node, nodes.Compare): | ||
_find_lower_upper_bounds(comparison_node, uses) | ||
|
||
for bounds in uses.values(): | ||
num_shared = len(bounds["lower_bound"].intersection(bounds["upper_bound"])) | ||
num_lower_bounds = len(bounds["lower_bound"]) | ||
num_upper_bounds = len(bounds["upper_bound"]) | ||
if num_shared < num_lower_bounds and num_shared < num_upper_bounds: | ||
self.add_message("chained-comparison", node=node) | ||
break | ||
if len(paths) < len(node.values): | ||
suggestions = [] | ||
for path in paths: | ||
cur_statement = str(path[0]) | ||
for i in range(len(path) - 1): | ||
cur_statement += ( | ||
" " + symbol_dict[path[i], path[i + 1]] + " " + str(path[i + 1]) | ||
) | ||
suggestions.append(cur_statement) | ||
args = " and ".join(sorted(suggestions)) | ||
self.add_message("chained-comparison", node=node, args=(args,)) | ||
|
||
def _get_graph_from_comparison_nodes( | ||
self, node: nodes.BoolOp | ||
) -> None | tuple[ | ||
dict[str | int | float, set[str | int | float]], | ||
dict[tuple[str | int | float, str | int | float], str], | ||
dict[str | int | float, int], | ||
dict[tuple[str | int | float, str | int | float], int], | ||
]: | ||
if node.op != "and" or len(node.values) < 2: | ||
return None | ||
|
||
graph_dict: dict[ | ||
str | int | float, set[str | int | float] | ||
] = collections.defaultdict(set) | ||
symbol_dict: dict[ | ||
tuple[str | int | float, str | int | float], str | ||
] = collections.defaultdict(lambda: ">") | ||
frequency_dict: dict[ | ||
tuple[str | int | float, str | int | float], int | ||
] = collections.defaultdict(int) | ||
indegree_dict: dict[str | int | float, int] = collections.defaultdict(int) | ||
const_values: list[int | float] = [] | ||
|
||
for statement in node.values: | ||
if not isinstance(statement, nodes.Compare): | ||
return None | ||
ops = list(statement.ops) | ||
left_statement = statement.left | ||
while ops: | ||
left = self._get_compare_operand_value(left_statement, const_values) | ||
# Pop from ops or else we never advance along the statement | ||
operator, right_statement = ops.pop(0) | ||
# The operand is not a constant or variable or the operator is not a comparison | ||
if operator not in {"<", ">", "==", "<=", ">="} or left is None: | ||
return None | ||
right = self._get_compare_operand_value(right_statement, const_values) | ||
if right is None: | ||
return None | ||
|
||
# Make the graph always point from larger to smaller | ||
if operator == "<": | ||
operator = ">" | ||
left, right = right, left | ||
elif operator == "<=": | ||
operator = ">=" | ||
left, right = right, left | ||
|
||
# Update maps | ||
graph_dict[left].add(right) | ||
if not graph_dict[right]: | ||
graph_dict[right] = set() # Ensure the node exists in graph | ||
symbol_dict[(left, right)] = operator | ||
indegree_dict[left] += 0 # Make sure every node has an entry | ||
indegree_dict[right] += 1 | ||
frequency_dict[(left, right)] += 1 | ||
|
||
# advance onto the next comparison if it exists | ||
Pierre-Sassoulas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
left_statement = right_statement | ||
|
||
# Nothing was added and we have no recommendations | ||
if ( | ||
not graph_dict | ||
or not symbol_dict | ||
or all(val == "==" for val in symbol_dict.values()) | ||
): | ||
return None | ||
|
||
# Link up constant nodes, i.e. create synthetic nodes between 1 and 5 such that 5 > 1 | ||
sorted_consts = sorted(const_values) | ||
while sorted_consts: | ||
largest = sorted_consts.pop() | ||
for smaller in set(sorted_consts): | ||
if smaller < largest: | ||
symbol_dict[(largest, smaller)] = ">" | ||
indegree_dict[smaller] += 1 | ||
frequency_dict[(largest, smaller)] += 1 | ||
graph_dict[largest].add(smaller) | ||
|
||
# Remove paths from the larger number to the smaller number's adjacent nodes | ||
# This prevents duplicated paths in the output | ||
for adj in graph_dict[smaller]: | ||
if isinstance(adj, str): | ||
graph_dict[largest].discard(adj) | ||
|
||
return (graph_dict, symbol_dict, indegree_dict, frequency_dict) | ||
|
||
def _get_compare_operand_value( | ||
self, node: nodes.Compare, const_values: list[int | float] | ||
Pierre-Sassoulas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> str | int | float | None: | ||
value = None | ||
Pierre-Sassoulas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(node, nodes.Name) and isinstance(node.name, str): | ||
value = node.name | ||
elif isinstance(node, nodes.Const) and isinstance(node.value, (int, float)): | ||
value = node.value | ||
const_values.append(value) | ||
return value | ||
|
||
def _handle_cycles( | ||
self, | ||
node: nodes.BoolOp, | ||
symbol_dict: dict[tuple[str | int | float, str | int | float], str], | ||
cycles: Sequence[list[str]], | ||
) -> None: | ||
for cycle in cycles: | ||
all_geq = all( | ||
symbol_dict[(cur_item, cycle[i + 1])] == ">=" | ||
for (i, cur_item) in enumerate(cycle) | ||
if i < len(cycle) - 1 | ||
) | ||
all_geq = all_geq and symbol_dict[cycle[-1], cycle[0]] == ">=" | ||
if all_geq: | ||
self.add_message("comparison-all-equal", node=node) | ||
else: | ||
self.add_message("impossible-comparison", node=node) | ||
|
||
@staticmethod | ||
def _apply_boolean_simplification_rules( | ||
|
@@ -1441,7 +1546,7 @@ def _check_simplifiable_condition(self, node: nodes.BoolOp) -> None: | |
def visit_boolop(self, node: nodes.BoolOp) -> None: | ||
self._check_consider_merging_isinstance(node) | ||
self._check_consider_using_in(node) | ||
self._check_chained_comparison(node) | ||
self._check_comparisons(node) | ||
self._check_simplifiable_condition(node) | ||
|
||
@staticmethod | ||
|
Uh oh!
There was an error while loading. Please reload this page.