Skip to content

Commit b3ef0b3

Browse files
authored
Move the main code to transformers_js_py/proxies.py (#55)
1 parent 8db755b commit b3ef0b3

File tree

2 files changed

+191
-190
lines changed

2 files changed

+191
-190
lines changed

transformers_js_py/__init__.py

Lines changed: 2 additions & 190 deletions
Original file line numberDiff line numberDiff line change
@@ -1,192 +1,4 @@
1-
import re
2-
from typing import Any, Awaitable, Union
3-
4-
import js
5-
import pyodide.code
6-
import pyodide.ffi
7-
import pyodide.webloop
8-
9-
from .url import as_url, is_url
10-
11-
try:
12-
import numpy as np
13-
except ImportError:
14-
np = None
15-
16-
try:
17-
import PIL.Image as PILImage
18-
except ImportError:
19-
PILImage = None
20-
21-
22-
rx_class_def_code = re.compile(r"^\s*class\s+([a-zA-Z0-9_]+)\s*{", re.MULTILINE)
23-
24-
25-
class TjsModuleProxy:
26-
def __init__(self, js_obj: pyodide.ffi.JsProxy):
27-
if not isinstance(js_obj, pyodide.ffi.JsProxy) or js_obj.typeof != "object":
28-
raise TypeError("js_obj must be a JS module object")
29-
self.js_obj = js_obj
30-
31-
def __getattr__(self, name: str) -> Any:
32-
res = getattr(self.js_obj, name)
33-
if isinstance(res, pyodide.ffi.JsProxy):
34-
return proxy_tjs_object(res)
35-
return res
36-
37-
def __repr__(self) -> str:
38-
return "TjsModuleProxy({})".format(", ".join(self.js_obj.object_keys()))
39-
40-
41-
class TjsProxy:
42-
def __init__(self, js_obj: pyodide.ffi.JsProxy):
43-
self._js_obj = js_obj
44-
self._is_class = self._js_obj.typeof == "function" and rx_class_def_code.match(
45-
self._js_obj.toString()
46-
) # Ref: https://stackoverflow.com/a/30760236/13103190
47-
48-
def __call__(self, *args: Any, **kwds: Any) -> Any:
49-
args = [arg._js_obj if isinstance(arg, TjsProxy) else arg for arg in args]
50-
kwds = {k: v._js_obj if isinstance(v, TjsProxy) else v for k, v in kwds.items()}
51-
args = pyodide.ffi.to_js(args)
52-
kwds = pyodide.ffi.to_js(kwds)
53-
54-
if hasattr(self._js_obj, "_call"):
55-
# Transformers.js uses a custom _call() method
56-
# to make the JS classes callable.
57-
# https://github.com/xenova/transformers.js/blob/2.4.1/src/utils/core.js#L45-L77
58-
res = self._js_obj._call(*args, **kwds)
59-
else:
60-
if self._is_class:
61-
res = self._js_obj.new(*args, **kwds)
62-
else:
63-
res = self._js_obj(*args, **kwds)
64-
65-
return wrap_or_unwrap_proxy_object(res)
66-
67-
def __getattr__(self, name: str) -> Any:
68-
res = getattr(self._js_obj, name)
69-
return wrap_or_unwrap_proxy_object(res)
70-
71-
def __getitem__(self, key: Any) -> Any:
72-
res = self._js_obj[key]
73-
return wrap_or_unwrap_proxy_object(res)
74-
75-
def __setitem__(self, key: Any, value: Any) -> None:
76-
self._js_obj[key] = value
77-
78-
def __setattr__(self, __name: str, __value: Any) -> None:
79-
if __name == "_js_obj" or __name == "_is_class":
80-
super().__setattr__(__name, __value)
81-
else:
82-
setattr(self._js_obj, __name, __value)
83-
84-
85-
class TjsRawImageClassProxy(TjsProxy):
86-
def read(
87-
self, input: Union["TjsRawImageProxy", str]
88-
) -> Awaitable["TjsRawImageProxy"]:
89-
if isinstance(input, TjsRawImageProxy):
90-
res = self._js_obj.read(input._js_obj)
91-
elif is_url(input):
92-
res = self._js_obj.read(input)
93-
else:
94-
res = self._js_obj.read(as_url(input))
95-
return wrap_or_unwrap_proxy_object(res)
96-
97-
98-
class TjsRawImageProxy(TjsProxy):
99-
def to_numpy(self):
100-
if np is None:
101-
raise RuntimeError("numpy is not available")
102-
103-
data = self._js_obj.data # Uint8ClampedArray|Uint8Array
104-
width = self._js_obj.width
105-
height = self._js_obj.height
106-
channels = self._js_obj.channels
107-
return np.asarray(data.to_py()).reshape((height, width, channels))
108-
109-
def to_pil(self):
110-
if PILImage is None:
111-
raise RuntimeError("PIL is not available")
112-
113-
numpy_img = self.to_numpy()
114-
if numpy_img.shape[2] == 1:
115-
# Gray scale image
116-
numpy_img = numpy_img[:, :, 0]
117-
return PILImage.fromarray(numpy_img)
118-
119-
def save(self, path: str):
120-
self.to_pil().save(path)
121-
122-
123-
class TjsTensorProxy(TjsProxy):
124-
def numpy(self):
125-
if np is None:
126-
raise RuntimeError("numpy is not available")
127-
128-
data = self._js_obj.data.to_py()
129-
dims = self._js_obj.dims.to_py()
130-
dtype = self._js_obj.type
131-
132-
return np.asarray(data, dtype=dtype).reshape(dims)
133-
134-
135-
def proxy_tjs_object(js_obj: pyodide.ffi.JsProxy):
136-
"""A factory function that wraps a JsProxy object wrapping a Transformers.js object
137-
into a Python object of type TjsProxy or is subclass in the case of a special object
138-
such as RawImage.
139-
"""
140-
if js_obj == js._transformers.RawImage:
141-
return TjsRawImageClassProxy(js_obj)
142-
if js_obj.constructor == js._transformers.RawImage:
143-
return TjsRawImageProxy(js_obj)
144-
if js_obj.constructor == js._transformers.Tensor:
145-
return TjsTensorProxy(js_obj)
146-
return TjsProxy(js_obj)
147-
148-
149-
def to_py_default_converter(value: pyodide.ffi.JsProxy, _ignored1, _ignored2):
150-
# Pyodide tries to convert the JS object to a Python object
151-
# as best as possible, but it doesn't always work.
152-
# In such a case, this custom converter is called
153-
# and it wraps the JS object into a TjsProxy object.
154-
return proxy_tjs_object(value)
155-
156-
157-
def wrap_or_unwrap_proxy_object(obj):
158-
if isinstance(obj, pyodide.ffi.JsProxy):
159-
if obj.typeof == "object":
160-
return obj.to_py(default_converter=to_py_default_converter)
161-
162-
return proxy_tjs_object(obj)
163-
elif isinstance(obj, pyodide.webloop.PyodideFuture):
164-
return obj.then(wrap_or_unwrap_proxy_object)
165-
return obj
166-
167-
168-
async def import_transformers_js(version: str = "latest"):
169-
pyodide.code.run_js(
170-
"""
171-
async function loadTransformersJs(version) {
172-
const isBrowserMainThread = typeof window !== 'undefined';
173-
const isWorker = typeof WorkerGlobalScope !== 'undefined' && self instanceof WorkerGlobalScope;
174-
const isBrowser = isBrowserMainThread || isWorker;
175-
const transformers = await import(isBrowser ? 'https://cdn.jsdelivr.net/npm/@xenova/transformers@' + version : '@xenova/transformers');
176-
177-
transformers.env.allowLocalModels = false;
178-
179-
globalThis._transformers = { // Convert a module to an object.
180-
...transformers,
181-
};
182-
}
183-
""" # noqa: E501
184-
)
185-
loadTransformersJsFn = js.loadTransformersJs
186-
await loadTransformersJsFn(version)
187-
188-
transformers = js._transformers
189-
return TjsModuleProxy(transformers)
190-
1+
from .proxies import import_transformers_js
2+
from .url import as_url
1913

1924
__all__ = ["as_url", "import_transformers_js"]

transformers_js_py/proxies.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
import re
2+
from typing import Any, Awaitable, Union
3+
4+
import js
5+
import pyodide.code
6+
import pyodide.ffi
7+
import pyodide.webloop
8+
9+
from .url import as_url, is_url
10+
11+
try:
12+
import numpy as np
13+
except ImportError:
14+
np = None
15+
16+
try:
17+
import PIL.Image as PILImage
18+
except ImportError:
19+
PILImage = None
20+
21+
22+
rx_class_def_code = re.compile(r"^\s*class\s+([a-zA-Z0-9_]+)\s*{", re.MULTILINE)
23+
24+
25+
class TjsModuleProxy:
26+
def __init__(self, js_obj: pyodide.ffi.JsProxy):
27+
if not isinstance(js_obj, pyodide.ffi.JsProxy) or js_obj.typeof != "object":
28+
raise TypeError("js_obj must be a JS module object")
29+
self.js_obj = js_obj
30+
31+
def __getattr__(self, name: str) -> Any:
32+
res = getattr(self.js_obj, name)
33+
if isinstance(res, pyodide.ffi.JsProxy):
34+
return proxy_tjs_object(res)
35+
return res
36+
37+
def __repr__(self) -> str:
38+
return "TjsModuleProxy({})".format(", ".join(self.js_obj.object_keys()))
39+
40+
41+
class TjsProxy:
42+
def __init__(self, js_obj: pyodide.ffi.JsProxy):
43+
self._js_obj = js_obj
44+
self._is_class = self._js_obj.typeof == "function" and rx_class_def_code.match(
45+
self._js_obj.toString()
46+
) # Ref: https://stackoverflow.com/a/30760236/13103190
47+
48+
def __call__(self, *args: Any, **kwds: Any) -> Any:
49+
args = [arg._js_obj if isinstance(arg, TjsProxy) else arg for arg in args]
50+
kwds = {k: v._js_obj if isinstance(v, TjsProxy) else v for k, v in kwds.items()}
51+
args = pyodide.ffi.to_js(args)
52+
kwds = pyodide.ffi.to_js(kwds)
53+
54+
if hasattr(self._js_obj, "_call"):
55+
# Transformers.js uses a custom _call() method
56+
# to make the JS classes callable.
57+
# https://github.com/xenova/transformers.js/blob/2.4.1/src/utils/core.js#L45-L77
58+
res = self._js_obj._call(*args, **kwds)
59+
else:
60+
if self._is_class:
61+
res = self._js_obj.new(*args, **kwds)
62+
else:
63+
res = self._js_obj(*args, **kwds)
64+
65+
return wrap_or_unwrap_proxy_object(res)
66+
67+
def __getattr__(self, name: str) -> Any:
68+
res = getattr(self._js_obj, name)
69+
return wrap_or_unwrap_proxy_object(res)
70+
71+
def __getitem__(self, key: Any) -> Any:
72+
res = self._js_obj[key]
73+
return wrap_or_unwrap_proxy_object(res)
74+
75+
def __setitem__(self, key: Any, value: Any) -> None:
76+
self._js_obj[key] = value
77+
78+
def __setattr__(self, __name: str, __value: Any) -> None:
79+
if __name == "_js_obj" or __name == "_is_class":
80+
super().__setattr__(__name, __value)
81+
else:
82+
setattr(self._js_obj, __name, __value)
83+
84+
85+
class TjsRawImageClassProxy(TjsProxy):
86+
def read(
87+
self, input: Union["TjsRawImageProxy", str]
88+
) -> Awaitable["TjsRawImageProxy"]:
89+
if isinstance(input, TjsRawImageProxy):
90+
res = self._js_obj.read(input._js_obj)
91+
elif is_url(input):
92+
res = self._js_obj.read(input)
93+
else:
94+
res = self._js_obj.read(as_url(input))
95+
return wrap_or_unwrap_proxy_object(res)
96+
97+
98+
class TjsRawImageProxy(TjsProxy):
99+
def to_numpy(self):
100+
if np is None:
101+
raise RuntimeError("numpy is not available")
102+
103+
data = self._js_obj.data # Uint8ClampedArray|Uint8Array
104+
width = self._js_obj.width
105+
height = self._js_obj.height
106+
channels = self._js_obj.channels
107+
return np.asarray(data.to_py()).reshape((height, width, channels))
108+
109+
def to_pil(self):
110+
if PILImage is None:
111+
raise RuntimeError("PIL is not available")
112+
113+
numpy_img = self.to_numpy()
114+
if numpy_img.shape[2] == 1:
115+
# Gray scale image
116+
numpy_img = numpy_img[:, :, 0]
117+
return PILImage.fromarray(numpy_img)
118+
119+
def save(self, path: str):
120+
self.to_pil().save(path)
121+
122+
123+
class TjsTensorProxy(TjsProxy):
124+
def numpy(self):
125+
if np is None:
126+
raise RuntimeError("numpy is not available")
127+
128+
data = self._js_obj.data.to_py()
129+
dims = self._js_obj.dims.to_py()
130+
dtype = self._js_obj.type
131+
132+
return np.asarray(data, dtype=dtype).reshape(dims)
133+
134+
135+
def proxy_tjs_object(js_obj: pyodide.ffi.JsProxy):
136+
"""A factory function that wraps a JsProxy object wrapping a Transformers.js object
137+
into a Python object of type TjsProxy or is subclass in the case of a special object
138+
such as RawImage.
139+
"""
140+
if js_obj == js._transformers.RawImage:
141+
return TjsRawImageClassProxy(js_obj)
142+
if js_obj.constructor == js._transformers.RawImage:
143+
return TjsRawImageProxy(js_obj)
144+
if js_obj.constructor == js._transformers.Tensor:
145+
return TjsTensorProxy(js_obj)
146+
return TjsProxy(js_obj)
147+
148+
149+
def to_py_default_converter(value: pyodide.ffi.JsProxy, _ignored1, _ignored2):
150+
# Pyodide tries to convert the JS object to a Python object
151+
# as best as possible, but it doesn't always work.
152+
# In such a case, this custom converter is called
153+
# and it wraps the JS object into a TjsProxy object.
154+
return proxy_tjs_object(value)
155+
156+
157+
def wrap_or_unwrap_proxy_object(obj):
158+
if isinstance(obj, pyodide.ffi.JsProxy):
159+
if obj.typeof == "object":
160+
return obj.to_py(default_converter=to_py_default_converter)
161+
162+
return proxy_tjs_object(obj)
163+
elif isinstance(obj, pyodide.webloop.PyodideFuture):
164+
return obj.then(wrap_or_unwrap_proxy_object)
165+
return obj
166+
167+
168+
async def import_transformers_js(version: str = "latest"):
169+
pyodide.code.run_js(
170+
"""
171+
async function loadTransformersJs(version) {
172+
const isBrowserMainThread = typeof window !== 'undefined';
173+
const isWorker = typeof WorkerGlobalScope !== 'undefined' && self instanceof WorkerGlobalScope;
174+
const isBrowser = isBrowserMainThread || isWorker;
175+
const transformers = await import(isBrowser ? 'https://cdn.jsdelivr.net/npm/@xenova/transformers@' + version : '@xenova/transformers');
176+
177+
transformers.env.allowLocalModels = false;
178+
179+
globalThis._transformers = { // Convert a module to an object.
180+
...transformers,
181+
};
182+
}
183+
""" # noqa: E501
184+
)
185+
loadTransformersJsFn = js.loadTransformersJs
186+
await loadTransformersJsFn(version)
187+
188+
transformers = js._transformers
189+
return TjsModuleProxy(transformers)

0 commit comments

Comments
 (0)