diff --git a/README.md b/README.md index 031fa4d..246eb20 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ yarn add zustand-computed The middleware layer takes in your store creation function and a compute function, which transforms your state into a computed state. It does not need to handle merging states. ```js -import computed from "zustand-computed" +import { createComputed } from "zustand-computed" const computeState = (state) => ({ countSq: state.count ** 2, @@ -47,7 +47,7 @@ const useStore = create( With types, the previous example would look like this: ```ts -import computed from "zustand-computed" +import { createComputed } from "zustand-computed" type Store = { count: number @@ -59,11 +59,10 @@ type ComputedStore = { countSq: number } -const computeState = (state: Store): ComputedStore => ({ +const computed = createComputed((state: Store): ComputedStore => ({ countSq: state.count ** 2, -}) +})) -// use curried create const useStore = create()( computed( (set) => ({ @@ -73,8 +72,7 @@ const useStore = create()( // get() function has access to ComputedStore square: () => set(() => ({ count: get().countSq })), root: () => set((state) => ({ count: Math.floor(Math.sqrt(state.count)) })), - }), - computeState + }) ) ) ``` @@ -103,10 +101,8 @@ A fully-featured example can be found under the "example" directory. Here's an example with the Immer middleware. -> [!WARNING] -> Immer derives the SetState type from the output of GetState, where `zustand-computed` types SetState to allow only the regular Store and types GetState to return both the store and the computed store. To avoid this issue, you may need to apply Immer outside of `zustand-computed`. If `zustand-computed` must be outside of Immer, you will need to assert the `Store` type as `Store & ComputedStore`. - ```ts +const computed = createComputed((state: Store) => { /* ... */ }) const useStore = create()( devtools( immer( @@ -120,7 +116,6 @@ const useStore = create()( }), dec: () => set((state) => ({ count: state.count - 1 })), }), - computeState ), ) ) @@ -129,9 +124,10 @@ const useStore = create()( ## Selectors -By default, when `zustand-computed` runs your `computeState` function, it tracks accessed variables and does not trigger a computation if one of those variables do not change. This could potentially be problematic if you have nested control flow inside of `computeState`, or perhaps you want it to run on _all_ changes regardless of use inside of `computeState`. To disable automatic selector detection, you can pass a third `opts` variable to the `computed` constructor, e.g. +By default, when `zustand-computed` runs your `computeState` function, it tracks accessed variables and does not trigger a computation if one of those variables do not change. This could potentially be problematic if you have nested control flow inside of `computeState`, or perhaps you want it to run on _all_ changes regardless of use inside of `computeState`. To disable automatic selector detection, you can pass a second `opts` variable to the `createComputed` function, e.g. ```ts +const computed = createComputed((state: Store) => { /* ... */ }, { disableProxy: true }) const useStore = create( computed( (set) => ({ @@ -139,8 +135,6 @@ const useStore = create set((state) => ({ count: state.count + 1 })), dec: () => set((state) => ({ count: state.count - 1 })), }), - computeState, - { disableProxy: true } ) ) ``` diff --git a/src/computed.test.ts b/src/computed.test.ts index 5980562..9e1729b 100644 --- a/src/computed.test.ts +++ b/src/computed.test.ts @@ -1,200 +1,194 @@ -import { describe, expect, test, beforeEach, mock } from "bun:test" -import { type StateCreator, create } from "zustand" -import { type ComputedStateOpts, computed } from "./computed" +import { describe, expect, test, beforeEach, mock } from "bun:test"; +import { type StateCreator, create } from "zustand"; +import { type ComputedStateOpts, createComputed } from "./computed"; type Store = { - count: number - x: number - y: number - inc: () => void - dec: () => void -} + count: number; + x: number; + y: number; + inc: () => void; + dec: () => void; +}; type ComputedStore = { - countSq: number + countSq: number; nestedResult: { - stringified: string - } -} + stringified: string; + }; +}; function computeState(state: Store): ComputedStore { const nestedResult = { stringified: JSON.stringify(state.count), - } + }; return { countSq: state.count ** 2, nestedResult, - } + }; } describe("default config", () => { - const computeStateMock = mock(computeState) + const computeStateMock = mock(computeState); + const computed = createComputed(computeStateMock); const makeStore = () => create( - computed( - (set) => ({ - count: 1, - x: 1, - y: 1, - inc: () => set((state) => ({ count: state.count + 1 })), - dec: () => set((state) => ({ count: state.count - 1 })), - }), - computeStateMock, - ), - ) - - let useStore: ReturnType + computed((set) => ({ + count: 1, + x: 1, + y: 1, + inc: () => set((state) => ({ count: state.count + 1 })), + dec: () => set((state) => ({ count: state.count - 1 })), + })), + ); + + let useStore: ReturnType; beforeEach(() => { - computeStateMock.mockClear() - useStore = makeStore() - }) + computeStateMock.mockClear(); + useStore = makeStore(); + }); test("computed works on simple counter example", () => { // note: this function should have been called once on store creation - expect(computeStateMock).toHaveBeenCalledTimes(1) - expect(useStore.getState().count).toEqual(1) - expect(useStore.getState().countSq).toEqual(1) - useStore.getState().inc() - expect(useStore.getState().count).toEqual(2) - expect(useStore.getState().countSq).toEqual(4) - useStore.getState().dec() - expect(useStore.getState().count).toEqual(1) - expect(useStore.getState().countSq).toEqual(1) - useStore.setState({ count: 4 }) - expect(useStore.getState().countSq).toEqual(16) - expect(computeStateMock).toHaveBeenCalledTimes(4) - }) + expect(computeStateMock).toHaveBeenCalledTimes(1); + expect(useStore.getState().count).toEqual(1); + expect(useStore.getState().countSq).toEqual(1); + useStore.getState().inc(); + expect(useStore.getState().count).toEqual(2); + expect(useStore.getState().countSq).toEqual(4); + useStore.getState().dec(); + expect(useStore.getState().count).toEqual(1); + expect(useStore.getState().countSq).toEqual(1); + useStore.setState({ count: 4 }); + expect(useStore.getState().countSq).toEqual(16); + expect(computeStateMock).toHaveBeenCalledTimes(4); + }); test("computed does not modify object ref even after change", () => { - useStore.setState({ count: 4 }) - expect(useStore.getState().count).toEqual(4) - const obj = useStore.getState().nestedResult - useStore.setState({ count: 4 }) - const toCompare = useStore.getState().nestedResult - expect(obj).toEqual(toCompare) - }) + useStore.setState({ count: 4 }); + expect(useStore.getState().count).toEqual(4); + const obj = useStore.getState().nestedResult; + useStore.setState({ count: 4 }); + const toCompare = useStore.getState().nestedResult; + expect(obj).toEqual(toCompare); + }); test("modifying variables x and y do not trigger compute function more than once, as they are not used in compute function", () => { - expect(computeStateMock).toHaveBeenCalledTimes(1) - useStore.setState({ x: 2 }) - expect(computeStateMock).toHaveBeenCalledTimes(2) - useStore.setState({ x: 3 }) - expect(computeStateMock).toHaveBeenCalledTimes(2) - useStore.setState({ y: 2 }) - expect(computeStateMock).toHaveBeenCalledTimes(2) - }) -}) + expect(computeStateMock).toHaveBeenCalledTimes(1); + useStore.setState({ x: 2 }); + expect(computeStateMock).toHaveBeenCalledTimes(2); + useStore.setState({ x: 3 }); + expect(computeStateMock).toHaveBeenCalledTimes(2); + useStore.setState({ y: 2 }); + expect(computeStateMock).toHaveBeenCalledTimes(2); + }); +}); describe("custom config", () => { - const computeStateMock = mock(computeState) - const makeStore = (opts?: ComputedStateOpts) => - create( - computed( - (set) => ({ - count: 1, - x: 1, - y: 1, - inc: () => set((state) => ({ count: state.count + 1 })), - dec: () => set((state) => ({ count: state.count - 1 })), - }), - computeStateMock, - opts, - ), - ) + const computeStateMock = mock(computeState); + const makeStore = (opts?: ComputedStateOpts) => { + const computed = createComputed(computeStateMock, opts); + return create( + computed((set) => ({ + count: 1, + x: 1, + y: 1, + inc: () => set((state) => ({ count: state.count + 1 })), + dec: () => set((state) => ({ count: state.count - 1 })), + })), + ); + }; beforeEach(() => { - computeStateMock.mockClear() - }) + computeStateMock.mockClear(); + }); test("computed does not update when a custom key selector is given", () => { - const useStore = makeStore({ keys: ["x", "y"] }) + const useStore = makeStore({ keys: ["x", "y"] }); // because we only care about x and y, the compute function should not be called when count changes - expect(computeStateMock).toHaveBeenCalledTimes(1) - expect(useStore.getState().count).toEqual(1) - expect(useStore.getState().countSq).toEqual(1) - useStore.getState().inc() - expect(useStore.getState().count).toEqual(2) - expect(useStore.getState().countSq).toEqual(1) - useStore.getState().dec() - expect(useStore.getState().count).toEqual(1) - expect(useStore.getState().countSq).toEqual(1) - expect(computeStateMock).toHaveBeenCalledTimes(1) - }) + expect(computeStateMock).toHaveBeenCalledTimes(1); + expect(useStore.getState().count).toEqual(1); + expect(useStore.getState().countSq).toEqual(1); + useStore.getState().inc(); + expect(useStore.getState().count).toEqual(2); + expect(useStore.getState().countSq).toEqual(1); + useStore.getState().dec(); + expect(useStore.getState().count).toEqual(1); + expect(useStore.getState().countSq).toEqual(1); + expect(computeStateMock).toHaveBeenCalledTimes(1); + }); test("disabling proxy causes compute to run every time", () => { - const useStore = makeStore({ disableProxy: true }) - expect(computeStateMock).toHaveBeenCalledTimes(1) - useStore.setState({ count: 4 }) - useStore.setState({ x: 2 }) - useStore.setState({ y: 3 }) - expect(useStore.getState().count).toEqual(4) - expect(useStore.getState().countSq).toEqual(16) - expect(computeStateMock).toHaveBeenCalledTimes(4) - }) -}) - -type CountSlice = Pick -type XYSlice = Pick + const useStore = makeStore({ disableProxy: true }); + expect(computeStateMock).toHaveBeenCalledTimes(1); + useStore.setState({ count: 4 }); + useStore.setState({ x: 2 }); + useStore.setState({ y: 3 }); + expect(useStore.getState().count).toEqual(4); + expect(useStore.getState().countSq).toEqual(16); + expect(computeStateMock).toHaveBeenCalledTimes(4); + }); +}); + +type CountSlice = Pick; +type XYSlice = Pick; function computeSlice(state: CountSlice): ComputedStore { const nestedResult = { stringified: JSON.stringify(state.count), - } + }; return { countSq: state.count ** 2, nestedResult, - } + }; } describe("slices pattern", () => { - const computeSliceMock = mock(computeSlice) + const computeSliceMock = mock(computeSlice); + const computed = createComputed(computeSliceMock); const makeStore = () => { const createCountSlice: StateCreator< Store, [], [["chrisvander/zustand-computed", ComputedStore]], CountSlice & ComputedStore - > = computed( - (set) => ({ - count: 1, - dec: () => set((state) => ({ count: state.count - 1 })), - }), - computeSliceMock, - ) + > = computed((set) => ({ + count: 1, + dec: () => set((state) => ({ count: state.count - 1 })), + })); const createXySlice: StateCreator = (set) => ({ x: 1, y: 1, // this should not trigger compute function inc: () => set((state) => ({ count: state.count + 2 })), - }) + }); return create()((...a) => ({ ...createCountSlice(...a), ...createXySlice(...a), - })) - } + })); + }; beforeEach(() => { - computeSliceMock.mockClear() - }) + computeSliceMock.mockClear(); + }); test("computed works on slices pattern example", () => { - const useStore = makeStore() - expect(computeSliceMock).toHaveBeenCalledTimes(1) - expect(useStore.getState().count).toEqual(1) - expect(useStore.getState().countSq).toEqual(1) - useStore.getState().inc() - expect(useStore.getState().count).toEqual(3) - expect(useStore.getState().countSq).toEqual(1) - expect(computeSliceMock).toHaveBeenCalledTimes(1) - useStore.getState().dec() - expect(useStore.getState().count).toEqual(2) - expect(useStore.getState().countSq).toEqual(4) - expect(computeSliceMock).toHaveBeenCalledTimes(2) - useStore.setState({ count: 4 }) - expect(useStore.getState().countSq).toEqual(16) - expect(computeSliceMock).toHaveBeenCalledTimes(3) - }) -}) + const useStore = makeStore(); + expect(computeSliceMock).toHaveBeenCalledTimes(1); + expect(useStore.getState().count).toEqual(1); + expect(useStore.getState().countSq).toEqual(1); + useStore.getState().inc(); + expect(useStore.getState().count).toEqual(3); + expect(useStore.getState().countSq).toEqual(1); + expect(computeSliceMock).toHaveBeenCalledTimes(1); + useStore.getState().dec(); + expect(useStore.getState().count).toEqual(2); + expect(useStore.getState().countSq).toEqual(4); + expect(computeSliceMock).toHaveBeenCalledTimes(2); + useStore.setState({ count: 4 }); + expect(useStore.getState().countSq).toEqual(16); + expect(computeSliceMock).toHaveBeenCalledTimes(3); + }); +}); diff --git a/src/computed.ts b/src/computed.ts index 8cd5d66..9675070 100644 --- a/src/computed.ts +++ b/src/computed.ts @@ -1,100 +1,109 @@ -import type { Mutate, StateCreator, StoreApi, StoreMutatorIdentifier } from "zustand" -import { shallow } from "zustand/shallow" +import type { + Mutate, + StateCreator, + StoreApi, + StoreMutatorIdentifier, +} from "zustand"; +import { shallow } from "zustand/shallow"; export type ComputedStateOpts = { - keys?: (keyof T)[] - disableProxy?: boolean - equalityFn?: (a: Y, b: Y) => boolean -} + keys?: (keyof T)[]; + disableProxy?: boolean; + equalityFn?: (a: Y, b: Y) => boolean; +}; -export type ComputedStateCreator = < - T extends object, - A extends object, +export type ComputedStateCreator = ( + compute: (state: T) => A, + opts?: ComputedStateOpts, +) => < Mps extends [StoreMutatorIdentifier, unknown][] = [], Mcs extends [StoreMutatorIdentifier, unknown][] = [], U = T, >( f: StateCreator, - compute: (state: T) => A, - opts?: ComputedStateOpts, -) => StateCreator +) => StateCreator; -type Cast = T extends U ? T : U -type Write = Omit & U +type Cast = T extends U ? T : U; +type Write = Omit & U; type StoreCompute = S extends { - getState: () => infer T + getState: () => infer T; } ? Omit, "setState"> - : never -type WithCompute = Write> + : never; +type WithCompute = Write>; declare module "zustand/vanilla" { interface StoreMutators { - "chrisvander/zustand-computed": WithCompute, A> + "chrisvander/zustand-computed": WithCompute, A>; } } type ComputedStateImpl = ( - f: StateCreator, compute: (state: T) => A, opts?: ComputedStateOpts, -) => StateCreator +) => (f: StateCreator) => StateCreator; -type SetStateWithArgs = Parameters>[0] extends (...args: infer U) => void +type SetStateWithArgs = Parameters< + ReturnType> +>[0] extends (...args: infer U) => void ? (...args: [...U, ...unknown[]]) => void - : never + : never; -const computedImpl: ComputedStateImpl = (f, compute, opts) => { +const computedImpl: ComputedStateImpl = (compute, opts) => (f) => { // set of keys that have been accessed in any compute call - const trackedSelectors = new Set() + const trackedSelectors = new Set(); return (set, get, api) => { - type T = ReturnType - type A = ReturnType + type T = ReturnType; + type A = ReturnType; - const equalityFn = opts?.equalityFn ?? shallow + const equalityFn = opts?.equalityFn ?? shallow; if (opts?.keys) { - const selectorKeys = opts.keys + const selectorKeys = opts.keys; for (const key of selectorKeys) { - trackedSelectors.add(key) + trackedSelectors.add(key); } } // we track which selectors are accessed - const useSelectors = opts?.disableProxy !== true || !!opts?.keys - const useProxy = opts?.disableProxy !== true && !opts?.keys + const useSelectors = opts?.disableProxy !== true || !!opts?.keys; + const useProxy = opts?.disableProxy !== true && !opts?.keys; const computeAndMerge = (state: T | (T & A)): T & A => { // create a Proxy to track which selectors are accessed const stateProxy = new Proxy( { ...state }, { get: (_, prop) => { - trackedSelectors.add(prop) - return state[prop as keyof T] + trackedSelectors.add(prop); + return state[prop as keyof T]; }, }, - ) + ); // calculate the new computed state - const computedState: A = compute(useProxy ? stateProxy : { ...state }) + const computedState: A = compute(useProxy ? stateProxy : { ...state }); // if part of the computed state did not change according to the equalityFn // then we use the object ref from the previous state. This is to prevent // unnecessary re-renders. for (const k of Object.keys(computedState) as (keyof A)[]) { if (equalityFn(computedState[k], (state as T & A)[k])) { - computedState[k] = (state as T & A)[k] + computedState[k] = (state as T & A)[k]; } } - return { ...state, ...computedState } - } + return { ...state, ...computedState }; + }; // higher level function to handle compute & compare overhead - const setWithComputed = (update: T | ((state: T) => T), replace?: boolean, ...args: unknown[]) => { - ;(set as SetStateWithArgs)( + const setWithComputed = ( + update: T | ((state: T) => T), + replace?: boolean, + ...args: unknown[] + ) => { + (set as SetStateWithArgs)( (state: T): T & A => { - const updated = typeof update === "object" ? update : update(state) + const updated = typeof update === "object" ? update : update(state); if ( useSelectors && @@ -102,22 +111,24 @@ const computedImpl: ComputedStateImpl = (f, compute, opts) => { !Object.keys(updated).some((k) => trackedSelectors.has(k)) ) { // if we have a selector set, but none of the updated keys are in the selector set, then we can skip the compute - return { ...state, ...updated } as T & A + return { ...state, ...updated } as T & A; } - return computeAndMerge({ ...state, ...updated }) + return computeAndMerge({ ...state, ...updated }); }, replace, ...args, - ) - } - - const _api = api as Mutate, [["chrisvander/zustand-computed", A]]> - _api.setState = setWithComputed - const st = f(setWithComputed, get, _api) as T & A - return Object.assign({}, st, compute(st)) - } -} - -export const computed = computedImpl as unknown as ComputedStateCreator -export default computed + ); + }; + + const _api = api as Mutate< + StoreApi, + [["chrisvander/zustand-computed", A]] + >; + _api.setState = setWithComputed; + const st = f(setWithComputed, get, _api) as T & A; + return Object.assign({}, st, compute(st)); + }; +}; + +export const createComputed = computedImpl as unknown as ComputedStateCreator; diff --git a/src/index.ts b/src/index.ts index d615254..6c6e96f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,2 +1 @@ -export * from "./computed" -export { computed as default } from "./computed" +export * from "./computed";