Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,10 @@ dependencies {
implementation "org.apache.kafka:kafka-clients:${kafka_version}"
implementation 'com.onelogin:java-saml:2.5.0'
implementation 'com.onelogin:java-saml-core:2.5.0'
implementation 'io.protostuff:protostuff-api:1.7.4'
implementation 'io.protostuff:protostuff-core:1.7.4'
implementation 'io.protostuff:protostuff-collectionschema:1.7.4'
implementation 'io.protostuff:protostuff-runtime:1.7.4'

runtimeOnly 'net.minidev:accessors-smart:2.4.7'

Expand Down
76 changes: 74 additions & 2 deletions src/main/java/org/opensearch/security/support/Base64Helper.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.BaseEncoding;
import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import io.protostuff.runtime.RuntimeSchema;
import org.ldaptive.AbstractLdapBean;
import org.ldaptive.LdapAttribute;
import org.ldaptive.LdapEntry;
Expand All @@ -66,6 +70,8 @@

public class Base64Helper {

private static final ThreadLocal<LinkedBuffer> threadLocalLinkedBuffer = ThreadLocal.withInitial(() -> LinkedBuffer.allocate(1024));

private static final Set<Class<?>> SAFE_CLASSES = ImmutableSet.of(
String.class,
SocketAddress.class,
Expand Down Expand Up @@ -156,7 +162,7 @@ protected Object replaceObject(Object obj) throws IOException {
}
}

public static String serializeObject(final Serializable object) {
public static String serializeObjectJDK(final Serializable object) {

Preconditions.checkArgument(object != null, "object must not be null");

Expand All @@ -170,7 +176,7 @@ public static String serializeObject(final Serializable object) {
return BaseEncoding.base64().encode(bytes);
}

public static Serializable deserializeObject(final String string) {
public static Serializable deserializeObjectJDK(final String string) {

Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "string must not be null or empty");

Expand All @@ -183,6 +189,52 @@ public static Serializable deserializeObject(final String string) {
}
}


public static Serializable deserializeObjectProto(final String string) {
//ToDo: introduce safe class checks during deserialization using proto
Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "string must not be null or empty");
final byte[] bytes = BaseEncoding.base64().decode(string);
try {
Schema<SerializableWrapper> schema = RuntimeSchema.getSchema(SerializableWrapper.class);
SerializableWrapper serializableWrapper = schema.newMessage();
ProtostuffIOUtil.mergeFrom(bytes, serializableWrapper, schema);
return serializableWrapper.serializable;
} catch (final Exception e) {
throw new OpenSearchException(e);
}
}

public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) {
return useJDKDeserialization ? deserializeObjectJDK(string) : deserializeObjectProto(string);
}

public static Serializable deserializeObject(final String string) {
return deserializeObjectProto(string);
}

public static String serializeObjectProto(final Serializable object) {
//ToDo: introduce safe class checks during serialization using proto
SerializableWrapper serializableWrapper = new SerializableWrapper(object);
Preconditions.checkArgument(object != null, "object must not be null");
byte[] byteArray;
Schema<SerializableWrapper> schema = RuntimeSchema.getSchema(SerializableWrapper.class);
try {
byteArray = ProtostuffIOUtil.toByteArray(serializableWrapper, schema, threadLocalLinkedBuffer.get());
threadLocalLinkedBuffer.get().clear();
} catch (Exception e) {
throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass());
}
return BaseEncoding.base64().encode(byteArray);
}

public static String serializeObject(final Serializable object, final boolean useJDKSerialization) {
return useJDKSerialization ? serializeObjectJDK(object) : serializeObjectProto(object);
}

public static String serializeObject(final Serializable object) {
return serializeObjectProto(object);
}

private final static class SafeObjectInputStream extends ObjectInputStream {

public SafeObjectInputStream(InputStream in) throws IOException {
Expand All @@ -200,4 +252,24 @@ protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, Clas
throw new InvalidClassException("Unauthorized deserialization attempt ", clazz.getName());
}
}


private static class SerializableWrapper {
/*
* Introduction of SerializableWrapper eases the protostuff deserialization part.
*
* When deserializing, we need to fetch the root proto Schema by specifying the class of the object that we
* intend to deserialize. The serialized bytes in case of proto do not have a class label, hence it's not
* possible to generically identify what object type are we deserializing.
*
* SerializableWrapper here will hold our actual serializable object, and we'll always (de)serialize
* SerializableWrapper object. Protostuff will internally construct and maintain schemas for underlying
* classes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

*/
Serializable serializable;

public SerializableWrapper(Serializable serializable) {
this.serializable = serializable;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ public class ConfigConstants {

public static final String OPENDISTRO_SECURITY_INITIAL_ACTION_CLASS_HEADER = OPENDISTRO_SECURITY_CONFIG_PREFIX+"initial_action_class_header";

public static final String OPENDISTRO_SECURITY_SOURCE_FIELD_CONTEXT = OPENDISTRO_SECURITY_CONFIG_PREFIX+"source_field_context";

/**
* Set by SSL plugin for https requests only
*/
Expand Down
29 changes: 29 additions & 0 deletions src/main/java/org/opensearch/security/support/HeaderHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
package org.opensearch.security.support;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import com.google.common.base.Strings;

Expand Down Expand Up @@ -77,4 +80,30 @@ public static Serializable deserializeSafeFromHeader(final ThreadContext context
public static boolean isTrustedClusterRequest(final ThreadContext context) {
return context.getTransient(ConfigConstants.OPENDISTRO_SECURITY_SSL_TRANSPORT_TRUSTED_CLUSTER_REQUEST) == Boolean.TRUE;
}


/**
* Returns all headers present in <code>ThreadContext::getHeaders()</code> which have values serialized within
* the security plugin
* @param context current ThreadContext
* @return Map containing all serialized headers
*/
public static Map<String, String> getAllSerializedHeaders(ThreadContext context) {
Map<String, String> headerMap = new HashMap<>();
Arrays.asList(
ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER,
ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER,
ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER,
ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER,
ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER,
ConfigConstants.OPENDISTRO_SECURITY_DLS_FILTER_LEVEL_QUERY_HEADER,
ConfigConstants.OPENDISTRO_SECURITY_SOURCE_FIELD_CONTEXT
).forEach(headerName -> {
String headerValue = context.getHeader(headerName);
if(headerValue != null) {
headerMap.put(headerName, headerValue);
}
});
return headerMap;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.Version;
import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsAction;
import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
import org.opensearch.action.get.GetRequest;
Expand All @@ -58,6 +59,7 @@
import org.opensearch.security.ssl.transport.SSLConfig;
import org.opensearch.security.support.Base64Helper;
import org.opensearch.security.support.ConfigConstants;
import org.opensearch.security.support.HeaderHelper;
import org.opensearch.security.user.User;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.Transport.Connection;
Expand Down Expand Up @@ -126,6 +128,8 @@ public <T extends TransportResponse> void sendRequestDecorate(AsyncSender sender
final String origCCSTransientFls = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_CCS);
final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS);

final boolean useJDKSerialization = connection.getVersion().before(Version.V_2_7_0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👏


final boolean isDebugEnabled = log.isDebugEnabled();
try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) {
final TransportResponseHandler<T> restoringHandler = new RestoringTransportResponseHandler<T>(handler, stashedContext);
Expand All @@ -142,7 +146,7 @@ public <T extends TransportResponse> void sendRequestDecorate(AsyncSender sender
|| k.equals(ConfigConstants.OPENDISTRO_SECURITY_DOC_ALLOWLIST_HEADER)
|| k.equals(ConfigConstants.OPENDISTRO_SECURITY_FILTER_LEVEL_DLS_DONE)
|| k.equals(ConfigConstants.OPENDISTRO_SECURITY_DLS_MODE_HEADER)
|| k.equals(ConfigConstants.OPENDISTRO_SECURITY_DLS_FILTER_LEVEL_QUERY_HEADER)
|| k.equals(ConfigConstants.OPENDISTRO_SECURITY_DLS_FILTER_LEVEL_QUERY_HEADER)
|| (k.equals("_opendistro_security_source_field_context") && ! (request instanceof SearchRequest) && !(request instanceof GetRequest))
|| k.startsWith("_opendistro_security_trace")
|| k.startsWith(ConfigConstants.OPENDISTRO_SECURITY_INITIAL_ACTION_CLASS_HEADER)
Expand Down Expand Up @@ -203,10 +207,18 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL
getThreadContext().putHeader("_opendistro_security_trace"+System.currentTimeMillis()+"#"+UUID.randomUUID().toString(), Thread.currentThread().getName()+" IC -> "+action+" "+getThreadContext().getHeaders().entrySet().stream().filter(p->!p.getKey().startsWith("_opendistro_security_trace")).collect(Collectors.toMap(p -> p.getKey(), p -> p.getValue())));
}

if (useJDKSerialization) {
serializeHeadersUsingJdkForVersionUpgrade();
}

sender.sendRequest(connection, action, request, options, restoringHandler);
}
}

private void serializeHeadersUsingJdkForVersionUpgrade() {
HeaderHelper.getAllSerializedHeaders(getThreadContext()).forEach((key, value) -> getThreadContext().putHeader(key, Base64Helper.serializeObject(Base64Helper.deserializeObject(value), true)));
}

private void ensureCorrectHeaders(final Object remoteAdr, final User origUser, final String origin,
final String injectedUserString, final String injectedRolesString) {
// keep original address
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.Version;
import org.opensearch.action.bulk.BulkShardRequest;
import org.opensearch.action.support.replication.TransportReplicationAction.ConcreteShardRequest;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -102,6 +103,12 @@ protected void messageReceivedDecorate(final T request, final TransportRequestHa
resolvedActionClass = ((ConcreteShardRequest<?>) request).getRequest().getClass().getSimpleName();
}

final boolean useJDKSerialization = transportChannel.getVersion().before(Version.V_2_7_0);

if(useJDKSerialization) {
serializeHeadersUsingProtoForVersionUpgrade();
}

String initialActionClassValue = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INITIAL_ACTION_CLASS_HEADER);

final ThreadContext.StoredContext sgContext = getThreadContext().newStoredContext(false);
Expand Down Expand Up @@ -296,6 +303,10 @@ else if(!Strings.isNullOrEmpty(injectedUserHeader)) {
}
}
}

private void serializeHeadersUsingProtoForVersionUpgrade() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can move this function to HeaderHelpers to re-use

HeaderHelper.getAllSerializedHeaders(getThreadContext()).forEach((key, value) -> getThreadContext().putHeader(key, Base64Helper.serializeObject(Base64Helper.deserializeObject(value, true))));
}

private void putInitialActionClassHeader(String initialActionClassValue, String resolvedActionClass) {
if(initialActionClassValue == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,76 +26,95 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.security.user.User;

import static org.opensearch.security.support.Base64Helper.deserializeObject;
import static org.opensearch.security.support.Base64Helper.serializeObject;
import static org.opensearch.security.support.Base64Helper.deserializeObjectJDK;
import static org.opensearch.security.support.Base64Helper.deserializeObjectProto;
import static org.opensearch.security.support.Base64Helper.serializeObjectJDK;
import static org.opensearch.security.support.Base64Helper.serializeObjectProto;

public class Base64HelperTest {

private static final class NotSafeSerializable implements Serializable {
private static final long serialVersionUID = 5135559266828470092L;
}

private static Serializable ds(Serializable s) {
return deserializeObject(serializeObject(s));
private static Serializable dsJDK(Serializable s) {
return deserializeObjectJDK(serializeObjectJDK(s));
}

private static Serializable dsProto(Serializable s) {
return deserializeObjectProto(serializeObjectProto(s));
}

@Test
public void testString() {
String string = "string";
Assert.assertEquals(string, ds(string));
Assert.assertEquals(string, dsJDK(string));
Assert.assertEquals(string, dsProto(string));
}

@Test
public void testInteger() {
Integer integer = Integer.valueOf(0);
Assert.assertEquals(integer, ds(integer));
Assert.assertEquals(integer, dsJDK(integer));
Assert.assertEquals(integer, dsProto(integer));
}

@Test
public void testDouble() {
Double number = Double.valueOf(0.);
Assert.assertEquals(number, ds(number));
Assert.assertEquals(number, dsJDK(number));
Assert.assertEquals(number, dsProto(number));
}

@Test
public void testInetSocketAddress() {
InetSocketAddress inetSocketAddress = new InetSocketAddress(0);
Assert.assertEquals(inetSocketAddress, ds(inetSocketAddress));
Assert.assertEquals(inetSocketAddress, dsJDK(inetSocketAddress));
Assert.assertEquals(inetSocketAddress, dsProto(inetSocketAddress));

}

@Test
public void testPattern() {
Pattern pattern = Pattern.compile(".*");
Assert.assertEquals(pattern.pattern(), ((Pattern) ds(pattern)).pattern());
Assert.assertEquals(pattern.pattern(), ((Pattern) dsJDK(pattern)).pattern());
Assert.assertEquals(pattern.pattern(), ((Pattern) dsProto(pattern)).pattern());

}

@Test
public void testUser() {
User user = new User("user");
Assert.assertEquals(user, ds(user));
Assert.assertEquals(user, dsJDK(user));
Assert.assertEquals(user, dsProto(user));
}

@Test
public void testSourceFieldsContext() {
SourceFieldsContext sourceFieldsContext = new SourceFieldsContext(new SearchRequest(""));
Assert.assertEquals(sourceFieldsContext.toString(), ds(sourceFieldsContext).toString());
Assert.assertEquals(sourceFieldsContext.toString(), dsJDK(sourceFieldsContext).toString());
Assert.assertEquals(sourceFieldsContext.toString(), dsProto(sourceFieldsContext).toString());

}

@Test
public void testHashMap() {
HashMap map = new HashMap();
Assert.assertEquals(map, ds(map));
Assert.assertEquals(map, dsJDK(map));
Assert.assertEquals(map, dsProto(map));
}

@Test
public void testArrayList() {
ArrayList list = new ArrayList();
Assert.assertEquals(list, ds(list));
Assert.assertEquals(list, dsJDK(list));
Assert.assertEquals(list, dsProto(list));
}

@Test(expected = OpenSearchException.class)
public void notSafeSerializable() {
serializeObject(new NotSafeSerializable());
serializeObjectJDK(new NotSafeSerializable());
serializeObjectProto(new NotSafeSerializable());
}

@Test(expected = OpenSearchException.class)
Expand All @@ -104,6 +123,7 @@ public void notSafeDeserializable() throws Exception {
try (final ObjectOutputStream out = new ObjectOutputStream(bos)) {
out.writeObject(new NotSafeSerializable());
}
deserializeObject(BaseEncoding.base64().encode(bos.toByteArray()));
deserializeObjectJDK(BaseEncoding.base64().encode(bos.toByteArray()));
deserializeObjectProto(BaseEncoding.base64().encode(bos.toByteArray()));
}
}