Skip to content

Commit f542e97

Browse files
Fix regression in sdk (changes in token federation) (#1052)
- The issue : - An [E2E test](https://docs.google.com/document/d/1McX4IgD-ZBTtiNXNUrEemsjbVj3oeoQYDwegjjln6UA/edit?ouid=113104269373381935368&tab=t.0#heading=h.k832j8h2svg) conducted with JDBC, with additional SDK-level logging, confirmed that the SDK was refreshing tokens on every call—even when successive calls occurred within seconds. This behavior occurred despite Databricks M2M tokens being valid for 59 minutes. As a result, the token endpoint was hit excessively, eventually triggering global rate limits and throttling for the IP. - Each [getToken call in the SDK](https://github.com/databricks/databricks-jdbc/blob/a17b84c1a0418094a8434f56246c764fa235d19b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java#L166) fetched a new token from the server because the SDK did not cache tokens correctly. This was traced to a regression in the SDK’s token management logic. - This PR ensures that the SDK is configured once in the constructor, preventing repeated token endpoint calls. We also plan to perform a broader SDK code audit to identify and address any similar issues going forward. - Tested manually using M2M flow : The token retrieval is now performed only once when M2M creds are used. - Unit tests - Internal doc : https://docs.google.com/document/d/1McX4IgD-ZBTtiNXNUrEemsjbVj3oeoQYDwegjjln6UA/edit?tab=t.g07tag19b223 --------- Co-authored-by: Gopal Lal <[email protected]>
1 parent 965bd19 commit f542e97

File tree

4 files changed

+24
-21
lines changed

4 files changed

+24
-21
lines changed

NEXT_CHANGELOG.md

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/main/java/com/databricks/jdbc/auth/DatabricksTokenFederationProvider.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public class DatabricksTokenFederationProvider implements CredentialsProvider, T
4646
private static final JdbcLogger LOGGER =
4747
JdbcLoggerFactory.getLogger(DatabricksTokenFederationProvider.class);
4848
private Token token;
49+
private HeaderFactory externalHeaderFactory;
4950
private static final Map<String, String> TOKEN_EXCHANGE_PARAMS =
5051
Map.of(
5152
"grant_type",
@@ -69,6 +70,9 @@ public DatabricksTokenFederationProvider(
6970
this.credentialsProvider = credentialsProvider;
7071
this.externalProviderHeaders = new HashMap<>();
7172
this.hc = DatabricksHttpClientFactory.getInstance().getClient(connectionContext);
73+
// Initialize a minimal config; real config will be provided via configure(databricksConfig)
74+
this.config = null;
75+
this.externalHeaderFactory = null;
7276
this.token =
7377
new Token(
7478
DatabricksJdbcConstants.EMPTY_STRING,
@@ -85,6 +89,7 @@ public DatabricksTokenFederationProvider(
8589
this.connectionContext = connectionContext;
8690
this.credentialsProvider = credentialsProvider;
8791
this.config = config;
92+
this.externalHeaderFactory = this.credentialsProvider.configure(this.config);
8893
this.externalProviderHeaders = new HashMap<>();
8994
this.token =
9095
new Token(
@@ -113,6 +118,8 @@ public HeaderFactory configure(DatabricksConfig databricksConfig) {
113118
}
114119

115120
this.config = databricksConfig;
121+
// Call the underlying provider's configure ONCE and cache the HeaderFactory
122+
this.externalHeaderFactory = this.credentialsProvider.configure(this.config);
116123
return () -> {
117124
Token exchangedToken = getToken();
118125
Map<String, String> headers = new HashMap<>(this.externalProviderHeaders);
@@ -124,7 +131,11 @@ public HeaderFactory configure(DatabricksConfig databricksConfig) {
124131
}
125132

126133
public Token getToken() {
127-
this.externalProviderHeaders = this.credentialsProvider.configure(this.config).headers();
134+
if (this.externalHeaderFactory == null) {
135+
// Lazy-initialize if configure(databricksConfig) was not called yet
136+
this.externalHeaderFactory = this.credentialsProvider.configure(this.config);
137+
}
138+
this.externalProviderHeaders = this.externalHeaderFactory.headers();
128139
String[] tokenInfo = extractTokenInfoFromHeader(this.externalProviderHeaders);
129140
String accessTokenType = tokenInfo[0];
130141
String accessToken = tokenInfo[1];

src/main/java/com/databricks/jdbc/dbclient/impl/common/ClientConfigurator.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -377,18 +377,21 @@ public void setupM2MConfig() throws DatabricksParsingException {
377377
connectionContext, new AzureServicePrincipalCredentialsProvider()));
378378
} else {
379379
databricksConfig
380-
.setAuthType(DatabricksJdbcConstants.M2M_AUTH_TYPE)
381380
.setClientId(connectionContext.getClientId())
382381
.setClientSecret(connectionContext.getClientSecret());
383382
if (connectionContext.useJWTAssertion()) {
384-
databricksConfig.setCredentialsProvider(
385-
new DatabricksTokenFederationProvider(
386-
connectionContext,
387-
new PrivateKeyClientCredentialProvider(connectionContext, databricksConfig)));
383+
CredentialsProvider jwtProvider =
384+
new PrivateKeyClientCredentialProvider(connectionContext, databricksConfig);
385+
databricksConfig
386+
.setAuthType(jwtProvider.authType())
387+
.setCredentialsProvider(
388+
new DatabricksTokenFederationProvider(connectionContext, jwtProvider));
388389
} else {
389-
databricksConfig.setCredentialsProvider(
390-
new DatabricksTokenFederationProvider(
391-
connectionContext, new OAuthM2MServicePrincipalCredentialsProvider()));
390+
CredentialsProvider m2mProvider = new OAuthM2MServicePrincipalCredentialsProvider();
391+
databricksConfig
392+
.setAuthType(DatabricksJdbcConstants.M2M_AUTH_TYPE)
393+
.setCredentialsProvider(
394+
new DatabricksTokenFederationProvider(connectionContext, m2mProvider));
392395
}
393396
}
394397
}

src/test/java/com/databricks/jdbc/dbclient/impl/common/ClientConfiguratorTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ void testM2MWithJWT() throws DatabricksSQLException {
170170
assertEquals("https://sample-host.18.azuredatabricks.net", config.getHost());
171171
assertEquals("test-client", config.getClientId());
172172
assertEquals("custom-oauth-m2m", provider.authType());
173-
assertEquals(DatabricksJdbcConstants.M2M_AUTH_TYPE, config.getAuthType());
173+
assertEquals(provider.authType(), config.getAuthType());
174174
assertEquals(
175175
PrivateKeyClientCredentialProvider.class, provider.getCredentialsProvider().getClass());
176176
}

0 commit comments

Comments
 (0)