Skip to content

Commit

Permalink
fix(jsx/dom): fix memo for DOM renderer (honojs#3568)
Browse files Browse the repository at this point in the history
Fixes honojs#3473
Fixes honojs#3567

* fix(jsx/dom): fix memoization mechanism in dom renderer

* fix(jsx/dom): fix `memo` for DOM renderer

* feat(jsx/dom): implement light weight `memo` function for DOM renderer

* test(jsx/dom): add tests for memoization
  • Loading branch information
usualoma authored and TinsFox committed Nov 11, 2024
1 parent 928d591 commit c6677eb
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 50 deletions.
19 changes: 15 additions & 4 deletions src/jsx/base.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { raw } from '../helper/html'
import { escapeToBuffer, resolveCallbackSync, stringBufferToString } from '../utils/html'
import type { HtmlEscaped, HtmlEscapedString, StringBufferWithCallbacks } from '../utils/html'
import { DOM_RENDERER } from './constants'
import { DOM_RENDERER, DOM_MEMO } from './constants'
import type { Context } from './context'
import { createContext, globalContexts, useContext } from './context'
import { domRenderers } from './intrinsic-element/common'
Expand Down Expand Up @@ -346,7 +346,7 @@ export const jsxFn = (
}
}

const shallowEqual = (a: Props, b: Props): boolean => {
export const shallowEqual = (a: Props, b: Props): boolean => {
if (a === b) {
return true
}
Expand All @@ -373,19 +373,30 @@ const shallowEqual = (a: Props, b: Props): boolean => {
return true
}

export type MemorableFC<T> = FC<T> & {
[DOM_MEMO]: (prevProps: Readonly<T>, nextProps: Readonly<T>) => boolean
}
export const memo = <T>(
component: FC<T>,
propsAreEqual: (prevProps: Readonly<T>, nextProps: Readonly<T>) => boolean = shallowEqual
): FC<T> => {
let computed: ReturnType<FC<T>> = null
let prevProps: T | undefined = undefined
return ((props) => {
const wrapper: MemorableFC<T> = ((props: T) => {
if (prevProps && !propsAreEqual(prevProps, props)) {
computed = null
}
prevProps = props
return (computed ||= component(props))
}) as FC<T>
}) as MemorableFC<T>

// This function is for toString(), but it can also be used for DOM renderer.
// So, set DOM_MEMO and DOM_RENDERER for DOM renderer.
wrapper[DOM_MEMO] = propsAreEqual
// eslint-disable-next-line @typescript-eslint/no-explicit-any
;(wrapper as any)[DOM_RENDERER] = component

return wrapper as FC<T>
}

export const Fragment = ({
Expand Down
1 change: 1 addition & 0 deletions src/jsx/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ export const DOM_RENDERER = Symbol('RENDERER')
export const DOM_ERROR_HANDLER = Symbol('ERROR_HANDLER')
export const DOM_STASH = Symbol('STASH')
export const DOM_INTERNAL_TAG = Symbol('INTERNAL')
export const DOM_MEMO = Symbol('MEMO')
export const PERMALINK = Symbol('PERMALINK')
165 changes: 132 additions & 33 deletions src/jsx/dom/index.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ describe('DOM', () => {
})
})

describe('skip build child', () => {
describe('child component', () => {
it('simple', async () => {
const Child = vi.fn(({ count }: { count: number }) => <div>{count}</div>)
const App = () => {
Expand All @@ -301,11 +301,11 @@ describe('DOM', () => {
root.querySelector('button')?.click()
await Promise.resolve()
expect(root.innerHTML).toBe('<div>1</div><div>0</div><button>+</button>')
expect(Child).toBeCalledTimes(1)
expect(Child).toBeCalledTimes(2)
root.querySelector('button')?.click()
await Promise.resolve()
expect(root.innerHTML).toBe('<div>2</div><div>1</div><button>+</button>')
expect(Child).toBeCalledTimes(2)
expect(Child).toBeCalledTimes(3)
})
})

Expand Down Expand Up @@ -1321,38 +1321,137 @@ describe('DOM', () => {
})
})

it('memo', async () => {
let renderCount = 0
const Counter = ({ count }: { count: number }) => {
renderCount++
return (
<div>
<p>Count: {count}</p>
</div>
describe('memo', () => {
it('simple', async () => {
let renderCount = 0
const Counter = ({ count }: { count: number }) => {
renderCount++
return (
<div>
<p>Count: {count}</p>
</div>
)
}
const MemoCounter = memo(Counter)
const App = () => {
const [count, setCount] = useState(0)
return (
<div>
<MemoCounter count={Math.min(count, 1)} />
<button onClick={() => setCount(count + 1)}>+</button>
</div>
)
}
const app = <App />
render(app, root)
expect(root.innerHTML).toBe('<div><div><p>Count: 0</p></div><button>+</button></div>')
expect(renderCount).toBe(1)
root.querySelector('button')?.click()
await Promise.resolve()
expect(root.innerHTML).toBe('<div><div><p>Count: 1</p></div><button>+</button></div>')
expect(renderCount).toBe(2)
root.querySelector('button')?.click()
await Promise.resolve()
expect(root.innerHTML).toBe('<div><div><p>Count: 1</p></div><button>+</button></div>')
expect(renderCount).toBe(2)
})

it('useState', async () => {
const Child = vi.fn(({ count }: { count: number }) => {
const [count2, setCount2] = useState(0)
return (
<>
<div>
{count} : {count2}
</div>
<button id='child-button' onClick={() => setCount2(count2 + 1)}>
Child +
</button>
</>
)
})
const MemoChild = memo(Child)
const App = () => {
const [count, setCount] = useState(0)
return (
<>
<button id='app-button' onClick={() => setCount(count + 1)}>
App +
</button>
<MemoChild count={Math.floor(count / 2)} />
</>
)
}
render(<App />, root)
expect(root.innerHTML).toBe(
'<button id="app-button">App +</button><div>0 : 0</div><button id="child-button">Child +</button>'
)
}
const MemoCounter = memo(Counter)
const App = () => {
const [count, setCount] = useState(0)
return (
<div>
<MemoCounter count={Math.min(count, 1)} />
<button onClick={() => setCount(count + 1)}>+</button>
</div>
root.querySelector<HTMLButtonElement>('button#app-button')?.click()
await Promise.resolve()
expect(Child).toBeCalledTimes(1)
expect(root.innerHTML).toBe(
'<button id="app-button">App +</button><div>0 : 0</div><button id="child-button">Child +</button>'
)
}
const app = <App />
render(app, root)
expect(root.innerHTML).toBe('<div><div><p>Count: 0</p></div><button>+</button></div>')
expect(renderCount).toBe(1)
root.querySelector('button')?.click()
await Promise.resolve()
expect(root.innerHTML).toBe('<div><div><p>Count: 1</p></div><button>+</button></div>')
expect(renderCount).toBe(2)
root.querySelector('button')?.click()
await Promise.resolve()
expect(root.innerHTML).toBe('<div><div><p>Count: 1</p></div><button>+</button></div>')
expect(renderCount).toBe(2)
root.querySelector<HTMLButtonElement>('button#app-button')?.click()
await Promise.resolve()
expect(Child).toBeCalledTimes(2)
expect(root.innerHTML).toBe(
'<button id="app-button">App +</button><div>1 : 0</div><button id="child-button">Child +</button>'
)
root.querySelector<HTMLButtonElement>('button#child-button')?.click()
await Promise.resolve()
expect(Child).toBeCalledTimes(3)
expect(root.innerHTML).toBe(
'<button id="app-button">App +</button><div>1 : 1</div><button id="child-button">Child +</button>'
)
})

// The react compiler generates code like the following for memoization.
it('react compiler', async () => {
let renderCount = 0
const Counter = ({ count }: { count: number }) => {
renderCount++
return (
<div>
<p>Count: {count}</p>
</div>
)
}

const App = () => {
const [cache] = useState<unknown[]>(() => [])
const [count, setCount] = useState(0)
const countForDisplay = Math.floor(count / 2)

let localCounter
if (cache[0] !== countForDisplay) {
localCounter = <Counter count={countForDisplay} />
cache[0] = countForDisplay
cache[1] = localCounter
} else {
localCounter = cache[1]
}

return (
<div>
{localCounter}
<button onClick={() => setCount(count + 1)}>+</button>
</div>
)
}
const app = <App />
render(app, root)
expect(root.innerHTML).toBe('<div><div><p>Count: 0</p></div><button>+</button></div>')
expect(renderCount).toBe(1)
root.querySelector('button')?.click()
await Promise.resolve()
expect(root.innerHTML).toBe('<div><div><p>Count: 0</p></div><button>+</button></div>')
expect(renderCount).toBe(1)
root.querySelector('button')?.click()
await Promise.resolve()
expect(root.innerHTML).toBe('<div><div><p>Count: 1</p></div><button>+</button></div>')
expect(renderCount).toBe(2)
})
})

describe('useRef', async () => {
Expand Down
14 changes: 12 additions & 2 deletions src/jsx/dom/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
* This module provides APIs for `hono/jsx/dom`.
*/

import { isValidElement, memo, reactAPICompatVersion } from '../base'
import type { Child, DOMAttributes, JSX, JSXNode, Props } from '../base'
import { isValidElement, reactAPICompatVersion, shallowEqual } from '../base'
import type { Child, DOMAttributes, JSX, JSXNode, Props, FC, MemorableFC } from '../base'
import { Children } from '../children'
import { DOM_MEMO } from '../constants'
import { useContext } from '../context'
import {
createRef,
Expand Down Expand Up @@ -72,6 +73,15 @@ const cloneElement = <T extends JSXNode | JSX.Element>(
) as T
}

const memo = <T>(
component: FC<T>,
propsAreEqual: (prevProps: Readonly<T>, nextProps: Readonly<T>) => boolean = shallowEqual
): FC<T> => {
const wrapper = ((props: T) => component(props)) as MemorableFC<T>
wrapper[DOM_MEMO] = propsAreEqual
return wrapper as FC<T>
}

export {
reactAPICompatVersion as version,
createElement as jsx,
Expand Down
27 changes: 16 additions & 11 deletions src/jsx/dom/render.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import type { Child, FC, JSXNode, Props } from '../base'
import type { Child, FC, JSXNode, Props, MemorableFC } from '../base'
import { toArray } from '../children'
import { DOM_ERROR_HANDLER, DOM_INTERNAL_TAG, DOM_RENDERER, DOM_STASH } from '../constants'
import {
DOM_ERROR_HANDLER,
DOM_INTERNAL_TAG,
DOM_MEMO,
DOM_RENDERER,
DOM_STASH,
} from '../constants'
import type { Context as JSXContext } from '../context'
import { globalContexts as globalJSXContexts, useContext } from '../context'
import type { EffectData } from '../hooks'
Expand Down Expand Up @@ -45,6 +51,7 @@ export type NodeObject = {
e: SupportedElement | Text | undefined // rendered element
p?: PreserveNodeType // preserve HTMLElement if it will be unmounted
a?: boolean // cancel apply() if true
o?: NodeObject // original node
[DOM_STASH]:
| [
number, // current hook index
Expand Down Expand Up @@ -516,15 +523,12 @@ export const build = (context: Context, node: NodeObject, children?: Child[]): v
oldChild[DOM_STASH][2] = child[DOM_STASH][2] || []
oldChild[DOM_STASH][3] = child[DOM_STASH][3]

if (!oldChild.f) {
const prevPropsKeys = Object.keys(pP)
const currentProps = oldChild.props
if (
prevPropsKeys.length === Object.keys(currentProps).length &&
prevPropsKeys.every((k) => k in currentProps && currentProps[k] === pP[k])
) {
oldChild.s = true
}
if (
!oldChild.f &&
((oldChild.o || oldChild) === child.o || // The code generated by the react compiler is memoized under this condition.
(oldChild.tag as MemorableFC<unknown>)[DOM_MEMO]?.(pP, oldChild.props)) // The `memo` function is memoized under this condition.
) {
oldChild.s = true
}
}
child = oldChild
Expand Down Expand Up @@ -626,6 +630,7 @@ export const buildNode = (node: Child): Node | undefined => {
f: (node as NodeObject).f,
type: (node as NodeObject).tag,
ref: (node as NodeObject).props.ref,
o: (node as NodeObject).o || node,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any
}
Expand Down

0 comments on commit c6677eb

Please sign in to comment.