Skip to content

Commit

Permalink
Merge pull request #456 from MCYBA/main
Browse files Browse the repository at this point in the history
Add onnxruntime providers selection feature
  • Loading branch information
danielgatis authored May 24, 2023
2 parents c360cf6 + 8383a70 commit 8b6abef
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
4 changes: 2 additions & 2 deletions rembg/session_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .sessions.u2net import U2netSession


def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
def new_session(model_name: str = "u2net", providers=None, *args, **kwargs) -> BaseSession:
session_class: Type[BaseSession] = U2netSession

for sc in sessions_class:
Expand All @@ -21,4 +21,4 @@ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
if "OMP_NUM_THREADS" in os.environ:
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])

return session_class(model_name, sess_opts, *args, **kwargs)
return session_class(model_name, sess_opts, providers, *args, **kwargs)
15 changes: 13 additions & 2 deletions rembg/sessions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,22 @@


class BaseSession:
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, providers=None, *args, **kwargs):
self.model_name = model_name

self.providers = []

_providers = ort.get_available_providers()
if providers:
for provider in providers:
if provider in _providers:
self.providers.append(provider)
else:
self.providers.extend(_providers)

self.inner_session = ort.InferenceSession(
str(self.__class__.download_models()),
providers=ort.get_available_providers(),
providers=self.providers,
sess_options=sess_opts,
)

Expand Down

0 comments on commit 8b6abef

Please sign in to comment.