Skip to content

Commit

Permalink
Merge pull request #89 from Hynn01/main
Browse files Browse the repository at this point in the history
Add library constraint to forward-pytorch checker
  • Loading branch information
Hynn01 authored Jun 17, 2022
2 parents b64433a + 17ce3cf commit 78877ff
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 9 deletions.
19 changes: 18 additions & 1 deletion dslinter/checkers/forward_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pylint.interfaces import IAstroidChecker

from dslinter.utils.exception_handler import ExceptionHandler
from dslinter.utils.randomness_control_helper import has_import


class ForwardPytorchChecker(BaseChecker):
Expand All @@ -22,6 +23,12 @@ class ForwardPytorchChecker(BaseChecker):
}
options = ()

_import_torch = False

def visit_import(self, import_node: astroid.Import):
if self._import_torch is False:
self._import_torch = has_import(import_node, "torch")

def visit_call(self, call_node: astroid.Call):
"""
When a Call node is visited, check whether it violated the rule in this checker.
Expand All @@ -46,7 +53,17 @@ def visit_call(self, call_node: astroid.Call):
and call_node.func.expr.func.name == "super"
):
_call_from_super = True
if _has_forward is True and (_call_from_self is False and _call_from_super is False):
if(
self._import_torch is True
and _has_forward is True
and (
_call_from_self is False
and _call_from_super is False
)
):
self.add_message("forward-pytorch", node=call_node)
except: # pylint: disable = bare-except
ExceptionHandler.handle(self, call_node)

def leave_module(self, module: astroid.Module):
self._import_torch = False
22 changes: 16 additions & 6 deletions dslinter/tests/checkers/test_forward_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TestForwardPytorchChecker(pylint.testutils.CheckerTestCase):
def test_use_forward(self):
"""Message will be added if the self.net.forward() is used in the code rather than self.net()."""
script = """
import torch.nn as nn
import torch.nn as nn #@
class Net(nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -33,14 +33,16 @@ def forward(self, x):
x = self.fc3(x)
return x
"""
call_node = astroid.extract_node(script).value
import_node, assign_node = astroid.extract_node(script)
call_node = assign_node.value
with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="forward-pytorch", node=call_node)):
self.checker.visit_import(import_node)
self.checker.visit_call(call_node)

def test_not_use_forward(self):
"""No message will be added if self.net() is used in the code."""
script = """
import torch.nn as nn
import torch.nn as nn #@
class Net(nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -60,25 +62,31 @@ def forward(self, x):
x = self.fc3(x)
return x
"""
call_node = astroid.extract_node(script).value
import_node, assign_node = astroid.extract_node(script)
call_node = assign_node.value
with self.assertNoMessages():
self.checker.visit_import(import_node)
self.checker.visit_call(call_node)

def test_use_self_forward(self):
"""No Message will be added if the self.forward() is used in the code."""
script = """
import torch #@
def training_step(self, batch, batch_nb):
idx = batch['idx']
loss = self.forward(batch)[0] #@
return {'loss': loss, 'idx': idx}
"""
call_node = astroid.extract_node(script).value.value
import_node, assign_node = astroid.extract_node(script)
call_node = assign_node.value.value
with self.assertNoMessages():
self.checker.visit_import(import_node)
self.checker.visit_call(call_node)

def test_use_super_forward(self):
"""No Message will be added if the super().forward() is used in the code."""
script = """
import torch #@
class SpatialDropout(nn.Dropout2d):
def forward(self, x):
x = x.unsqueeze(2) # (N, T, 1, K)
Expand All @@ -88,6 +96,8 @@ def forward(self, x):
x = x.squeeze(2) # (N, T, K)
return x
"""
call_node = astroid.extract_node(script).value
import_node, assign_node = astroid.extract_node(script)
call_node = assign_node.value
with self.assertNoMessages():
self.checker.visit_import(import_node)
self.checker.visit_call(call_node)
2 changes: 1 addition & 1 deletion dslinter/utils/randomness_control_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def check_main_module(module: astroid.Module) -> bool:

def has_import(node: astroid.Import, library_name: str):
for name, _ in node.names:
if name == library_name:
if name == library_name or library_name in name:
return True
return False

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ skip = 'scripts'

[tool.poetry]
name = "dslinter"
version = "2.0.8"
version = "2.0.9"
description = "`dslinter` is a pylint plugin for linting data science and machine learning code. We plan to support the following Python libraries: TensorFlow, PyTorch, Scikit-Learn, Pandas, NumPy and SciPy."

license = "GPL-3.0 License"
Expand Down

0 comments on commit 78877ff

Please sign in to comment.