Skip to content
Merged
207 changes: 131 additions & 76 deletions src/main/java/com/basho/riak/client/core/RiakNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.security.KeyStore;
import java.security.Security;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
Expand Down Expand Up @@ -202,6 +203,7 @@ private RiakNode(Builder builder)
permits = new Sync(builder.maxConnections);
}

checkNetworkAddressCacheSettings();

this.state = State.CREATED;
}
Expand Down Expand Up @@ -246,14 +248,9 @@ public synchronized RiakNode start() throws UnknownHostException
ownsBootstrap = true;
}

InetSocketAddress socketAddress = new InetSocketAddress(remoteAddress, port);

if (socketAddress.isUnresolved())
{
throw new UnknownHostException("RiakNode:start - Failed resolving host " + remoteAddress);
}
bootstrap.handler(new RiakChannelInitializer(this));

bootstrap.handler(new RiakChannelInitializer(this)).remoteAddress(socketAddress);
refreshBootstrapRemoteAddress();

if (connectionTimeout > 0)
{
Expand All @@ -268,7 +265,7 @@ public synchronized RiakNode start() throws UnknownHostException
Channel channel;
try
{
channel = doGetConnection();
channel = doGetConnection(false);
minChannels.add(channel);
}
catch (ConnectionFailedException ex)
Expand All @@ -293,6 +290,19 @@ public synchronized RiakNode start() throws UnknownHostException
return this;
}

private void refreshBootstrapRemoteAddress() throws UnknownHostException
{
// Refresh the address, hope their DNS TTL settings allow this.
InetSocketAddress socketAddress = new InetSocketAddress(remoteAddress, port);

if (socketAddress.isUnresolved())
{
throw new UnknownHostException("RiakNode:start - Failed resolving host " + remoteAddress);
}

bootstrap.remoteAddress(socketAddress);
}

public synchronized Future<Boolean> shutdown()
{
stateCheck(State.RUNNING, State.HEALTH_CHECKING);
Expand Down Expand Up @@ -640,18 +650,23 @@ private Channel getConnection()
{
try
{
channel = doGetConnection();
channel = doGetConnection(true);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems weird that we get a connection during shutdown.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do we get a connection on shutdown?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that what is happening here? shutdown() method calls doGetConnection ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that I can see, the only places we call doGetConnection/1 are in start(), the general getConnection/1 method, and the health checker.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argh stupid GH code folding. Even when I expanded the diff it still wasn't clear. Sorry!

channel.closeFuture().removeListener(inAvailableCloseListener);
}
catch (ConnectionFailedException ex)
{
permits.release();
}
catch (UnknownHostException ex)
{
permits.release();
logger.error("Unknown host encountered while trying to open connection; {}", ex);
}
}
return channel;
}

private Channel doGetConnection() throws ConnectionFailedException
private Channel doGetConnection(boolean forceAddressRefresh) throws ConnectionFailedException, UnknownHostException
{
ChannelWithIdleTime cwi;
while ((cwi = available.poll()) != null)
Expand All @@ -667,6 +682,11 @@ private Channel doGetConnection() throws ConnectionFailedException
}
}

if(forceAddressRefresh)
{
refreshBootstrapRemoteAddress();
}

ChannelFuture f = bootstrap.connect();

try
Expand Down Expand Up @@ -694,83 +714,87 @@ private Channel doGetConnection() throws ConnectionFailedException

if (trustStore != null)
{
SSLContext context;
try
{
context = SSLContext.getInstance("TLS");
TrustManagerFactory tmf =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(trustStore);
if (keyStore!=null)
{
KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
kmf.init(keyStore, keyPassword==null?"".toCharArray():keyPassword.toCharArray());
context.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
}
else
{
context.init(null, tmf.getTrustManagers(), null);
}

}
catch (Exception ex)
{
c.close();
logger.error("Failure configuring SSL; {}:{} {}", remoteAddress, port, ex);
throw new ConnectionFailedException(ex);
}
setupTLSAndAuthenticate(c);
}

SSLEngine engine = context.createSSLEngine();
return c;

Set<String> protocols = new HashSet<String>(Arrays.asList(engine.getSupportedProtocols()));
}

if (protocols.contains("TLSv1.2"))
private void setupTLSAndAuthenticate(Channel c) throws ConnectionFailedException
{
SSLContext context;
try
{
context = SSLContext.getInstance("TLS");
TrustManagerFactory tmf =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(trustStore);
if(keyStore!=null)
{
engine.setEnabledProtocols(new String[] {"TLSv1.2"});
logger.debug("Using TLSv1.2");
KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
kmf.init(keyStore, keyPassword==null?"".toCharArray():keyPassword.toCharArray());
context.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
}
else if (protocols.contains("TLSv1.1"))
else
{
engine.setEnabledProtocols(new String[] {"TLSv1.1"});
logger.debug("Using TLSv1.1");
context.init(null, tmf.getTrustManagers(), null);
}

engine.setUseClientMode(true);
RiakSecurityDecoder decoder = new RiakSecurityDecoder(engine, username, password);
c.pipeline().addFirst(decoder);
}
catch (Exception ex)
{
c.close();
logger.error("Failure configuring SSL; {}:{} {}", remoteAddress, port, ex);
throw new ConnectionFailedException(ex);
}

try
{
DefaultPromise<Void> promise = decoder.getPromise();
logger.debug("Waiting on SSL Promise");
promise.await();
SSLEngine engine = context.createSSLEngine();

if (promise.isSuccess())
{
logger.debug("Auth succeeded; {}:{}", remoteAddress, port);
}
else
{
c.close();
logger.error("Failure during Auth; {}:{} {}",remoteAddress, port, promise.cause());
throw new ConnectionFailedException(promise.cause());
}
Set<String> protocols = new HashSet<String>(Arrays.asList(engine.getSupportedProtocols()));

if (protocols.contains("TLSv1.2"))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious why not just pass the supported protocols as the enabled ones? TLS should negotiate the highest supported version.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. I've been meaning to revamp some of the TLS code + testing sometime soon so I'll roll this into that issue.

{
engine.setEnabledProtocols(new String[] {"TLSv1.2"});
logger.debug("Using TLSv1.2");
}
else if (protocols.contains("TLSv1.1"))
{
engine.setEnabledProtocols(new String[] {"TLSv1.1"});
logger.debug("Using TLSv1.1");
}

engine.setUseClientMode(true);
RiakSecurityDecoder decoder = new RiakSecurityDecoder(engine, username, password);
c.pipeline().addFirst(decoder);

try
{
DefaultPromise<Void> promise = decoder.getPromise();
logger.debug("Waiting on SSL Promise");
promise.await();

if (promise.isSuccess())
{
logger.debug("Auth succeeded; {}:{}", remoteAddress, port);
}
catch (InterruptedException e)
else
{
c.close();
logger.error("Thread interrupted during Auth; {}:{}",
remoteAddress, port);
Thread.currentThread().interrupt();
throw new ConnectionFailedException(e);
logger.error("Failure during Auth; {}:{} {}",remoteAddress, port, promise.cause());
throw new ConnectionFailedException(promise.cause());
}

}

return c;

}
catch (InterruptedException e)
{
c.close();
logger.error("Thread interrupted during Auth; {}:{}",
remoteAddress, port);
Thread.currentThread().interrupt();
throw new ConnectionFailedException(e);
}
}

/**
Expand Down Expand Up @@ -878,6 +902,34 @@ public void onException(Channel channel, final Throwable t)
}
}

private void checkNetworkAddressCacheSettings()
{
final String property = Security.getProperty("networkaddress.cache.ttl");

final boolean usingSecurityMgr = System.getSecurityManager() != null;
final boolean propertyUndefined = property == null;
boolean logWarning = false;

if(propertyUndefined && usingSecurityMgr)
{
logWarning = true;
}
else if(!propertyUndefined)
{
final int cacheTTL = Integer.parseInt(property);
logWarning = (cacheTTL == -1);
}

if (logWarning)
{
logger.warn(
"The Java Security \"networkaddress.cache.ttl\" property may be set to cache DNS lookups forever. " +
"Using domain names for Riak nodes or an intermediate load balancer could result in stale IP " +
"addresses being used for new connections, causing connection errors. " +
"If you use domain names for Riak nodes, please set this property to a value greater than zero.");
}
}

/**
* Returns the {@code remoteAddress} for this RiakNode
*
Expand Down Expand Up @@ -1063,7 +1115,7 @@ public void run()
}
}

private void checkHealth()
private void checkHealth()
{
try
{
Expand All @@ -1073,7 +1125,7 @@ private void checkHealth()
// connections from the available queue and either
// return/create a new one (meaning the node is up) or throw
// an exception if a connection can't be made.
Channel c = doGetConnection();
Channel c = doGetConnection(true);
logger.debug("Healthcheck channel: {} isOpen: {} handlers:{}", c.hashCode(), c.isOpen(), c.pipeline().names());


Expand Down Expand Up @@ -1123,18 +1175,21 @@ private void checkHealth()
{
healthCheckFailed(ex);
}
catch (IllegalStateException e)
catch (UnknownHostException ex)
{
healthCheckFailed(ex);
}
catch (IllegalStateException ex)
{
// no-op; there's a race condition where the bootstrap is shutting down
// right when a healthcheck occurs and netty will throw this
logger.debug("Illegal state exception during healthcheck.");
logger.debug("Stack: {}", e);
logger.debug("Stack: {}", ex);
}
catch (RuntimeException e)
catch (RuntimeException ex)
{
logger.error("Runtime exception during healthcheck: {}",e);
logger.error("Runtime exception during healthcheck: {}", ex);
}

}

private void healthCheckFailed(Throwable cause)
Expand Down
Loading