Skip to content

Commit

Permalink
Merge pull request #4 from openinfradev/support_tls
Browse files Browse the repository at this point in the history
feature. support tls
  • Loading branch information
ktkfree authored Mar 21, 2022
2 parents 57d9571 + e2afd0c commit 105302d
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 47 deletions.
61 changes: 52 additions & 9 deletions pkg/grpc_client/client.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package grpc_client

import (
"fmt"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"

"github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/openinfradev/tks-common/pkg/log"
"github.com/openinfradev/tks-common/pkg/helper"
pb "github.com/openinfradev/tks-proto/tks_pb"
)

func CreateCspInfoClient(address string, port int, caller string) (*grpc.ClientConn, pb.CspInfoServiceClient, error) {
cc, err := helper.CreateConnection(address, port, caller)
func CreateCspInfoClient(address string, port int, tlsEnabled bool, certPath string ) (*grpc.ClientConn, pb.CspInfoServiceClient, error) {
cc, err := createConnection(address, port, tlsEnabled, certPath)
if err != nil {
log.Fatal("Could not connect to gRPC server", err)
return nil, nil, err
Expand All @@ -18,8 +21,8 @@ func CreateCspInfoClient(address string, port int, caller string) (*grpc.ClientC
return cc, sc, nil
}

func CreateContractClient(address string, port int, caller string) (*grpc.ClientConn, pb.ContractServiceClient, error) {
cc, err := helper.CreateConnection(address, port, caller)
func CreateContractClient(address string, port int, tlsEnabled bool, certPath string) (*grpc.ClientConn, pb.ContractServiceClient, error) {
cc, err := createConnection(address, port, tlsEnabled, certPath)
if err != nil {
log.Fatal("Could not connect to gRPC server", err)
return nil, nil, err
Expand All @@ -28,8 +31,8 @@ func CreateContractClient(address string, port int, caller string) (*grpc.Client
return cc, sc, nil
}

func CreateClusterInfoClient(address string, port int, caller string) (*grpc.ClientConn, pb.ClusterInfoServiceClient, error) {
cc, err := helper.CreateConnection(address, port, caller)
func CreateClusterInfoClient(address string, port int, tlsEnabled bool, certPath string) (*grpc.ClientConn, pb.ClusterInfoServiceClient, error) {
cc, err := createConnection(address, port, tlsEnabled, certPath)
if err != nil {
log.Fatal("Could not connect to gRPC server", err)
return nil, nil, err
Expand All @@ -38,8 +41,8 @@ func CreateClusterInfoClient(address string, port int, caller string) (*grpc.Cli
return cc, sc, nil
}

func CreateAppInfoClient(address string, port int, caller string) (*grpc.ClientConn, pb.AppInfoServiceClient, error) {
cc, err := helper.CreateConnection(address, port, caller)
func CreateAppInfoClient(address string, port int, tlsEnabled bool, certPath string) (*grpc.ClientConn, pb.AppInfoServiceClient, error) {
cc, err := createConnection(address, port, tlsEnabled, certPath)
if err != nil {
log.Fatal("Could not connect to gRPC server", err)
return nil, nil, err
Expand All @@ -48,3 +51,43 @@ func CreateAppInfoClient(address string, port int, caller string) (*grpc.ClientC
return cc, sc, nil
}


func createConnection(address string, port int, tlsEnabled bool, certPath string) (*grpc.ClientConn, error) {
var err error
var creds credentials.TransportCredentials

if tlsEnabled {
creds, err = loadTLSClientCredential( certPath )
if err != nil {
return nil, err
}
} else {
creds = insecure.NewCredentials()
}

host := fmt.Sprintf("%s:%d", address, port)
conn, err := grpc.Dial(
host,
grpc.WithTransportCredentials(creds),
grpc.WithUnaryInterceptor(
grpc_middleware.ChainUnaryClient(
log.IOLoggingForClientSide(),
),
),
)
if err != nil {
return nil, err
}
return conn, nil
}

func loadTLSClientCredential(clientCertPath string) (credentials.TransportCredentials, error) {
creds, err := credentials.NewClientTLSFromFile(clientCertPath, "")
if err != nil {
log.Error("Fail to load client credentials: ", err)
return nil, err
}

return creds, nil
}

57 changes: 57 additions & 0 deletions pkg/grpc_server/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package grpc_server

import (
"net"
"strconv"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials"

"github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/go-grpc-middleware/recovery"

"github.com/openinfradev/tks-common/pkg/log"

)

func CreateServer(port int, tlsEnabled bool, certPath string, keyPath string) (*grpc.Server, net.Listener, error) {
log.Info("Starting to listen port ", port)

lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
if err != nil {
log.Error("failed to listen:", err)
return nil, nil, err
}

serverOptions := []grpc.ServerOption{
grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer(
grpc_recovery.UnaryServerInterceptor(),
log.IOLoggingForServerSide(),
),
),
}

if tlsEnabled {
log.Info("TLS enabled!!!")
tlsCredentials, err := loadTLSCredentials(certPath, keyPath)
if err != nil {
log.Error("Cannot load TLS credentials: ", err)
return nil, nil, err
}
serverOptions = append(serverOptions, grpc.Creds(tlsCredentials))
}

return grpc.NewServer(serverOptions...), lis, nil
}

func loadTLSCredentials(certPath string, keyPath string) (credentials.TransportCredentials, error) {
creds, err := credentials.NewServerTLSFromFile(certPath, keyPath)
if err != nil {
log.Error("Fail to load credentials: ", err)
return nil, err
}

return creds, nil
}

31 changes: 0 additions & 31 deletions pkg/helper/grpc.go

This file was deleted.

25 changes: 18 additions & 7 deletions pkg/log/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"io/ioutil"
"context"
"fmt"
"time"

"github.com/sirupsen/logrus"
"google.golang.org/grpc"
Expand Down Expand Up @@ -74,17 +73,29 @@ func Disable() {
}


// for grpc IO logging
func IOLog() grpc.UnaryClientInterceptor {
// grpc IO logging for client-side
func IOLoggingForClientSide() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
start := time.Now()
err := invoker(ctx, method, req, reply, cc, opts...)
end := time.Now()

Info(fmt.Sprintf("[GRPC:%s][START:%s][END:%s][ERR:%v]", method, start.Format(time.RFC3339), end.Format(time.RFC3339), err))
Debug(fmt.Sprintf("[GRPC:%s][REQUEST %s][REPLY %s]", method, req, reply))
Info(fmt.Sprintf("[INTERNAL_CALL:%s][REQUEST %s][RESPONSE %s]", method, req, reply))

return err
}
}

// grpc IO logging for server-side
func IOLoggingForServerSide() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
Info(fmt.Sprintf("[START:%s][REQUEST %s]", info.FullMethod, req))

res, err := handler(ctx, req)
if err != nil {
Error(err)
}

Info(fmt.Sprintf("[END:%s][RESPONSE %s]", info.FullMethod, res))

return res, err
}
}

0 comments on commit 105302d

Please sign in to comment.