diff --git a/certificate/certificate.go b/certificate/certificate.go index 781752bb7..c74dc45f1 100644 --- a/certificate/certificate.go +++ b/certificate/certificate.go @@ -8,6 +8,7 @@ import ( "crypto/sha1" "crypto/sha256" "crypto/x509" + "encoding/asn1" "encoding/base64" "fmt" "log" @@ -447,6 +448,20 @@ func getPublicKeyInfo(cert *x509.Certificate) (SubjectPublicKeyInfo, error) { } +func GetHexASN1Serial(cert *x509.Certificate) (serial string, err error) { + m, err := asn1.Marshal(cert.SerialNumber) + if err != nil { + return + } + var rawValue asn1.RawValue + _, err = asn1.Unmarshal(m, &rawValue) + if err != nil { + return + } + serial = fmt.Sprintf("%X", rawValue.Bytes) + return +} + //certtoStored returns a Certificate struct created from a X509.Certificate func CertToStored(cert *x509.Certificate, parentSignature, domain, ip string, TSName string, valInfo *ValidationInfo) Certificate { var ( @@ -458,7 +473,11 @@ func CertToStored(cert *x509.Certificate, parentSignature, domain, ip string, TS stored.IPs = make([]string, 0) stored.Version = cert.Version - stored.Serial = fmt.Sprintf("%X", cert.SerialNumber) + + // If there's an error, we just store the zero value ("") + serial, _ := GetHexASN1Serial(cert) + stored.Serial = serial + stored.SignatureAlgorithm = SignatureAlgorithm[cert.SignatureAlgorithm] stored.Key, err = getPublicKeyInfo(cert) diff --git a/certificate/certificate_test.go b/certificate/certificate_test.go new file mode 100644 index 000000000..d4c920f35 --- /dev/null +++ b/certificate/certificate_test.go @@ -0,0 +1,42 @@ +package certificate + +import ( + "crypto/x509" + "math/big" + "testing" +) + +func TestGetHexASN1Serial(t *testing.T) { + type testcase struct { + input *x509.Certificate + output string + } + testcases := []testcase{ + { + &x509.Certificate{SerialNumber: big.NewInt(-1)}, + "FF", + }, + { + &x509.Certificate{SerialNumber: big.NewInt(1)}, + "01", + }, + { + &x509.Certificate{SerialNumber: big.NewInt(0)}, + "00", + }, + { + &x509.Certificate{SerialNumber: big.NewInt(201)}, + "00C9", + }, + { + &x509.Certificate{SerialNumber: big.NewInt(-201)}, + "FF37", + }, + } + for _, tc := range testcases { + serial, _ := GetHexASN1Serial(tc.input) + if serial != tc.output { + t.Errorf("Expected %s, got %s", tc.output, serial) + } + } +} diff --git a/tools/fixserialnumber.go b/tools/fixserialnumber.go new file mode 100644 index 000000000..2242ab422 --- /dev/null +++ b/tools/fixserialnumber.go @@ -0,0 +1,173 @@ +package main + +import ( + "crypto/x509" + "encoding/base64" + "flag" + "fmt" + "log" + "math" + "os" + + "github.com/mozilla/tls-observatory/certificate" + "github.com/mozilla/tls-observatory/database" +) + +type job struct { + id int64 + // Needs to be *string because apparently serial numbers can be NULL in the db + currentSerialNumber *string + cert *x509.Certificate +} + +type result struct { + id int64 + changed bool + err error +} + +func main() { + var workerCount int + var batchSize int64 + var minID int64 + var maxID int64 + flag.IntVar(&workerCount, "workers", 4, "Number of workers to use") + flag.Int64Var(&batchSize, "batchSize", 1000, "Batch size") + flag.Int64Var(&minID, "minID", 0, "Minimum certificate ID to modify") + flag.Int64Var(&maxID, "maxID", math.MaxInt64, "Maximum certificate ID to modify") + flag.Parse() + jobs := make(chan job, batchSize) + results := make(chan result, batchSize) + + db, err := database.RegisterConnection( + os.Getenv("TLSOBS_POSTGRESDB"), + os.Getenv("TLSOBS_POSTGRESUSER"), + os.Getenv("TLSOBS_POSTGRESPASS"), + os.Getenv("TLSOBS_POSTGRES"), + "disable", + ) + if err != nil { + log.Fatalf("Error connecting to database: %s", err) + } + defer db.Close() + + for w := 1; w <= workerCount; w++ { + go worker(w, jobs, results, db) + } + changedCount := 0 + errorCount := 0 + total := 0 + go func() { + for { + log.Printf("Fetching %d certificates with id > %d", batchSize, minID) + nextBatch, err := fetchNextBatchWithRetries(5, db, minID, batchSize) + if err != nil { + log.Fatalf("Error fetching next batch: %s", err) + } + if len(nextBatch) == 0 || minID >= maxID { + close(jobs) + close(results) + log.Printf("Done. %d/%d errors. %d/%d changed.", errorCount, total, changedCount, total) + return + } + total += len(nextBatch) + for _, j := range nextBatch { + jobs <- j + minID = j.id + } + } + }() + for result := range results { + if result.err != nil { + errorCount++ + log.Printf("Received error for cert id %d: %s", result.id, result.err) + } + if result.changed { + changedCount++ + } + } +} + +func fetchNextBatchWithRetries(retries int, db *database.DB, minID int64, batchSize int64) (jobs []job, err error) { + for i := 0; i < retries; i++ { + jobs, err = fetchNextBatch(db, minID, batchSize) + if err == nil { + break + } + } + return +} + +func fetchNextBatch(db *database.DB, minID int64, batchSize int64) ([]job, error) { + rows, err := db.Query(`SELECT id, serial_number, raw_cert + FROM certificates + WHERE id > $1 + ORDER BY id + LIMIT $2`, + minID, + batchSize, + ) + if err != nil { + log.Fatalf("Error querying database: %s", err) + } + defer rows.Close() + var jobs []job + for rows.Next() { + var j job + var b64Crt string + if err = rows.Scan(&j.id, &j.currentSerialNumber, &b64Crt); err != nil { + return nil, fmt.Errorf("Error scanning row: %s", err) + } + cert, err := b64RawCertToX509Cert(b64Crt) + if err != nil { + log.Printf("Error converting database certificate to crypto/x509 certificate: %s", err) + continue + } + j.cert = cert + jobs = append(jobs, j) + } + return jobs, nil +} + +func b64RawCertToX509Cert(b64Crt string) (*x509.Certificate, error) { + rawCert, err := base64.StdEncoding.DecodeString(b64Crt) + if err != nil { + return nil, fmt.Errorf("Error b64 decoding certificate: %s", err) + } + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + return nil, fmt.Errorf("Error parsing x509 certificate: %s", err) + } + return cert, nil +} + +func worker(id int, jobs <-chan job, results chan result, db *database.DB) { + for j := range jobs { + correctSerialNumber, err := certificate.GetHexASN1Serial(j.cert) + if err != nil { + results <- result{id: j.id, err: err} + continue + } + if correctSerialNumber == *j.currentSerialNumber { + // Serial number is already stored correctly in the database + results <- result{id: j.id, err: nil} + continue + } + err = updateSerialNumberInDB(db, j.id, correctSerialNumber) + if err != nil { + results <- result{ + id: j.id, + err: fmt.Errorf("Error updating serial number in database: %s", err), + } + continue + } + results <- result{id: j.id, err: nil, changed: true} + } +} + +func updateSerialNumberInDB(db *database.DB, id int64, correctSerialNumber string) error { + _, err := db.Exec(`UPDATE certificates + SET serial_number = $1 + WHERE id = $2`, correctSerialNumber, id) + return err +}