|
1 | | -import re |
2 | | -from typing import Any, Awaitable, Union |
3 | | - |
4 | | -import js |
5 | | -import pyodide.code |
6 | | -import pyodide.ffi |
7 | | -import pyodide.webloop |
8 | | - |
9 | | -from .url import as_url, is_url |
10 | | - |
11 | | -try: |
12 | | - import numpy as np |
13 | | -except ImportError: |
14 | | - np = None |
15 | | - |
16 | | -try: |
17 | | - import PIL.Image as PILImage |
18 | | -except ImportError: |
19 | | - PILImage = None |
20 | | - |
21 | | - |
22 | | -rx_class_def_code = re.compile(r"^\s*class\s+([a-zA-Z0-9_]+)\s*{", re.MULTILINE) |
23 | | - |
24 | | - |
25 | | -class TjsModuleProxy: |
26 | | - def __init__(self, js_obj: pyodide.ffi.JsProxy): |
27 | | - if not isinstance(js_obj, pyodide.ffi.JsProxy) or js_obj.typeof != "object": |
28 | | - raise TypeError("js_obj must be a JS module object") |
29 | | - self.js_obj = js_obj |
30 | | - |
31 | | - def __getattr__(self, name: str) -> Any: |
32 | | - res = getattr(self.js_obj, name) |
33 | | - if isinstance(res, pyodide.ffi.JsProxy): |
34 | | - return proxy_tjs_object(res) |
35 | | - return res |
36 | | - |
37 | | - def __repr__(self) -> str: |
38 | | - return "TjsModuleProxy({})".format(", ".join(self.js_obj.object_keys())) |
39 | | - |
40 | | - |
41 | | -class TjsProxy: |
42 | | - def __init__(self, js_obj: pyodide.ffi.JsProxy): |
43 | | - self._js_obj = js_obj |
44 | | - self._is_class = self._js_obj.typeof == "function" and rx_class_def_code.match( |
45 | | - self._js_obj.toString() |
46 | | - ) # Ref: https://stackoverflow.com/a/30760236/13103190 |
47 | | - |
48 | | - def __call__(self, *args: Any, **kwds: Any) -> Any: |
49 | | - args = [arg._js_obj if isinstance(arg, TjsProxy) else arg for arg in args] |
50 | | - kwds = {k: v._js_obj if isinstance(v, TjsProxy) else v for k, v in kwds.items()} |
51 | | - args = pyodide.ffi.to_js(args) |
52 | | - kwds = pyodide.ffi.to_js(kwds) |
53 | | - |
54 | | - if hasattr(self._js_obj, "_call"): |
55 | | - # Transformers.js uses a custom _call() method |
56 | | - # to make the JS classes callable. |
57 | | - # https://github.com/xenova/transformers.js/blob/2.4.1/src/utils/core.js#L45-L77 |
58 | | - res = self._js_obj._call(*args, **kwds) |
59 | | - else: |
60 | | - if self._is_class: |
61 | | - res = self._js_obj.new(*args, **kwds) |
62 | | - else: |
63 | | - res = self._js_obj(*args, **kwds) |
64 | | - |
65 | | - return wrap_or_unwrap_proxy_object(res) |
66 | | - |
67 | | - def __getattr__(self, name: str) -> Any: |
68 | | - res = getattr(self._js_obj, name) |
69 | | - return wrap_or_unwrap_proxy_object(res) |
70 | | - |
71 | | - def __getitem__(self, key: Any) -> Any: |
72 | | - res = self._js_obj[key] |
73 | | - return wrap_or_unwrap_proxy_object(res) |
74 | | - |
75 | | - def __setitem__(self, key: Any, value: Any) -> None: |
76 | | - self._js_obj[key] = value |
77 | | - |
78 | | - def __setattr__(self, __name: str, __value: Any) -> None: |
79 | | - if __name == "_js_obj" or __name == "_is_class": |
80 | | - super().__setattr__(__name, __value) |
81 | | - else: |
82 | | - setattr(self._js_obj, __name, __value) |
83 | | - |
84 | | - |
85 | | -class TjsRawImageClassProxy(TjsProxy): |
86 | | - def read( |
87 | | - self, input: Union["TjsRawImageProxy", str] |
88 | | - ) -> Awaitable["TjsRawImageProxy"]: |
89 | | - if isinstance(input, TjsRawImageProxy): |
90 | | - res = self._js_obj.read(input._js_obj) |
91 | | - elif is_url(input): |
92 | | - res = self._js_obj.read(input) |
93 | | - else: |
94 | | - res = self._js_obj.read(as_url(input)) |
95 | | - return wrap_or_unwrap_proxy_object(res) |
96 | | - |
97 | | - |
98 | | -class TjsRawImageProxy(TjsProxy): |
99 | | - def to_numpy(self): |
100 | | - if np is None: |
101 | | - raise RuntimeError("numpy is not available") |
102 | | - |
103 | | - data = self._js_obj.data # Uint8ClampedArray|Uint8Array |
104 | | - width = self._js_obj.width |
105 | | - height = self._js_obj.height |
106 | | - channels = self._js_obj.channels |
107 | | - return np.asarray(data.to_py()).reshape((height, width, channels)) |
108 | | - |
109 | | - def to_pil(self): |
110 | | - if PILImage is None: |
111 | | - raise RuntimeError("PIL is not available") |
112 | | - |
113 | | - numpy_img = self.to_numpy() |
114 | | - if numpy_img.shape[2] == 1: |
115 | | - # Gray scale image |
116 | | - numpy_img = numpy_img[:, :, 0] |
117 | | - return PILImage.fromarray(numpy_img) |
118 | | - |
119 | | - def save(self, path: str): |
120 | | - self.to_pil().save(path) |
121 | | - |
122 | | - |
123 | | -class TjsTensorProxy(TjsProxy): |
124 | | - def numpy(self): |
125 | | - if np is None: |
126 | | - raise RuntimeError("numpy is not available") |
127 | | - |
128 | | - data = self._js_obj.data.to_py() |
129 | | - dims = self._js_obj.dims.to_py() |
130 | | - dtype = self._js_obj.type |
131 | | - |
132 | | - return np.asarray(data, dtype=dtype).reshape(dims) |
133 | | - |
134 | | - |
135 | | -def proxy_tjs_object(js_obj: pyodide.ffi.JsProxy): |
136 | | - """A factory function that wraps a JsProxy object wrapping a Transformers.js object |
137 | | - into a Python object of type TjsProxy or is subclass in the case of a special object |
138 | | - such as RawImage. |
139 | | - """ |
140 | | - if js_obj == js._transformers.RawImage: |
141 | | - return TjsRawImageClassProxy(js_obj) |
142 | | - if js_obj.constructor == js._transformers.RawImage: |
143 | | - return TjsRawImageProxy(js_obj) |
144 | | - if js_obj.constructor == js._transformers.Tensor: |
145 | | - return TjsTensorProxy(js_obj) |
146 | | - return TjsProxy(js_obj) |
147 | | - |
148 | | - |
149 | | -def to_py_default_converter(value: pyodide.ffi.JsProxy, _ignored1, _ignored2): |
150 | | - # Pyodide tries to convert the JS object to a Python object |
151 | | - # as best as possible, but it doesn't always work. |
152 | | - # In such a case, this custom converter is called |
153 | | - # and it wraps the JS object into a TjsProxy object. |
154 | | - return proxy_tjs_object(value) |
155 | | - |
156 | | - |
157 | | -def wrap_or_unwrap_proxy_object(obj): |
158 | | - if isinstance(obj, pyodide.ffi.JsProxy): |
159 | | - if obj.typeof == "object": |
160 | | - return obj.to_py(default_converter=to_py_default_converter) |
161 | | - |
162 | | - return proxy_tjs_object(obj) |
163 | | - elif isinstance(obj, pyodide.webloop.PyodideFuture): |
164 | | - return obj.then(wrap_or_unwrap_proxy_object) |
165 | | - return obj |
166 | | - |
167 | | - |
168 | | -async def import_transformers_js(version: str = "latest"): |
169 | | - pyodide.code.run_js( |
170 | | - """ |
171 | | - async function loadTransformersJs(version) { |
172 | | - const isBrowserMainThread = typeof window !== 'undefined'; |
173 | | - const isWorker = typeof WorkerGlobalScope !== 'undefined' && self instanceof WorkerGlobalScope; |
174 | | - const isBrowser = isBrowserMainThread || isWorker; |
175 | | - const transformers = await import(isBrowser ? 'https://cdn.jsdelivr.net/npm/@xenova/transformers@' + version : '@xenova/transformers'); |
176 | | -
|
177 | | - transformers.env.allowLocalModels = false; |
178 | | -
|
179 | | - globalThis._transformers = { // Convert a module to an object. |
180 | | - ...transformers, |
181 | | - }; |
182 | | - } |
183 | | - """ # noqa: E501 |
184 | | - ) |
185 | | - loadTransformersJsFn = js.loadTransformersJs |
186 | | - await loadTransformersJsFn(version) |
187 | | - |
188 | | - transformers = js._transformers |
189 | | - return TjsModuleProxy(transformers) |
190 | | - |
| 1 | +from .proxies import import_transformers_js |
| 2 | +from .url import as_url |
191 | 3 |
|
192 | 4 | __all__ = ["as_url", "import_transformers_js"] |
0 commit comments