Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Ofirschwartz/ms defender user context #6

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,17 @@ There are multiple ways to run this sample: locally using Ollama or Azure OpenAI

See the [cost estimation](./docs/cost.md) details for running this sample on Azure.

#### (Optional) Enable additional user context to Microsoft Defender for Cloud
In case you have Microsoft Defender for Cloud protection on your Azure OpenAI resource and you want to have additional context on the alerts, run this command:
```bash
azd env set MS_DEFENDER_ENABLED true
```

To customize the application name of the context, run this command:
ofirschwartz1 marked this conversation as resolved.
Show resolved Hide resolved
```bash
azd env set APPLICATION_NAME <your application name>
ofirschwartz1 marked this conversation as resolved.
Show resolved Hide resolved
```

#### Deploy the sample

1. Open a terminal and navigate to the root of the project.
Expand Down
1 change: 1 addition & 0 deletions infra/main.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,4 @@ output OPENAI_API_VERSION string = useAzureOpenAi ? openAiApiVersion : ''
output OPENAI_MODEL_NAME string = chatModelName

output WEBAPP_URL string = webapp.outputs.uri
ofirschwartz1 marked this conversation as resolved.
Show resolved Hide resolved

10 changes: 9 additions & 1 deletion packages/api/src/functions/chat-post.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { AIChatCompletionRequest, AIChatCompletionDelta, AIChatCompletion } from
import { AzureOpenAI, OpenAI } from "openai";
import 'dotenv/config';
import { ChatCompletionChunk } from 'openai/resources';
import { getMsDefenderUserJson } from "./security/ms-defender-utils"

const azureOpenAiScope = 'https://cognitiveservices.azure.com/.default';
const systemPrompt = `Assistant helps the user with cooking questions. Be brief in your answers. Answer only plain text, DO NOT use Markdown.
Expand Down Expand Up @@ -53,6 +54,11 @@ export async function postChat(stream: boolean, request: HttpRequest, context: I
throw new Error('No OpenAI API key or Azure OpenAI deployment provided');
}

var userContext: string | undefined;
if (process.env.MS_DEFENDER_ENABLED) {
userContext = getMsDefenderUserJson(request);
}

if (stream) {
const responseStream = await openai.chat.completions.create({
messages: [
Expand All @@ -61,7 +67,8 @@ export async function postChat(stream: boolean, request: HttpRequest, context: I
],
temperature: 0.7,
model,
stream: true
stream: true,
user: userContext,
});
const jsonStream = Readable.from(createJsonStream(responseStream));

Expand All @@ -80,6 +87,7 @@ export async function postChat(stream: boolean, request: HttpRequest, context: I
],
temperature: 0.7,
model,
user: userContext,
});

return {
Expand Down
74 changes: 74 additions & 0 deletions packages/api/src/functions/security/ms-defender-utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import process from 'node:process';
import { HttpRequest } from '@azure/functions';

export function getMsDefenderUserJson(request: HttpRequest): string {
ofirschwartz1 marked this conversation as resolved.
Show resolved Hide resolved

const sourceIp = getSourceIp(request);
const authenticatedUserDetails = getAuthenticatedUserDetails(request);

const userObject = {
"EndUserTenantId": authenticatedUserDetails.get('tenantId'),
"EndUserId": authenticatedUserDetails.get('userId'),
"EndUserIdType": authenticatedUserDetails.get('identityProvider'),
"SourceIp": sourceIp,
"SourceRequestHeaders": extractSpecificHeaders(request),
"ApplicationName": process.env.APPLICATION_NAME,
};

var userContextJsonString = JSON.stringify(userObject);
return userContextJsonString;
}

function extractSpecificHeaders(request: HttpRequest): any {
const headerNames = ['User-Agent', 'X-Forwarded-For', 'Forwarded', 'X-Real-IP', 'True-Client-IP', 'CF-Connecting-IP'];
var relevantHeaders = new Map<string, string>();

for (const header of headerNames) {
if (request.headers.has(header)) {
relevantHeaders.set(header, request.headers.get(header)!);
}
}

return Object.fromEntries(relevantHeaders);
}

function getAuthenticatedUserDetails(request: HttpRequest) : Map<string, string> {
var authenticatedUserDetails = new Map<string, string>();
var principalHeader = request.headers.get('X-Ms-Client-Principal');
ofirschwartz1 marked this conversation as resolved.
Show resolved Hide resolved
if (principalHeader == null) {
return authenticatedUserDetails;
}

const principal = parsePrincipal(principalHeader);
if (principal != null) {
var idp = principal['identityProvider'] == "aad" ? "EntraId" : principal['identityProvider'];
authenticatedUserDetails.set('identityProvider', idp);
}

if (principal['identityProvider'] == "aad") {
// TODO: add only when userId represents actual IDP user id
// authenticatedUserDetails.set('userId', principal['userId']);
if (process.env.AZURE_TENANT_ID != null) {
authenticatedUserDetails.set('tenantId', process.env.AZURE_TENANT_ID);
}
}

return authenticatedUserDetails
}

function parsePrincipal(principal : string | null) : any {
if (principal == null) {
return null;
}

try {
return JSON.parse(Buffer.from(principal, 'base64').toString('utf-8'));
ofirschwartz1 marked this conversation as resolved.
Show resolved Hide resolved
} catch (error) {
return null;
}
}

function getSourceIp(request: HttpRequest) {
ofirschwartz1 marked this conversation as resolved.
Show resolved Hide resolved
var sourceIp = request.headers.get('X-Forwarded-For') ?? "";
ofirschwartz1 marked this conversation as resolved.
Show resolved Hide resolved
return sourceIp.split(',')[0].split(':')[0]
ofirschwartz1 marked this conversation as resolved.
Show resolved Hide resolved
}