Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@ public void processAfter(CdsReadEventContext context, List<CdsData> data) {
if (ApplicationHandlerHelper.noContentFieldInData(context.getTarget(), data)) {
return;
}
logger.debug("Processing after read event for entity {}", context.getTarget().getName());
logger.debug("Processing after read event for entity {}", context.getTarget().getQualifiedName());

Converter converter = (path, element, value) -> {
logger.info("Processing after read event for entity {}", element.getName());
String contentId = (String) path.target().values().get(Attachments.CONTENT_ID);
String status = (String) path.target().values().get(Attachments.STATUS);
InputStream content = (InputStream) path.target().values().get(Attachments.CONTENT);
Expand All @@ -107,7 +106,7 @@ public void processAfter(CdsReadEventContext context, List<CdsData> data) {

private List<String> getAttachmentAssociations(CdsModel model, CdsEntity entity, String associationName,
List<String> processedEntities) {
List<String> associationNames = new ArrayList<String>();
List<String> associationNames = new ArrayList<>();
if (ApplicationHandlerHelper.isMediaEntity(entity)) {
associationNames.add(associationName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,28 @@ private void startScanning(CdsEntity attachmentEntityToScan, String contentId) {
// get current request context
RequestContextRunner runner = runtime.requestContext();

logger.debug("Transaction completed. Starting to scan attachment asynchronously.");
logger.debug("Transaction completed. Starting to scan attachment {} in entity {} asynchronously.", contentId,
attachmentEntityToScan.getQualifiedName());
Supplier<Void> executeAdapterSupplier = () -> {
// run malware scan asynchronously with current request context
runner.run(resourceCtx -> {
// ensure that DB transaction is still active until the content is completely read from InputStream and
// scanned by malware scanner
runtime.changeSetContext().run(changeSetCtx -> {
logger.info("Starting to scan attachment");
logger.debug("Started asynchronously scan of attachment {} in entity {}.", contentId,
attachmentEntityToScan.getQualifiedName());
attachmentMalwareScanner.scanAttachment(attachmentEntityToScan, contentId);
});
});
return null;
};
CompletableFuture.supplyAsync(executeAdapterSupplier).whenComplete((result, exception) -> {
if (Objects.nonNull(exception)) {
logger.error("Error during scanning attachment", exception);
logger.error("Error scanning attachment {} in entity {}.", contentId,
attachmentEntityToScan.getQualifiedName(), exception);
} else {
logger.info("Scanning attachment completed.");
logger.debug("Scanning of attachment {} in entity {} was completed successfully.", contentId,
attachmentEntityToScan.getQualifiedName());
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.slf4j.LoggerFactory;

import com.google.common.annotations.VisibleForTesting;
import com.sap.cds.CdsData;
import com.sap.cds.Result;
import com.sap.cds.feature.attachments.generated.cds4j.sap.attachments.Attachments;
import com.sap.cds.feature.attachments.generated.cds4j.sap.attachments.StatusCode;
Expand Down Expand Up @@ -68,33 +67,34 @@ public DefaultAttachmentMalwareScanner(PersistenceService persistenceService, At

@Override
public void scanAttachment(CdsEntity attachmentEntity, String contentId) {
logger.info("Service handler called to scan document for malware");
logger.debug("Started scanning attachment {} of entity {}.", contentId, attachmentEntity.getQualifiedName());

List<SelectionResult> selectionResult = selectData(attachmentEntity, contentId);

selectionResult.forEach(result -> {

long rowCount = result.result().rowCount();
if (rowCount <= 0) {
logger.info("No data found, nothing to scan for entity: {}", result.entity.getQualifiedName());
logger.debug("No attachments {} found in entity {}, nothing to scan.", contentId,
result.entity.getQualifiedName());
return;
}

if (rowCount > 1) {
logger.warn("More than one attachment found for document id: {} in entity: {}", contentId,
logger.warn("More than one attachment {} found in entity {}.", contentId,
result.entity.getQualifiedName());
throw new IllegalStateException("More than one attachment found for document id: " + contentId);
throw new IllegalStateException("More than one attachment with contentId %s.".formatted(contentId));
}

CdsData cdsData = result.result().single(CdsData.class);
MalwareScanResultStatus status = scanDocument(cdsData);
Attachments attachment = result.result().single(Attachments.class);
MalwareScanResultStatus status = scanDocument(attachment);
updateData(result.entity, contentId, status);
});

}

private List<SelectionResult> selectData(CdsEntity attachmentEntity, String contentId) {
var result = new ArrayList<SelectionResult>();
List<SelectionResult> result = new ArrayList<>();
try {
CdsEntity entity = (CdsEntity) attachmentEntity.getTargetOf(Drafts.SIBLING_ENTITY);
Result selectionResult = readData(contentId, entity);
Expand All @@ -109,37 +109,43 @@ private List<SelectionResult> selectData(CdsEntity attachmentEntity, String cont
}

private Result readData(String contentId, CdsEntity entity) {
CqnSelect select = Select.from(entity).columns(Attachments.CONTENT_ID, Attachments.CONTENT)
.where(entry -> entry.get(Attachments.CONTENT_ID).eq(contentId));
return persistenceService.run(select);
CqnSelect select = Select.from(entity).columns(Attachments.CONTENT_ID, Attachments.CONTENT, Attachments.STATUS)
.where(e -> e.get(Attachments.CONTENT_ID).eq(contentId)
.and(e.get(Attachments.STATUS).ne(StatusCode.CLEAN)));

Result result = persistenceService.run(select);
result.streamOf(Attachments.class)
.forEach(attachment -> logger.debug("Found attachment {} in entity {} with status {}.",
attachment.getContentId(), entity.getQualifiedName(), attachment.getStatus()));
return result;
}

private MalwareScanResultStatus scanDocument(CdsData data) {
private MalwareScanResultStatus scanDocument(Attachments attachment) {
if (malwareScanClient != null) {
String contentId = (String) data.get(Attachments.CONTENT_ID);
InputStream dbContent = (InputStream) data.get(Attachments.CONTENT);
try {
InputStream content = Objects.nonNull(dbContent) ? dbContent
: attachmentService.readAttachment(contentId);
InputStream content = Objects.nonNull(attachment.getContent()) ? attachment.getContent()
: attachmentService.readAttachment(attachment.getContentId());
logger.debug("Start scanning attachment {}.", attachment.getContentId());
return malwareScanClient.scanContent(content);
} catch (RuntimeException e) {
logger.error("Error while scanning document with document id: {}", contentId, e);
logger.error("Error while scanning attachment {}.", attachment.getContentId(), e);
return MalwareScanResultStatus.FAILED;
}
}
return MalwareScanResultStatus.NO_SCANNER;
}

private void updateData(CdsEntity attachmentEntity, String contentId, MalwareScanResultStatus status) {
CdsData updateData = CdsData.create();
updateData.put(Attachments.STATUS, mapStatus(status));
updateData.put(Attachments.SCANNED_AT, Instant.now());
logger.debug("CdsData shall be updated for entity: {}", attachmentEntity.getQualifiedName());
Attachments updateData = Attachments.create();
updateData.setStatus(mapStatus(status));
updateData.setScannedAt(Instant.now());

CqnUpdate update = Update.entity(attachmentEntity).data(updateData)
.where(entry -> entry.get(Attachments.CONTENT_ID).eq(contentId));
Result result = persistenceService.run(update);
logger.info("Attachment has been updated, with result row count {} for entity {}", result.rowCount(),
attachmentEntity.getQualifiedName());

logger.debug("Updated scan status to {} of attachment {} in entity {} -> Row count {}.", updateData.getStatus(),
contentId, attachmentEntity.getQualifiedName(), result.rowCount());
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ public DefaultMalwareScanClient(HttpClientProvider clientProvider) {

@Override
public MalwareScanResultStatus scanContent(InputStream content) {
logger.info("Start scanning document");
HttpPost request = buildHttpRequest(content);
return executeRequest(request);
}
Expand Down Expand Up @@ -97,7 +96,6 @@ private static MalwareScanResultStatus mapResponseToStatus(MalwareScanResult sca
scanResult.isMalwareDetected(), scanResult.isEncryptedContentDetected());
return MalwareScanResultStatus.INFECTED;
} else {
logger.info("Document is clean");
return MalwareScanResultStatus.CLEAN;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ private void verifyLogIsWritten() {
observer.stop();
var errorList = observer.getLogEvents().stream().filter(event -> event.getLevel().equals(Level.ERROR)).toList();
assertThat(errorList).hasSize(1);
assertThat(errorList.get(0).getFormattedMessage()).contains("Error during scanning attachment");
assertThat(errorList.get(0).getFormattedMessage()).contains("Error scanning attachment");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.junit.jupiter.params.provider.EnumSource;
import org.mockito.ArgumentCaptor;

import com.sap.cds.CdsData;
import com.sap.cds.Result;
import com.sap.cds.feature.attachments.generated.cds4j.sap.attachments.Attachments;
import com.sap.cds.feature.attachments.generated.cds4j.sap.attachments.StatusCode;
Expand Down Expand Up @@ -80,7 +79,7 @@ void correctSelectForNonDraftEntity() {
@Test
void correctSelectForDraftEntity() {
var entity = runtime.getCdsModel().findEntity(getTestServiceAttachmentName());
mockSelectResult(CdsData.create(), MalwareScanResultStatus.CLEAN);
mockSelectResult(Attachments.create(), MalwareScanResultStatus.CLEAN);

cut.scanAttachment(entity.orElseThrow(), "ID");

Expand All @@ -100,9 +99,9 @@ void fallbackToActiveEntityIfDraftHasNoData() {
when(persistenceService.run(any(CqnSelect.class))).thenReturn(emptyResult).thenReturn(result);
when(result.rowCount()).thenReturn(1L);
var content = mock(InputStream.class);
var cdsData = CdsData.create();
var cdsData = Attachments.create();
cdsData.put(Attachments.CONTENT, content);
when(result.single(CdsData.class)).thenReturn(cdsData);
when(result.single(Attachments.class)).thenReturn(cdsData);
when(malwareScanClient.scanContent(any())).thenReturn(MalwareScanResultStatus.CLEAN);

cut.scanAttachment(entity.orElseThrow(), "ID");
Expand Down Expand Up @@ -130,7 +129,7 @@ void exceptionIfTooManyResultsAreSelected() {
@EnumSource(MalwareScanResultStatus.class)
void dataAreUpdatedWithStatus(MalwareScanResultStatus status) {
var entity = runtime.getCdsModel().findEntity(getTestServiceAttachmentName());
mockSelectResult(CdsData.create(), status);
mockSelectResult(Attachments.create(), status);

cut.scanAttachment(entity.orElseThrow(), "ID");

Expand All @@ -142,7 +141,7 @@ void dataAreUpdatedWithStatusFromFailingScanClient() {
var entity = runtime.getCdsModel().findEntity(getTestServiceAttachmentName());
when(persistenceService.run(any(CqnSelect.class))).thenReturn(result);
when(result.rowCount()).thenReturn(1L);
when(result.single(CdsData.class)).thenReturn(CdsData.create());
when(result.single(Attachments.class)).thenReturn(Attachments.create());
when(malwareScanClient.scanContent(any())).thenThrow(new ServiceException("Error reading attachment"));

cut.scanAttachment(entity.orElseThrow(), "ID");
Expand All @@ -155,7 +154,7 @@ void dataAreUpdatedWithStatusFromFailingAttachmentService() {
var entity = runtime.getCdsModel().findEntity(getTestServiceAttachmentName());
when(persistenceService.run(any(CqnSelect.class))).thenReturn(result);
when(result.rowCount()).thenReturn(1L);
when(result.single(CdsData.class)).thenReturn(CdsData.create());
when(result.single(Attachments.class)).thenReturn(Attachments.create());
when(attachmentService.readAttachment(any())).thenThrow(new ServiceException("Error reading attachment"));

cut.scanAttachment(entity.orElseThrow(), "ID");
Expand All @@ -167,7 +166,7 @@ void dataAreUpdatedWithStatusFromFailingAttachmentService() {
void contentTakenFromTheDatabaseSelect() {
var entity = runtime.getCdsModel().findEntity(getTestServiceAttachmentName());
var content = mock(InputStream.class);
var data = CdsData.create();
var data = Attachments.create();
data.put("content", content);
mockSelectResult(data, MalwareScanResultStatus.CLEAN);

Expand All @@ -181,7 +180,7 @@ void contentTakenFromTheDatabaseSelect() {
void contentTakenFromTheAttachmentService() {
var entity = runtime.getCdsModel().findEntity(getTestServiceAttachmentName());
var contentId = "contentId";
var data = CdsData.create();
var data = Attachments.create();
data.put(Attachments.CONTENT_ID, contentId);
mockSelectResult(data, MalwareScanResultStatus.CLEAN);
var content = mock(InputStream.class);
Expand All @@ -197,7 +196,7 @@ void contentTakenFromTheAttachmentService() {
void contentTakenFromTheAttachmentServiceForNonDraft() {
var entity = runtime.getCdsModel().findEntity(Attachment_.CDS_NAME);
var contentId = "contentId";
var data = CdsData.create();
var data = Attachments.create();
data.put(Attachments.CONTENT_ID, contentId);
mockSelectResult(data, MalwareScanResultStatus.CLEAN);
var content = mock(InputStream.class);
Expand All @@ -214,10 +213,10 @@ void noDataReturnedForUpdateNothingDoneForNonDraftEntity() {
var entity = runtime.getCdsModel().findEntity(getTestServiceAttachmentName());
when(persistenceService.run(any(CqnSelect.class))).thenReturn(result);
when(result.rowCount()).thenReturn(1L).thenReturn(0L);
var originSelectionData = CdsData.create();
var originSelectionData = Attachments.create();
originSelectionData.put(Attachments.CONTENT_ID, "first contentId");
originSelectionData.put(Attachments.CONTENT, mock(InputStream.class));
when(result.single(CdsData.class)).thenReturn(originSelectionData).thenReturn(CdsData.create());
when(result.single(Attachments.class)).thenReturn(originSelectionData).thenReturn(Attachments.create());
when(malwareScanClient.scanContent(any())).thenReturn(MalwareScanResultStatus.CLEAN);

cut.scanAttachment(entity.orElseThrow(), "ID");
Expand All @@ -233,10 +232,10 @@ void clientNotCalledIfNoInstanceBound() {
var entity = runtime.getCdsModel().findEntity(getTestServiceAttachmentName());
var secondResult = mock(Result.class);
when(secondResult.rowCount()).thenReturn(0L);
when(secondResult.single(CdsData.class)).thenReturn(CdsData.create());
when(secondResult.single(Attachments.class)).thenReturn(Attachments.create());
when(persistenceService.run(any(CqnSelect.class))).thenReturn(result).thenReturn(secondResult);
when(result.rowCount()).thenReturn(1L);
when(result.single(CdsData.class)).thenReturn(CdsData.create());
when(result.single(Attachments.class)).thenReturn(Attachments.create());

cut.scanAttachment(entity.orElseThrow(), "ID");

Expand Down Expand Up @@ -273,17 +272,17 @@ private String getTestServiceAttachmentName() {
return com.sap.cds.feature.attachments.generated.test.cds4j.unit.test.testservice.Attachment_.CDS_NAME;
}

private void mockSelectResult(CdsData cdsData, MalwareScanResultStatus status) {
private void mockSelectResult(Attachments cdsData, MalwareScanResultStatus status) {
when(persistenceService.run(any(CqnSelect.class))).thenReturn(result);
when(result.rowCount()).thenReturn(1L);
when(result.single(CdsData.class)).thenReturn(cdsData);
when(result.single(Attachments.class)).thenReturn(cdsData);
when(malwareScanClient.scanContent(any())).thenReturn(status);
}

private void verifyKeyWhereCondition(CqnSelect select) {
assertThat(select.where()).isPresent();
var selectWhere = select.where().get();
assertThat(selectWhere.toString()).contains("[{\"ref\":[\"contentId\"]},\"=\",{\"val\":\"ID\"}]");
assertThat(selectWhere.toString()).contains("[{\"ref\":[\"contentId\"]},\"=\",{\"val\":\"ID\"},\"and\",{\"ref\":[\"status\"]},\"<>\",{\"val\":\"Clean\"}]");
}

}