diff --git a/src/__tests__/filtercanvas.test.ts b/src/__tests__/filtercanvas.test.ts index 5d68080..a72c40b 100644 --- a/src/__tests__/filtercanvas.test.ts +++ b/src/__tests__/filtercanvas.test.ts @@ -1,4 +1,8 @@ -import { getCanvasElement, loadImage } from "./../filtercanvas" +import { + getCanvasElement, + getCanvasContext, + loadImage, +} from "./../filtercanvas" describe("FilterCanvas", () => { beforeAll(() => (document.body.innerHTML = ``)) @@ -12,6 +16,14 @@ describe("FilterCanvas", () => { }) }) +describe("getCanvasContext", () => { + it("should return a 2d context", () => { + const result = getCanvasContext(document.createElement("canvas"), "2d") + + expect(result).toBeInstanceOf(CanvasRenderingContext2D) + }) +}) + describe("loadImage", () => { const originalImageFn = Object.getOwnPropertyDescriptor( Image.prototype, diff --git a/src/filtercanvas.ts b/src/filtercanvas.ts index 8c9185c..7857c8d 100644 --- a/src/filtercanvas.ts +++ b/src/filtercanvas.ts @@ -10,6 +10,22 @@ export const getCanvasElement = (elem: string): HTMLCanvasElement => { return canvas } +export interface CanvasContextMap { + "2d": CanvasRenderingContext2D + webgl: WebGLRenderingContext + webgl2: WebGL2RenderingContext + bitmaprenderer: ImageBitmapRenderingContext +} + +export const getCanvasContext = ( + canvas: HTMLCanvasElement, + contextType: T, +): CanvasContextMap[T] => { + const context = canvas.getContext(contextType) + if (!context) throw new Error("could not return drawing context") + return context as CanvasContextMap[T] +} + export const loadImage = (url: string): Promise => new Promise((resolve, reject) => { const img = new Image() @@ -31,7 +47,7 @@ class FilterCanvas { constructor(elem: string, imgUrl: string) { this.canvas = getCanvasElement(elem) - this.context = this.canvas.getContext("2d") as CanvasRenderingContext2D + this.context = getCanvasContext(this.canvas, "2d") this.frames = new FrameCounter(30) @@ -104,4 +120,3 @@ class FilterCanvas { } export default FilterCanvas -