Skip to content

Commit 13cb10d

Browse files
authored
Use temp dir and abs path in api_gen.py (keras-team#19533)
* Use temp dir and abs path * Use temp dir and abs path * Update Readme
1 parent e57b138 commit 13cb10d

File tree

2 files changed

+68
-36
lines changed

2 files changed

+68
-36
lines changed

README.md

+6
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ pip install -r requirements.txt
5050
python pip_build.py --install
5151
```
5252

53+
3. Run API generation script when creating PRs that update `keras_export` public APIs:
54+
55+
```
56+
./shell/api_gen.sh
57+
```
58+
5359
#### Adding GPU support
5460

5561
The `requirements.txt` file will install a CPU-only version of TensorFlow, JAX, and PyTorch. For GPU support, we also

api_gen.py

+62-36
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,41 @@ def ignore_files(_, filenames):
1818
return [f for f in filenames if f.endswith("_test.py")]
1919

2020

21-
def create_legacy_directory():
22-
API_DIR = os.path.join(package, "api")
21+
def copy_source_to_build_directory(root_path):
22+
# Copy sources (`keras/` directory and setup files) to build dir
23+
build_dir = os.path.join(root_path, "tmp_build_dir")
24+
if os.path.exists(build_dir):
25+
shutil.rmtree(build_dir)
26+
os.mkdir(build_dir)
27+
shutil.copytree(
28+
package, os.path.join(build_dir, package), ignore=ignore_files
29+
)
30+
return build_dir
31+
32+
33+
def create_legacy_directory(package_dir):
34+
src_dir = os.path.join(package_dir, "src")
35+
api_dir = os.path.join(package_dir, "api")
2336
# Make keras/_tf_keras/ by copying keras/
24-
tf_keras_dirpath_parent = os.path.join(API_DIR, "_tf_keras")
37+
tf_keras_dirpath_parent = os.path.join(api_dir, "_tf_keras")
2538
tf_keras_dirpath = os.path.join(tf_keras_dirpath_parent, "keras")
2639
os.makedirs(tf_keras_dirpath, exist_ok=True)
2740
with open(os.path.join(tf_keras_dirpath_parent, "__init__.py"), "w") as f:
2841
f.write("from keras.api._tf_keras import keras\n")
29-
with open(os.path.join(API_DIR, "__init__.py")) as f:
42+
with open(os.path.join(api_dir, "__init__.py")) as f:
3043
init_file = f.read()
3144
init_file = init_file.replace(
3245
"from keras.api import _legacy",
3346
"from keras.api import _tf_keras",
3447
)
35-
with open(os.path.join(API_DIR, "__init__.py"), "w") as f:
48+
with open(os.path.join(api_dir, "__init__.py"), "w") as f:
3649
f.write(init_file)
3750
# Remove the import of `_tf_keras` in `keras/_tf_keras/keras/__init__.py`
3851
init_file = init_file.replace("from keras.api import _tf_keras\n", "\n")
3952
with open(os.path.join(tf_keras_dirpath, "__init__.py"), "w") as f:
4053
f.write(init_file)
41-
for dirname in os.listdir(API_DIR):
42-
dirpath = os.path.join(API_DIR, dirname)
54+
for dirname in os.listdir(api_dir):
55+
dirpath = os.path.join(api_dir, dirname)
4356
if os.path.isdir(dirpath) and dirname not in (
4457
"_legacy",
4558
"_tf_keras",
@@ -57,16 +70,16 @@ def create_legacy_directory():
5770
# Copy keras/_legacy/ file contents to keras/_tf_keras/keras
5871
legacy_submodules = [
5972
path[:-3]
60-
for path in os.listdir(os.path.join(package, "src", "legacy"))
73+
for path in os.listdir(os.path.join(src_dir, "legacy"))
6174
if path.endswith(".py")
6275
]
6376
legacy_submodules += [
6477
path
65-
for path in os.listdir(os.path.join(package, "src", "legacy"))
66-
if os.path.isdir(os.path.join(package, "src", "legacy", path))
78+
for path in os.listdir(os.path.join(src_dir, "legacy"))
79+
if os.path.isdir(os.path.join(src_dir, "legacy", path))
6780
]
6881

69-
for root, _, fnames in os.walk(os.path.join(package, "_legacy")):
82+
for root, _, fnames in os.walk(os.path.join(package_dir, "_legacy")):
7083
for fname in fnames:
7184
if fname.endswith(".py"):
7285
legacy_fpath = os.path.join(root, fname)
@@ -102,19 +115,18 @@ def create_legacy_directory():
102115
f.write(legacy_contents)
103116

104117
# Delete keras/api/_legacy/
105-
shutil.rmtree(os.path.join(API_DIR, "_legacy"))
118+
shutil.rmtree(os.path.join(api_dir, "_legacy"))
106119

107120

108-
def export_version_string():
109-
API_INIT = os.path.join(package, "api", "__init__.py")
110-
with open(API_INIT) as f:
121+
def export_version_string(api_init_fname):
122+
with open(api_init_fname) as f:
111123
contents = f.read()
112-
with open(API_INIT, "w") as f:
124+
with open(api_init_fname, "w") as f:
113125
contents += "from keras.src.version import __version__\n"
114126
f.write(contents)
115127

116128

117-
def update_package_init():
129+
def update_package_init(init_fname):
118130
contents = """
119131
# Import everything from /api/ into keras.
120132
from keras.api import * # noqa: F403
@@ -142,34 +154,48 @@ def __dir__():
142154
for name in globals().keys()
143155
if not (name.startswith("_") or name in ("src", "api"))
144156
]"""
145-
with open(os.path.join(package, "__init__.py")) as f:
157+
with open(init_fname) as f:
146158
init_contents = f.read()
147-
with open(os.path.join(package, "__init__.py"), "w") as f:
159+
with open(init_fname, "w") as f:
148160
f.write(init_contents.replace("\nfrom keras import api", contents))
149161

150162

151-
if __name__ == "__main__":
163+
def build():
152164
# Backup the `keras/__init__.py` and restore it on error in api gen.
153-
os.makedirs(os.path.join(package, "api"), exist_ok=True)
154-
init_fname = os.path.join(package, "__init__.py")
155-
backup_init_fname = os.path.join(package, "__init__.py.bak")
165+
root_path = os.path.dirname(os.path.abspath(__file__))
166+
code_api_dir = os.path.join(root_path, package, "api")
167+
code_init_fname = os.path.join(root_path, package, "__init__.py")
168+
# Create temp build dir
169+
build_dir = copy_source_to_build_directory(root_path)
170+
build_api_dir = os.path.join(build_dir, package, "api")
171+
build_init_fname = os.path.join(build_dir, package, "__init__.py")
172+
build_api_init_fname = os.path.join(build_api_dir, "__init__.py")
156173
try:
157-
if os.path.exists(init_fname):
158-
shutil.move(init_fname, backup_init_fname)
174+
os.chdir(build_dir)
159175
# Generates `keras/api` directory.
176+
if os.path.exists(build_api_dir):
177+
shutil.rmtree(build_api_dir)
178+
if os.path.exists(build_init_fname):
179+
os.remove(build_init_fname)
180+
os.makedirs(build_api_dir)
160181
namex.generate_api_files(
161182
"keras", code_directory="src", target_directory="api"
162183
)
163184
# Creates `keras/__init__.py` importing from `keras/api`
164-
update_package_init()
165-
except Exception as e:
166-
if os.path.exists(backup_init_fname):
167-
shutil.move(backup_init_fname, init_fname)
168-
raise e
185+
update_package_init(build_init_fname)
186+
# Add __version__ to keras package
187+
export_version_string(build_api_init_fname)
188+
# Creates `_tf_keras` with full keras API
189+
create_legacy_directory(package_dir=os.path.join(build_dir, package))
190+
# Copy back the keras/api and keras/__init__.py from build directory
191+
if os.path.exists(code_api_dir):
192+
shutil.rmtree(code_api_dir)
193+
shutil.copytree(build_api_dir, code_api_dir)
194+
shutil.copy(build_init_fname, code_init_fname)
169195
finally:
170-
if os.path.exists(backup_init_fname):
171-
os.remove(backup_init_fname)
172-
# Add __version__ to keras package
173-
export_version_string()
174-
# Creates `_tf_keras` with full keras API
175-
create_legacy_directory()
196+
# Clean up: remove the build directory (no longer needed)
197+
shutil.rmtree(build_dir)
198+
199+
200+
if __name__ == "__main__":
201+
build()

0 commit comments

Comments
 (0)