Skip to content

Commit

Permalink
fix/class:list directive was not properly merging with the class attr…
Browse files Browse the repository at this point in the history
…ibute (#1039)

Co-authored-by: Emanuele Stoppa <[email protected]>
  • Loading branch information
Prokodo and ematipico authored Aug 7, 2024
1 parent 1c01c72 commit f55a2af
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 28 deletions.
5 changes: 5 additions & 0 deletions .changeset/ten-melons-return.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@astrojs/compiler": patch
---

Resolves an issue where the `class:list` directive was not correctly merging with the class attribute.
28 changes: 21 additions & 7 deletions internal/transform/transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -625,21 +625,35 @@ func walk(doc *astro.Node, cb func(*astro.Node)) {
func mergeClassList(doc *astro.Node, n *astro.Node, opts *TransformOptions) {
var classListAttrValue string
var classListAttrIndex int = -1

var classAttrType astro.AttributeType
var classAttrValue string
var classAttrIndex int = -1
for i, a := range n.Attr {
if a.Key == "class:list" {
classListAttrValue = a.Val

for i, attr := range n.Attr {
if attr.Key == "class:list" {
classListAttrValue = attr.Val
classListAttrIndex = i
}
if a.Key == "class" {
classAttrValue = a.Val
if attr.Key == "class" {
classAttrType = attr.Type
classAttrValue = attr.Val
classAttrIndex = i
}
}

// Check if both `class:list` and `class` attributes are present
if classListAttrIndex >= 0 && classAttrIndex >= 0 {
// we append the prepend the value of class to class:list
n.Attr[classListAttrIndex].Val = fmt.Sprintf("['%s', %s]", classAttrValue, classListAttrValue)
// Merge the `class` attribute value into `class:list`
if classAttrType == astro.ExpressionAttribute {
// If the `class` attribute is an expression, include it directly without surrounding quotes.
// This respects the fact that expressions are evaluated dynamically and should not be treated as strings.
n.Attr[classListAttrIndex].Val = fmt.Sprintf("[%s, %s]", classAttrValue, classListAttrValue)
} else {
// If the `class` attribute is a static string, wrap it in quotes.
// This ensures that static class names are treated as string values within the list.
n.Attr[classListAttrIndex].Val = fmt.Sprintf("['%s', %s]", classAttrValue, classListAttrValue)
}
// Now that the value of `class` is carried by `class:list`, we can remove the `class` node from the AST.
// Doing so will allow us to generate valid HTML at runtime
n.Attr = remove(n.Attr, classAttrIndex)
Expand Down
99 changes: 78 additions & 21 deletions internal/transform/transform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,22 +234,19 @@ func TestFullTransform(t *testing.T) {
want string
}{
{
name: "top-level component with leading style",
source: `<style>:root{}</style><Component><h1>Hello world</h1></Component>
`,
want: `<Component><h1>Hello world</h1></Component>`,
name: "top-level component with leading style",
source: `<style>:root{}</style><Component><h1>Hello world</h1></Component>`,
want: `<Component><h1>Hello world</h1></Component>`,
},
{
name: "top-level component with leading style body",
source: `<style>:root{}</style><Component><div><h1>Hello world</h1></div></Component>
`,
want: `<Component><div><h1>Hello world</h1></div></Component>`,
name: "top-level component with leading style body",
source: `<style>:root{}</style><Component><div><h1>Hello world</h1></div></Component>`,
want: `<Component><div><h1>Hello world</h1></div></Component>`,
},
{
name: "top-level component with trailing style",
source: `<Component><h1>Hello world</h1></Component><style>:root{}</style>
`,
want: `<Component><h1>Hello world</h1></Component>`,
name: "top-level component with trailing style",
source: `<Component><h1>Hello world</h1></Component><style>:root{}</style>`,
want: `<Component><h1>Hello world</h1></Component>`,
},
{
name: "Component before html I",
Expand Down Expand Up @@ -287,15 +284,9 @@ func TestFullTransform(t *testing.T) {
want: `<A><div><B></B></div></A>`,
},
{
name: "does not remove trailing siblings",
source: `<title>Title</title>
<span />
<Component />
<span />`,
want: `<title>Title</title>
<span></span>
<Component></Component>
<span></span>`,
name: "does not remove trailing siblings",
source: `<title>Title</title><span /><Component /><span />`,
want: `<title>Title</title><span></span><Component></Component><span></span>`,
},
}
var b strings.Builder
Expand Down Expand Up @@ -532,3 +523,69 @@ func TestAnnotation(t *testing.T) {
})
}
}

func TestClassAndClassListMerging(t *testing.T) {
tests := []struct {
name string
source string
want string
}{
{
name: "Single class attribute",
source: `<div class="astro-xxxxxx" />`,
want: `<div class="astro-xxxxxx"></div>`,
},
{
name: "Class attribute with parameter",
source: "<div class={`astro-xxxxxx ${astro}`} />",
want: "<div class={`astro-xxxxxx ${astro}`}></div>",
},
{
name: "Single class:list attribute",
source: `<div class:list={"astro-xxxxxx"} />`,
want: `<div class:list={"astro-xxxxxx"}></div>`,
},
{
name: "Merge class with empty class:list (double quotes)",
source: `<div class="astro-xxxxxx" class:list={} />`,
want: `<div class:list={['astro-xxxxxx', ]}></div>`,
},
{
name: "Merge class with empty class:list (single quotes)",
source: `<div class='astro-xxxxxx' class:list={} />`,
want: `<div class:list={['astro-xxxxxx', ]}></div>`,
},
{
name: "Merge class and class:list attributes (string)",
source: `<div class="astro-xxxxxx" class:list={"astro-yyyyyy"} />`,
want: `<div class:list={['astro-xxxxxx', "astro-yyyyyy"]}></div>`,
},
{
name: "Merge class and class:list attributes (expression)",
source: `<div class={"astro-xxxxxx"} class:list={"astro-yyyyyy"} />`,
want: `<div class:list={["astro-xxxxxx", "astro-yyyyyy"]}></div>`,
},
{
name: "Merge Class and Class List Attributes (concatenation)",
source: `<div class={"astro-xxxxxx" + name} class:list={"astro-yyyyyy"} />`,
want: `<div class:list={["astro-xxxxxx" + name, "astro-yyyyyy"]}></div>`,
},
}

var b strings.Builder
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b.Reset()
doc, err := astro.Parse(strings.NewReader(tt.source))
if err != nil {
t.Error(err)
}
Transform(doc, TransformOptions{}, handler.NewHandler(tt.source, "/test.astro"))
astro.PrintToSource(&b, doc.LastChild.FirstChild.NextSibling.FirstChild)
got := b.String()
if tt.want != got {
t.Errorf("\nFAIL: %s\n want: %s\n got: %s", tt.name, tt.want, got)
}
})
}
}

0 comments on commit f55a2af

Please sign in to comment.