diff --git a/src/__tests__/filtercanvas.test.ts b/src/__tests__/filtercanvas.test.ts
index 5d68080..a72c40b 100644
--- a/src/__tests__/filtercanvas.test.ts
+++ b/src/__tests__/filtercanvas.test.ts
@@ -1,4 +1,8 @@
-import { getCanvasElement, loadImage } from "./../filtercanvas"
+import {
+ getCanvasElement,
+ getCanvasContext,
+ loadImage,
+} from "./../filtercanvas"
describe("FilterCanvas", () => {
beforeAll(() => (document.body.innerHTML = ``))
@@ -12,6 +16,14 @@ describe("FilterCanvas", () => {
})
})
+describe("getCanvasContext", () => {
+ it("should return a 2d context", () => {
+ const result = getCanvasContext(document.createElement("canvas"), "2d")
+
+ expect(result).toBeInstanceOf(CanvasRenderingContext2D)
+ })
+})
+
describe("loadImage", () => {
const originalImageFn = Object.getOwnPropertyDescriptor(
Image.prototype,
diff --git a/src/filtercanvas.ts b/src/filtercanvas.ts
index 8c9185c..7857c8d 100644
--- a/src/filtercanvas.ts
+++ b/src/filtercanvas.ts
@@ -10,6 +10,22 @@ export const getCanvasElement = (elem: string): HTMLCanvasElement => {
return canvas
}
+export interface CanvasContextMap {
+ "2d": CanvasRenderingContext2D
+ webgl: WebGLRenderingContext
+ webgl2: WebGL2RenderingContext
+ bitmaprenderer: ImageBitmapRenderingContext
+}
+
+export const getCanvasContext = (
+ canvas: HTMLCanvasElement,
+ contextType: T,
+): CanvasContextMap[T] => {
+ const context = canvas.getContext(contextType)
+ if (!context) throw new Error("could not return drawing context")
+ return context as CanvasContextMap[T]
+}
+
export const loadImage = (url: string): Promise =>
new Promise((resolve, reject) => {
const img = new Image()
@@ -31,7 +47,7 @@ class FilterCanvas {
constructor(elem: string, imgUrl: string) {
this.canvas = getCanvasElement(elem)
- this.context = this.canvas.getContext("2d") as CanvasRenderingContext2D
+ this.context = getCanvasContext(this.canvas, "2d")
this.frames = new FrameCounter(30)
@@ -104,4 +120,3 @@ class FilterCanvas {
}
export default FilterCanvas
-