Skip to content

Commit

Permalink
Merge pull request #23 from BeastByteAI/phonnx
Browse files Browse the repository at this point in the history
Phonnx
  • Loading branch information
iryna-kondr committed Aug 30, 2023
2 parents 6a06c00 + 4384237 commit 0f27b55
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 67 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/pypi-deploy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: PyPi Deploy

on:
release:
types: [published]
workflow_dispatch:

jobs:
deploy:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.10'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build twine
- name: Build and publish
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: |
python -m build
twine upload dist/*
74 changes: 10 additions & 64 deletions falcon/runtime.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,22 @@
from typing import Any, Union, Dict, List, Type, Optional

try:
import onnxruntime as ort
from phonnx.runtime import Runtime as _PhonnxRuntime
except (ImportError, ModuleNotFoundError):
print("ONNXRuntime is not installed. Inference modules will not work.")
ort = None
print("ONNXRuntime/PHONNX is not installed. Inference modules will not work.")
_PhonnxRuntime = None
import numpy as np
from abc import ABC, abstractmethod


class BaseRuntime(ABC):
@abstractmethod
def run(self, X: np.ndarray, **kwargs: Any) -> Any:
pass


class ONNXRuntime(BaseRuntime):
class ONNXRuntime:
"""
Runtime for ONNX models. This runtime can only run onnx models produced by falcon.
Runtime for ONNX models based on PHONNX.
"""

def __init__(self, model: Union[bytes, str]):
if ort is None:
raise RuntimeError(
"ONNXRuntime is not installed. Please install it with `pip install onnxruntime`."
)
self.ort_session = ort.InferenceSession(model)
if _PhonnxRuntime is None:
raise ImportError("PHONNX is not installed.")
self.runtime = _PhonnxRuntime(model=model)

def run(
self, X: np.ndarray, outputs: str = "final", **kwargs: Any
Expand All @@ -49,50 +41,4 @@ def run(
f"Expected `outputs` to be one of [all, final], got `{outputs}`."
)

inputs = self._get_inputs(X)
output_names = self._get_output_names(outputs)

return self.ort_session.run(output_names, inputs)

def _get_inputs(self, X: np.ndarray) -> Dict:
ort_inputs = self.ort_session.get_inputs()
inputs = {}

for i, inp in enumerate(ort_inputs):
dtype: Any
if str(inp.type) == "tensor(float)":
dtype = np.float32
elif str(inp.type) == "tensor(string)":
dtype = np.str_
elif "int" in str(inp.type).lower():
dtype = np.int64
else:
RuntimeError(
f"The model input type should be one of [str, int, float], got f{inp.type}"
)
if len(ort_inputs) > 1:
inputs[str(inp.name)] = np.expand_dims(X[:, i], 1).astype(dtype)
else:
inputs[str(inp.name)] = X.astype(dtype)
return inputs

def _get_output_names(self, outputs: str = "final") -> List[str]:
ort_outputs = self.ort_session.get_outputs()
output_names = [o.name for o in ort_outputs]
if outputs == "final" and len(ort_outputs) > 1:
idx_ = []
for name in output_names:
if not name[0:10] in ("falcon_pl_", "falcon-pl-"):
raise RuntimeError("One of the output nodes has an invalid name.")
idx = int(name.split("/")[0][10:])
idx_.append(idx)
max_idx = max(idx_)
output_names = [
n
for n in output_names
if (
n.startswith(f"falcon_pl_{max_idx}")
or n.startswith(f"falcon-pl-{max_idx}")
)
]
return output_names
return self.runtime.run(X, outputs_to_return=outputs)
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ dependencies = [
"imbalanced-learn>=0.8.1",
"pyarrow>=8.0.0",
"optuna>=3.0.0",
"packaging>=20.0.0"
"packaging>=20.0.0",
"phonnx>=0.0.1",
]
name = "falcon-ml"
version = "0.6.0"
version = "0.7.0"
authors = [
{ name="Oleg Kostromin", email="[email protected]" },
{ name="Iryna Kondrashchenko", email="[email protected]" },
Expand All @@ -26,7 +27,7 @@ authors = [
description = "AutoML library for fast experementations."
readme = "README.md"
license = {text = "MIT"}
requires-python = ">=3.8"
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
Expand Down

0 comments on commit 0f27b55

Please sign in to comment.