Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,20 @@ public boolean preSampleRequest(final @Nonnull AppSecRequestContext ctx) {
if (counter.tryAcquire()) {
log.debug("API security sampling is required for this request (presampled)");
ctx.setKeepOpenForApiSecurityPostProcessing(true);
// Update immediately to prevent concurrent requests from seeing the same expired state
updateApiAccessIfExpired(hash);
return true;
}
return false;
}

/** Get the final sampling decision. This method is NOT thread-safe. */
/**
* Confirms the final sampling decision.
*
* <p>This method is called after the span completes. The actual sampling decision and map update
* already happened in {@link #preSampleRequest(AppSecRequestContext)} to prevent race conditions.
* This method only serves as a final confirmation gate before schema extraction.
*/
@Override
public boolean sampleRequest(AppSecRequestContext ctx) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method seems useless with the new approach but, I decided to maintain it to keep the checks although updateApiAccessIfExpired is not necessary anymore

if (ctx == null) {
Expand All @@ -88,7 +96,7 @@ public boolean sampleRequest(AppSecRequestContext ctx) {
// This should never happen, it should have been short-circuited before.
return false;
}
return updateApiAccessIfExpired(hash);
return true;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,26 @@ public AppSecSpanPostProcessor(ApiSecuritySampler sampler, EventProducerService

@Override
public void process(@Nonnull AgentSpan span, @Nonnull BooleanSupplier timeoutCheck) {
final RequestContext ctx_ = span.getRequestContext();
if (ctx_ == null) {
return;
}
final AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC);
if (ctx == null) {
return;
}

if (!ctx.isKeepOpenForApiSecurityPostProcessing()) {
return;
}
AppSecRequestContext ctx = null;
RequestContext ctx_ = null;
boolean needsRelease = false;

try {
ctx_ = span.getRequestContext();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not understand why introducing this logic in the try/catch helps, in the end needRelease will always be false and nothing will be done in the finally bock (which is the same as it was before).

if (ctx_ == null) {
return;
}
ctx = ctx_.getData(RequestContextSlot.APPSEC);
if (ctx == null) {
return;
}

// Check if we acquired a permit for this request - must be inside try to ensure finally runs
needsRelease = ctx.isKeepOpenForApiSecurityPostProcessing();
if (!needsRelease) {
return;
}

if (timeoutCheck.getAsBoolean()) {
log.debug("Timeout detected, skipping API security post-processing");
return;
Expand All @@ -56,17 +62,25 @@ public void process(@Nonnull AgentSpan span, @Nonnull BooleanSupplier timeoutChe
log.debug("Request sampled, processing API security post-processing");
extractSchemas(ctx, ctx_.getTraceSegment());
} finally {
ctx.setKeepOpenForApiSecurityPostProcessing(false);
try {
// XXX: Close the additive first. This is not strictly needed, but it'll prevent getting it
// detected as a
// missed request-ended event.
ctx.closeWafContext();
ctx.close();
} catch (Exception e) {
log.debug("Error closing AppSecRequestContext", e);
// Always release the semaphore permit if we acquired one
if (needsRelease) {
if (ctx != null) {
ctx.setKeepOpenForApiSecurityPostProcessing(false);
// XXX: Close the additive first. This is not strictly needed, but it'll prevent getting
// it detected as a missed request-ended event.
try {
ctx.closeWafContext();
} catch (Exception e) {
log.debug("Error closing WAF context", e);
}
try {
ctx.close();
} catch (Exception e) {
log.debug("Error closing AppSecRequestContext", e);
}
}
sampler.releaseOne();
}
sampler.releaseOne();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -836,17 +836,19 @@ private NoopFlow onRequestEnded(RequestContext ctx_, IGSpanInfo spanInfo) {
TraceSegment traceSeg = ctx_.getTraceSegment();
Map<String, Object> tags = spanInfo.getTags();

if (maybeSampleForApiSecurity(ctx, spanInfo, tags)) {
if (!Config.get().isApmTracingEnabled()) {
traceSeg.setTagTop(Tags.ASM_KEEP, true);
traceSeg.setTagTop(Tags.PROPAGATED_TRACE_SOURCE, ProductTraceSource.ASM);
}
} else {
boolean sampledForApiSec = maybeSampleForApiSecurity(ctx, spanInfo, tags);

if (!sampledForApiSec) {
ctx.closeWafContext();
}

// AppSec report metric and events for web span only
if (traceSeg != null) {
if (sampledForApiSec && !Config.get().isApmTracingEnabled()) {
traceSeg.setTagTop(Tags.ASM_KEEP, true);
traceSeg.setTagTop(Tags.PROPAGATED_TRACE_SOURCE, ProductTraceSource.ASM);
}

traceSeg.setTagTop("_dd.appsec.enabled", 1);
traceSeg.setTagTop("_dd.runtime_family", "jvm");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,35 +138,40 @@ class ApiSecuritySamplerTest extends DDSpecification {
!sampled
}

void 'sampleRequest honors expiration'() {
void 'preSampleRequest honors expiration'() {
given:
def ctx = createContext('route1', 'GET', 200)
ctx.setApiSecurityEndpointHash(42L)
ctx.setKeepOpenForApiSecurityPostProcessing(true)
def ctx1 = createContext('route1', 'GET', 200)
def ctx2 = createContext('route1', 'GET', 200)
def ctx3 = createContext('route1', 'GET', 200)
final timeSource = new ControllableTimeSource()
timeSource.set(0)
final long expirationTimeInMs = 10L
final long expirationTimeInNs = expirationTimeInMs * 1_000_000
def sampler = new ApiSecuritySamplerImpl(10, expirationTimeInMs, timeSource)

when:
def sampled = sampler.sampleRequest(ctx)
when: 'first request samples'
def preSampled1 = sampler.preSampleRequest(ctx1)
def sampled1 = sampler.sampleRequest(ctx1)

then:
sampled
preSampled1
sampled1

when:
sampled = sampler.sampleRequest(ctx)
when: 'second request to same endpoint before expiration'
def preSampled2 = sampler.preSampleRequest(ctx2)

then: 'second request is not sampled'
!sampled
!preSampled2

when: 'expiration time has passed'
sampler.releaseOne()
timeSource.advance(expirationTimeInNs)
sampled = sampler.sampleRequest(ctx)
def preSampled3 = sampler.preSampleRequest(ctx3)
def sampled3 = sampler.sampleRequest(ctx3)

then: 'request is sampled again'
sampled
preSampled3
sampled3
}

void 'internal accessMap never goes beyond capacity'() {
Expand Down Expand Up @@ -198,10 +203,13 @@ class ApiSecuritySamplerTest extends DDSpecification {

expect:
for (int i = 0; i < maxCapacity * 10; i++) {
final ctx = createContext('route1', 'GET', 200 + 1)
ctx.setApiSecurityEndpointHash(i as long)
ctx.setKeepOpenForApiSecurityPostProcessing(true)
assert sampler.sampleRequest(ctx)
final ctx = createContext('route1', 'GET', 200 + i)
def preSampled = sampler.preSampleRequest(ctx)
// First request always samples, then we advance time so each subsequent request expires
assert preSampled
def sampled = sampler.sampleRequest(ctx)
assert sampled
sampler.releaseOne()
assert sampler.accessMap.size() <= 2
if (i % 2) {
timeSource.advance(expirationTimeInMs * 1_000_000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,85 @@ class AppSecSpanPostProcessorTest extends DDSpecification {
1 * sampler.releaseOne()
0 * _
}

void 'permit is released even if extractSchemas throws exception'() {
given:
def sampler = Mock(ApiSecuritySamplerImpl)
def producer = Mock(EventProducerService)
def span = Mock(AgentSpan)
def reqCtx = Mock(RequestContext)
def ctx = Mock(AppSecRequestContext)
def processor = new AppSecSpanPostProcessor(sampler, producer)

when:
processor.process(span, { false })

then:
def ex = thrown(RuntimeException)
ex.message == "Unexpected error"
1 * span.getRequestContext() >> reqCtx
1 * reqCtx.getData(_) >> ctx
1 * ctx.isKeepOpenForApiSecurityPostProcessing() >> true
1 * sampler.sampleRequest(_) >> true
1 * reqCtx.getTraceSegment() >> { throw new RuntimeException("Unexpected error") }
1 * ctx.setKeepOpenForApiSecurityPostProcessing(false)
1 * ctx.closeWafContext()
1 * ctx.close()
1 * sampler.releaseOne() // Critical: permit is still released despite exception
0 * _
}

void 'multiple requests do not exhaust semaphore permits'() {
given:
// Use real ApiSecuritySamplerImpl which has a semaphore with 4 permits
def realSampler = new ApiSecuritySamplerImpl()
def producer = Mock(EventProducerService)
def processor = new AppSecSpanPostProcessor(realSampler, producer)

when: 'Process 5 consecutive requests that acquire permits'
5.times { i ->
def span = Mock(AgentSpan)
def reqCtx = Mock(RequestContext)
def ctx = Mock(AppSecRequestContext)

// Mock the interactions
span.getRequestContext() >> reqCtx
reqCtx.getData(_) >> ctx
ctx.isKeepOpenForApiSecurityPostProcessing() >> true
ctx.setKeepOpenForApiSecurityPostProcessing(false)
ctx.closeWafContext()
ctx.close()

// Process should complete without issues, releasing permit each time
processor.process(span, { false })
}

then: 'All requests complete successfully without permit exhaustion'
noExceptionThrown()
}

void 'permit is released when ctx cleanup operations fail'() {
given:
def sampler = Mock(ApiSecuritySamplerImpl)
def producer = Mock(EventProducerService)
def span = Mock(AgentSpan)
def reqCtx = Mock(RequestContext)
def ctx = Mock(AppSecRequestContext)
def processor = new AppSecSpanPostProcessor(sampler, producer)

when:
processor.process(span, { false })

then:
noExceptionThrown()
1 * span.getRequestContext() >> reqCtx
1 * reqCtx.getData(_) >> ctx
1 * ctx.isKeepOpenForApiSecurityPostProcessing() >> true
1 * sampler.sampleRequest(_) >> false
1 * ctx.setKeepOpenForApiSecurityPostProcessing(false)
1 * ctx.closeWafContext() >> { throw new RuntimeException("WAF context close failed") }
1 * ctx.close() >> { throw new RuntimeException("Context close failed") }
1 * sampler.releaseOne() // Critical: permit is still released despite cleanup failures
0 * _
}
}