Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client/v2)!: dynamic prompt #22775

Merged
merged 13 commits into from
Dec 11, 2024
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i

### API Breaking Changes

* (client) [#22775](https://github.com/cosmos/cosmos-sdk/pull/22775) Removed client prompt validations.

### Deprecated

## [v0.52.0](https://github.com/cosmos/cosmos-sdk/releases/tag/v0.52.0) - 2024-XX-XX
Expand Down
39 changes: 0 additions & 39 deletions client/prompt_validation_test.go

This file was deleted.

1 change: 1 addition & 0 deletions client/v2/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Ref: https://keepachangelog.com/en/1.0.0/
* [#20623](https://github.com/cosmos/cosmos-sdk/pull/20623) Introduce client/v2 tx factory.
* [#20623](https://github.com/cosmos/cosmos-sdk/pull/20623) Extend client/v2 keyring interface with `KeyType` and `KeyInfo`.
* [#22282](https://github.com/cosmos/cosmos-sdk/pull/22282) Added custom broadcast logic.
* [#22775](https://github.com/cosmos/cosmos-sdk/pull/22775) Added interactive autocli prompt functionality, including message field prompting, validation helpers, and default value support.

### Improvements

Expand Down
259 changes: 259 additions & 0 deletions client/v2/autocli/prompt/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
package prompt

import (
"fmt"
"io"
"strconv"
"strings"

"github.com/manifoldco/promptui"
"google.golang.org/protobuf/reflect/protoreflect"

"cosmossdk.io/client/v2/autocli/flag"
addresscodec "cosmossdk.io/core/address"
)

// PromptMessage prompts the user for values to populate a protobuf message interactively.
// It returns the populated message and any error encountered during prompting.
func PromptMessage(
addressCodec, validatorAddressCodec, consensusAddressCodec addresscodec.Codec,
promptPrefix string, msg protoreflect.Message,
) (protoreflect.Message, error) {
return promptMessage(addressCodec, validatorAddressCodec, consensusAddressCodec, promptPrefix, nil, msg)
}

// promptMessage prompts the user for values to populate a protobuf message interactively.
// stdIn is provided to make the function easier to unit test by allowing injection of predefined inputs.
func promptMessage(
addressCodec, validatorAddressCodec, consensusAddressCodec addresscodec.Codec,
promptPrefix string, stdIn io.ReadCloser, msg protoreflect.Message,
) (protoreflect.Message, error) {
fields := msg.Descriptor().Fields()
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
fieldName := string(field.Name())

promptUi := promptui.Prompt{
Validate: ValidatePromptNotEmpty,
Stdin: stdIn,
}

// If this signer field has already a valid default value set,
// use that value as the default prompt value. This is useful for
// commands that have an authority such as gov.
if strings.EqualFold(fieldName, flag.GetSignerFieldName(msg.Descriptor())) {
if defaultValue := msg.Get(field); defaultValue.IsValid() {
promptUi.Default = defaultValue.String()
}
}

// validate address fields
scalarField, ok := flag.GetScalarType(field)
if ok {
switch scalarField {
case flag.AddressStringScalarType:
promptUi.Validate = ValidateAddress(addressCodec)
case flag.ValidatorAddressStringScalarType:
promptUi.Validate = ValidateAddress(validatorAddressCodec)
case flag.ConsensusAddressStringScalarType:
promptUi.Validate = ValidateAddress(consensusAddressCodec)
default:
// prompt.Validate = ValidatePromptNotEmpty (we possibly don't want to force all fields to be non-empty)
promptUi.Validate = nil
}
}

// handle nested message fields recursively
if field.Kind() == protoreflect.MessageKind {
err := promptInnerMessageKind(field, addressCodec, validatorAddressCodec, consensusAddressCodec, promptPrefix, stdIn, msg)
if err != nil {
return nil, err
}
continue
}

// handle repeated fields by prompting for a comma-separated list of values
if field.IsList() {
list, err := promptList(field, msg, promptUi, promptPrefix)
if err != nil {
return nil, err
}

msg.Set(field, protoreflect.ValueOfList(list))
continue
}

promptUi.Label = fmt.Sprintf("Enter %s %s", promptPrefix, fieldName)
result, err := promptUi.Run()
if err != nil {
return msg, fmt.Errorf("failed to prompt for %s: %w", fieldName, err)
}

v, err := valueOf(field, result)
if err != nil {
return msg, err
}
msg.Set(field, v)
}

return msg, nil
}

// valueOf converts a string input value to a protoreflect.Value based on the field's type.
// It handles string, numeric, bool, bytes and enum field types.
// Returns the converted value and any error that occurred during conversion.
func valueOf(field protoreflect.FieldDescriptor, result string) (protoreflect.Value, error) {
switch field.Kind() {
case protoreflect.StringKind:
return protoreflect.ValueOfString(result), nil
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
resultUint, err := strconv.ParseUint(result, 10, 0)
if err != nil {
return protoreflect.Value{}, fmt.Errorf("invalid value for int: %w", err)
}

return protoreflect.ValueOfUint64(resultUint), nil
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
resultInt, err := strconv.ParseInt(result, 10, 0)
if err != nil {
return protoreflect.Value{}, fmt.Errorf("invalid value for int: %w", err)
}
// If a value was successfully parsed the ranges of:
// [minInt, maxInt]
// are within the ranges of:
// [minInt64, maxInt64]
// of which on 64-bit machines, which are most common,
// int==int64
return protoreflect.ValueOfInt64(resultInt), nil
case protoreflect.BoolKind:
resultBool, err := strconv.ParseBool(result)
if err != nil {
return protoreflect.Value{}, fmt.Errorf("invalid value for bool: %w", err)
}

return protoreflect.ValueOfBool(resultBool), nil
case protoreflect.BytesKind:
resultBytes := []byte(result)
return protoreflect.ValueOfBytes(resultBytes), nil
case protoreflect.EnumKind:
enumValue := field.Enum().Values().ByName(protoreflect.Name(result))
if enumValue == nil {
return protoreflect.Value{}, fmt.Errorf("invalid enum value %q", result)
}
return protoreflect.ValueOfEnum(enumValue.Number()), nil
default:
// TODO: add more kinds
// skip any other types
return protoreflect.Value{}, nil
Comment on lines +145 to +147
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Handle additional protobuf field kinds in valueOf function.

Currently, the valueOf function does not handle certain field kinds such as floats. Consider implementing support for additional protobuf field kinds to ensure all message fields can be appropriately parsed.

You can add cases for FloatKind and DoubleKind in the switch statement:

func valueOf(field protoreflect.FieldDescriptor, result string) (protoreflect.Value, error) {
	switch field.Kind() {
		// existing cases...
+	case protoreflect.FloatKind, protoreflect.DoubleKind:
+		resultFloat, err := strconv.ParseFloat(result, 64)
+		if err != nil {
+			return protoreflect.Value{}, fmt.Errorf("invalid value for float: %w", err)
+		}
+		return protoreflect.ValueOfFloat64(resultFloat), nil
		// existing cases...
	}
}

Committable suggestion skipped: line range outside the PR's diff.

}
}

// promptList prompts the user for a comma-separated list of values for a repeated field.
// The user will be prompted to enter values separated by commas which will be parsed
// according to the field's type using valueOf.
func promptList(field protoreflect.FieldDescriptor, msg protoreflect.Message, promptUi promptui.Prompt, promptPrefix string) (protoreflect.List, error) {
promptUi.Label = fmt.Sprintf("Enter %s %s list (separate values with ',')", promptPrefix, string(field.Name()))
result, err := promptUi.Run()
if err != nil {
return nil, fmt.Errorf("failed to prompt for %s: %w", string(field.Name()), err)
}

list := msg.Mutable(field).List()
for _, item := range strings.Split(result, ",") {
v, err := valueOf(field, item)
if err != nil {
return nil, err
}
list.Append(v)
}

return list, nil
}

// promptInnerMessageKind handles prompting for fields that are of message kind.
// It handles both single messages and repeated message fields by delegating to
// promptInnerMessage and promptMessageList respectively.
func promptInnerMessageKind(
f protoreflect.FieldDescriptor, addressCodec addresscodec.Codec,
validatorAddressCodec, consensusAddressCodec addresscodec.Codec,
promptPrefix string, stdIn io.ReadCloser, msg protoreflect.Message,
) error {
if f.IsList() {
return promptMessageList(f, addressCodec, validatorAddressCodec, consensusAddressCodec, promptPrefix, stdIn, msg)
}
return promptInnerMessage(f, addressCodec, validatorAddressCodec, consensusAddressCodec, promptPrefix, stdIn, msg)
}

// promptInnerMessage prompts for a single nested message field. It creates a new message instance,
// recursively prompts for its fields, and sets the populated message on the parent message.
func promptInnerMessage(
f protoreflect.FieldDescriptor, addressCodec addresscodec.Codec,
validatorAddressCodec, consensusAddressCodec addresscodec.Codec,
promptPrefix string, stdIn io.ReadCloser, msg protoreflect.Message,
) error {
fieldName := promptPrefix + "." + string(f.Name())
nestedMsg := msg.Get(f).Message()
nestedMsg = nestedMsg.New()
// Recursively prompt for nested message fields
updatedMsg, err := promptMessage(
addressCodec,
validatorAddressCodec,
consensusAddressCodec,
fieldName,
stdIn,
nestedMsg,
)
if err != nil {
return fmt.Errorf("failed to prompt for nested message %s: %w", fieldName, err)
}

msg.Set(f, protoreflect.ValueOfMessage(updatedMsg))
return nil
}

// promptMessageList prompts for a repeated message field by repeatedly creating new message instances,
// prompting for their fields, and appending them to the list until the user chooses to stop.
func promptMessageList(
f protoreflect.FieldDescriptor, addressCodec addresscodec.Codec,
validatorAddressCodec, consensusAddressCodec addresscodec.Codec,
promptPrefix string, stdIn io.ReadCloser, msg protoreflect.Message,
) error {
list := msg.Mutable(f).List()
for {
fieldName := promptPrefix + "." + string(f.Name())
// Create and populate a new message for the list
nestedMsg := list.NewElement().Message()
updatedMsg, err := promptMessage(
addressCodec,
validatorAddressCodec,
consensusAddressCodec,
fieldName,
stdIn,
nestedMsg,
)
if err != nil {
return fmt.Errorf("failed to prompt for list item in %s: %w", fieldName, err)
}

list.Append(protoreflect.ValueOfMessage(updatedMsg))

// Prompt whether to continue
// TODO: may be better yes/no rather than interactive?
continuePrompt := promptui.Select{
Label: "Add another item?",
Items: []string{"No", "Yes"},
Stdin: stdIn,
}

_, result, err := continuePrompt.Run()
if err != nil {
return fmt.Errorf("failed to prompt for continuation: %w", err)
}

if result == "No" {
break
}
}

return nil
}
59 changes: 59 additions & 0 deletions client/v2/autocli/prompt/message_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package prompt

import (
"io"
"strings"
"testing"

"github.com/stretchr/testify/require"
"google.golang.org/protobuf/reflect/protoreflect"

"cosmossdk.io/client/v2/internal/testpb"

address2 "github.com/cosmos/cosmos-sdk/codec/address"
)

func getReader(inputs []string) io.ReadCloser {
// https://github.com/manifoldco/promptui/issues/63#issuecomment-621118463
var paddedInputs []string
for _, input := range inputs {
padding := strings.Repeat("a", 4096-1-len(input)%4096)
paddedInputs = append(paddedInputs, input+"\n"+padding)
}
return io.NopCloser(strings.NewReader(strings.Join(paddedInputs, "")))
}

func TestPromptMessage(t *testing.T) {
tests := []struct {
name string
msg protoreflect.Message
inputs []string
}{
{
name: "testPb",
inputs: []string{
"1", "2", "string", "bytes", "10101010", "0", "234234", "3", "4", "5", "true", "ENUM_ONE",
"bar", "6", "10000", "stake", "cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn",
"bytes", "6", "7", "false", "false", "true,false,true", "1,2,3", "hello,hola,ciao", "ENUM_ONE,ENUM_TWO",
"10239", "0", "No", "bar", "343", "No", "134", "positional2", "23455", "stake", "No", "deprecate",
"shorthand", "false", "cosmosvaloper1tnh2q55v8wyygtt9srz5safamzdengsn9dsd7z",
},
msg: (&testpb.MsgRequest{}).ProtoReflect(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// https://github.com/manifoldco/promptui/issues/63#issuecomment-621118463
var paddedInputs []string
for _, input := range tt.inputs {
padding := strings.Repeat("a", 4096-1-len(input)%4096)
paddedInputs = append(paddedInputs, input+"\n"+padding)
}
reader := io.NopCloser(strings.NewReader(strings.Join(paddedInputs, "")))

got, err := promptMessage(address2.NewBech32Codec("cosmos"), address2.NewBech32Codec("cosmosvaloper"), address2.NewBech32Codec("cosmosvalcons"), "prefix", reader, tt.msg)
require.NoError(t, err)
require.NotNil(t, got)
})
}
}
Comment on lines +26 to +59
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance test coverage

The current test only covers the happy path. Consider adding:

  1. Error cases (invalid inputs, wrong formats)
  2. Validation of the resulting message content
  3. Edge cases (empty inputs, maximum values)

Would you like me to help generate additional test cases?


🛠️ Refactor suggestion

Refactor test implementation

The test has two issues that should be addressed:

  1. The padding logic is duplicated instead of using the getReader function
  2. The test inputs lack documentation explaining their purpose

Apply these improvements:

 func TestPromptMessage(t *testing.T) {
     tests := []struct {
         name   string
         msg    protoreflect.Message
         inputs []string
     }{
         {
             name: "testPb",
+            // Each input corresponds to a field in MsgRequest:
+            // 1-12: Basic types (numbers, strings, bools, enums)
+            // 13-17: Complex types (addresses, coins)
+            // 18-40: Array and repeated fields
             inputs: []string{
                 "1", "2", "string", "bytes", "10101010", "0", "234234", "3", "4", "5", "true", "ENUM_ONE",
                 "bar", "6", "10000", "stake", "cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn",
                 "bytes", "6", "7", "false", "false", "true,false,true", "1,2,3", "hello,hola,ciao", "ENUM_ONE,ENUM_TWO",
                 "10239", "0", "No", "bar", "343", "No", "134", "positional2", "23455", "stake", "No", "deprecate",
                 "shorthand", "false", "cosmosvaloper1tnh2q55v8wyygtt9srz5safamzdengsn9dsd7z",
             },
             msg: (&testpb.MsgRequest{}).ProtoReflect(),
         },
     }
     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
-            var paddedInputs []string
-            for _, input := range tt.inputs {
-                padding := strings.Repeat("a", 4096-1-len(input)%4096)
-                paddedInputs = append(paddedInputs, input+"\n"+padding)
-            }
-            reader := io.NopCloser(strings.NewReader(strings.Join(paddedInputs, "")))
+            reader := getReader(tt.inputs)
             got, err := promptMessage(address2.NewBech32Codec("cosmos"), address2.NewBech32Codec("cosmosvaloper"), address2.NewBech32Codec("cosmosvalcons"), "prefix", reader, tt.msg)
             require.NoError(t, err)
             require.NotNil(t, got)
         })
     }
 }

Committable suggestion skipped: line range outside the PR's diff.

Loading
Loading