Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]add other devices supported model list #2259

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddlex/configs/formula_recognition/LaTeX_OCR_rec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Global:
model: LaTeX_OCR_rec
mode: check_dataset # check_dataset/train/evaluate/predict
dataset_dir: "./dataset/ocr_rec_latexocr_dataset_example"
device: gpu:0
device: gpu:0,1,2,3
output: "output"

CheckDataset:
Expand Down
2 changes: 1 addition & 1 deletion paddlex/inference/utils/new_ir_blacklist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

NEWIR_BLOCKLIST = [
NEWIR_BLACKLIST = [
"FasterRCNN-ResNet34-FPN",
"FasterRCNN-ResNet50",
"FasterRCNN-ResNet50-FPN",
Expand Down
18 changes: 9 additions & 9 deletions paddlex/inference/utils/pp_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils.device import parse_device, set_env_for_device, get_default_device
from ...utils.device import (
parse_device,
set_env_for_device,
get_default_device,
check_device,
)
from ...utils import logging
from .new_ir_blacklist import NEWIR_BLOCKLIST
from .new_ir_blacklist import NEWIR_BLACKLIST


class PaddlePredictorOption(object):
Expand All @@ -28,7 +33,6 @@ class PaddlePredictorOption(object):
"mkldnn",
"mkldnn_bf16",
)
SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu")

def __init__(self, model_name=None, **kwargs):
super().__init__()
Expand Down Expand Up @@ -61,7 +65,7 @@ def _get_default_config(self):
"cpu_threads": 1,
"trt_use_static": False,
"delete_pass": [],
"enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
"enable_new_ir": True if self.model_name not in NEWIR_BLACKLIST else False,
"batch_size": 1, # only for trt
}

Expand Down Expand Up @@ -101,11 +105,7 @@ def device(self, device: str):
if not device:
return
device_type, device_ids = parse_device(device)
if device_type not in self.SUPPORT_DEVICE:
support_run_mode_str = ", ".join(self.SUPPORT_DEVICE)
raise ValueError(
f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
)
check_device(self.model_name, device_type)
self._update("device", device_type)
device_id = device_ids[0] if device_ids is not None else 0
self._update("device_id", device_id)
Expand Down
11 changes: 9 additions & 2 deletions paddlex/modules/base/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from abc import ABC, abstractmethod

from .build_model import build_model
from ...utils.device import update_device_num, set_env_for_device
from ...utils.device import (
update_device_num,
set_env_for_device,
parse_device,
check_device,
)
from ...utils.misc import AutoRegisterABCMetaClass
from ...utils.config import AttrDict
from ...utils.logging import *
Expand Down Expand Up @@ -138,8 +143,10 @@ def get_device(self, using_device_number: int = None) -> str:
Returns:
str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
"""
device_type, device_ids = parse_device(self.global_config.device)
check_device(self.global_config.model, device_type)
if using_device_number:
return update_device_num(self.global_config.device, using_device_number)
return update_device_num(device_type, device_ids, using_device_number)
set_env_for_device(self.global_config.device)
return self.global_config.device

Expand Down
11 changes: 9 additions & 2 deletions paddlex/modules/base/exportor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from abc import ABC, abstractmethod

from .build_model import build_model
from ...utils.device import update_device_num, set_env_for_device
from ...utils.device import (
update_device_num,
set_env_for_device,
parse_device,
check_device,
)
from ...utils.misc import AutoRegisterABCMetaClass
from ...utils.config import AttrDict
from ...utils import logging
Expand Down Expand Up @@ -103,8 +108,10 @@ def get_device(self, using_device_number: int = None) -> str:
Returns:
str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
"""
device_type, device_ids = parse_device(self.global_config.device)
check_device(self.global_config.model, device_type)
if using_device_number:
return update_device_num(self.global_config.device, using_device_number)
return update_device_num(device_type, device_ids, using_device_number)
set_env_for_device(self.global_config.device)
return self.global_config.device

Expand Down
11 changes: 9 additions & 2 deletions paddlex/modules/base/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from abc import ABC, abstractmethod
from pathlib import Path
from .build_model import build_model
from ...utils.device import update_device_num, set_env_for_device
from ...utils.device import (
update_device_num,
set_env_for_device,
parse_device,
check_device,
)
from ...utils.misc import AutoRegisterABCMetaClass
from ...utils.config import AttrDict

Expand Down Expand Up @@ -95,8 +100,10 @@ def get_device(self, using_device_number: int = None) -> str:
Returns:
str: device setting, such as: `gpu:0,1`, `npu:0,1` `cpu`.
"""
device_type, device_ids = parse_device(self.global_config.device)
check_device(self.global_config.model, device_type)
if using_device_number:
return update_device_num(self.global_config.device, using_device_number)
return update_device_num(device_type, device_ids, using_device_number)
set_env_for_device(self.global_config.device)
return self.global_config.device

Expand Down
15 changes: 15 additions & 0 deletions paddlex/paddlex_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
import argparse
import subprocess
import sys
import shutil
import tempfile
from pathlib import Path

from . import create_pipeline
from .inference.pipelines import create_pipeline_from_config, load_pipeline_config
from .repo_manager import setup, get_all_supported_repo_names
from .utils.cache import CACHE_DIR
from .utils import logging
from .utils.interactive_get_pipeline import interactive_get_pipeline

Expand Down Expand Up @@ -65,6 +68,7 @@ def parse_str(s):

################# install pdx #################
parser.add_argument("--install", action="store_true", default=False, help="")
parser.add_argument("--clear_cache", action="store_true", default=False, help="")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改为可接受参数,参数值可以对应CACHE_DIR下的目录名(或者是相应的),默认是删除全部

parser.add_argument("plugins", nargs="*", default=[])
parser.add_argument("--no_deps", action="store_true")
parser.add_argument("--platform", type=str, default="github.com")
Expand Down Expand Up @@ -159,6 +163,15 @@ def serve(pipeline, *, device, use_hpip, serial_number, update_license, host, po
run_server(app, host=host, port=port, debug=False)


def clear_cache():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clear_cache 函数移到.utils.cache中,同时clear_cache()可以接受参数控制删除CACHE下的什么内容

cache_dir = Path(CACHE_DIR) / "official_models"
if cache_dir.exists() and cache_dir.is_dir():
shutil.rmtree(cache_dir)
logging.info(f"Successfully cleared the cache models at {cache_dir}")
else:
logging.info(f"No cache models found at {cache_dir}")


# for CLI
def main():
"""API for commad line"""
Expand All @@ -180,6 +193,8 @@ def main():
host=args.host,
port=args.port,
)
elif args.clear_cache:
clear_cache()
else:
if args.get_pipeline_config is not None:
interactive_get_pipeline(args.get_pipeline_config, args.save_path)
Expand Down
24 changes: 17 additions & 7 deletions paddlex/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@

from . import logging
from .errors import raise_unsupported_device_error

SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu"]
from .other_devices_model_list import OTHER_DEVICES_MODEL_LIST


def _constr_device(device_type, device_ids):
Expand All @@ -38,6 +37,21 @@ def get_default_device():
return _constr_device("gpu", [avail_gpus[0]])


def check_device(model_name, device_type):
supported_device_type = ["cpu", "gpu", "xpu", "npu", "mlu", "dcu"]
device_type = device_type.lower()
if device_type not in supported_device_type:
support_run_mode_str = ", ".join(supported_device_type)
raise ValueError(
f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
)
if device_type in OTHER_DEVICES_MODEL_LIST:
if model_name not in OTHER_DEVICES_MODEL_LIST[device_type]:
raise ValueError(
f"The model '{model_name}' is not supported on {device_type}."
)


def parse_device(device):
"""parse_device"""
# According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html
Expand All @@ -55,14 +69,10 @@ def parse_device(device):
f"Device ID must be an integer. Invalid device ID: {device_id}"
)
device_ids = list(map(int, device_ids))
device_type = device_type.lower()
# raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE)
assert device_type.lower() in SUPPORTED_DEVICE_TYPE
return device_type, device_ids


def update_device_num(device, num):
device_type, device_ids = parse_device(device)
def update_device_num(device_type, device_ids, num):
if device_ids:
assert len(device_ids) >= num
return _constr_device(device_type, device_ids[:num])
Expand Down
1 change: 0 additions & 1 deletion paddlex/utils/file_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
try:
import ujson as json
except:
logging.error("failed to import ujson, using json instead")
import json

from contextlib import contextmanager
Expand Down
Loading