Skip to content

Commit

Permalink
Merge pull request #567 from ali-ince/1.7-pass-access-mode
Browse files Browse the repository at this point in the history
Pass AccessMode in BEGIN and RUN messages
  • Loading branch information
zhenlineo authored Mar 6, 2019
2 parents 49ef4db + 227c0fc commit 742d280
Show file tree
Hide file tree
Showing 26 changed files with 769 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.concurrent.CompletionStage;

import org.neo4j.driver.internal.async.AccessModeConnection;
import org.neo4j.driver.internal.spi.Connection;
import org.neo4j.driver.internal.spi.ConnectionPool;
import org.neo4j.driver.internal.spi.ConnectionProvider;
Expand All @@ -45,7 +46,7 @@ public class DirectConnectionProvider implements ConnectionProvider
@Override
public CompletionStage<Connection> acquireConnection( AccessMode mode )
{
return connectionPool.acquire( address );
return connectionPool.acquire( address ).thenApply( connection -> new AccessModeConnection( connection, mode ) );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright (c) 2002-2019 "Neo4j,"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neo4j.driver.internal.async;

import java.util.concurrent.CompletionStage;

import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.messaging.BoltProtocol;
import org.neo4j.driver.internal.messaging.Message;
import org.neo4j.driver.internal.spi.Connection;
import org.neo4j.driver.internal.spi.ResponseHandler;
import org.neo4j.driver.internal.util.ServerVersion;
import org.neo4j.driver.v1.AccessMode;

public class AccessModeConnection implements Connection
{
private final Connection delegate;
private final AccessMode mode;

public AccessModeConnection( Connection delegate, AccessMode mode )
{
this.delegate = delegate;
this.mode = mode;
}

public Connection connection()
{
return delegate;
}

@Override
public boolean isOpen()
{
return delegate.isOpen();
}

@Override
public void enableAutoRead()
{
delegate.enableAutoRead();
}

@Override
public void disableAutoRead()
{
delegate.disableAutoRead();
}

@Override
public void write( Message message, ResponseHandler handler )
{
delegate.write( message, handler );
}

@Override
public void write( Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2 )
{
delegate.write( message1, handler1, message2, handler2 );
}

@Override
public void writeAndFlush( Message message, ResponseHandler handler )
{
delegate.writeAndFlush( message, handler );
}

@Override
public void writeAndFlush( Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2 )
{
delegate.writeAndFlush( message1, handler1, message2, handler2 );
}

@Override
public CompletionStage<Void> reset()
{
return delegate.reset();
}

@Override
public CompletionStage<Void> release()
{
return delegate.release();
}

@Override
public void terminateAndRelease( String reason )
{
delegate.terminateAndRelease( reason );
}

@Override
public BoltServerAddress serverAddress()
{
return delegate.serverAddress();
}

@Override
public ServerVersion serverVersion()
{
return delegate.serverVersion();
}

@Override
public BoltProtocol protocol()
{
return delegate.protocol();
}

@Override
public AccessMode mode()
{
return mode;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.RoutingErrorHandler;
import org.neo4j.driver.internal.async.AccessModeConnection;
import org.neo4j.driver.internal.async.RoutingConnection;
import org.neo4j.driver.internal.cluster.AddressSet;
import org.neo4j.driver.internal.cluster.ClusterComposition;
Expand Down Expand Up @@ -95,7 +96,8 @@ public CompletionStage<Connection> acquireConnection( AccessMode mode )
{
return freshRoutingTable( mode )
.thenCompose( routingTable -> acquire( mode, routingTable ) )
.thenApply( connection -> new RoutingConnection( connection, mode, this ) );
.thenApply( connection -> new RoutingConnection( connection, mode, this ) )
.thenApply( connection -> new AccessModeConnection( connection, mode ) );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,22 @@
import java.util.Objects;

import org.neo4j.driver.internal.Bookmarks;
import org.neo4j.driver.v1.AccessMode;
import org.neo4j.driver.v1.TransactionConfig;
import org.neo4j.driver.v1.Value;

public class BeginMessage extends TransactionStartingMessage
{
public static final byte SIGNATURE = 0x11;

public BeginMessage( Bookmarks bookmarks, TransactionConfig config )
public BeginMessage( Bookmarks bookmarks, TransactionConfig config, AccessMode mode )
{
this( bookmarks, config.timeout(), config.metadata() );
this( bookmarks, config.timeout(), config.metadata(), mode );
}

public BeginMessage( Bookmarks bookmarks, Duration txTimeout, Map<String,Value> txMetadata )
public BeginMessage( Bookmarks bookmarks, Duration txTimeout, Map<String,Value> txMetadata, AccessMode mode )
{
super( bookmarks, txTimeout, txMetadata );
super( bookmarks, txTimeout, txMetadata, mode );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Objects;

import org.neo4j.driver.internal.Bookmarks;
import org.neo4j.driver.v1.AccessMode;
import org.neo4j.driver.v1.TransactionConfig;
import org.neo4j.driver.v1.Value;

Expand All @@ -33,14 +34,15 @@ public class RunWithMetadataMessage extends TransactionStartingMessage
private final String statement;
private final Map<String,Value> parameters;

public RunWithMetadataMessage( String statement, Map<String,Value> parameters, Bookmarks bookmarks, TransactionConfig config )
public RunWithMetadataMessage( String statement, Map<String,Value> parameters, Bookmarks bookmarks, TransactionConfig config, AccessMode mode )
{
this( statement, parameters, bookmarks, config.timeout(), config.metadata() );
this( statement, parameters, bookmarks, config.timeout(), config.metadata(), mode );
}

public RunWithMetadataMessage( String statement, Map<String,Value> parameters, Bookmarks bookmarks, Duration txTimeout, Map<String,Value> txMetadata )
public RunWithMetadataMessage( String statement, Map<String,Value> parameters, Bookmarks bookmarks, Duration txTimeout, Map<String,Value> txMetadata,
AccessMode mode )
{
super( bookmarks, txTimeout, txMetadata );
super( bookmarks, txTimeout, txMetadata, mode );
this.statement = statement;
this.parameters = parameters;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.neo4j.driver.internal.Bookmarks;
import org.neo4j.driver.internal.messaging.Message;
import org.neo4j.driver.internal.util.Iterables;
import org.neo4j.driver.v1.AccessMode;
import org.neo4j.driver.v1.Value;

import static java.util.Collections.emptyMap;
Expand All @@ -34,26 +35,29 @@ abstract class TransactionStartingMessage implements Message
private static final String BOOKMARKS_METADATA_KEY = "bookmarks";
private static final String TX_TIMEOUT_METADATA_KEY = "tx_timeout";
private static final String TX_METADATA_METADATA_KEY = "tx_metadata";
private static final String MODE_KEY = "mode";
private static final String MODE_READ_VALUE = "r";

final Map<String,Value> metadata;

TransactionStartingMessage( Bookmarks bookmarks, Duration txTimeout, Map<String,Value> txMetadata )
TransactionStartingMessage( Bookmarks bookmarks, Duration txTimeout, Map<String,Value> txMetadata, AccessMode mode )
{
this.metadata = buildMetadata( bookmarks, txTimeout, txMetadata );
this.metadata = buildMetadata( bookmarks, txTimeout, txMetadata, mode );
}

public final Map<String,Value> metadata()
{
return metadata;
}

private static Map<String,Value> buildMetadata( Bookmarks bookmarks, Duration txTimeout, Map<String,Value> txMetadata )
private static Map<String,Value> buildMetadata( Bookmarks bookmarks, Duration txTimeout, Map<String,Value> txMetadata, AccessMode mode )
{
boolean bookmarksPresent = bookmarks != null && !bookmarks.isEmpty();
boolean txTimeoutPresent = txTimeout != null;
boolean txMetadataPresent = txMetadata != null && !txMetadata.isEmpty();
boolean accessModePresent = mode == AccessMode.READ;

if ( !bookmarksPresent && !txTimeoutPresent && !txMetadataPresent )
if ( !bookmarksPresent && !txTimeoutPresent && !txMetadataPresent && !accessModePresent )
{
return emptyMap();
}
Expand All @@ -73,6 +77,13 @@ private static Map<String,Value> buildMetadata( Bookmarks bookmarks, Duration tx
result.put( TX_METADATA_METADATA_KEY, value( txMetadata ) );
}

switch ( mode )
{
case READ:
result.put( MODE_KEY, value( MODE_READ_VALUE ) );
break;
}

return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public void prepareToCloseChannel( Channel channel )
@Override
public CompletionStage<Void> beginTransaction( Connection connection, Bookmarks bookmarks, TransactionConfig config )
{
BeginMessage beginMessage = new BeginMessage( bookmarks, config );
BeginMessage beginMessage = new BeginMessage( bookmarks, config, connection.mode() );

if ( bookmarks.isEmpty() )
{
Expand Down Expand Up @@ -148,7 +148,7 @@ private static CompletionStage<InternalStatementResultCursor> runStatement( Conn
Map<String,Value> params = statement.parameters().asMap( ofValue() );

CompletableFuture<Void> runCompletedFuture = new CompletableFuture<>();
Message runMessage = new RunWithMetadataMessage( query, params, bookmarksHolder.getBookmarks(), config );
Message runMessage = new RunWithMetadataMessage( query, params, bookmarksHolder.getBookmarks(), config, connection.mode() );
RunResponseHandler runHandler = new RunResponseHandler( runCompletedFuture, METADATA_EXTRACTOR );
PullAllResponseHandler pullAllHandler = newPullAllHandler( statement, runHandler, connection, bookmarksHolder, tx );

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.neo4j.driver.internal.messaging.BoltProtocol;
import org.neo4j.driver.internal.messaging.Message;
import org.neo4j.driver.internal.util.ServerVersion;
import org.neo4j.driver.v1.AccessMode;

public interface Connection
{
Expand Down Expand Up @@ -52,4 +53,9 @@ public interface Connection
ServerVersion serverVersion();

BoltProtocol protocol();

default AccessMode mode()
{
return AccessMode.WRITE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,20 @@
package org.neo4j.driver.internal;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import org.neo4j.driver.internal.async.AccessModeConnection;
import org.neo4j.driver.internal.spi.Connection;
import org.neo4j.driver.internal.spi.ConnectionPool;
import org.neo4j.driver.v1.AccessMode;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.junit.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.mock;
Expand All @@ -48,8 +54,26 @@ void acquiresConnectionsFromThePool()
ConnectionPool pool = poolMock( address, connection1, connection2 );
DirectConnectionProvider provider = new DirectConnectionProvider( address, pool );

assertSame( connection1, await( provider.acquireConnection( READ ) ) );
assertSame( connection2, await( provider.acquireConnection( WRITE ) ) );
Connection acquired1 = await( provider.acquireConnection( READ ) );
assertThat( acquired1, instanceOf( AccessModeConnection.class ) );
assertSame( connection1, ((AccessModeConnection) acquired1).connection() );

Connection acquired2 = await( provider.acquireConnection( WRITE ) );
assertThat( acquired2, instanceOf( AccessModeConnection.class ) );
assertSame( connection2, ((AccessModeConnection) acquired2).connection() );
}

@ParameterizedTest
@EnumSource( AccessMode.class )
void returnsCorrectAccessMode( AccessMode mode )
{
BoltServerAddress address = BoltServerAddress.LOCAL_DEFAULT;
ConnectionPool pool = poolMock( address, mock( Connection.class ) );
DirectConnectionProvider provider = new DirectConnectionProvider( address, pool );

Connection acquired = await( provider.acquireConnection( mode ) );

assertEquals( mode, acquired.mode() );
}

@Test
Expand Down
Loading

0 comments on commit 742d280

Please sign in to comment.