Skip to content

Commit 79f4ed7

Browse files
authored
Implement notebook-native auth for the Java SDK (#171)
## Changes

Notebook-native authentication has been supported in the Python SDK for the last several months, and customers using the SDKs in Databricks appreciate the easier critical user journey (CUJ) for authentication. This PR ports the behavior to the Java SDK.

The approach taken for this implementation differs somewhat from the Python SDK's implementation. In particular, in Python, the command context can be obtained through functionality provided by DBR (`from dbruntime.databricks_repl_context import get_context`, `from dbruntime.sdk_credential_provider import init_runtime_native_auth`) as well as from IPython. The Java SDK uses the DBUtils object directly. However, the SDK doesn't depend on the DBUtils API (which is only in Scala), so instead we load it reflectively.

## Tests

Ran in a notebook in DBR, worked.
1 parent c64abf5 commit 79f4ed7

3 files changed

Lines changed: 128 additions & 2 deletions

File tree

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksError.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ public boolean isRetriable() {
106106
return true;
107107
}
108108
for (String substring : TRANSIENT_ERROR_STRING_MATCHES) {
109-
if (message.contains(substring)) {
109+
if (message != null && message.contains(substring)) {
110110
LOG.debug("Attempting retry because of {}", substring);
111111
return true;
112112
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ public DefaultCredentialsProvider() {
2828
new AzureServicePrincipalCredentialsProvider(),
2929
new AzureCliCredentialsProvider(),
3030
new ExternalBrowserCredentialsProvider(),
31-
new DatabricksCliCredentialsProvider());
31+
new DatabricksCliCredentialsProvider(),
32+
new NotebookNativeCredentialsProvider());
3233
}
3334

3435
@Override
@@ -42,6 +43,7 @@ public synchronized HeaderFactory configure(DatabricksConfig config) {
4243
continue;
4344
}
4445
try {
46+
LOG.info("Trying {} auth", provider.authType());
4547
HeaderFactory headerFactory = provider.configure(config);
4648
if (headerFactory == null) {
4749
continue;
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package com.databricks.sdk.core;
2+
3+
import java.lang.reflect.Field;
4+
import java.lang.reflect.InvocationTargetException;
5+
import java.util.*;
6+
import org.slf4j.Logger;
7+
import org.slf4j.LoggerFactory;
8+
9+
/**
10+
* A CredentialsProvider that uses the API token from the command context to authenticate.
11+
*
12+
* <p>The token and hostname are read from the command context, which can be retrieved through the
13+
* dbutils API. As the Java SDK does not depend on DBUtils directly, reflection is used to retrieve
14+
* the token. This token should be available wherever the DBUtils API is accessible (i.e. in the
15+
* Spark driver).
16+
*/
17+
public class NotebookNativeCredentialsProvider implements CredentialsProvider {
18+
private static final Logger LOG =
19+
LoggerFactory.getLogger(NotebookNativeCredentialsProvider.class);
20+
21+
@Override
22+
public String authType() {
23+
return "runtime";
24+
}
25+
26+
@Override
27+
public HeaderFactory configure(DatabricksConfig config) {
28+
if (System.getenv("DATABRICKS_RUNTIME_VERSION") == null) {
29+
LOG.debug("DBR not detected, skipping runtime auth");
30+
return null;
31+
}
32+
33+
try {
34+
// DBUtils is not available in the Java SDK, so we have to use reflection to get the token.
35+
// First, we get the context by calling getContext on the notebook field of dbutils, then we
36+
// get the apiKey and apiUrl fields from the context. If this is successful, we set the host
37+
// on the config.
38+
Object dbutils = getDbUtils();
39+
if (dbutils == null) {
40+
LOG.debug("DBUtils is not available, skipping runtime auth");
41+
return null;
42+
}
43+
Object notebook = getField(dbutils, "notebook");
44+
TokenAndUrl testTokenAndUrl = getTokenAndUrl(notebook);
45+
if (testTokenAndUrl.url == null) {
46+
LOG.debug("Workspace URL is not available, skipping runtime auth");
47+
}
48+
config.setHost(testTokenAndUrl.url);
49+
50+
return () -> {
51+
Map<String, String> headers = new HashMap<>();
52+
TokenAndUrl tokenAndUrl = getTokenAndUrl(notebook);
53+
headers.put("Authorization", String.format("Bearer %s", tokenAndUrl.token));
54+
return headers;
55+
};
56+
} catch (DatabricksException e) {
57+
LOG.debug("Failed to get token from command context, skipping runtime auth", e);
58+
return null;
59+
}
60+
}
61+
62+
/** Load the dbutils object initialized by DBR. */
63+
private static Object getDbUtils() {
64+
try {
65+
Class<?> dbutilsHolderClass = Class.forName("com.databricks.dbutils_v1.DBUtilsHolder$");
66+
Object dbutilsHolder = dbutilsHolderClass.getDeclaredField("MODULE$").get(null);
67+
InheritableThreadLocal<Object> dbutils = getField(dbutilsHolder, "dbutils0");
68+
return dbutils.get();
69+
} catch (ClassNotFoundException | IllegalAccessException | NoSuchFieldException e) {
70+
throw new DatabricksException("failed getting DBUtils", e);
71+
}
72+
}
73+
74+
/** Reflectively get a field by name from an object. */
75+
private static <T> T getField(Object o, String fieldName) {
76+
Field f;
77+
try {
78+
f = o.getClass().getDeclaredField(fieldName);
79+
} catch (NoSuchFieldException e) {
80+
throw new DatabricksException("field " + fieldName + " does not exist", e);
81+
}
82+
boolean accessible = f.isAccessible();
83+
try {
84+
f.setAccessible(true);
85+
return (T) f.get(o);
86+
} catch (IllegalAccessException e) {
87+
throw new DatabricksException("failed getting field " + fieldName, e);
88+
} finally {
89+
if (!accessible) {
90+
f.setAccessible(false);
91+
}
92+
}
93+
}
94+
95+
private static class TokenAndUrl {
96+
public final String token;
97+
public final String url;
98+
99+
TokenAndUrl(String token, String url) {
100+
this.token = token;
101+
this.url = url;
102+
}
103+
}
104+
105+
/** Fetch the current command context, and read the API token and URL from it. */
106+
private static TokenAndUrl getTokenAndUrl(Object notebook) {
107+
try {
108+
Object testCommandContext =
109+
notebook.getClass().getDeclaredMethod("getContext").invoke(notebook);
110+
Object tokenOpt =
111+
testCommandContext.getClass().getDeclaredMethod("apiToken").invoke(testCommandContext);
112+
String token = (String) tokenOpt.getClass().getDeclaredMethod("get").invoke(tokenOpt);
113+
Object hostOpt =
114+
testCommandContext.getClass().getDeclaredMethod("apiUrl").invoke(testCommandContext);
115+
String host = (String) hostOpt.getClass().getDeclaredMethod("get").invoke(hostOpt);
116+
return new TokenAndUrl(token, host);
117+
} catch (InvocationTargetException
118+
| NoSuchMethodException
119+
| IllegalAccessException
120+
| NoSuchElementException e) {
121+
throw new DatabricksException("failed to get token and URL from command context", e);
122+
}
123+
}
124+
}

0 commit comments

Comments
 (0)