Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create Azure OpenAI detector #2347

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
feat(azure): create openai detector
  • Loading branch information
rgmz committed Nov 20, 2024
commit 7f4f6de86d0a715080242d1ec5ac2859a19ace42
172 changes: 172 additions & 0 deletions pkg/detectors/azure_openai/azure_openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package azure_openai

import (
rgmz marked this conversation as resolved.
Show resolved Hide resolved
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

regexp "github.com/wasilibs/go-re2"

"github.com/trufflesecurity/trufflehog/v3/pkg/cache/simple"
logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb"

"github.com/trufflesecurity/trufflehog/v3/pkg/common"
)

// Scanner detects API keys for Azure's OpenAI service.
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
type Scanner struct {
client *http.Client
}

// Ensure the Scanner satisfies the interface at compile time.
var _ detectors.Detector = (*Scanner)(nil)

var (
// TODO: Investigate custom `azure-api.net` endpoints.
// https://github.com/openai/openai-python#microsoft-azure-openai
azureUrlPat = regexp.MustCompile(`(?i)([a-z0-9-]+\.openai\.azure\.com)`)
azureKeyPat = regexp.MustCompile(detectors.PrefixRegex([]string{"api[_.-]?key", "openai[_.-]?key"}) + `\b(?-i:([a-f0-9]{32}))\b`)

invalidServices = simple.NewCache[struct{}]()
)

// Keywords are used for efficiently pre-filtering chunks.
// Use identifiers in the secret preferably, or the provider name.
func (s Scanner) Keywords() []string {
return []string{".openai.azure.com"}
}

func (s Scanner) Type() detectorspb.DetectorType {
return detectorspb.DetectorType_AzureOpenAI
}

func (s Scanner) Description() string {
return "Azure OpenAI provides various AI models and services. The API keys can be used to access and interact with these models and services."
}

// FromData will find and optionally verify OpenAI secrets in a given set of bytes.
func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (results []detectors.Result, err error) {
dataStr := string(data)

// De-duplicate results.
tokens := make(map[string]struct{})
for _, match := range azureKeyPat.FindAllStringSubmatch(dataStr, -1) {
tokens[match[1]] = struct{}{}
}
if len(tokens) == 0 {
return
}
urls := make(map[string]struct{})
for _, match := range azureUrlPat.FindAllStringSubmatch(dataStr, -1) {
u := match[1]
if invalidServices.Exists(u) {
continue
}
urls[u] = struct{}{}
}

// Process results.
logCtx := logContext.AddLogger(ctx)
for token := range tokens {
s1 := detectors.Result{
DetectorType: s.Type(),
Redacted: token[:3] + "..." + token[25:],
Raw: []byte(token),
}

for url := range urls {
if verify {
client := s.client
if client == nil {
client = common.SaneHttpClient()
}

isVerified, extraData, verificationErr := verifyAzureToken(logCtx, client, url, token)
if isVerified || len(urls) == 1 {
s1.RawV2 = []byte(token + ":" + url)
s1.Verified = isVerified
s1.ExtraData = extraData
s1.SetVerificationError(verificationErr, token)
break
}

// Instance doesn't exist.
// Verification issue: lookup azsdk-east-us.openai.azure.com: no such host
if verificationErr != nil && strings.Contains(verificationErr.Error(), "no such host") {
delete(urls, url)
invalidServices.Set(url, struct{}{})
}
}
}

results = append(results, s1)
}
return
}

func verifyAzureToken(ctx logContext.Context, client *http.Client, baseUrl, token string) (bool, map[string]string, error) {
// TODO: Replace this with a more suitable long-term endpoint.
// Most endpoints require additional info, e.g., deployment name, which complicates verification.
// This may be retired in the future, so we should look for another candidate.
// https://learn.microsoft.com/en-us/answers/questions/1371786/get-azure-openai-deployments-in-api
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://%s/openai/deployments?api-version=2023-03-15-preview", baseUrl), nil)
if err != nil {
return false, nil, nil
}

req.Header.Set("Api-Key", token)
req.Header.Set("Content-Type", "application/json")
res, err := client.Do(req)
if err != nil {
return false, nil, err
}
defer func() {
_, _ = io.Copy(io.Discard, res.Body)
_ = res.Body.Close()
}()

switch res.StatusCode {
case http.StatusOK:
body, err := io.ReadAll(res.Body)
if err != nil {
return false, nil, err
}

var deployments deploymentsResponse
if err := json.Unmarshal(body, &deployments); err != nil {
if json.Valid(body) {
return false, nil, fmt.Errorf("failed to decode response %s: %w", req.URL, err)
} else {
// If the response isn't JSON it's highly unlikely to be valid.
return false, nil, nil
}
}

// JSON unmarshal doesn't check whether the structure actually matches.
if deployments.Object == "" {
return false, nil, nil
}

// No extra data available at the moment.
return true, nil, nil
case http.StatusUnauthorized:
return false, nil, nil
default:
return false, nil, fmt.Errorf("unexpected response status %d for %s", res.StatusCode, req.URL)
}
}

type deploymentsResponse struct {
Data []deployment `json:"data"`
Object string `json:"object"`
}

type deployment struct {
ID string `json:"id"`
}
159 changes: 159 additions & 0 deletions pkg/detectors/azure_openai/azure_openai_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
//go:build detectors
// +build detectors

package azure_openai

import (
"context"
"fmt"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb"
"testing"
"time"
)

func TestAzureOpenAI_FromChunk(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
testSecrets, err := common.GetSecret(ctx, "trufflehog-testing", "detectors5")
if err != nil {
t.Fatalf("could not get test secrets from GCP: %s", err)
}
secret := testSecrets.MustGetField("AZUREOPENAI")
inactiveSecret := testSecrets.MustGetField("AZUREOPENAI_INACTIVE")

type args struct {
ctx context.Context
data []byte
verify bool
}
tests := []struct {
name string
s Scanner
args args
want []detectors.Result
wantErr bool
wantVerificationErr bool
}{
{
name: "found, verified",
s: Scanner{},
args: args{
ctx: context.Background(),
data: []byte(fmt.Sprintf("You can find a azureopenai secret %s within", secret)),
verify: true,
},
want: []detectors.Result{
{
DetectorType: detectorspb.DetectorType_AzureOpenAI,
Verified: true,
},
},
wantErr: false,
wantVerificationErr: false,
},
{
name: "found, unverified",
s: Scanner{},
args: args{
ctx: context.Background(),
data: []byte(fmt.Sprintf("You can find a azureopenai secret %s within but not valid", inactiveSecret)), // the secret would satisfy the regex but not pass validation
verify: true,
},
want: []detectors.Result{
{
DetectorType: detectorspb.DetectorType_AzureOpenAI,
Verified: false,
},
},
wantErr: false,
wantVerificationErr: false,
},
{
name: "not found",
s: Scanner{},
args: args{
ctx: context.Background(),
data: []byte("You cannot find the secret within"),
verify: true,
},
want: nil,
wantErr: false,
wantVerificationErr: false,
},
{
name: "found, would be verified if not for timeout",
s: Scanner{client: common.SaneHttpClientTimeOut(1 * time.Microsecond)},
args: args{
ctx: context.Background(),
data: []byte(fmt.Sprintf("You can find a azureopenai secret %s within", secret)),
verify: true,
},
want: []detectors.Result{
{
DetectorType: detectorspb.DetectorType_AzureOpenAI,
Verified: false,
},
},
wantErr: false,
wantVerificationErr: true,
},
{
name: "found, verified but unexpected api surface",
s: Scanner{client: common.ConstantResponseHttpClient(404, "")},
args: args{
ctx: context.Background(),
data: []byte(fmt.Sprintf("You can find a azureopenai secret %s within", secret)),
verify: true,
},
want: []detectors.Result{
{
DetectorType: detectorspb.DetectorType_AzureOpenAI,
Verified: false,
},
},
wantErr: false,
wantVerificationErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.s.FromData(tt.args.ctx, tt.args.verify, tt.args.data)
if (err != nil) != tt.wantErr {
t.Errorf("Azureopenai.FromData() error = %v, wantErr %v", err, tt.wantErr)
return
}
for i := range got {
if len(got[i].Raw) == 0 {
t.Fatalf("no raw secret present: \n %+v", got[i])
}
if (got[i].VerificationError() != nil) != tt.wantVerificationErr {
t.Fatalf("wantVerificationError = %v, verification error = %v", tt.wantVerificationErr, got[i].VerificationError())
}
}
ignoreOpts := cmpopts.IgnoreFields(detectors.Result{}, "Raw", "verificationError")
if diff := cmp.Diff(got, tt.want, ignoreOpts); diff != "" {
t.Errorf("Azureopenai.FromData() %s diff: (-got +want)\n%s", tt.name, diff)
}
})
}
}

func BenchmarkFromData(benchmark *testing.B) {
ctx := context.Background()
s := Scanner{}
for name, data := range detectors.MustGetBenchmarkData() {
benchmark.Run(name, func(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
_, err := s.FromData(ctx, false, data)
if err != nil {
b.Fatal(err)
}
}
})
}
}
Loading
Loading