Skip to content

Commit

Permalink
bug fix: generate directory if path not exist in compile_mmap_model (#…
Browse files Browse the repository at this point in the history
…281)

Generate directory if path not exist in compile_mmap_model when saving a XLinearModel.

Co-authored-by: jianhao peng <[email protected]>
  • Loading branch information
jianhao2016 and jianhao peng authored Feb 23, 2024
1 parent c181496 commit 845ca6a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
5 changes: 3 additions & 2 deletions pecos/xmc/xlinear/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def save(self, model_folder):
Args:
model_folder (str): dir to save the model
"""
if not path.exists(model_folder):
os.makedirs(model_folder)

os.makedirs(model_folder, exist_ok=True)
param = self.append_meta({})
with open(f"{model_folder}/param.json", "w", encoding="utf-8") as fout:
fout.write(json.dumps(param, indent=True))
Expand Down Expand Up @@ -144,6 +144,7 @@ def compile_mmap_model(cls, npz_folder, mmap_folder):
"""
import shutil

os.makedirs(mmap_folder, exist_ok=True)
shutil.copyfile(path.join(npz_folder, "param.json"), path.join(mmap_folder, "param.json"))
HierarchicalMLModel.compile_mmap_model(
path.join(npz_folder, "ranker"), path.join(mmap_folder, "ranker")
Expand Down
2 changes: 0 additions & 2 deletions test/pecos/xmc/xlinear/test_xlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,6 @@ def test_on_model(model, X):


def test_mmap(tmpdir):
from pathlib import Path
from pecos.utils import smat_util
from pecos.xmc.xlinear import XLinearModel
from pecos.xmc import PostProcessor
Expand All @@ -1153,7 +1152,6 @@ def test_mmap(tmpdir):

npz_model_folder = str(tmpdir.join("save_model_npz"))
mmap_model_folder = str(tmpdir.join("save_model_mmap"))
Path(mmap_model_folder).mkdir(parents=True, exist_ok=True)
py_model.save(npz_model_folder)
XLinearModel.compile_mmap_model(npz_model_folder, mmap_model_folder)
mmap_model = XLinearModel.load(mmap_model_folder, is_predict_only=True)
Expand Down

0 comments on commit 845ca6a

Please sign in to comment.