Skip to content

Commit

Permalink
MINOR: Few cleanups
Browse files Browse the repository at this point in the history
Reviewers: Manikumar Reddy <[email protected]>
  • Loading branch information
soondenana authored and omkreddy committed Sep 11, 2024
1 parent 450b707 commit 0a00456
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ public byte[] evaluateResponse(byte[] response) throws SaslException, SaslAuthen
case RECEIVE_CLIENT_FINAL_MESSAGE:
try {
ClientFinalMessage clientFinalMessage = new ClientFinalMessage(response);
if (!clientFinalMessage.nonce().endsWith(serverFirstMessage.nonce())) {
throw new SaslException("Invalid client nonce in the final client message.");
}
verifyClientProof(clientFinalMessage);
byte[] serverKey = scramCredential.serverKey();
byte[] serverSignature = formatter.serverSignature(serverKey, clientFirstMessage, serverFirstMessage, clientFinalMessage);
Expand Down Expand Up @@ -222,7 +225,8 @@ private void setState(State state) {
this.state = state;
}

private void verifyClientProof(ClientFinalMessage clientFinalMessage) throws SaslException {
// Visible for testing
void verifyClientProof(ClientFinalMessage clientFinalMessage) throws SaslException {
try {
byte[] expectedStoredKey = scramCredential.storedKey();
byte[] clientSignature = formatter.clientSignature(expectedStoredKey, clientFirstMessage, serverFirstMessage, clientFinalMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,26 @@
package org.apache.kafka.common.security.scram.internals;


import java.nio.charset.StandardCharsets;
import java.util.HashMap;

import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.security.authenticator.CredentialCache;
import org.apache.kafka.common.security.scram.ScramCredential;
import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFirstMessage;
import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFinalMessage;
import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFirstMessage;
import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HashMap;

import javax.security.sasl.SaslException;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand Down Expand Up @@ -67,10 +76,69 @@ public void authorizationIdNotEqualsAuthenticationId() {
assertThrows(SaslAuthenticationException.class, () -> saslServer.evaluateResponse(clientFirstMessage(USER_A, USER_B)));
}

/**
* Validate that server responds with client's nonce as prefix of its nonce in the
* server first message.
* <br>
* In addition, it checks that the client final message has nonce that it sent in its
* first message.
*/
@Test
public void validateNonceExchange() throws SaslException {
ScramSaslServer spySaslServer = Mockito.spy(saslServer);
byte[] clientFirstMsgBytes = clientFirstMessage(USER_A, USER_A);
ClientFirstMessage clientFirstMessage = new ClientFirstMessage(clientFirstMsgBytes);

byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes);
ServerFirstMessage serverFirstMessage = new ServerFirstMessage(serverFirstMsgBytes);
assertTrue(serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()),
"Nonce in server message should start with client first message's nonce");

byte[] clientFinalMessage = clientFinalMessage(serverFirstMessage.nonce());
Mockito.doNothing()
.when(spySaslServer).verifyClientProof(Mockito.any(ScramMessages.ClientFinalMessage.class));
byte[] serverFinalMsgBytes = spySaslServer.evaluateResponse(clientFinalMessage);
ServerFinalMessage serverFinalMessage = new ServerFinalMessage(serverFinalMsgBytes);
assertNull(serverFinalMessage.error(), "Server final message should not contain error");
}

@Test
public void validateFailedNonceExchange() throws SaslException {
ScramSaslServer spySaslServer = Mockito.spy(saslServer);
byte[] clientFirstMsgBytes = clientFirstMessage(USER_A, USER_A);
ClientFirstMessage clientFirstMessage = new ClientFirstMessage(clientFirstMsgBytes);

byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes);
ServerFirstMessage serverFirstMessage = new ServerFirstMessage(serverFirstMsgBytes);
assertTrue(serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()),
"Nonce in server message should start with client first message's nonce");

byte[] clientFinalMessage = clientFinalMessage(formatter.secureRandomString());
Mockito.doNothing()
.when(spySaslServer).verifyClientProof(Mockito.any(ScramMessages.ClientFinalMessage.class));
SaslException saslException = assertThrows(SaslException.class,
() -> spySaslServer.evaluateResponse(clientFinalMessage));
assertEquals("Invalid client nonce in the final client message.",
saslException.getMessage(),
"Failure message: " + saslException.getMessage());
}

private byte[] clientFirstMessage(String userName, String authorizationId) {
String nonce = formatter.secureRandomString();
String authorizationField = authorizationId != null ? "a=" + authorizationId : "";
String firstMessage = String.format("n,%s,n=%s,r=%s", authorizationField, userName, nonce);
return firstMessage.getBytes(StandardCharsets.UTF_8);
}

private byte[] clientFinalMessage(String nonce) {
String channelBinding = randomBytesAsString();
String proof = randomBytesAsString();

String message = String.format("c=%s,r=%s,p=%s", channelBinding, nonce, proof);
return message.getBytes(StandardCharsets.UTF_8);
}

private String randomBytesAsString() {
return Base64.getEncoder().encodeToString(formatter.secureRandomBytes());
}
}

0 comments on commit 0a00456

Please sign in to comment.