Skip to content

Commit 8a93623

Browse files
manuel-alvarez-alvarezamarziali
authored andcommitted
Add the tags returned by the service to the ai_guard span (#9931)
Add the tags returned by the service to the ai_guard span
1 parent 32b5c4f commit 8a93623

File tree

3 files changed

+47
-12
lines changed

3 files changed

+47
-12
lines changed

dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.CONTENT;
55
import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.MESSAGES;
66
import static datadog.trace.util.Strings.isBlank;
7-
import static java.util.Collections.singletonMap;
87

98
import com.squareup.moshi.JsonAdapter;
109
import com.squareup.moshi.JsonReader;
@@ -69,7 +68,8 @@ public BadConfigurationException(final String message) {
6968
static final String REASON_TAG = "ai_guard.reason";
7069
static final String BLOCKED_TAG = "ai_guard.blocked";
7170
static final String META_STRUCT_TAG = "ai_guard";
72-
static final String META_STRUCT_KEY = "messages";
71+
static final String META_STRUCT_MESSAGES = "messages";
72+
static final String META_STRUCT_CATEGORIES = "attack_categories";
7373

7474
public static void install() {
7575
final Config config = Config.get();
@@ -208,8 +208,8 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
208208
} else {
209209
span.setTag(TARGET_TAG, "prompt");
210210
}
211-
final Map<String, Object> metaStruct =
212-
singletonMap(META_STRUCT_KEY, messagesForMetaStruct(messages));
211+
final Map<String, Object> metaStruct = new HashMap<>(2);
212+
metaStruct.put(META_STRUCT_MESSAGES, messagesForMetaStruct(messages));
213213
span.setMetaStruct(META_STRUCT_TAG, metaStruct);
214214
final Request.Builder request =
215215
new Request.Builder()
@@ -224,14 +224,21 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
224224
}
225225
final Action action = Action.valueOf(actionStr);
226226
final String reason = (String) result.get("reason");
227+
@SuppressWarnings("unchecked")
228+
final List<String> tags = (List<String>) result.get("tags");
227229
span.setTag(ACTION_TAG, action);
228-
span.setTag(REASON_TAG, reason);
230+
if (reason != null) {
231+
span.setTag(REASON_TAG, reason);
232+
}
233+
if (tags != null && !tags.isEmpty()) {
234+
metaStruct.put(META_STRUCT_CATEGORIES, tags);
235+
}
229236
final boolean shouldBlock =
230237
isBlockingEnabled(options, result.get("is_blocking_enabled")) && action != Action.ALLOW;
231238
WafMetricCollector.get().aiGuardRequest(action, shouldBlock);
232239
if (shouldBlock) {
233240
span.setTag(BLOCKED_TAG, true);
234-
throw new AIGuardAbortError(action, reason);
241+
throw new AIGuardAbortError(action, reason, tags);
235242
}
236243
return new Evaluation(action, reason);
237244
}

dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,14 @@ class AIGuardInternalTests extends DDSpecification {
157157
Request request = null
158158
Throwable error = null
159159
AIGuard.Evaluation eval = null
160+
Map<String, Object> receivedMeta = null
160161
final throwAbortError = suite.blocking && suite.action != ALLOW
161162
final call = Mock(Call) {
162163
execute() >> {
163164
return mockResponse(
164165
request,
165166
200,
166-
[data: [attributes: [action: suite.action, reason: suite.reason, is_blocking_enabled: suite.blocking]]]
167+
[data: [attributes: [action: suite.action, reason: suite.reason, tags: suite.tags ?: [], is_blocking_enabled: suite.blocking]]]
167168
)
168169
}
169170
}
@@ -189,11 +190,18 @@ class AIGuardInternalTests extends DDSpecification {
189190
}
190191
1 * span.setTag(AIGuardInternal.ACTION_TAG, suite.action)
191192
1 * span.setTag(AIGuardInternal.REASON_TAG, suite.reason)
192-
1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, [messages: suite.messages])
193+
1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _ as Map) >> {
194+
receivedMeta = it[1] as Map<String, Object>
195+
return span
196+
}
193197
if (throwAbortError) {
194198
1 * span.addThrowable(_ as AIGuard.AIGuardAbortError)
195199
}
196200

201+
receivedMeta.messages == suite.messages
202+
if (suite.tags) {
203+
receivedMeta.attack_categories == suite.tags
204+
}
197205
assertRequest(request, suite.messages)
198206
if (throwAbortError) {
199207
error instanceof AIGuard.AIGuardAbortError
@@ -444,6 +452,14 @@ class AIGuardInternalTests extends DDSpecification {
444452
0 * span.setTag(AIGuardInternal.TOOL_TAG, _)
445453
}
446454

455+
void 'map requires even number of params'() {
456+
when:
457+
AIGuardInternal.mapOf('1', '2', '3')
458+
459+
then:
460+
thrown(IllegalArgumentException)
461+
}
462+
447463
private static assertTelemetry(final String metric, final String...tags) {
448464
final metrics = WafMetricCollector.get().with {
449465
prepareMetrics()
@@ -497,22 +513,28 @@ class AIGuardInternalTests extends DDSpecification {
497513
private static class TestSuite {
498514
private final AIGuard.Action action
499515
private final String reason
516+
private final List<String> tags
500517
private final boolean blocking
501518
private final String description
502519
private final String target
503520
private final List<AIGuard.Message> messages
504521

505-
TestSuite(AIGuard.Action action, String reason, boolean blocking, String description, String target, List<AIGuard.Message> messages) {
522+
TestSuite(AIGuard.Action action, String reason, List<String> tags, boolean blocking, String description, String target, List<AIGuard.Message> messages) {
506523
this.action = action
507524
this.reason = reason
525+
this.tags = tags
508526
this.blocking = blocking
509527
this.description = description
510528
this.target = target
511529
this.messages = messages
512530
}
513531

514532
static List<TestSuite> build() {
515-
def actionValues = [[ALLOW, 'Go ahead'], [DENY, 'Nope'], [ABORT, 'Kill it with fire']]
533+
def actionValues = [
534+
[ALLOW, 'Go ahead', []],
535+
[DENY, 'Nope', ['deny_everything', 'test_deny']],
536+
[ABORT, 'Kill it with fire', ['alarm_tag', 'abort_everything']]
537+
]
516538
def blockingValues = [true, false]
517539
def suiteValues = [
518540
['tool call', 'tool', TOOL_CALL],
@@ -521,7 +543,7 @@ class AIGuardInternalTests extends DDSpecification {
521543
]
522544
return combinations([actionValues, blockingValues, suiteValues] as Iterable)
523545
.collect { action, blocking, suite ->
524-
new TestSuite(action[0], action[1], blocking, suite[0], suite[1], suite[2])
546+
new TestSuite(action[0], action[1], action[2], blocking, suite[0], suite[1], suite[2])
525547
}
526548
}
527549

dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,13 @@ public static Evaluation evaluate(final List<Message> messages, final Options op
6363
public static class AIGuardAbortError extends RuntimeException {
6464
private final Action action;
6565
private final String reason;
66+
private final List<String> tags;
6667

67-
public AIGuardAbortError(final Action action, final String reason) {
68+
public AIGuardAbortError(final Action action, final String reason, final List<String> tags) {
6869
super(reason);
6970
this.action = action;
7071
this.reason = reason;
72+
this.tags = tags;
7173
}
7274

7375
public Action getAction() {
@@ -77,6 +79,10 @@ public Action getAction() {
7779
public String getReason() {
7880
return reason;
7981
}
82+
83+
public List<String> getTags() {
84+
return tags;
85+
}
8086
}
8187

8288
/**

0 commit comments

Comments
 (0)