Skip to content

Commit 8f5f388

Browse files
authored
Add TjsRawImageProxy and fix the proxying logic to support the depth estimation feature (#30)
* Add TjsRawImageProxy and fix the proxying logic to support the depth estimation feature
* Support new RawImage() constructor
* Support RawImage.read() without as_url()
* Fix
* Pin Git-hosted resources' revisions
1 parent ff16478 commit 8f5f388

File tree

4 files changed

+243
-13
lines changed

4 files changed

+243
-13
lines changed

pyodide-e2e/src/tests/pipeline.test.ts

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ suite("transformers.pipeline", () => {
66
let pyodide: PyodideInterface;
77

88
beforeEach(async () => {
9-
pyodide = await setupPyodideForTest();
9+
pyodide = await setupPyodideForTest(["numpy", "Pillow"]);
1010
});
1111

1212
test("zero-shot-image-classification with a local file wrapped by as_url()", async () => {
13-
await fetch("https://huggingface.co/spaces/gradio/image_mod/resolve/main/images/lion.jpg")
13+
await fetch("https://huggingface.co/spaces/gradio/image_mod/resolve/e07924a/images/lion.jpg")
1414
.then((response) => response.blob())
1515
.then((blob) => blob.arrayBuffer())
1616
.then((arrayBuffer) => {
@@ -38,4 +38,45 @@ result = {item['label']: round(item['score'], 2) for item in data}
3838
const topLabel = Object.keys(resultObj).reduce((a, b) => resultObj[a] > resultObj[b] ? a : b);
3939
expect(topLabel).toEqual("lion");
4040
});
41+
42+
test("depth-estimation", async () => {
43+
await fetch("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/db8bd36/bread_small.png")
44+
.then((response) => response.blob())
45+
.then((blob) => blob.arrayBuffer())
46+
.then((arrayBuffer) => {
47+
const fileData = new Uint8Array(arrayBuffer);
48+
const filePath = "/tmp/bread_small.png";
49+
pyodide.FS.writeFile(filePath, fileData);
50+
});
51+
52+
await pyodide.runPythonAsync(`
53+
from transformers_js_py import import_transformers_js, as_url
54+
55+
transformers = await import_transformers_js()
56+
57+
pipeline = transformers.pipeline
58+
RawImage = transformers.RawImage
59+
60+
depth_estimator = await pipeline('depth-estimation', 'Xenova/depth-anything-small-hf');
61+
62+
output = await depth_estimator(as_url("/tmp/bread_small.png"))
63+
`);
64+
const outputMap = await pyodide.globals.get("output").toJs() // Python's dict to JS's Map
65+
const output = Object.fromEntries(outputMap);
66+
67+
const depth = output.depth.toJs();
68+
const predictedDepth = output.predicted_depth.toJs();
69+
70+
// API reference: https://huggingface.co/Xenova/depth-anything-small-hf
71+
expect(depth.width).toBe(640)
72+
expect(depth.height).toBe(424)
73+
expect(depth.channels).toBe(1)
74+
expect(predictedDepth).toBeDefined()
75+
76+
await pyodide.runPythonAsync(`
77+
output["depth"].save('/tmp/depth.png')
78+
`);
79+
const depthImage: Uint8Array = pyodide.FS.readFile("/tmp/depth.png", { encoding: "binary" });
80+
// TODO: How to assert the depth image? Image snapshot is not available in the browser env.
81+
})
4182
});
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import type { PyodideInterface } from "pyodide";
2+
import { beforeEach, describe, it, expect } from "vitest";
3+
import { setupPyodideForTest } from "./utils";
4+
5+
describe("RawImage", () => {
6+
let pyodide: PyodideInterface;
7+
8+
beforeEach(async () => {
9+
pyodide = await setupPyodideForTest(["numpy", "Pillow"]);
10+
});
11+
12+
it("can be initialized via .fromURL()", async () => {
13+
await pyodide.runPythonAsync(`
14+
from transformers_js_py import import_transformers_js
15+
transformers = await import_transformers_js()
16+
RawImage = transformers.RawImage
17+
raw_image = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/db8bd36/bread_small.png')
18+
`)
19+
const rawImage = await pyodide.globals.get("raw_image").toJs();
20+
expect(rawImage).toBeDefined();
21+
expect(rawImage).toHaveProperty("width", 640);
22+
expect(rawImage).toHaveProperty("height", 424);
23+
expect(rawImage).toHaveProperty("channels", 4);
24+
});
25+
26+
it("can be initialized from a local file via .read()", async () => {
27+
await fetch("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/db8bd36/bread_small.png")
28+
.then((response) => response.blob())
29+
.then((blob) => blob.arrayBuffer())
30+
.then((arrayBuffer) => {
31+
const fileData = new Uint8Array(arrayBuffer);
32+
const filePath = "/tmp/bread_small.png";
33+
pyodide.FS.writeFile(filePath, fileData);
34+
});
35+
36+
await pyodide.runPythonAsync(`
37+
from transformers_js_py import import_transformers_js
38+
transformers = await import_transformers_js()
39+
40+
RawImage = transformers.RawImage
41+
raw_image = await RawImage.read("/tmp/bread_small.png")
42+
`)
43+
const rawImage = await pyodide.globals.get("raw_image").toJs();
44+
expect(rawImage).toBeDefined();
45+
expect(rawImage).toHaveProperty("width", 640);
46+
expect(rawImage).toHaveProperty("height", 424);
47+
expect(rawImage).toHaveProperty("channels", 4);
48+
});
49+
50+
it("can be initialized via the constructor", async () => {
51+
await pyodide.runPythonAsync(`
52+
from transformers_js_py import import_transformers_js
53+
transformers = await import_transformers_js()
54+
55+
RawImage = transformers.RawImage
56+
raw_image = RawImage(bytes([0] * 16*10*3), 16, 10, 3)
57+
`)
58+
59+
const rawImage = await pyodide.globals.get("raw_image").toJs();
60+
expect(rawImage).toBeDefined();
61+
expect(rawImage).toHaveProperty("width", 16);
62+
expect(rawImage).toHaveProperty("height", 10);
63+
expect(rawImage).toHaveProperty("channels", 3);
64+
});
65+
66+
it("can be transformed into a numpy array and a PIL image and saved to a local file", async () => {
67+
await pyodide.runPythonAsync(`
68+
from transformers_js_py import import_transformers_js
69+
transformers = await import_transformers_js()
70+
RawImage = transformers.RawImage
71+
raw_image = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/db8bd36/bread_small.png')
72+
73+
numpy_array = raw_image.to_numpy()
74+
pil_image = raw_image.to_pil()
75+
raw_image.save('/tmp/bread_small.png')
76+
`)
77+
const numpyArrayProxy = await pyodide.globals.get("numpy_array");
78+
const numpyArray = numpyArrayProxy.getBuffer("u8");
79+
expect(numpyArray.shape).toEqual([424, 640, 4]);
80+
81+
const pilImage = await pyodide.globals.get("pil_image")
82+
expect(pilImage.width).toBe(640);
83+
expect(pilImage.height).toBe(424);
84+
expect(pilImage.mode).toBe("RGBA");
85+
});
86+
87+
describe("Color transform methods such as .grayscale(), rgb(), and .rgba()", async () => {
88+
([["grayscale", 1], ["rgb", 3], ["rgba", 4]] as const).forEach(([source, sourceChannels]) => {
89+
describe(`${source} image`, () => {
90+
beforeEach(async () => {
91+
await pyodide.runPythonAsync(`
92+
from transformers_js_py import import_transformers_js
93+
transformers = await import_transformers_js()
94+
95+
RawImage = transformers.RawImage
96+
raw_image = RawImage(bytes([0] * 16*10*3), 16, 10, ${sourceChannels})
97+
`)
98+
});
99+
100+
([["grayscale", 1], ["rgb", 3], ["rgba", 4]] as const).forEach(([target, targetChannels]) => {
101+
it(`can be transformed into a ${target} image`, async () => {
102+
await pyodide.runPythonAsync(`
103+
converted_image = raw_image.${target}()
104+
`)
105+
const convertedImage = await pyodide.globals.get("converted_image").toJs();
106+
expect(convertedImage).toBeDefined();
107+
expect(convertedImage).toHaveProperty("width", 16);
108+
expect(convertedImage).toHaveProperty("height", 10);
109+
expect(convertedImage).toHaveProperty("channels", targetChannels);
110+
});
111+
});
112+
113+
});
114+
});
115+
});
116+
})

pyodide-e2e/src/tests/utils.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import wheelUrl from "transformers-js-py.whl"; // This is the alias from vite.c
55

66
export const IS_NODE = typeof window === 'undefined';
77

8-
export async function setupPyodideForTest(): Promise<PyodideInterface> {
8+
export async function setupPyodideForTest(requirements: string[] = []): Promise<PyodideInterface> {
99
const pyodide = await loadPyodide({
1010
indexURL: IS_NODE
1111
? "node_modules/pyodide" // pnpm puts pyodide at this path
@@ -25,6 +25,8 @@ export async function setupPyodideForTest(): Promise<PyodideInterface> {
2525
await micropip.install(wheelUrl);
2626
}
2727

28+
await micropip.install(requirements);
29+
2830
await pyodide.runPythonAsync(`
2931
from transformers_js_py import import_transformers_js
3032
transformers = await import_transformers_js()

transformers_js_py/__init__.py

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any
1+
import re
2+
from typing import Any, Union
23

34
import js
45
import pyodide.code
@@ -7,6 +8,19 @@
78

89
from .url import as_url
910

11+
try:
12+
import numpy as np
13+
except ImportError:
14+
np = None
15+
16+
try:
17+
import PIL.Image as PILImage
18+
except ImportError:
19+
PILImage = None
20+
21+
22+
rx_class_def_code = re.compile(r"^\s*class\s+([a-zA-Z0-9_]+)\s*{", re.MULTILINE)
23+
1024

1125
class TjsModuleProxy:
1226
def __init__(self, js_obj: pyodide.ffi.JsProxy):
@@ -17,7 +31,7 @@ def __init__(self, js_obj: pyodide.ffi.JsProxy):
1731
def __getattr__(self, name: str) -> Any:
1832
res = getattr(self.js_obj, name)
1933
if isinstance(res, pyodide.ffi.JsProxy):
20-
return TjsProxy(res)
34+
return proxy_tjs_object(res)
2135
return res
2236

2337
def __repr__(self) -> str:
@@ -27,18 +41,24 @@ def __repr__(self) -> str:
2741
class TjsProxy:
2842
def __init__(self, js_obj: pyodide.ffi.JsProxy):
2943
self._js_obj = js_obj
44+
self._is_class = self._js_obj.typeof == "function" and rx_class_def_code.match(
45+
self._js_obj.toString()
46+
) # Ref: https://stackoverflow.com/a/30760236/13103190
3047

3148
def __call__(self, *args: Any, **kwds: Any) -> Any:
32-
if hasattr(self._js_obj, "_call"):
33-
args = pyodide.ffi.to_js(args)
34-
kwds = pyodide.ffi.to_js(kwds)
49+
args = pyodide.ffi.to_js(args)
50+
kwds = pyodide.ffi.to_js(kwds)
3551

52+
if hasattr(self._js_obj, "_call"):
3653
# Transformers.js uses a custom _call() method
3754
# to make the JS classes callable.
3855
# https://github.com/xenova/transformers.js/blob/2.4.1/src/utils/core.js#L45-L77
3956
res = self._js_obj._call(*args, **kwds)
4057
else:
41-
res = self._js_obj(*args, **kwds)
58+
if self._is_class:
59+
res = self._js_obj.new(*args, **kwds)
60+
else:
61+
res = self._js_obj(*args, **kwds)
4262

4363
return wrap_or_unwrap_proxy_object(res)
4464

@@ -54,17 +74,68 @@ def __setitem__(self, key: Any, value: Any) -> None:
5474
self._js_obj[key] = value
5575

5676
def __setattr__(self, __name: str, __value: Any) -> None:
57-
if __name == "_js_obj":
58-
super().__setattr__("_js_obj", __value)
77+
if __name == "_js_obj" or __name == "_is_class":
78+
super().__setattr__(__name, __value)
5979
else:
6080
setattr(self._js_obj, __name, __value)
6181

6282

83+
class TjsRawImageClassProxy(TjsProxy):
84+
def read(self, input: Union["TjsRawImageProxy", str]):
85+
return wrap_or_unwrap_proxy_object(self._js_obj.read(as_url(input)))
86+
87+
88+
class TjsRawImageProxy(TjsProxy):
89+
def to_numpy(self):
90+
if np is None:
91+
raise RuntimeError("numpy is not available")
92+
93+
data = self._js_obj.data # Uint8ClampedArray|Uint8Array
94+
width = self._js_obj.width
95+
height = self._js_obj.height
96+
channels = self._js_obj.channels
97+
return np.asarray(data.to_py()).reshape((height, width, channels))
98+
99+
def to_pil(self):
100+
if PILImage is None:
101+
raise RuntimeError("PIL is not available")
102+
103+
numpy_img = self.to_numpy()
104+
if numpy_img.shape[2] == 1:
105+
# Gray scale image
106+
numpy_img = numpy_img[:, :, 0]
107+
return PILImage.fromarray(numpy_img)
108+
109+
def save(self, path: str):
110+
self.to_pil().save(path)
111+
112+
113+
def proxy_tjs_object(js_obj: pyodide.ffi.JsProxy):
114+
"""A factory function that wraps a JsProxy object wrapping a Transformers.js object
115+
into a Python object of type TjsProxy or its subclass in the case of a special object
116+
such as RawImage.
117+
"""
118+
if js_obj == js._transformers.RawImage:
119+
return TjsRawImageClassProxy(js_obj)
120+
if js_obj.constructor == js._transformers.RawImage:
121+
return TjsRawImageProxy(js_obj)
122+
return TjsProxy(js_obj)
123+
124+
125+
def to_py_default_converter(value: pyodide.ffi.JsProxy, _ignored1, _ignored2):
126+
# Pyodide tries to convert the JS object to a Python object
127+
# as best as possible, but it doesn't always work.
128+
# In such a case, this custom converter is called
129+
# and it wraps the JS object into a TjsProxy object (or one of its subclasses for RawImage).
130+
return proxy_tjs_object(value)
131+
132+
63133
def wrap_or_unwrap_proxy_object(obj):
64134
if isinstance(obj, pyodide.ffi.JsProxy):
65135
if obj.typeof == "object":
66-
return obj.to_py()
67-
return TjsProxy(obj)
136+
return obj.to_py(default_converter=to_py_default_converter)
137+
138+
return proxy_tjs_object(obj)
68139
elif isinstance(obj, pyodide.webloop.PyodideFuture):
69140
return obj.then(wrap_or_unwrap_proxy_object)
70141
return obj

0 commit comments

Comments
 (0)