diff --git a/v3/bind.go b/v3/bind.go index 5c7dd999..6cfd37eb 100644 --- a/v3/bind.go +++ b/v3/bind.go @@ -651,6 +651,9 @@ type GSSAPIClient interface { // to InitSecContext via the token parameters. // See RFC 4752 section 3.1. InitSecContext(target string, token []byte) (outputToken []byte, needContinue bool, err error) + // InitSecContextWithOptions is the same as InitSecContext but allows for additional options to be passed to the context establishment. + // See RFC 4752 section 3.1. + InitSecContextWithOptions(target string, token []byte, options []int) (outputToken []byte, needContinue bool, err error) // NegotiateSaslAuth performs the last step of the Sasl handshake. // It takes a token, which, when unwrapped, describes the servers supported // security layers (first octet) and maximum receive buffer (remaining @@ -688,6 +691,11 @@ func (l *Conn) GSSAPIBind(client GSSAPIClient, servicePrincipal, authzid string) // GSSAPIBindRequest performs the GSSAPI SASL bind using the provided GSSAPI client. func (l *Conn) GSSAPIBindRequest(client GSSAPIClient, req *GSSAPIBindRequest) error { + return l.GSSAPIBindRequestWithAPOptions(client, req, []int{}) +} + +// GSSAPIBindRequest performs the GSSAPI SASL bind using the provided GSSAPI client. +func (l *Conn) GSSAPIBindRequestWithAPOptions(client GSSAPIClient, req *GSSAPIBindRequest, APOptions []int) error { //nolint:errcheck defer client.DeleteSecContext() @@ -698,7 +706,7 @@ func (l *Conn) GSSAPIBindRequest(client GSSAPIClient, req *GSSAPIBindRequest) er for { if needInit { // Establish secure context between client and server. - reqToken, needInit, err = client.InitSecContext(req.ServicePrincipalName, recvToken) + reqToken, needInit, err = client.InitSecContextWithOptions(req.ServicePrincipalName, recvToken, APOptions) if err != nil { return err } diff --git a/v3/gssapi/client.go b/v3/gssapi/client.go index d6c6dbd4..c8f20a70 100644 --- a/v3/gssapi/client.go +++ b/v3/gssapi/client.go @@ -1,6 +1,10 @@ package gssapi import ( + "bytes" + "encoding/binary" + "encoding/hex" + "errors" "fmt" "github.com/jcmturner/gokrb5/v8/client" @@ -100,6 +104,13 @@ func (client *Client) DeleteSecContext() error { // GSS-API between the client and server. // See RFC 4752 section 3.1. func (client *Client) InitSecContext(target string, input []byte) ([]byte, bool, error) { + return client.InitSecContextWithOptions(target, input, []int{}) +} + +// InitSecContextWithOptions initiates the establishment of a security context for +// GSS-API between the client and server. +// See RFC 4752 section 3.1. +func (client *Client) InitSecContextWithOptions(target string, input []byte, APOptions []int) ([]byte, bool, error) { gssapiFlags := []int{gssapi.ContextFlagInteg, gssapi.ContextFlagConf, gssapi.ContextFlagMutual} switch input { @@ -110,7 +121,7 @@ func (client *Client) InitSecContext(target string, input []byte) ([]byte, bool, } client.ekey = ekey - token, err := spnego.NewKRB5TokenAPREQ(client.Client, tkt, ekey, gssapiFlags, []int{}) + token, err := spnego.NewKRB5TokenAPREQ(client.Client, tkt, ekey, gssapiFlags, APOptions) if err != nil { return nil, false, err } @@ -160,7 +171,7 @@ func (client *Client) InitSecContext(target string, input []byte) ([]byte, bool, // See RFC 4752 section 3.1. func (client *Client) NegotiateSaslAuth(input []byte, authzid string) ([]byte, error) { token := &gssapi.WrapToken{} - err := token.Unmarshal(input, true) + err := UnmarshalWrapToken(token, input, true) if err != nil { return nil, err } @@ -212,3 +223,49 @@ func (client *Client) NegotiateSaslAuth(input []byte, authzid string) ([]byte, e return output, nil } + +func getGssWrapTokenId() *[2]byte { + return &[2]byte{0x05, 0x04} +} + +func UnmarshalWrapToken(wt *gssapi.WrapToken, b []byte, expectFromAcceptor bool) error { + // Check if we can read a whole header + if len(b) < 16 { + return errors.New("bytes shorter than header length") + } + // Is the Token ID correct? + if !bytes.Equal(getGssWrapTokenId()[:], b[0:2]) { + return fmt.Errorf("wrong Token ID. Expected %s, was %s", + hex.EncodeToString(getGssWrapTokenId()[:]), + hex.EncodeToString(b[0:2])) + } + // Check the acceptor flag + flags := b[2] + isFromAcceptor := flags&0x01 == 1 + if isFromAcceptor && !expectFromAcceptor { + return errors.New("unexpected acceptor flag is set: not expecting a token from the acceptor") + } + if !isFromAcceptor && expectFromAcceptor { + return errors.New("expected acceptor flag is not set: expecting a token from the acceptor, not the initiator") + } + // Check the filler byte + if b[3] != gssapi.FillerByte { + return fmt.Errorf("unexpected filler byte: expecting 0xFF, was %s ", hex.EncodeToString(b[3:4])) + } + checksumL := binary.BigEndian.Uint16(b[4:6]) + // Sanity check on the checksum length + if int(checksumL) > len(b)-gssapi.HdrLen { + return fmt.Errorf("inconsistent checksum length: %d bytes to parse, checksum length is %d", len(b), checksumL) + } + + payloadStart := 16 + checksumL + + wt.Flags = flags + wt.EC = checksumL + wt.RRC = binary.BigEndian.Uint16(b[6:8]) + wt.SndSeqNum = binary.BigEndian.Uint64(b[8:16]) + wt.CheckSum = b[16:payloadStart] + wt.Payload = b[payloadStart:] + + return nil +}