Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax][Frontend] PyTorch importer #95

Merged
merged 5 commits into from
Jan 27, 2023

Conversation

spectrometerHBH
Copy link
Member

@spectrometerHBH spectrometerHBH commented Jan 11, 2023

Implements the Relax importer from PyTorch, using torch FX.

An example use of the importer is:

# Import the importer.
from tvm.relax.frontend import from_pytorch

# Define the module
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True)

    def forward(self, input):
        return self.linear(input)

# Instantiate the model and create the input info dict.
torch_model = MyModule()
input_info = {"input_1": ((128, 10), "float32")}

# Use the importer to import the PyTorch model to Relax.
mod: tvm.IRModule = from_pytorch(torch_model, input_info)

# Print out the imported model.
# print(mod.script())

@tqchen
Copy link
Contributor

tqchen commented Jan 20, 2023

@Hzfengsy @jinhongyii would be great if you can help review

python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
Copy link
Contributor

@sunggg sunggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for brining the Fx importer!
A few comments.

tests/python/relax/test_torch_importer.py Outdated Show resolved Hide resolved
python/tvm/relax/frontend/pytorch_fx.py Outdated Show resolved Hide resolved
@MasterJH5574 MasterJH5574 merged commit d8e982f into mlc-ai:relax Jan 27, 2023
MasterJH5574 added a commit that referenced this pull request Jan 28, 2023
Implements the Relax importer from PyTorch, using torch FX.

An example use of the importer is:

```python
# Import the importer.
from tvm.relax.frontend import from_pytorch

# Define the module
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True)

    def forward(self, input):
        return self.linear(input)

# Instantiate the model and create the input info dict.
torch_model = MyModule()
input_info = {"input_1": ((128, 10), "float32")}

# Use the importer to import the PyTorch model to Relax.
mod: tvm.IRModule = from_pytorch(torch_model, input_info)

# Print out the imported model.
# print(mod.script())
```

---------

Co-authored-by: Ruihang Lai <[email protected]>
MasterJH5574 added a commit that referenced this pull request Jan 31, 2023
Implements the Relax importer from PyTorch, using torch FX.

An example use of the importer is:

```python
# Import the importer.
from tvm.relax.frontend import from_pytorch

# Define the module
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True)

    def forward(self, input):
        return self.linear(input)

# Instantiate the model and create the input info dict.
torch_model = MyModule()
input_info = {"input_1": ((128, 10), "float32")}

# Use the importer to import the PyTorch model to Relax.
mod: tvm.IRModule = from_pytorch(torch_model, input_info)

# Print out the imported model.
# print(mod.script())
```

---------

Co-authored-by: Ruihang Lai <[email protected]>
MasterJH5574 added a commit that referenced this pull request Feb 8, 2023
Enhance VM Executable as a Subclass of runtime::Module
MasterJH5574 added a commit that referenced this pull request Feb 8, 2023
Implements the Relax importer from PyTorch, using torch FX.

An example use of the importer is:

```python
# Import the importer.
from tvm.relax.frontend import from_pytorch

# Define the module
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True)

    def forward(self, input):
        return self.linear(input)

# Instantiate the model and create the input info dict.
torch_model = MyModule()
input_info = {"input_1": ((128, 10), "float32")}

# Use the importer to import the PyTorch model to Relax.
mod: tvm.IRModule = from_pytorch(torch_model, input_info)

# Print out the imported model.
# print(mod.script())
```

---------

Co-authored-by: Ruihang Lai <[email protected]>
spectrometerHBH pushed a commit to spectrometerHBH/relax that referenced this pull request Feb 9, 2023
Enhance VM Executable as a Subclass of runtime::Module
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants