Skip to content
Merged
19 changes: 17 additions & 2 deletions libs/dyn/convert/to_typed.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ func toTypedBool(dst reflect.Value, src dyn.Value) error {
case dyn.KindString:
// See https://github.com/go-yaml/yaml/blob/f6f7691b1fdeb513f56608cd2c32c51f8194bf51/decode.go#L684-L693.
switch src.MustString() {
case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON":
case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON", "true":
dst.SetBool(true)
return nil
case "n", "N", "no", "No", "NO", "off", "Off", "OFF":
case "n", "N", "no", "No", "NO", "off", "Off", "OFF", "false":
Comment thread
andrewnester marked this conversation as resolved.
dst.SetBool(false)
return nil
}
Expand All @@ -246,6 +246,17 @@ func toTypedInt(dst reflect.Value, src dyn.Value) error {
case dyn.KindInt:
dst.SetInt(src.MustInt())
return nil
case dyn.KindFloat:
v := src.MustFloat()
if canConvertToInt(v) {
dst.SetInt(int64(src.MustFloat()))
Comment thread
andrewnester marked this conversation as resolved.
return nil
}

return TypeError{
value: src,
msg: fmt.Sprintf("expected an int, found a %s", src.Kind()),
}
Comment thread
andrewnester marked this conversation as resolved.
case dyn.KindString:
if i64, err := strconv.ParseInt(src.MustString(), 10, 64); err == nil {
dst.SetInt(i64)
Expand All @@ -264,6 +275,10 @@ func toTypedInt(dst reflect.Value, src dyn.Value) error {
}
}

func canConvertToInt(v float64) bool {
return v == float64(int(v))
Comment thread
andrewnester marked this conversation as resolved.
Outdated
}

Comment thread
andrewnester marked this conversation as resolved.
Outdated
func toTypedFloat(dst reflect.Value, src dyn.Value) error {
switch src.Kind() {
case dyn.KindFloat:
Expand Down
21 changes: 21 additions & 0 deletions libs/dyn/jsonloader/json.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package jsonloader

import (
"encoding/json"

"github.com/databricks/cli/libs/dyn"
)

func LoadJSON(data []byte) (dyn.Value, error) {
var root map[string]interface{}
err := json.Unmarshal(data, &root)
if err != nil {
return dyn.InvalidValue, err
}

loc := dyn.Location{
Line: 1,
Column: 1,
}
return newLoader().load(&root, loc)
Comment thread
andrewnester marked this conversation as resolved.
Outdated
}
53 changes: 53 additions & 0 deletions libs/dyn/jsonloader/json_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package jsonloader

import (
"testing"

"github.com/databricks/cli/libs/dyn/convert"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/require"
)

const jsonData = `
{
"job_id": 123,
"new_settings": {
"name": "xxx",
"email_notifications": {
"on_start": [],
"on_success": [],
"on_failure": []
},
"webhook_notifications": {
"on_start": [],
"on_failure": []
},
"notification_settings": {
"no_alert_for_skipped_runs": true,
"no_alert_for_canceled_runs": true
},
"timeout_seconds": 0,
"max_concurrent_runs": 1,
"tasks": [
{
"task_key": "xxx",
"email_notifications": {},
"notification_settings": {},
"timeout_seconds": 0,
"max_retries": 0,
"min_retry_interval_millis": 0,
"retry_on_timeout": "true"
}
]
}
}
`

func TestJsonLoader(t *testing.T) {
v, err := LoadJSON([]byte(jsonData))
require.NoError(t, err)

var r jobs.ResetJob
err = convert.ToTyped(&r, v)
require.NoError(t, err)
}
Comment thread
andrewnester marked this conversation as resolved.
99 changes: 99 additions & 0 deletions libs/dyn/jsonloader/loader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package jsonloader

import (
"fmt"
"reflect"

"github.com/databricks/cli/libs/dyn"
)

type loader struct {
}

func newLoader() *loader {
return &loader{}
}

func errorf(loc dyn.Location, format string, args ...interface{}) error {
return fmt.Errorf("json (%s): %s", loc, fmt.Sprintf(format, args...))
}

func (d *loader) load(node any, loc dyn.Location) (dyn.Value, error) {
var value dyn.Value
var err error

if node == nil {
return dyn.NilValue, nil
}

if reflect.TypeOf(node).Kind() == reflect.Ptr {
return d.load(reflect.ValueOf(node).Elem().Interface(), loc)
}

switch reflect.TypeOf(node).Kind() {
case reflect.Map:
value, err = d.loadMapping(node.(map[string]interface{}), loc)
case reflect.Slice:
value, err = d.loadSequence(node.([]interface{}), loc)
case reflect.String, reflect.Bool,
reflect.Float64, reflect.Float32,
reflect.Int, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint32, reflect.Uint64:
value, err = d.loadScalar(node, loc)

default:
return dyn.InvalidValue, errorf(loc, "unknown node kind: %v", reflect.TypeOf(node).Kind())
}

if err != nil {
return dyn.InvalidValue, err
}

return value, nil
}

func (d *loader) loadScalar(node any, loc dyn.Location) (dyn.Value, error) {
switch reflect.TypeOf(node).Kind() {
case reflect.String:
return dyn.NewValue(node.(string), []dyn.Location{loc}), nil
case reflect.Bool:
return dyn.NewValue(node.(bool), []dyn.Location{loc}), nil
case reflect.Float64, reflect.Float32:
return dyn.NewValue(node.(float64), []dyn.Location{loc}), nil
case reflect.Int, reflect.Int32, reflect.Int64:
return dyn.NewValue(node.(int64), []dyn.Location{loc}), nil
case reflect.Uint, reflect.Uint32, reflect.Uint64:
return dyn.NewValue(node.(uint64), []dyn.Location{loc}), nil
default:
return dyn.InvalidValue, errorf(loc, "unknown scalar type: %v", reflect.TypeOf(node).Kind())
}
}

func (d *loader) loadSequence(node []interface{}, loc dyn.Location) (dyn.Value, error) {
dst := make([]dyn.Value, len(node))
for i, value := range node {
v, err := d.load(value, loc)
if err != nil {
return dyn.InvalidValue, err
}
dst[i] = v
}
return dyn.NewValue(dst, []dyn.Location{loc}), nil
}

func (d *loader) loadMapping(node map[string]interface{}, loc dyn.Location) (dyn.Value, error) {
dst := make(map[string]dyn.Value)
index := 0
for key, value := range node {
index += 1
v, err := d.load(value, dyn.Location{
Line: loc.Line + index,
Column: loc.Column,
})
if err != nil {
return dyn.InvalidValue, err
}
dst[key] = v
}
return dyn.NewValue(dst, []dyn.Location{loc}), nil
}
26 changes: 24 additions & 2 deletions libs/flags/json_flag.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package flags

import (
"encoding/json"
"fmt"
"os"

"github.com/databricks/cli/libs/dyn/convert"
"github.com/databricks/cli/libs/dyn/jsonloader"
)

type JsonFlag struct {
Expand Down Expand Up @@ -33,7 +35,27 @@ func (j *JsonFlag) Unmarshal(v any) error {
if j.raw == nil {
return nil
}
return json.Unmarshal(j.raw, v)

dv, err := jsonloader.LoadJSON(j.raw)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we do the same for the YAML flag?

I believe that's the only reason we still depend on github.com/ghodss/yaml. This package uses the json tags in the types as opposed to yaml tags, which the upstream YAML package uses. Since we use the json tags as well in the dyn package it should be possible to replace it.

Not blocking for this PR, of course.

if err != nil {
return err
}

err = convert.ToTyped(v, dv)
if err != nil {
return err
}

_, diags := convert.Normalize(v, dv)
if len(diags) > 0 {
summary := ""
for _, diag := range diags {
summary += fmt.Sprintf("- %s\n", diag.Summary)
}
return fmt.Errorf("json input error:\n%v", summary)
Comment thread
andrewnester marked this conversation as resolved.
Outdated
}
Comment thread
andrewnester marked this conversation as resolved.

return nil
}

func (j *JsonFlag) Type() string {
Expand Down
101 changes: 99 additions & 2 deletions libs/flags/json_flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"path"
"testing"

"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -52,7 +53,7 @@ func TestJsonFlagFile(t *testing.T) {
var request any

var fpath string
var payload = []byte(`"hello world"`)
var payload = []byte(`{"hello": "world"}`)

{
f, err := os.Create(path.Join(t.TempDir(), "file"))
Expand All @@ -68,5 +69,101 @@ func TestJsonFlagFile(t *testing.T) {
err = body.Unmarshal(&request)
require.NoError(t, err)

assert.Equal(t, "hello world", request)
assert.Equal(t, map[string]interface{}{"hello": "world"}, request)
}

const jsonData = `
{
"job_id": 123,
"new_settings": {
"name": "new job",
"email_notifications": {
"on_start": [],
"on_success": [],
"on_failure": []
},
"notification_settings": {
"no_alert_for_skipped_runs": true,
"no_alert_for_canceled_runs": true
},
"timeout_seconds": 0,
"max_concurrent_runs": 1,
"tasks": [
{
"task_key": "new task",
"email_notifications": {},
"notification_settings": {},
"timeout_seconds": 0,
"max_retries": 0,
"min_retry_interval_millis": 0,
"retry_on_timeout": "true"
}
]
}
}
`

func TestJsonUnmarshalForRequest(t *testing.T) {
var body JsonFlag

var r jobs.ResetJob
err := body.Set(jsonData)
require.NoError(t, err)

err = body.Unmarshal(&r)
require.NoError(t, err)

assert.Equal(t, int64(123), r.JobId)
assert.Equal(t, "new job", r.NewSettings.Name)
assert.Equal(t, 0, r.NewSettings.TimeoutSeconds)
assert.Equal(t, 1, r.NewSettings.MaxConcurrentRuns)
assert.Equal(t, 1, len(r.NewSettings.Tasks))
assert.Equal(t, "new task", r.NewSettings.Tasks[0].TaskKey)
assert.Equal(t, 0, r.NewSettings.Tasks[0].TimeoutSeconds)
assert.Equal(t, 0, r.NewSettings.Tasks[0].MaxRetries)
assert.Equal(t, 0, r.NewSettings.Tasks[0].MinRetryIntervalMillis)
assert.Equal(t, true, r.NewSettings.Tasks[0].RetryOnTimeout)
}

const incorrectJsonData = `
{
"job_id": 123,
"settings": {
"name": "new job",
"email_notifications": {
"on_start": [],
"on_success": [],
"on_failure": []
},
"notification_settings": {
"no_alert_for_skipped_runs": true,
"no_alert_for_canceled_runs": true
},
"timeout_seconds": {},
"max_concurrent_runs": {},
"tasks": [
{
"task_key": "new task",
"email_notifications": {},
"notification_settings": {},
"timeout_seconds": 0,
"max_retries": 0,
"min_retry_interval_millis": 0,
"retry_on_timeout": "true"
}
]
}
}
`

func TestJsonUnmarshalRequestMismatch(t *testing.T) {
var body JsonFlag

var r jobs.ResetJob
err := body.Set(incorrectJsonData)
require.NoError(t, err)

err = body.Unmarshal(&r)
require.ErrorContains(t, err, `json input error:
- unknown field: settings`)
}