Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net
ctx := s.srv.ctx

if params.GetUserAgent == nil && !params.TargetServer.IsOpenSSHNode() {
return nil, trace.BadParameter("agentless node require an agent getter")
return nil, trace.BadParameter("user agent getter is required for teleport nodes (this is a bug)")
}
s.logger.DebugContext(ctx, "Initiating dial and forwarding request",
"source_addr", logutils.StringerAttr(params.From),
Expand Down
13 changes: 8 additions & 5 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,17 @@ func (s *ServerConfig) CheckDefaults() error {
if s.TargetServer == nil {
return trace.BadParameter("target server is required")
}
if s.TargetServer.IsOpenSSHNode() {
switch s.TargetServer.GetSubKind() {
case types.SubKindTeleportNode:
if s.UserAgent == nil {
return trace.BadParameter("user agent required for teleport nodes (agentless)")
}
case types.SubKindOpenSSHNode:
if s.AgentlessSigner == nil {
return trace.BadParameter("agentless signer is required for OpenSSH Nodes")
}
} else {
if s.UserAgent == nil {
return trace.BadParameter("user agent required for teleport nodes")
}
case types.SubKindOpenSSHEICENode:
// agentless signer is set once the forwarding server is started.
}
if s.TargetConn == nil {
return trace.BadParameter("connection to target connection required")
Expand Down
104 changes: 104 additions & 0 deletions lib/srv/forward/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,22 @@ import (
"crypto/ed25519"
"crypto/rand"
"errors"
"net"
"os/user"
"sync/atomic"
"testing"

"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/keys"
apisshutils "github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils/log/logtest"
Expand Down Expand Up @@ -276,3 +280,103 @@ func TestCheckTCPIPForward(t *testing.T) {

// TODO(atburke): Add test for handleForwardedTCPIPRequest once we have
// infrastructure for higher-level tests here.

func TestServerConfigCheckDefaults(t *testing.T) {
teleportNode, err := types.NewNode("teleport-node", "", types.ServerSpecV2{}, nil)
require.NoError(t, err)

openSSHNode, err := types.NewNode("openssh-node", types.SubKindOpenSSHNode, types.ServerSpecV2{
Addr: "openssh.example.com:22",
Hostname: "openssh.example.com",
}, nil)
require.NoError(t, err)

openSSHEICENode, err := types.NewEICENode(types.ServerSpecV2{
Addr: "openssheice.example.com:22",
Hostname: "openssheice.example.com",
CloudMetadata: &types.CloudMetadata{
AWS: &types.AWSInfo{
AccountID: "123456789012",
InstanceID: "i-123456789012",
Region: "us-east-1",
VPCID: "vpc-abcd",
SubnetID: "subnet-123",
Integration: "teleportdev",
},
},
}, nil)
require.NoError(t, err)

for _, tt := range []struct {
name string
modifyCfg func(c *ServerConfig)
errorAssertion require.ErrorAssertionFunc
}{
{
name: "no targetServer",
modifyCfg: func(c *ServerConfig) {},
errorAssertion: func(tt require.TestingT, err error, i ...interface{}) {
require.Error(t, err)
require.ErrorContains(t, err, "target server is required")
},
}, {
name: "Teleport Node",
modifyCfg: func(c *ServerConfig) {
c.TargetServer = teleportNode
c.UserAgent = &sshutils.AgentChannel{}
},
errorAssertion: require.NoError,
}, {
name: "Teleport Node no agent",
modifyCfg: func(c *ServerConfig) {
c.TargetServer = teleportNode
},
errorAssertion: func(tt require.TestingT, err error, i ...interface{}) {
require.Error(t, err)
require.ErrorContains(t, err, "user agent required")
},
}, {
name: "OpenSSH Node",
modifyCfg: func(c *ServerConfig) {
c.TargetServer = openSSHNode
c.AgentlessSigner = &sshutils.LegacySHA1Signer{}
},
errorAssertion: require.NoError,
}, {
name: "OpenSSH Node no signer",
modifyCfg: func(c *ServerConfig) {
c.TargetServer = openSSHNode
},
errorAssertion: func(tt require.TestingT, err error, i ...interface{}) {
require.Error(t, err)
require.ErrorContains(t, err, "agentless signer is required")
},
}, {
name: "OpenSSH EICE Node",
modifyCfg: func(c *ServerConfig) {
c.TargetServer = openSSHEICENode
},
errorAssertion: require.NoError,
},
} {
t.Run(tt.name, func(t *testing.T) {
config := &ServerConfig{
LocalAuthClient: &authclient.Client{},
TargetClusterAccessPoint: &authclient.Client{},
DataDir: "datadir",
TargetConn: &net.UnixConn{},
SrcAddr: &net.IPAddr{},
DstAddr: &net.IPAddr{},
HostCertificate: &sshutils.LegacySHA1Signer{},
Clock: clockwork.NewFakeClock(),
Emitter: &authclient.Client{},
LockWatcher: &services.LockWatcher{},
}

tt.modifyCfg(config)

err := config.CheckDefaults()
tt.errorAssertion(t, err)
})
}
}
Loading