|
18 | 18 | package org.apache.spark |
19 | 19 |
|
20 | 20 | import java.io.File |
| 21 | +import java.util.UUID |
21 | 22 | import javax.net.ssl.SSLContext |
22 | 23 |
|
23 | 24 | import org.apache.hadoop.conf.Configuration |
| 25 | +import org.apache.hadoop.security.alias.{CredentialProvider, CredentialProviderFactory} |
24 | 26 | import org.scalatest.BeforeAndAfterAll |
25 | 27 |
|
26 | 28 | import org.apache.spark.util.SparkConfWithEnv |
@@ -154,4 +156,60 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { |
154 | 156 | assert(opts.trustStore === Some(new File("val2"))) |
155 | 157 | } |
156 | 158 |
|
| 159 | + test("get password from Hadoop credential provider") { |
| 160 | + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath |
| 161 | + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath |
| 162 | + |
| 163 | + val conf = new SparkConf |
| 164 | + val hadoopConf = new Configuration() |
| 165 | + val tmpPath = s"localjceks://file${sys.props("java.io.tmpdir")}/test-" + |
| 166 | + s"${UUID.randomUUID().toString}.jceks" |
| 167 | + val provider = createCredentialProvider(tmpPath, hadoopConf) |
| 168 | + |
| 169 | + conf.set("spark.ssl.enabled", "true") |
| 170 | + conf.set("spark.ssl.keyStore", keyStorePath) |
| 171 | + storePassword(provider, "spark.ssl.keyStorePassword", "password") |
| 172 | + storePassword(provider, "spark.ssl.keyPassword", "password") |
| 173 | + conf.set("spark.ssl.trustStore", trustStorePath) |
| 174 | + storePassword(provider, "spark.ssl.trustStorePassword", "password") |
| 175 | + conf.set("spark.ssl.enabledAlgorithms", |
| 176 | + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") |
| 177 | + conf.set("spark.ssl.protocol", "SSLv3") |
| 178 | + |
| 179 | + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) |
| 180 | + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) |
| 181 | + |
| 182 | + assert(opts.enabled === true) |
| 183 | + assert(opts.trustStore.isDefined === true) |
| 184 | + assert(opts.trustStore.get.getName === "truststore") |
| 185 | + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) |
| 186 | + assert(opts.keyStore.isDefined === true) |
| 187 | + assert(opts.keyStore.get.getName === "keystore") |
| 188 | + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) |
| 189 | + assert(opts.trustStorePassword === Some("password")) |
| 190 | + assert(opts.keyStorePassword === Some("password")) |
| 191 | + assert(opts.keyPassword === Some("password")) |
| 192 | + assert(opts.protocol === Some("SSLv3")) |
| 193 | + assert(opts.enabledAlgorithms === |
| 194 | + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) |
| 195 | + } |
| 196 | + |
| 197 | + private def createCredentialProvider(tmpPath: String, conf: Configuration): CredentialProvider = { |
| 198 | + conf.set(CredentialProviderFactory.CREDENTIAL_PROVIDER_PATH, tmpPath) |
| 199 | + |
| 200 | + val provider = CredentialProviderFactory.getProviders(conf).get(0) |
| 201 | + if (provider == null) { |
| 202 | + throw new IllegalStateException(s"Fail to get credential provider with path $tmpPath") |
| 203 | + } |
| 204 | + |
| 205 | + provider |
| 206 | + } |
| 207 | + |
| 208 | + private def storePassword( |
| 209 | + provider: CredentialProvider, |
| 210 | + passwordKey: String, |
| 211 | + password: String): Unit = { |
| 212 | + provider.createCredentialEntry(passwordKey, password.toCharArray) |
| 213 | + provider.flush() |
| 214 | + } |
157 | 215 | } |
0 commit comments