diff --git a/rpc/server.go b/rpc/server.go index 432a7b65..acb2095b 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -3,7 +3,9 @@ package rpc import ( "context" "crypto/tls" + "errors" "fmt" + "math" "net" "runtime/debug" "strings" @@ -116,12 +118,25 @@ func (s *Server) Send(ctx context.Context, in *proto.NotificationRequest) (*prot } }() + counts, err := safeIntToInt32(len(notification.Tokens)) + if err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + return &proto.NotificationReply{ Success: true, - Counts: int32(len(notification.Tokens)), + Counts: counts, }, nil } +// safeIntToInt32 converts an int to an int32, returning an error if the int is out of range. +func safeIntToInt32(n int) (int32, error) { + if n < math.MinInt32 || n > math.MaxInt32 { + return 0, errors.New("integer overflow: value out of int32 range") + } + return int32(n), nil +} + // RunGRPCServer run gorush grpc server func RunGRPCServer(ctx context.Context, cfg *config.ConfYaml) error { if !cfg.GRPC.Enabled { diff --git a/rpc/server_test.go b/rpc/server_test.go index 56d6f512..5cf878d2 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -1,5 +1,38 @@ package rpc +import ( + "math" + "testing" +) + +func TestSafeIntToInt32(t *testing.T) { + tests := []struct { + name string + input int + want int32 + wantErr bool + }{ + {"Valid int32", 123, 123, false}, + {"Max int32", math.MaxInt32, math.MaxInt32, false}, + {"Min int32", math.MinInt32, math.MinInt32, false}, + {"Overflow int32", math.MaxInt32 + 1, 0, true}, + {"Underflow int32", math.MinInt32 - 1, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := safeIntToInt32(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("safeIntToInt32() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("safeIntToInt32() = %v, want %v", got, tt.want) + } + }) + } +} + // const gRPCAddr = "localhost:9000" // func initTest() *config.ConfYaml {