diff --git a/example/trivial/trivial.go b/example/trivial/trivial.go
index 6aa478a0..e8be7cb9 100644
--- a/example/trivial/trivial.go
+++ b/example/trivial/trivial.go
@@ -27,7 +27,7 @@ func main() {
}
rootURL, _ := url.Parse("http://localhost:8000")
- idpMetadataURL, _ := url.Parse("https://www.testshib.org/metadata/testshib-providers.xml")
+ idpMetadataURL, _ := url.Parse("https://samltest.id/saml/idp")
idpMetadata, err := samlsp.FetchMetadata(
context.Background(),
@@ -42,6 +42,7 @@ func main() {
IDPMetadata: idpMetadata,
Key: keyPair.PrivateKey.(*rsa.PrivateKey),
Certificate: keyPair.Leaf,
+ SignRequest: true,
})
if err != nil {
panic(err) // TODO handle error
diff --git a/samlsp/new.go b/samlsp/new.go
index 451a65aa..a28a46b4 100644
--- a/samlsp/new.go
+++ b/samlsp/new.go
@@ -5,6 +5,7 @@ import (
"context"
"crypto/rsa"
"crypto/x509"
+ dsig "github.com/russellhaering/goxmldsig"
"net/http"
"net/url"
"time"
@@ -22,6 +23,7 @@ type Options struct {
Intermediates []*x509.Certificate
AllowIDPInitiated bool
IDPMetadata *saml.EntityDescriptor
+ SignRequest bool
ForceAuthn bool // TODO(ross): this should be *bool
// The following fields exist <= 0.3.0, but are superceded by the new
@@ -125,6 +127,10 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider {
if opts.ForceAuthn {
forceAuthn = &opts.ForceAuthn
}
+ signatureMethod := dsig.RSASHA1SignatureMethod
+ if !opts.SignRequest {
+ signatureMethod = ""
+ }
return saml.ServiceProvider{
EntityID: opts.EntityID,
@@ -136,6 +142,7 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider {
SloURL: *sloURL,
IDPMetadata: opts.IDPMetadata,
ForceAuthn: forceAuthn,
+ SignatureMethod: signatureMethod,
AllowIDPInitiated: opts.AllowIDPInitiated,
}
}
diff --git a/service_provider.go b/service_provider.go
index 04d3eddf..a8b64a58 100644
--- a/service_provider.go
+++ b/service_provider.go
@@ -4,6 +4,7 @@ import (
"bytes"
"compress/flate"
"crypto/rsa"
+ "crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/xml"
@@ -101,6 +102,9 @@ type ServiceProvider struct {
// SignatureVerifier, if non-nil, allows you to implement an alternative way
// to verify signatures.
SignatureVerifier SignatureVerifier
+
+ // SignatureMethod, if non-empty, authentication requests will be signed
+ SignatureMethod string
}
// MaxIssueDelay is the longest allowed time between when a SAML assertion is
@@ -126,7 +130,7 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor {
validDuration = sp.MetadataValidDuration
}
- authnRequestsSigned := false
+ authnRequestsSigned := len(sp.SignatureMethod) > 0
wantAssertionsSigned := true
validUntil := TimeNow().Add(validDuration)
@@ -137,12 +141,6 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor {
certBytes = append(certBytes, intermediate.Raw...)
}
keyDescriptors = []KeyDescriptor{
- {
- Use: "signing",
- KeyInfo: KeyInfo{
- Certificate: base64.StdEncoding.EncodeToString(certBytes),
- },
- },
{
Use: "encryption",
KeyInfo: KeyInfo{
@@ -156,6 +154,14 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor {
},
},
}
+ if len(sp.SignatureMethod) > 0 {
+ keyDescriptors = append(keyDescriptors, KeyDescriptor{
+ Use: "signing",
+ KeyInfo: KeyInfo{
+ Certificate: base64.StdEncoding.EncodeToString(certBytes),
+ },
+ })
+ }
}
return &EntityDescriptor{
@@ -330,9 +336,51 @@ func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string) (*AuthnReque
},
ForceAuthn: sp.ForceAuthn,
}
+ if len(sp.SignatureMethod) > 0 {
+ if err := sp.SignAuthnRequest(&req); err != nil {
+ return nil, err
+ }
+ }
return &req, nil
}
+// SignAuthnRequest adds the `Signature` element to the `AuthnRequest`.
+func (sp *ServiceProvider) SignAuthnRequest(req *AuthnRequest) error {
+ keyPair := tls.Certificate{
+ Certificate: [][]byte{sp.Certificate.Raw},
+ PrivateKey: sp.Key,
+ Leaf: sp.Certificate,
+ }
+ // TODO: add intermediates for SP
+ //for _, cert := range sp.Intermediates {
+ // keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
+ //}
+ keyStore := dsig.TLSCertKeyStore(keyPair)
+
+ if sp.SignatureMethod != dsig.RSASHA1SignatureMethod &&
+ sp.SignatureMethod != dsig.RSASHA256SignatureMethod &&
+ sp.SignatureMethod != dsig.RSASHA512SignatureMethod {
+ return fmt.Errorf("invalid signing method %s", sp.SignatureMethod)
+ }
+ signatureMethod := sp.SignatureMethod
+ signingContext := dsig.NewDefaultSigningContext(keyStore)
+ signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
+ if err := signingContext.SetSignatureMethod(signatureMethod); err != nil {
+ return err
+ }
+
+ assertionEl := req.Element()
+
+ signedRequestEl, err := signingContext.SignEnveloped(assertionEl)
+ if err != nil {
+ return err
+ }
+
+ sigEl := signedRequestEl.Child[len(signedRequestEl.Child)-1]
+ req.Signature = sigEl.(*etree.Element)
+ return nil
+}
+
// MakePostAuthenticationRequest creates a SAML authentication request using
// the HTTP-POST binding. It returns HTML text representing an HTML form that
// can be sent presented to a browser to initiate the login process.
diff --git a/service_provider_test.go b/service_provider_test.go
index ff908ceb..ac831215 100644
--- a/service_provider_test.go
+++ b/service_provider_test.go
@@ -98,7 +98,7 @@ func TestSPCanSetAuthenticationNameIDFormat(t *testing.T) {
assert.Equal(t, string(EmailAddressNameIDFormat), *req.NameIDPolicy.Format)
}
-func TestSPCanProduceMetadata(t *testing.T) {
+func TestSPCanProduceMetadataWithEncryptionCert(t *testing.T) {
test := NewServiceProviderTest()
s := ServiceProvider{
Key: test.Key,
@@ -116,13 +116,43 @@ func TestSPCanProduceMetadata(t *testing.T) {
assert.Equal(t, ""+
"\n"+
" \n"+
- " \n"+
+ " \n"+
" \n"+
" \n"+
" MIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJVUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9ibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmHO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKvRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgkakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeTQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvnOwJlNCASPZRH/JmF8tX0hoHuAQ==\n"+
" \n"+
" \n"+
+ " \n"+
+ " \n"+
+ " \n"+
+ " \n"+
" \n"+
+ " \n"+
+ " \n"+
+ " \n"+
+ "",
+ string(spMetadata))
+}
+
+func TestSPCanProduceMetadataWithBothCerts(t *testing.T) {
+ test := NewServiceProviderTest()
+ s := ServiceProvider{
+ Key: test.Key,
+ Certificate: test.Certificate,
+ MetadataURL: mustParseURL("https://example.com/saml2/metadata"),
+ AcsURL: mustParseURL("https://example.com/saml2/acs"),
+ SloURL: mustParseURL("https://example.com/saml2/slo"),
+ IDPMetadata: &EntityDescriptor{},
+ SignatureMethod: "not-empty",
+ }
+ err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata)
+ assert.NoError(t, err)
+
+ spMetadata, err := xml.MarshalIndent(s.Metadata(), "", " ")
+ assert.NoError(t, err)
+ assert.Equal(t, ""+
+ "\n"+
+ " \n"+
" \n"+
" \n"+
" \n"+
@@ -134,6 +164,13 @@ func TestSPCanProduceMetadata(t *testing.T) {
" \n"+
" \n"+
" \n"+
+ " \n"+
+ " \n"+
+ " \n"+
+ " MIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJVUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9ibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmHO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKvRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgkakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeTQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvnOwJlNCASPZRH/JmF8tX0hoHuAQ==\n"+
+ " \n"+
+ " \n"+
+ " \n"+
" \n"+
" \n"+
" \n"+
@@ -141,7 +178,7 @@ func TestSPCanProduceMetadata(t *testing.T) {
string(spMetadata))
}
-func TestCanProduceMetadataNoSigningKey(t *testing.T) {
+func TestCanProduceMetadataNoCerts(t *testing.T) {
test := NewServiceProviderTest()
s := ServiceProvider{
MetadataURL: mustParseURL("https://example.com/saml2/metadata"),
@@ -248,6 +285,62 @@ func TestSPCanProducePostRequest(t *testing.T) {
string(form))
}
+func TestSPCanProduceSignedRequest(t *testing.T) {
+ test := NewServiceProviderTest()
+ TimeNow = func() time.Time {
+ rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 UTC 2006", "Mon Dec 1 01:31:21.123456789 UTC 2015")
+ return rv
+ }
+ Clock = dsig.NewFakeClockAt(TimeNow())
+ s := ServiceProvider{
+ Key: test.Key,
+ Certificate: test.Certificate,
+ MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"),
+ AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"),
+ IDPMetadata: &EntityDescriptor{},
+ SignatureMethod: dsig.RSASHA1SignatureMethod,
+ }
+ err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata)
+ assert.NoError(t, err)
+
+ redirectURL, err := s.MakeRedirectAuthenticationRequest("relayState")
+ assert.NoError(t, err)
+
+ decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL)
+ assert.NoError(t, err)
+ assert.Equal(t,
+ "idp.testshib.org",
+ redirectURL.Host)
+ assert.Equal(t,
+ "/idp/profile/SAML2/Redirect/SSO",
+ redirectURL.Path)
+ assert.Equal(t,
+ "https://15661444.ngrok.io/saml2/metadataXQ5+kdgOf34vpAemZRFalLlzjr0=Wtomi/PiWx0bMFlImy5soCrrDbdY4BR2Qb8woGqc8KsVtXAwvl6lfYE2tuoT0YS5ipPLMMsFG8dB1TmLcA+0lnUcqfBiTiiHEwTIo3193RIsoH3STlOmXqBQf9Ax2nRdX+/4HwIYF58lgUzOb+nur+zGL6mYw2xjQBw6YGaX9Cc=MIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJVUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9ibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmHO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKvRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgkakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeTQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvnOwJlNCASPZRH/JmF8tX0hoHuAQ==",
+ string(decodedRequest))
+}
+
+func TestSPFailToProduceSignedRequestWithBogusSignatureMethod(t *testing.T) {
+ test := NewServiceProviderTest()
+ TimeNow = func() time.Time {
+ rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 UTC 2006", "Mon Dec 1 01:31:21.123456789 UTC 2015")
+ return rv
+ }
+ Clock = dsig.NewFakeClockAt(TimeNow())
+ s := ServiceProvider{
+ Key: test.Key,
+ Certificate: test.Certificate,
+ MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"),
+ AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"),
+ IDPMetadata: &EntityDescriptor{},
+ SignatureMethod: "bogus",
+ }
+ err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata)
+ assert.NoError(t, err)
+
+ _, err = s.MakeRedirectAuthenticationRequest("relayState")
+ assert.Errorf(t, err, "invalid signing method bogus")
+}
+
func TestSPCanProducePostLogoutRequest(t *testing.T) {
test := NewServiceProviderTest()
TimeNow = func() time.Time {