Skip to content

Commit

Permalink
[TVMSCRIPT] Using diagnostics for TVM Script (#6797)
Browse files Browse the repository at this point in the history
* [TVMSCRIPT] Using diagnostics for TVM Script

* fix lint

* More documentation, improve some error messages

* Apply suggestions from code review

Co-authored-by: Leandro Nunes <[email protected]>

* Add synr to ci setup and setup.py

* remove typed_ast dependency

Co-authored-by: Leandro Nunes <[email protected]>
  • Loading branch information
tkonolige and leandron committed Nov 6, 2020
1 parent b31f4ae commit d164aac
Show file tree
Hide file tree
Showing 10 changed files with 726 additions and 605 deletions.
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ set -u
set -o pipefail

# install libraries for python package on ubuntu
pip3 install six numpy pytest cython decorator scipy tornado typed_ast pytest pytest-xdist pytest-profiling mypy orderedset attrs requests Pillow packaging cloudpickle synr
pip3 install six numpy pytest cython decorator scipy tornado pytest pytest-xdist pytest-profiling mypy orderedset attrs requests Pillow packaging cloudpickle synr
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def get_package_data_files():
"decorator",
"attrs",
"psutil",
"typed_ast",
"synr>=0.2.1",
],
extras_require={
"test": ["pillow<7", "matplotlib"],
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def lookup_symbol(self, name):
return symbols[name]
return None

def report_error(self, message):
self.parser.report_error(message)
def report_error(self, message, span):
self.parser.report_error(message, span)
54 changes: 54 additions & 0 deletions python/tvm/script/diagnostics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Bridge from synr's (the library used for parsing the python AST)
DiagnosticContext to TVM's diagnostics
"""
import tvm
from synr import DiagnosticContext, ast
from tvm.ir.diagnostics import DiagnosticContext as TVMCtx
from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic


class TVMDiagnosticCtx(DiagnosticContext):
"""TVM diagnostics for synr"""

diag_ctx: TVMCtx

def __init__(self) -> None:
self.diag_ctx = TVMCtx(tvm.IRModule(), get_renderer())
self.source_name = None

def to_tvm_span(self, src_name, ast_span: ast.Span) -> tvm.ir.Span:
return tvm.ir.Span(
src_name,
ast_span.start_line,
ast_span.end_line,
ast_span.start_column,
ast_span.end_column,
)

def add_source(self, name: str, source: str) -> None:
src_name = self.diag_ctx.module.source_map.add(name, source)
self.source_name = src_name

def emit(self, _level, message, span):
span = self.to_tvm_span(self.source_name, span)
self.diag_ctx.emit(Diagnostic(DiagnosticLevel.ERROR, span, message))
self.diag_ctx.render() # Raise exception on the first error we hit. TODO remove

def render(self):
self.diag_ctx.render()
31 changes: 13 additions & 18 deletions python/tvm/script/meta_unparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,29 @@
"""Unparse meta AST node into a dict"""
# pylint: disable=invalid-name

from typed_ast import ast3 as ast
from synr import Transformer


class MetaUnparser(ast.NodeVisitor):
class MetaUnparser(Transformer):
"""Python AST Visitor to unparse meta AST node into a dict"""

def visit_Dict(self, node):
def transform(self, node):
method = "transform_" + node.__class__.__name__
visitor = getattr(self, method, None)
if visitor is None:
self.error(f"Unexpected node type {type(node)} when parsing __tvm_meta__", node.span)
return visitor(node)

def transform_DictLiteral(self, node):
keys = [self.visit(key) for key in node.keys]
values = [self.visit(value) for value in node.values]
return dict(zip(keys, values))

def visit_Tuple(self, node):
def transform_Tuple(self, node):
return tuple(self.visit(element) for element in node.elts)

def visit_List(self, node):
def transform_ArrayLiteral(self, node):
return [self.visit(element) for element in node.elts]

def visit_keyword(self, node):
return node.arg, self.visit(node.value)

def visit_NameConstant(self, node):
return node.value

def visit_Constant(self, node):
def transform_Constant(self, node):
return node.value

def visit_Num(self, node):
return node.n

def visit_Str(self, node):
return node.s
Loading

0 comments on commit d164aac

Please sign in to comment.