Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public ProtocolProxy<ProtocolMetaInfoPB> getProtocolMetaInfoProxy(
factory)), false);
}

private static class Invoker implements RpcInvocationHandler {
protected static class Invoker implements RpcInvocationHandler {
private final Map<String, Message> returnTypes =
new ConcurrentHashMap<String, Message>();
private boolean isClosed = false;
Expand All @@ -133,7 +133,7 @@ private static class Invoker implements RpcInvocationHandler {
private AtomicBoolean fallbackToSimpleAuth;
private AlignmentContext alignmentContext;

private Invoker(Class<?> protocol, InetSocketAddress addr,
protected Invoker(Class<?> protocol, InetSocketAddress addr,
UserGroupInformation ticket, Configuration conf, SocketFactory factory,
int rpcTimeout, RetryPolicy connectionRetryPolicy,
AtomicBoolean fallbackToSimpleAuth, AlignmentContext alignmentContext)
Expand All @@ -148,7 +148,7 @@ private Invoker(Class<?> protocol, InetSocketAddress addr,
/**
* This constructor takes a connectionId, instead of creating a new one.
*/
private Invoker(Class<?> protocol, Client.ConnectionId connId,
protected Invoker(Class<?> protocol, Client.ConnectionId connId,
Configuration conf, SocketFactory factory) {
this.remoteId = connId;
this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class);
Expand Down Expand Up @@ -225,8 +225,6 @@ public Message invoke(Object proxy, final Method method, Object[] args)
traceScope = tracer.newScope(RpcClientUtil.methodToTraceString(method));
}

RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method);

if (LOG.isTraceEnabled()) {
LOG.trace(Thread.currentThread().getId() + ": Call -> " +
remoteId + ": " + method.getName() +
Expand All @@ -238,7 +236,7 @@ public Message invoke(Object proxy, final Method method, Object[] args)
final RpcWritable.Buffer val;
try {
val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER,
new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId,
constructRpcRequest(method, theRequest), remoteId,
fallbackToSimpleAuth, alignmentContext);

} catch (Throwable e) {
Expand Down Expand Up @@ -283,6 +281,11 @@ public boolean isDone() {
}
}

protected Writable constructRpcRequest(Method method, Message theRequest) {
RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method);
return new RpcProtobufRequest(rpcRequestHeader, theRequest);
}

private Message getReturnMessage(final Method method,
final RpcWritable.Buffer buf) throws ServiceException {
Message prototype = null;
Expand Down Expand Up @@ -332,6 +335,14 @@ private Message getReturnProtoType(Method method) throws Exception {
public ConnectionId getConnectionId() {
return remoteId;
}

protected long getClientProtocolVersion() {
return clientProtocolVersion;
}

protected String getProtocolName() {
return protocolName;
}
}

@VisibleForTesting
Expand Down Expand Up @@ -518,6 +529,13 @@ public Writable call(RPC.Server server, String connectionProtocolName,
String declaringClassProtoName =
rpcRequest.getDeclaringClassProtocolName();
long clientVersion = rpcRequest.getClientProtocolVersion();
return call(server, connectionProtocolName, request, receiveTime,
methodName, declaringClassProtoName, clientVersion);
}

protected Writable call(RPC.Server server, String connectionProtocolName,
RpcWritable.Buffer request, long receiveTime, String methodName,
String declaringClassProtoName, long clientVersion) throws Exception {
if (server.verbose)
LOG.info("Call: connectionProtocolName=" + connectionProtocolName +
", method=" + methodName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public ProtocolProxy<ProtocolMetaInfoPB> getProtocolMetaInfoProxy(
factory)), false);
}

private static final class Invoker implements RpcInvocationHandler {
protected static class Invoker implements RpcInvocationHandler {
private final Map<String, Message> returnTypes =
new ConcurrentHashMap<String, Message>();
private boolean isClosed = false;
Expand All @@ -127,7 +127,7 @@ private static final class Invoker implements RpcInvocationHandler {
private AtomicBoolean fallbackToSimpleAuth;
private AlignmentContext alignmentContext;

private Invoker(Class<?> protocol, InetSocketAddress addr,
protected Invoker(Class<?> protocol, InetSocketAddress addr,
UserGroupInformation ticket, Configuration conf, SocketFactory factory,
int rpcTimeout, RetryPolicy connectionRetryPolicy,
AtomicBoolean fallbackToSimpleAuth, AlignmentContext alignmentContext)
Expand All @@ -142,7 +142,7 @@ private Invoker(Class<?> protocol, InetSocketAddress addr,
/**
* This constructor takes a connectionId, instead of creating a new one.
*/
private Invoker(Class<?> protocol, Client.ConnectionId connId,
protected Invoker(Class<?> protocol, Client.ConnectionId connId,
Configuration conf, SocketFactory factory) {
this.remoteId = connId;
this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class);
Expand Down Expand Up @@ -219,8 +219,6 @@ public Message invoke(Object proxy, final Method method, Object[] args)
traceScope = tracer.newScope(RpcClientUtil.methodToTraceString(method));
}

RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method);

if (LOG.isTraceEnabled()) {
LOG.trace(Thread.currentThread().getId() + ": Call -> " +
remoteId + ": " + method.getName() +
Expand All @@ -232,7 +230,7 @@ public Message invoke(Object proxy, final Method method, Object[] args)
final RpcWritable.Buffer val;
try {
val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER,
new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId,
constructRpcRequest(method, theRequest), remoteId,
fallbackToSimpleAuth, alignmentContext);

} catch (Throwable e) {
Expand Down Expand Up @@ -279,6 +277,11 @@ public boolean isDone() {
}
}

protected Writable constructRpcRequest(Method method, Message theRequest) {
RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method);
return new RpcProtobufRequest(rpcRequestHeader, theRequest);
}

private Message getReturnMessage(final Method method,
final RpcWritable.Buffer buf) throws ServiceException {
Message prototype = null;
Expand Down Expand Up @@ -328,6 +331,14 @@ private Message getReturnProtoType(Method method) throws Exception {
public ConnectionId getConnectionId() {
return remoteId;
}

protected long getClientProtocolVersion() {
return clientProtocolVersion;
}

protected String getProtocolName() {
return protocolName;
}
}

@VisibleForTesting
Expand Down Expand Up @@ -509,6 +520,13 @@ public Writable call(RPC.Server server, String connectionProtocolName,
String declaringClassProtoName =
rpcRequest.getDeclaringClassProtocolName();
long clientVersion = rpcRequest.getClientProtocolVersion();
return call(server, connectionProtocolName, request, receiveTime,
methodName, declaringClassProtoName, clientVersion);
}

protected Writable call(RPC.Server server, String connectionProtocolName,
RpcWritable.Buffer request, long receiveTime, String methodName,
String declaringClassProtoName, long clientVersion) throws Exception {
if (server.verbose) {
LOG.info("Call: connectionProtocolName=" + connectionProtocolName +
", method=" + methodName);
Expand Down