diff --git a/java/lance-jni/src/namespace.rs b/java/lance-jni/src/namespace.rs index 4b1d5a82d21..b9db171c064 100644 --- a/java/lance-jni/src/namespace.rs +++ b/java/lance-jni/src/namespace.rs @@ -1,23 +1,121 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::HashMap; +use std::sync::Arc; + use bytes::Bytes; -use jni::objects::{JByteArray, JMap, JObject, JString}; +use jni::objects::{GlobalRef, JByteArray, JMap, JObject, JString, JValue}; use jni::sys::{jbyteArray, jlong, jstring}; use jni::JNIEnv; use lance_namespace::models::*; use lance_namespace::LanceNamespace as LanceNamespaceTrait; use lance_namespace_impls::{ - ConnectBuilder, DirectoryNamespace, DirectoryNamespaceBuilder, RestAdapter, RestAdapterConfig, - RestNamespace, RestNamespaceBuilder, + ConnectBuilder, DirectoryNamespace, DirectoryNamespaceBuilder, DynamicContextProvider, + OperationInfo, RestAdapter, RestAdapterConfig, RestNamespace, RestNamespaceBuilder, }; use serde::{Deserialize, Serialize}; -use std::sync::Arc; use crate::error::{Error, Result}; use crate::utils::to_rust_map; use crate::RT; +/// Java-implemented dynamic context provider. +/// +/// Wraps a Java object that implements the DynamicContextProvider interface. +pub struct JavaDynamicContextProvider { + java_provider: GlobalRef, + jvm: Arc, +} + +impl JavaDynamicContextProvider { + /// Create a new Java context provider wrapper. 
+ pub fn new(env: &mut JNIEnv, java_provider: &JObject) -> Result { + let java_provider = env.new_global_ref(java_provider)?; + let jvm = Arc::new(env.get_java_vm()?); + Ok(Self { java_provider, jvm }) + } +} + +impl std::fmt::Debug for JavaDynamicContextProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JavaDynamicContextProvider") + } +} + +impl DynamicContextProvider for JavaDynamicContextProvider { + fn provide_context(&self, info: &OperationInfo) -> HashMap { + // Attach to JVM + let mut env = match self.jvm.attach_current_thread() { + Ok(env) => env, + Err(e) => { + log::error!("Failed to attach to JVM: {}", e); + return HashMap::new(); + } + }; + + // Create Java strings for parameters + let operation = match env.new_string(&info.operation) { + Ok(s) => s, + Err(e) => { + log::error!("Failed to create operation string: {}", e); + return HashMap::new(); + } + }; + + let object_id = match env.new_string(&info.object_id) { + Ok(s) => s, + Err(e) => { + log::error!("Failed to create object_id string: {}", e); + return HashMap::new(); + } + }; + + // Call provideContext(String, String) -> Map + let result = env.call_method( + &self.java_provider, + "provideContext", + "(Ljava/lang/String;Ljava/lang/String;)Ljava/util/Map;", + &[JValue::Object(&operation), JValue::Object(&object_id)], + ); + + match result { + Ok(jvalue) => match jvalue.l() { + Ok(obj) if !obj.is_null() => { + // Convert Java Map to Rust HashMap + convert_java_map_to_hashmap(&mut env, &obj).unwrap_or_default() + } + Ok(_) => HashMap::new(), + Err(e) => { + log::error!("provideContext did not return object: {}", e); + HashMap::new() + } + }, + Err(e) => { + log::error!("Failed to call provideContext: {}", e); + HashMap::new() + } + } + } +} + +fn convert_java_map_to_hashmap( + env: &mut JNIEnv, + map_obj: &JObject, +) -> Result> { + let jmap = JMap::from_env(env, map_obj)?; + let mut result = HashMap::new(); + + let mut iter = jmap.iter(env)?; + while let 
Some((key, value)) = iter.next(env)? { + let key_str: String = env.get_string(&JString::from(key))?.into(); + let value_str: String = env.get_string(&JString::from(value))?.into(); + result.insert(key_str, value_str); + } + + Ok(result) +} + /// Blocking wrapper for DirectoryNamespace pub struct BlockingDirectoryNamespace { pub(crate) inner: DirectoryNamespace, @@ -40,20 +138,47 @@ pub extern "system" fn Java_org_lance_namespace_DirectoryNamespace_createNative( ) -> jlong { ok_or_throw_with_return!( env, - create_directory_namespace_internal(&mut env, properties_map), + create_directory_namespace_internal(&mut env, properties_map, None), 0 ) } -fn create_directory_namespace_internal(env: &mut JNIEnv, properties_map: JObject) -> Result { +#[no_mangle] +pub extern "system" fn Java_org_lance_namespace_DirectoryNamespace_createNativeWithProvider( + mut env: JNIEnv, + _obj: JObject, + properties_map: JObject, + context_provider: JObject, +) -> jlong { + ok_or_throw_with_return!( + env, + create_directory_namespace_internal(&mut env, properties_map, Some(context_provider)), + 0 + ) +} + +fn create_directory_namespace_internal( + env: &mut JNIEnv, + properties_map: JObject, + context_provider: Option, +) -> Result { // Convert Java HashMap to Rust HashMap let jmap = JMap::from_env(env, &properties_map)?; let properties = to_rust_map(env, &jmap)?; // Build DirectoryNamespace using builder - let builder = DirectoryNamespaceBuilder::from_properties(properties, None).map_err(|e| { - Error::runtime_error(format!("Failed to create DirectoryNamespaceBuilder: {}", e)) - })?; + let mut builder = + DirectoryNamespaceBuilder::from_properties(properties, None).map_err(|e| { + Error::runtime_error(format!("Failed to create DirectoryNamespaceBuilder: {}", e)) + })?; + + // Add context provider if provided + if let Some(provider_obj) = context_provider { + if !provider_obj.is_null() { + let java_provider = JavaDynamicContextProvider::new(env, &provider_obj)?; + builder = 
builder.context_provider(Arc::new(java_provider)); + } + } let namespace = RT .block_on(builder.build()) @@ -537,21 +662,47 @@ pub extern "system" fn Java_org_lance_namespace_RestNamespace_createNative( ) -> jlong { ok_or_throw_with_return!( env, - create_rest_namespace_internal(&mut env, properties_map), + create_rest_namespace_internal(&mut env, properties_map, None), 0 ) } -fn create_rest_namespace_internal(env: &mut JNIEnv, properties_map: JObject) -> Result { +#[no_mangle] +pub extern "system" fn Java_org_lance_namespace_RestNamespace_createNativeWithProvider( + mut env: JNIEnv, + _obj: JObject, + properties_map: JObject, + context_provider: JObject, +) -> jlong { + ok_or_throw_with_return!( + env, + create_rest_namespace_internal(&mut env, properties_map, Some(context_provider)), + 0 + ) +} + +fn create_rest_namespace_internal( + env: &mut JNIEnv, + properties_map: JObject, + context_provider: Option, +) -> Result { // Convert Java HashMap to Rust HashMap let jmap = JMap::from_env(env, &properties_map)?; let properties = to_rust_map(env, &jmap)?; // Build RestNamespace using builder - let builder = RestNamespaceBuilder::from_properties(properties).map_err(|e| { + let mut builder = RestNamespaceBuilder::from_properties(properties).map_err(|e| { Error::runtime_error(format!("Failed to create RestNamespaceBuilder: {}", e)) })?; + // Add context provider if provided + if let Some(provider_obj) = context_provider { + if !provider_obj.is_null() { + let java_provider = JavaDynamicContextProvider::new(env, &provider_obj)?; + builder = builder.context_provider(Arc::new(java_provider)); + } + } + let namespace = builder.build(); let blocking_namespace = BlockingRestNamespace { inner: namespace }; diff --git a/java/src/main/java/org/lance/namespace/DirectoryNamespace.java b/java/src/main/java/org/lance/namespace/DirectoryNamespace.java index a0796739a3c..3ffe2b82f01 100644 --- a/java/src/main/java/org/lance/namespace/DirectoryNamespace.java +++ 
b/java/src/main/java/org/lance/namespace/DirectoryNamespace.java @@ -21,7 +21,10 @@ import org.apache.arrow.memory.BufferAllocator; import java.io.Closeable; +import java.lang.reflect.Constructor; +import java.util.HashMap; import java.util.Map; +import java.util.Optional; /** * DirectoryNamespace implementation that provides Lance namespace functionality for directory-based @@ -149,11 +152,43 @@ public DirectoryNamespace() {} @Override public void initialize(Map configProperties, BufferAllocator allocator) { + initialize(configProperties, allocator, null); + } + + /** + * Initialize with a dynamic context provider. + * + *

If contextProvider is null and the properties contain {@code dynamic_context_provider.impl}, + * the provider will be loaded from the class path. The class must implement {@link + * DynamicContextProvider} and have a constructor accepting {@code Map}. + * + * @param configProperties Configuration properties for the namespace + * @param allocator Arrow buffer allocator + * @param contextProvider Optional provider for per-request context (e.g., dynamic auth headers) + */ + public void initialize( + Map configProperties, + BufferAllocator allocator, + DynamicContextProvider contextProvider) { if (this.nativeDirectoryNamespaceHandle != 0) { throw new IllegalStateException("DirectoryNamespace already initialized"); } this.allocator = allocator; - this.nativeDirectoryNamespaceHandle = createNative(configProperties); + + // If no explicit provider, try to create from properties + DynamicContextProvider provider = contextProvider; + if (provider == null) { + provider = createProviderFromProperties(configProperties).orElse(null); + } + + // Filter out provider properties before passing to native layer + Map filteredProperties = filterProviderProperties(configProperties); + + if (provider != null) { + this.nativeDirectoryNamespaceHandle = createNativeWithProvider(filteredProperties, provider); + } else { + this.nativeDirectoryNamespaceHandle = createNative(filteredProperties); + } } @Override @@ -399,6 +434,9 @@ private static T fromJson(String json, Class clazz) { // Native methods private native long createNative(Map properties); + private native long createNativeWithProvider( + Map properties, DynamicContextProvider contextProvider); + private native void releaseNative(long handle); private native String namespaceIdNative(long handle); @@ -453,4 +491,77 @@ private native String mergeInsertIntoTableNative( private native String describeTransactionNative(long handle, String requestJson); private native String alterTransactionNative(long handle, String requestJson); + + // 
========================================================================== + // Provider loading helpers + // ========================================================================== + + private static final String PROVIDER_PREFIX = "dynamic_context_provider."; + private static final String IMPL_KEY = "dynamic_context_provider.impl"; + + /** + * Create a context provider from properties if configured. + * + *

Loads the class specified by {@code dynamic_context_provider.impl} from the class path and + * instantiates it with the extracted provider properties. + */ + private static Optional createProviderFromProperties( + Map properties) { + String className = properties.get(IMPL_KEY); + if (className == null || className.isEmpty()) { + return Optional.empty(); + } + + // Extract provider-specific properties (strip prefix, exclude impl key) + Map providerProps = new HashMap<>(); + for (Map.Entry entry : properties.entrySet()) { + String key = entry.getKey(); + if (key.startsWith(PROVIDER_PREFIX) && !key.equals(IMPL_KEY)) { + String propName = key.substring(PROVIDER_PREFIX.length()); + providerProps.put(propName, entry.getValue()); + } + } + + try { + Class providerClass = Class.forName(className); + if (!DynamicContextProvider.class.isAssignableFrom(providerClass)) { + throw new IllegalArgumentException( + String.format( + "Class '%s' does not implement DynamicContextProvider interface", className)); + } + + @SuppressWarnings("unchecked") + Class typedClass = + (Class) providerClass; + + Constructor constructor = + typedClass.getConstructor(Map.class); + return Optional.of(constructor.newInstance(providerProps)); + + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException( + String.format("Failed to load context provider class '%s': %s", className, e), e); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException( + String.format( + "Context provider class '%s' must have a public constructor " + + "that accepts Map", + className), + e); + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException( + String.format("Failed to instantiate context provider '%s': %s", className, e), e); + } + } + + /** Filter out dynamic_context_provider.* properties from the map. 
*/ + private static Map filterProviderProperties(Map properties) { + Map filtered = new HashMap<>(); + for (Map.Entry entry : properties.entrySet()) { + if (!entry.getKey().startsWith(PROVIDER_PREFIX)) { + filtered.put(entry.getKey(), entry.getValue()); + } + } + return filtered; + } } diff --git a/java/src/main/java/org/lance/namespace/DynamicContextProvider.java b/java/src/main/java/org/lance/namespace/DynamicContextProvider.java new file mode 100644 index 00000000000..77b10c892a4 --- /dev/null +++ b/java/src/main/java/org/lance/namespace/DynamicContextProvider.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.namespace; + +import java.util.Map; + +/** + * Interface for providing dynamic per-request context to namespace operations. + * + *

Implementations can generate per-request context (e.g., authentication headers) based on the + * operation being performed. The provider is called synchronously before each namespace operation. + * + *

For RestNamespace, context keys that start with {@code headers.} are converted to HTTP headers + * by stripping the prefix. For example, {@code {"headers.Authorization": "Bearer abc123"}} becomes + * the {@code Authorization: Bearer abc123} header. Keys without the {@code headers.} prefix are + * ignored for HTTP headers but may be used for other purposes. + * + *

Example implementation: + * + *

+ * public class MyContextProvider implements DynamicContextProvider {
+ *   @Override
+ *   public Map<String, String> provideContext(String operation, String objectId) {
+ *     Map<String, String> context = new HashMap<>();
+ *     context.put("headers.Authorization", "Bearer " + getAuthToken());
+ *     context.put("headers.X-Request-Id", UUID.randomUUID().toString());
+ *     return context;
+ *   }
+ * }
+ * 
+ * + *

Usage with DirectoryNamespace: + * + *

+ * DynamicContextProvider provider = new MyContextProvider();
+ * Map<String, String> properties = Map.of("root", "/path/to/data");
+ * DirectoryNamespace namespace = new DirectoryNamespace();
+ * namespace.initialize(properties, allocator, provider);
+ * 
+ * + *

Usage with RestNamespace: + * + *

+ * DynamicContextProvider provider = new MyContextProvider();
+ * Map<String, String> properties = Map.of("uri", "https://api.example.com");
+ * RestNamespace namespace = new RestNamespace();
+ * namespace.initialize(properties, allocator, provider);
+ * 
+ */ +public interface DynamicContextProvider { + + /** + * Provide context for a namespace operation. + * + *

This method is called synchronously before each namespace operation. Implementations should + * be thread-safe as multiple operations may be performed concurrently. + * + * @param operation The operation name (e.g., "list_tables", "describe_table", "create_namespace") + * @param objectId The object identifier (namespace or table ID in delimited form, e.g., + * "workspace$table_name") + * @return Map of context key-value pairs. For HTTP headers, use keys with the "headers." prefix + * (e.g., "headers.Authorization"). Return an empty map if no additional context is needed. + * Must not return null. + */ + Map provideContext(String operation, String objectId); +} diff --git a/java/src/main/java/org/lance/namespace/RestNamespace.java b/java/src/main/java/org/lance/namespace/RestNamespace.java index b55eeb2f200..840e9f3d690 100644 --- a/java/src/main/java/org/lance/namespace/RestNamespace.java +++ b/java/src/main/java/org/lance/namespace/RestNamespace.java @@ -21,7 +21,10 @@ import org.apache.arrow.memory.BufferAllocator; import java.io.Closeable; +import java.lang.reflect.Constructor; +import java.util.HashMap; import java.util.Map; +import java.util.Optional; /** * RestNamespace implementation that provides Lance namespace functionality via REST API endpoints. @@ -74,11 +77,47 @@ public RestNamespace() {} @Override public void initialize(Map configProperties, BufferAllocator allocator) { + initialize(configProperties, allocator, null); + } + + /** + * Initialize with a dynamic context provider. + * + *

The context provider is called before each namespace operation and can return per-request + * context (e.g., authentication headers). Context keys that start with {@code headers.} are + * converted to HTTP headers by stripping the prefix. + * + *

If contextProvider is null and the properties contain {@code dynamic_context_provider.impl}, + * the provider will be loaded from the class path. The class must implement {@link + * DynamicContextProvider} and have a constructor accepting {@code Map}. + * + * @param configProperties Configuration properties for the namespace + * @param allocator Arrow buffer allocator + * @param contextProvider Optional provider for per-request context (e.g., dynamic auth headers) + */ + public void initialize( + Map configProperties, + BufferAllocator allocator, + DynamicContextProvider contextProvider) { if (this.nativeRestNamespaceHandle != 0) { throw new IllegalStateException("RestNamespace already initialized"); } this.allocator = allocator; - this.nativeRestNamespaceHandle = createNative(configProperties); + + // If no explicit provider, try to create from properties + DynamicContextProvider provider = contextProvider; + if (provider == null) { + provider = createProviderFromProperties(configProperties).orElse(null); + } + + // Filter out provider properties before passing to native layer + Map filteredProperties = filterProviderProperties(configProperties); + + if (provider != null) { + this.nativeRestNamespaceHandle = createNativeWithProvider(filteredProperties, provider); + } else { + this.nativeRestNamespaceHandle = createNative(filteredProperties); + } } @Override @@ -321,6 +360,9 @@ private static T fromJson(String json, Class clazz) { // Native methods private native long createNative(Map properties); + private native long createNativeWithProvider( + Map properties, DynamicContextProvider contextProvider); + private native void releaseNative(long handle); private native String namespaceIdNative(long handle); @@ -375,4 +417,77 @@ private native String mergeInsertIntoTableNative( private native String describeTransactionNative(long handle, String requestJson); private native String alterTransactionNative(long handle, String requestJson); + + // 
========================================================================== + // Provider loading helpers + // ========================================================================== + + private static final String PROVIDER_PREFIX = "dynamic_context_provider."; + private static final String IMPL_KEY = "dynamic_context_provider.impl"; + + /** + * Create a context provider from properties if configured. + * + *

Loads the class specified by {@code dynamic_context_provider.impl} from the class path and + * instantiates it with the extracted provider properties. + */ + private static Optional createProviderFromProperties( + Map properties) { + String className = properties.get(IMPL_KEY); + if (className == null || className.isEmpty()) { + return Optional.empty(); + } + + // Extract provider-specific properties (strip prefix, exclude impl key) + Map providerProps = new HashMap<>(); + for (Map.Entry entry : properties.entrySet()) { + String key = entry.getKey(); + if (key.startsWith(PROVIDER_PREFIX) && !key.equals(IMPL_KEY)) { + String propName = key.substring(PROVIDER_PREFIX.length()); + providerProps.put(propName, entry.getValue()); + } + } + + try { + Class providerClass = Class.forName(className); + if (!DynamicContextProvider.class.isAssignableFrom(providerClass)) { + throw new IllegalArgumentException( + String.format( + "Class '%s' does not implement DynamicContextProvider interface", className)); + } + + @SuppressWarnings("unchecked") + Class typedClass = + (Class) providerClass; + + Constructor constructor = + typedClass.getConstructor(Map.class); + return Optional.of(constructor.newInstance(providerProps)); + + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException( + String.format("Failed to load context provider class '%s': %s", className, e), e); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException( + String.format( + "Context provider class '%s' must have a public constructor " + + "that accepts Map", + className), + e); + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException( + String.format("Failed to instantiate context provider '%s': %s", className, e), e); + } + } + + /** Filter out dynamic_context_provider.* properties from the map. 
*/ + private static Map filterProviderProperties(Map properties) { + Map filtered = new HashMap<>(); + for (Map.Entry entry : properties.entrySet()) { + if (!entry.getKey().startsWith(PROVIDER_PREFIX)) { + filtered.put(entry.getKey(), entry.getValue()); + } + } + return filtered; + } } diff --git a/java/src/test/java/org/lance/namespace/DynamicContextProviderTest.java b/java/src/test/java/org/lance/namespace/DynamicContextProviderTest.java new file mode 100644 index 00000000000..7959eb9be58 --- /dev/null +++ b/java/src/test/java/org/lance/namespace/DynamicContextProviderTest.java @@ -0,0 +1,307 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.namespace; + +import org.lance.namespace.model.*; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.*; + +/** Tests for DynamicContextProvider interface. 
*/ +public class DynamicContextProviderTest { + @TempDir Path tempDir; + + private BufferAllocator allocator; + + @BeforeEach + void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterEach + void tearDown() { + if (allocator != null) { + allocator.close(); + } + } + + @Test + void testDirectoryNamespaceWithContextProvider() { + AtomicInteger callCount = new AtomicInteger(0); + + DynamicContextProvider provider = + (operation, objectId) -> { + callCount.incrementAndGet(); + Map context = new HashMap<>(); + context.put("headers.Authorization", "Bearer test-token-123"); + context.put("headers.X-Request-Id", "req-" + operation); + return context; + }; + + try (DirectoryNamespace namespace = new DirectoryNamespace()) { + Map config = new HashMap<>(); + config.put("root", tempDir.toString()); + namespace.initialize(config, allocator, provider); + + // Perform operations to verify the provider is called + CreateNamespaceRequest createReq = + new CreateNamespaceRequest().id(Arrays.asList("workspace")); + namespace.createNamespace(createReq); + + ListNamespacesRequest listReq = new ListNamespacesRequest(); + namespace.listNamespaces(listReq); + + // The provider should have been called for each operation + // Note: DirectoryNamespace stores the provider but may not actively use context + // until the underlying Rust code is updated to use it for credential vending + assertNotNull(namespace.namespaceId()); + } + } + + @Test + void testDirectoryNamespaceWithNullProvider() { + try (DirectoryNamespace namespace = new DirectoryNamespace()) { + Map config = new HashMap<>(); + config.put("root", tempDir.toString()); + + // Should work with null provider (backward compatibility) + namespace.initialize(config, allocator, null); + + CreateNamespaceRequest createReq = + new CreateNamespaceRequest().id(Arrays.asList("workspace")); + namespace.createNamespace(createReq); + + ListNamespacesRequest listReq = new ListNamespacesRequest(); + ListNamespacesResponse 
listResp = namespace.listNamespaces(listReq); + + assertNotNull(listResp); + assertTrue(listResp.getNamespaces().contains("workspace")); + } + } + + @Test + void testContextProviderReturnsEmptyMap() { + DynamicContextProvider provider = (operation, objectId) -> new HashMap<>(); + + try (DirectoryNamespace namespace = new DirectoryNamespace()) { + Map config = new HashMap<>(); + config.put("root", tempDir.toString()); + namespace.initialize(config, allocator, provider); + + CreateNamespaceRequest createReq = + new CreateNamespaceRequest().id(Arrays.asList("workspace")); + CreateNamespaceResponse resp = namespace.createNamespace(createReq); + + assertNotNull(resp); + } + } + + @Test + void testRestNamespaceWithContextProviderIntegration() { + AtomicInteger callCount = new AtomicInteger(0); + + DynamicContextProvider provider = + (operation, objectId) -> { + callCount.incrementAndGet(); + Map context = new HashMap<>(); + context.put("headers.Authorization", "Bearer xyz-token"); + context.put("headers.X-Trace-Id", "trace-" + System.currentTimeMillis()); + return context; + }; + + // Start a test REST server with DirectoryNamespace backend + Map backendConfig = new HashMap<>(); + backendConfig.put("root", tempDir.toString()); + + try (RestAdapter adapter = new RestAdapter("dir", backendConfig, "127.0.0.1", null)) { + adapter.start(); + int port = adapter.getPort(); + + // Create RestNamespace client with context provider + try (RestNamespace namespace = new RestNamespace()) { + Map clientConfig = new HashMap<>(); + clientConfig.put("uri", "http://127.0.0.1:" + port); + namespace.initialize(clientConfig, allocator, provider); + + // Perform operations - context provider should be called + CreateNamespaceRequest createReq = + new CreateNamespaceRequest().id(Arrays.asList("workspace")); + namespace.createNamespace(createReq); + + ListNamespacesRequest listReq = new ListNamespacesRequest(); + ListNamespacesResponse listResp = namespace.listNamespaces(listReq); + + // Verify 
provider was called for REST operations + assertTrue(callCount.get() >= 2, "Context provider should be called for each operation"); + assertNotNull(listResp); + assertTrue(listResp.getNamespaces().contains("workspace")); + } + } + } + + @Test + void testContextProviderReceivesCorrectOperationInfo() { + Map capturedOperations = new HashMap<>(); + + DynamicContextProvider provider = + (operation, objectId) -> { + capturedOperations.put(operation, objectId); + return new HashMap<>(); + }; + + Map backendConfig = new HashMap<>(); + backendConfig.put("root", tempDir.toString()); + + try (RestAdapter adapter = new RestAdapter("dir", backendConfig, "127.0.0.1", null)) { + adapter.start(); + int port = adapter.getPort(); + + try (RestNamespace namespace = new RestNamespace()) { + Map clientConfig = new HashMap<>(); + clientConfig.put("uri", "http://127.0.0.1:" + port); + namespace.initialize(clientConfig, allocator, provider); + + // Create namespace + CreateNamespaceRequest createReq = + new CreateNamespaceRequest().id(Arrays.asList("workspace")); + namespace.createNamespace(createReq); + + // List namespaces + ListNamespacesRequest listReq = new ListNamespacesRequest(); + namespace.listNamespaces(listReq); + + // Verify operations were captured + assertTrue(capturedOperations.containsKey("create_namespace")); + assertTrue(capturedOperations.containsKey("list_namespaces")); + } + } + } + + // ========================================================================== + // Class path based provider tests + // ========================================================================== + + @Test + void testDirectoryNamespaceWithClassPathProvider() { + try (DirectoryNamespace namespace = new DirectoryNamespace()) { + Map config = new HashMap<>(); + config.put("root", tempDir.toString()); + config.put("dynamic_context_provider.impl", "org.lance.namespace.TestContextProvider"); + config.put("dynamic_context_provider.token", "my-secret-token"); + 
config.put("dynamic_context_provider.prefix", "Token"); + + namespace.initialize(config, allocator); + + // Verify namespace works + CreateNamespaceRequest createReq = + new CreateNamespaceRequest().id(Arrays.asList("workspace")); + namespace.createNamespace(createReq); + + ListNamespacesRequest listReq = new ListNamespacesRequest(); + ListNamespacesResponse listResp = namespace.listNamespaces(listReq); + + assertNotNull(listResp); + assertTrue(listResp.getNamespaces().contains("workspace")); + } + } + + @Test + void testRestNamespaceWithClassPathProvider() { + Map backendConfig = new HashMap<>(); + backendConfig.put("root", tempDir.toString()); + + try (RestAdapter adapter = new RestAdapter("dir", backendConfig, "127.0.0.1", null)) { + adapter.start(); + int port = adapter.getPort(); + + try (RestNamespace namespace = new RestNamespace()) { + Map clientConfig = new HashMap<>(); + clientConfig.put("uri", "http://127.0.0.1:" + port); + clientConfig.put( + "dynamic_context_provider.impl", "org.lance.namespace.TestContextProvider"); + clientConfig.put("dynamic_context_provider.token", "secret-api-key"); + + namespace.initialize(clientConfig, allocator); + + CreateNamespaceRequest createReq = + new CreateNamespaceRequest().id(Arrays.asList("workspace")); + namespace.createNamespace(createReq); + + ListNamespacesRequest listReq = new ListNamespacesRequest(); + ListNamespacesResponse listResp = namespace.listNamespaces(listReq); + + assertNotNull(listResp); + assertTrue(listResp.getNamespaces().contains("workspace")); + } + } + } + + @Test + void testUnknownProviderClassThrowsException() { + try (DirectoryNamespace namespace = new DirectoryNamespace()) { + Map config = new HashMap<>(); + config.put("root", tempDir.toString()); + config.put("dynamic_context_provider.impl", "com.nonexistent.NonExistentProvider"); + + assertThrows( + IllegalArgumentException.class, + () -> namespace.initialize(config, allocator), + "Failed to load context provider class"); + } + } + + @Test 
+ void testExplicitProviderTakesPrecedence() { + AtomicInteger explicitCallCount = new AtomicInteger(0); + + DynamicContextProvider explicitProvider = + (operation, objectId) -> { + explicitCallCount.incrementAndGet(); + Map ctx = new HashMap<>(); + ctx.put("headers.Authorization", "Bearer explicit"); + return ctx; + }; + + try (DirectoryNamespace namespace = new DirectoryNamespace()) { + Map config = new HashMap<>(); + config.put("root", tempDir.toString()); + // Even though we specify a class path, explicit provider should take precedence + config.put("dynamic_context_provider.impl", "org.lance.namespace.TestContextProvider"); + config.put("dynamic_context_provider.token", "ignored"); + + // Pass explicit provider - should take precedence over properties + namespace.initialize(config, allocator, explicitProvider); + + // Verify namespace works + CreateNamespaceRequest createReq = + new CreateNamespaceRequest().id(Arrays.asList("workspace")); + namespace.createNamespace(createReq); + + // Namespace should work + assertNotNull(namespace.namespaceId()); + } + } +} diff --git a/java/src/test/java/org/lance/namespace/TestContextProvider.java b/java/src/test/java/org/lance/namespace/TestContextProvider.java new file mode 100644 index 00000000000..4eea30c88c3 --- /dev/null +++ b/java/src/test/java/org/lance/namespace/TestContextProvider.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.lance.namespace; + +import java.util.HashMap; +import java.util.Map; + +/** Test implementation of DynamicContextProvider for testing class path loading. */ +public class TestContextProvider implements DynamicContextProvider { + private final String token; + private final String prefix; + + public TestContextProvider(Map properties) { + this.token = properties.get("token"); + this.prefix = properties.getOrDefault("prefix", "Bearer"); + } + + @Override + public Map provideContext(String operation, String objectId) { + Map context = new HashMap<>(); + context.put("headers.Authorization", prefix + " " + token); + context.put("headers.X-Operation", operation); + return context; + } +} diff --git a/python/python/lance/namespace.py b/python/python/lance/namespace.py index 9b18e3ee215..9df0c451173 100644 --- a/python/python/lance/namespace.py +++ b/python/python/lance/namespace.py @@ -7,11 +7,13 @@ 1. Native Rust-backed namespace implementations (DirectoryNamespace, RestNamespace) 2. Storage options integration with LanceNamespace for automatic credential refresh 3. Plugin registry for external namespace implementations +4. Dynamic context provider registry for per-request context injection The LanceNamespace ABC interface is provided by the lance_namespace package. """ -from typing import Dict, List +from abc import ABC, abstractmethod +from typing import Dict, List, Optional from lance_namespace import ( CreateEmptyTableRequest, @@ -61,9 +63,148 @@ "RestNamespace", "RestAdapter", "LanceNamespaceStorageOptionsProvider", + "DynamicContextProvider", ] +# ============================================================================= +# Dynamic Context Provider +# ============================================================================= + + +class DynamicContextProvider(ABC): + """Abstract base class for dynamic context providers. + + Implementations provide per-request context (e.g., authentication headers) + based on the operation being performed. 
The provider is called synchronously + before each namespace operation. + + For RestNamespace, context keys that start with `headers.` are converted to + HTTP headers by stripping the prefix. For example, `{"headers.Authorization": + "Bearer token"}` becomes the `Authorization: Bearer token` header. + + Example + ------- + >>> # Define a provider class + >>> class MyProvider(DynamicContextProvider): + ... def __init__(self, api_key: str): + ... self.api_key = api_key + ... + ... def provide_context(self, info: dict) -> dict: + ... return { + ... "headers.Authorization": f"Bearer {self.api_key}", + ... } + ... + >>> # Create provider instance and use directly + >>> provider = MyProvider(api_key="secret") + >>> provider.provide_context({"operation": "list_tables", "object_id": "ns"}) + {'headers.Authorization': 'Bearer secret'} + """ + + @abstractmethod + def provide_context(self, info: Dict[str, str]) -> Dict[str, str]: + """Provide context for a namespace operation. + + Parameters + ---------- + info : dict + Information about the operation: + - operation: The operation name (e.g., "list_tables", "describe_table") + - object_id: The object identifier (namespace or table ID) + + Returns + ------- + dict + Context key-value pairs. For HTTP headers, use keys with the + "headers." prefix (e.g., "headers.Authorization"). + """ + pass + + +def _create_context_provider_from_properties( + properties: Dict[str, str], +) -> Optional[DynamicContextProvider]: + """Create a context provider instance from properties. + + Extracts `dynamic_context_provider.*` properties and creates a provider + instance by dynamically loading the class from the given class path. + + Parameters + ---------- + properties : dict + The full properties dict that may contain dynamic_context_provider.* keys. + + Returns + ------- + DynamicContextProvider or None + The created provider instance, or None if no provider is configured. 
+ + Raises + ------ + ValueError + If dynamic_context_provider.impl is set but the class cannot be loaded. + """ + import importlib + + prefix = "dynamic_context_provider." + impl_key = "dynamic_context_provider.impl" + + impl_path = properties.get(impl_key) + if not impl_path: + return None + + # Parse the class path (e.g., "my_module.submodule.MyClass") + if "." not in impl_path: + raise ValueError( + f"Invalid context provider class path '{impl_path}'. " + f"Expected format: 'module.ClassName' (e.g., 'my_module.MyProvider')" + ) + + module_path, class_name = impl_path.rsplit(".", 1) + + try: + module = importlib.import_module(module_path) + provider_class = getattr(module, class_name) + except ModuleNotFoundError as e: + raise ValueError( + f"Failed to import module '{module_path}' for context provider: {e}" + ) from e + except AttributeError as e: + raise ValueError( + f"Class '{class_name}' not found in module '{module_path}': {e}" + ) from e + + # Extract provider-specific properties (strip prefix, exclude impl key) + provider_props = {} + for key, value in properties.items(): + if key.startswith(prefix) and key != impl_key: + prop_name = key[len(prefix) :] + provider_props[prop_name] = value + + # Create the provider instance + return provider_class(**provider_props) + + +def _filter_context_provider_properties(properties: Dict[str, str]) -> Dict[str, str]: + """Remove dynamic_context_provider.* properties from the dict. + + These properties are handled at the Python level and should not be + passed to the Rust layer. + + Parameters + ---------- + properties : dict + The full properties dict. + + Returns + ------- + dict + Properties with dynamic_context_provider.* keys removed. + """ + prefix = "dynamic_context_provider." + return {k: v for k, v in properties.items() if not k.startswith(prefix)} + + class DirectoryNamespace(LanceNamespace): """Directory-based Lance Namespace implementation backed by Rust. 
@@ -140,14 +281,40 @@ class DirectoryNamespace(LanceNamespace): ... "credential_vendor.aws_role_arn": "arn:aws:iam::123456789012:role/MyRole", ... "credential_vendor.aws_duration_millis": "3600000", ... }) + + With dynamic context provider: + + >>> import tempfile + >>> class MyProvider(DynamicContextProvider): + ... def __init__(self, token: str): + ... self.token = token + ... def provide_context(self, info: dict) -> dict: + ... return {"headers.Authorization": f"Bearer {self.token}"} + ... + >>> provider = MyProvider(token="secret-token") + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... ns = lance.namespace.DirectoryNamespace( + ... root=tmpdir, + ... context_provider=provider, + ... ) + ... _ = ns.namespace_id() # verify it works """ - def __init__(self, session=None, **properties): + def __init__(self, session=None, context_provider=None, **properties): # Convert all values to strings as expected by Rust from_properties str_properties = {str(k): str(v) for k, v in properties.items()} + # Create context provider from properties if configured + if context_provider is None: + context_provider = _create_context_provider_from_properties(str_properties) + + # Filter out dynamic_context_provider.* properties before passing to Rust + filtered_properties = _filter_context_provider_properties(str_properties) + # Create the underlying Rust namespace - self._inner = PyDirectoryNamespace(session=session, **str_properties) + self._inner = PyDirectoryNamespace( + session=session, context_provider=context_provider, **filtered_properties + ) def namespace_id(self) -> str: """Return a human-readable unique identifier for this namespace instance.""" @@ -254,9 +421,25 @@ class RestNamespace(LanceNamespace): >>> # Using the connect() factory function from lance_namespace >>> import lance_namespace >>> ns = lance_namespace.connect("rest", {"uri": "http://localhost:4099"}) + + With dynamic context provider: + + >>> class AuthProvider(DynamicContextProvider): + ... 
def __init__(self, api_key: str): + ... self.api_key = api_key + ... def provide_context(self, info: dict) -> dict: + ... return {"headers.Authorization": f"Bearer {self.api_key}"} + ... + >>> provider = AuthProvider(api_key="my-secret-key") + >>> ns = lance.namespace.RestNamespace( + ... uri="http://localhost:4099", + ... context_provider=provider, + ... ) + >>> ns.namespace_id() # verify it works + 'RestNamespace { endpoint: "http://localhost:4099", delimiter: "$" }' """ - def __init__(self, **properties): + def __init__(self, context_provider=None, **properties): if PyRestNamespace is None: raise RuntimeError( "RestNamespace is not available. " @@ -266,8 +449,17 @@ def __init__(self, **properties): # Convert all values to strings as expected by Rust from_properties str_properties = {str(k): str(v) for k, v in properties.items()} + # Create context provider from properties if configured + if context_provider is None: + context_provider = _create_context_provider_from_properties(str_properties) + + # Filter out dynamic_context_provider.* properties before passing to Rust + filtered_properties = _filter_context_provider_properties(str_properties) + # Create the underlying Rust namespace - self._inner = PyRestNamespace(**str_properties) + self._inner = PyRestNamespace( + context_provider=context_provider, **filtered_properties + ) def namespace_id(self) -> str: """Return a human-readable unique identifier for this namespace instance.""" diff --git a/python/python/tests/test_namespace_rest.py b/python/python/tests/test_namespace_rest.py index 7fa3a65c5f1..de1a57ace8d 100644 --- a/python/python/tests/test_namespace_rest.py +++ b/python/python/tests/test_namespace_rest.py @@ -680,3 +680,66 @@ def test_connect_with_custom_delimiter(self): ipc_data = table_to_ipc_bytes(table_data) response = ns.create_table(create_req, ipc_data) assert response is not None + + +class TestDynamicContextProvider: + """Tests for DynamicContextProvider with RestNamespace.""" + + def 
test_rest_namespace_with_explicit_provider(self): + """Test RestNamespace with an explicit context provider.""" + call_count = {"count": 0} + + class TestProvider(lance.namespace.DynamicContextProvider): + def provide_context(self, info): + call_count["count"] += 1 + return { + "headers.Authorization": "Bearer test-token", + "headers.X-Request-Id": f"req-{info.get('operation', 'unknown')}", + } + + with tempfile.TemporaryDirectory() as tmpdir: + backend_config = {"root": tmpdir} + + with lance.namespace.RestAdapter("dir", backend_config, port=0) as adapter: + ns = lance.namespace.RestNamespace( + uri=f"http://127.0.0.1:{adapter.port}", + context_provider=TestProvider(), + ) + + # Perform operations + create_req = CreateNamespaceRequest(id=["workspace"]) + ns.create_namespace(create_req) + + list_req = ListNamespacesRequest(id=[]) + ns.list_namespaces(list_req) + + # Context provider should have been called + assert call_count["count"] >= 2 + + def test_explicit_provider_takes_precedence(self): + """Test that explicit provider takes precedence over class path.""" + explicit_called = {"called": False} + + class ExplicitProvider(lance.namespace.DynamicContextProvider): + def provide_context(self, info): + explicit_called["called"] = True + return {"headers.Authorization": "Bearer explicit"} + + with tempfile.TemporaryDirectory() as tmpdir: + backend_config = {"root": tmpdir} + + with lance.namespace.RestAdapter("dir", backend_config, port=0) as adapter: + # Pass both explicit provider and class path - explicit should win + ns = lance.namespace.RestNamespace( + context_provider=ExplicitProvider(), + **{ + "uri": f"http://127.0.0.1:{adapter.port}", + "dynamic_context_provider.impl": "nonexistent.Provider", + }, + ) + + create_req = CreateNamespaceRequest(id=["workspace"]) + ns.create_namespace(create_req) + + # Explicit provider should have been used + assert explicit_called["called"] diff --git a/python/src/namespace.rs b/python/src/namespace.rs index 
f26574f221a..e37e5710ba3 100644 --- a/python/src/namespace.rs +++ b/python/src/namespace.rs @@ -7,11 +7,11 @@ use std::collections::HashMap; use std::sync::Arc; use bytes::Bytes; -use lance_namespace_impls::DirectoryNamespaceBuilder; #[cfg(feature = "rest")] use lance_namespace_impls::RestNamespaceBuilder; #[cfg(feature = "rest-adapter")] use lance_namespace_impls::{ConnectBuilder, RestAdapter, RestAdapterConfig, RestAdapterHandle}; +use lance_namespace_impls::{DirectoryNamespaceBuilder, DynamicContextProvider, OperationInfo}; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict}; use pythonize::{depythonize, pythonize}; @@ -19,6 +19,73 @@ use pythonize::{depythonize, pythonize}; use crate::error::PythonErrorExt; use crate::session::Session; +/// Python-implemented dynamic context provider. +/// +/// Wraps a Python object that has a `provide_context(info: dict) -> dict` method. +/// For RestNamespace, context keys that start with `headers.` are converted to +/// HTTP headers by stripping the prefix. +pub struct PyDynamicContextProvider { + provider: Py, +} + +impl Clone for PyDynamicContextProvider { + fn clone(&self) -> Self { + Python::attach(|py| Self { + provider: self.provider.clone_ref(py), + }) + } +} + +impl PyDynamicContextProvider { + /// Create a new Python context provider wrapper. 
+ pub fn new(provider: Py) -> Self { + Self { provider } + } +} + +impl std::fmt::Debug for PyDynamicContextProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "PyDynamicContextProvider") + } +} + +impl DynamicContextProvider for PyDynamicContextProvider { + fn provide_context(&self, info: &OperationInfo) -> HashMap { + Python::attach(|py| { + // Create Python dict for operation info + let py_info = PyDict::new(py); + if py_info.set_item("operation", &info.operation).is_err() { + return HashMap::new(); + } + if py_info.set_item("object_id", &info.object_id).is_err() { + return HashMap::new(); + } + + // Call the provider's provide_context method + let result = self + .provider + .call_method1(py, "provide_context", (py_info,)); + + match result { + Ok(headers_py) => { + // Convert Python dict to Rust HashMap + let bound_headers = headers_py.bind(py); + if let Ok(dict) = bound_headers.downcast::() { + dict_to_hashmap(dict).unwrap_or_default() + } else { + log::warn!("Context provider did not return a dict"); + HashMap::new() + } + } + Err(e) => { + log::error!("Failed to call context provider: {}", e); + HashMap::new() + } + } + }) + } +} + /// Convert Python dict to HashMap fn dict_to_hashmap(dict: &Bound<'_, PyDict>) -> PyResult> { let mut map = HashMap::new(); @@ -39,10 +106,18 @@ pub struct PyDirectoryNamespace { #[pymethods] impl PyDirectoryNamespace { /// Create a new DirectoryNamespace from properties + /// + /// # Arguments + /// + /// * `session` - Optional Lance session for sharing storage connections + /// * `context_provider` - Optional object with `provide_context(info: dict) -> dict` method + /// for providing dynamic per-request context + /// * `**properties` - Namespace configuration properties #[new] - #[pyo3(signature = (session = None, **properties))] + #[pyo3(signature = (session = None, context_provider = None, **properties))] fn new( session: Option<&Bound<'_, Session>>, + context_provider: 
Option<&Bound<'_, PyAny>>, properties: Option<&Bound<'_, PyDict>>, ) -> PyResult { let mut props = HashMap::new(); @@ -53,7 +128,7 @@ impl PyDirectoryNamespace { let session_arc = session.map(|s| s.borrow().inner.clone()); - let builder = + let mut builder = DirectoryNamespaceBuilder::from_properties(props, session_arc).map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!( "Failed to create DirectoryNamespace: {}", @@ -61,6 +136,12 @@ impl PyDirectoryNamespace { )) })?; + // Add context provider if provided + if let Some(provider) = context_provider { + let py_provider = PyDynamicContextProvider::new(provider.clone().unbind()); + builder = builder.context_provider(Arc::new(py_provider)); + } + let namespace = crate::rt().block_on(None, builder.build())?.infer_error()?; Ok(Self { @@ -256,22 +337,39 @@ pub struct PyRestNamespace { #[pymethods] impl PyRestNamespace { /// Create a new RestNamespace from properties + /// + /// # Arguments + /// + /// * `context_provider` - Optional object with `provide_context(info: dict) -> dict` method + /// for providing dynamic per-request context. Context keys that start with `headers.` + /// are converted to HTTP headers by stripping the prefix. For example, + /// `{"headers.Authorization": "Bearer token"}` becomes the `Authorization` header. + /// * `**properties` - Namespace configuration properties (uri, delimiter, header.*, etc.) 
#[new] - #[pyo3(signature = (**properties))] - fn new(properties: Option<&Bound<'_, PyDict>>) -> PyResult { + #[pyo3(signature = (context_provider = None, **properties))] + fn new( + context_provider: Option<&Bound<'_, PyAny>>, + properties: Option<&Bound<'_, PyDict>>, + ) -> PyResult { let mut props = HashMap::new(); if let Some(dict) = properties { props = dict_to_hashmap(dict)?; } - let builder = RestNamespaceBuilder::from_properties(props).map_err(|e| { + let mut builder = RestNamespaceBuilder::from_properties(props).map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!( "Failed to create RestNamespace: {}", e )) })?; + // Add context provider if provided + if let Some(provider) = context_provider { + let py_provider = PyDynamicContextProvider::new(provider.clone().unbind()); + builder = builder.context_provider(Arc::new(py_provider)); + } + let namespace = builder.build(); Ok(Self { diff --git a/rust/lance-namespace-impls/Cargo.toml b/rust/lance-namespace-impls/Cargo.toml index 85ee4a6989f..b41e7f44e01 100644 --- a/rust/lance-namespace-impls/Cargo.toml +++ b/rust/lance-namespace-impls/Cargo.toml @@ -13,7 +13,7 @@ rust-version.workspace = true [features] default = ["dir-aws", "dir-azure", "dir-gcp", "dir-oss", "dir-huggingface"] -rest = ["dep:reqwest"] +rest = ["dep:reqwest", "dep:serde"] rest-adapter = ["dep:axum", "dep:tower", "dep:tower-http", "dep:serde"] # Cloud storage features for directory implementation - align with lance-io dir-gcp = ["lance-io/gcp", "lance/gcp"] diff --git a/rust/lance-namespace-impls/src/connect.rs b/rust/lance-namespace-impls/src/connect.rs index aa84e2fd6c1..ba26fda3643 100644 --- a/rust/lance-namespace-impls/src/connect.rs +++ b/rust/lance-namespace-impls/src/connect.rs @@ -10,6 +10,8 @@ use lance::session::Session; use lance_core::{Error, Result}; use lance_namespace::LanceNamespace; +use crate::context::DynamicContextProvider; + /// Builder for creating Lance namespace connections. 
/// /// This builder provides a fluent API for configuring and establishing @@ -46,11 +48,53 @@ use lance_namespace::LanceNamespace; /// # Ok(()) /// # } /// ``` -#[derive(Debug, Clone)] +/// +/// ## With Dynamic Context Provider +/// +/// ```no_run +/// # use lance_namespace_impls::{ConnectBuilder, DynamicContextProvider, OperationInfo}; +/// # use std::collections::HashMap; +/// # use std::sync::Arc; +/// # async fn example() -> Result<(), Box> { +/// #[derive(Debug)] +/// struct MyProvider; +/// +/// impl DynamicContextProvider for MyProvider { +/// fn provide_context(&self, info: &OperationInfo) -> HashMap { +/// let mut ctx = HashMap::new(); +/// ctx.insert("headers.Authorization".to_string(), "Bearer token".to_string()); +/// ctx +/// } +/// } +/// +/// let namespace = ConnectBuilder::new("rest") +/// .property("uri", "https://api.example.com") +/// .context_provider(Arc::new(MyProvider)) +/// .connect() +/// .await?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Clone)] pub struct ConnectBuilder { impl_name: String, properties: HashMap, session: Option>, + context_provider: Option>, +} + +impl std::fmt::Debug for ConnectBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConnectBuilder") + .field("impl_name", &self.impl_name) + .field("properties", &self.properties) + .field("session", &self.session) + .field( + "context_provider", + &self.context_provider.as_ref().map(|_| "Some(...)"), + ) + .finish() + } } impl ConnectBuilder { @@ -64,6 +108,7 @@ impl ConnectBuilder { impl_name: impl_name.into(), properties: HashMap::new(), session: None, + context_provider: None, } } @@ -102,6 +147,20 @@ impl ConnectBuilder { self } + /// Set a dynamic context provider for per-request context. + /// + /// The provider will be called before each operation to generate + /// additional context. For RestNamespace, context keys that start with + /// `headers.` are converted to HTTP headers by stripping the prefix. 
+ /// + /// # Arguments + /// + /// * `provider` - The context provider implementation + pub fn context_provider(mut self, provider: Arc) -> Self { + self.context_provider = Some(provider); + self + } + /// Build and establish the connection to the namespace. /// /// # Returns @@ -119,8 +178,12 @@ impl ConnectBuilder { #[cfg(feature = "rest")] "rest" => { // Create REST implementation (REST doesn't use session) - crate::rest::RestNamespaceBuilder::from_properties(self.properties) - .map(|builder| Arc::new(builder.build()) as Arc) + let mut builder = + crate::rest::RestNamespaceBuilder::from_properties(self.properties)?; + if let Some(provider) = self.context_provider { + builder = builder.context_provider(provider); + } + Ok(Arc::new(builder.build()) as Arc) } #[cfg(not(feature = "rest"))] "rest" => Err(Error::Namespace { @@ -130,13 +193,17 @@ impl ConnectBuilder { }), "dir" => { // Create directory implementation (always available) - crate::dir::DirectoryNamespaceBuilder::from_properties( + let mut builder = crate::dir::DirectoryNamespaceBuilder::from_properties( self.properties, self.session, - )? - .build() - .await - .map(|ns| Arc::new(ns) as Arc) + )?; + if let Some(provider) = self.context_provider { + builder = builder.context_provider(provider); + } + builder + .build() + .await + .map(|ns| Arc::new(ns) as Arc) } _ => Err(Error::Namespace { source: format!( diff --git a/rust/lance-namespace-impls/src/context.rs b/rust/lance-namespace-impls/src/context.rs new file mode 100644 index 00000000000..028eb342bac --- /dev/null +++ b/rust/lance-namespace-impls/src/context.rs @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Dynamic context provider for per-request context overrides. +//! +//! This module provides the [`DynamicContextProvider`] trait that enables +//! per-request context injection (e.g., dynamic authentication headers). +//! +//! ## Usage +//! +//! 
Implement the trait and pass to namespace builders: +//! +//! ```ignore +//! use lance_namespace_impls::{RestNamespaceBuilder, DynamicContextProvider, OperationInfo}; +//! use std::collections::HashMap; +//! use std::sync::Arc; +//! +//! #[derive(Debug)] +//! struct MyProvider; +//! +//! impl DynamicContextProvider for MyProvider { +//! fn provide_context(&self, info: &OperationInfo) -> HashMap { +//! let mut context = HashMap::new(); +//! context.insert("headers.Authorization".to_string(), format!("Bearer {}", get_current_token())); +//! context.insert("headers.X-Request-Id".to_string(), generate_request_id()); +//! context +//! } +//! } +//! +//! let namespace = RestNamespaceBuilder::new("https://api.example.com") +//! .context_provider(Arc::new(MyProvider)) +//! .build(); +//! ``` +//! +//! For RestNamespace, context keys that start with `headers.` are converted to HTTP headers +//! by stripping the prefix. For example, `{"headers.Authorization": "Bearer abc123"}` +//! becomes the `Authorization: Bearer abc123` header. Keys without the `headers.` prefix +//! are ignored for HTTP headers but may be used for other purposes. + +use std::collections::HashMap; + +/// Information about the namespace operation being executed. +/// +/// This is passed to the [`DynamicContextProvider`] to allow it to make +/// context decisions based on the operation. +#[derive(Debug, Clone)] +pub struct OperationInfo { + /// The operation name (e.g., "list_tables", "describe_table", "create_namespace") + pub operation: String, + /// The object ID for the operation (namespace or table identifier). + /// This is the delimited string form, e.g., "workspace$table_name". + pub object_id: String, +} + +impl OperationInfo { + /// Create a new OperationInfo. + pub fn new(operation: impl Into, object_id: impl Into) -> Self { + Self { + operation: operation.into(), + object_id: object_id.into(), + } + } +} + +/// Trait for providing dynamic request context. 
+/// +/// Implementations can generate per-request context (e.g., authentication headers) +/// based on the operation being performed. The provider is called synchronously +/// before each namespace operation. +/// +/// For RestNamespace, context keys that start with `headers.` are converted to +/// HTTP headers by stripping the prefix. For example, `{"headers.Authorization": "Bearer token"}` +/// becomes the `Authorization: Bearer token` header. +/// +/// ## Thread Safety +/// +/// Implementations must be `Send + Sync` as the provider may be called from +/// multiple threads concurrently. +/// +/// ## Error Handling +/// +/// If the provider needs to signal an error, it should return an empty HashMap +/// and log the error. The namespace operation will proceed without the +/// additional context. +pub trait DynamicContextProvider: Send + Sync + std::fmt::Debug { + /// Provide context for a namespace operation. + /// + /// # Arguments + /// + /// * `info` - Information about the operation being performed + /// + /// # Returns + /// + /// Returns a HashMap of context key-value pairs. For HTTP headers, use keys + /// with the `headers.` prefix (e.g., `headers.Authorization`). + /// Returns an empty HashMap if no additional context is needed. 
+ fn provide_context(&self, info: &OperationInfo) -> HashMap; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug)] + struct MockContextProvider { + prefix: String, + } + + impl DynamicContextProvider for MockContextProvider { + fn provide_context(&self, info: &OperationInfo) -> HashMap { + let mut context = HashMap::new(); + context.insert( + "test-header".to_string(), + format!("{}-{}", self.prefix, info.operation), + ); + context.insert("object-id".to_string(), info.object_id.clone()); + context + } + } + + #[test] + fn test_operation_info_creation() { + let info = OperationInfo::new("describe_table", "workspace$my_table"); + assert_eq!(info.operation, "describe_table"); + assert_eq!(info.object_id, "workspace$my_table"); + } + + #[test] + fn test_context_provider_basic() { + let provider = MockContextProvider { + prefix: "test".to_string(), + }; + + let info = OperationInfo::new("list_tables", "workspace$ns"); + + let context = provider.provide_context(&info); + assert_eq!( + context.get("test-header"), + Some(&"test-list_tables".to_string()) + ); + assert_eq!(context.get("object-id"), Some(&"workspace$ns".to_string())); + } + + #[test] + fn test_empty_context() { + #[derive(Debug)] + struct EmptyProvider; + + impl DynamicContextProvider for EmptyProvider { + fn provide_context(&self, _info: &OperationInfo) -> HashMap { + HashMap::new() + } + } + + let provider = EmptyProvider; + let info = OperationInfo::new("list_tables", "ns"); + + let context = provider.provide_context(&info); + assert!(context.is_empty()); + } +} diff --git a/rust/lance-namespace-impls/src/dir.rs b/rust/lance-namespace-impls/src/dir.rs index 2168324a308..4d6a88419ee 100644 --- a/rust/lance-namespace-impls/src/dir.rs +++ b/rust/lance-namespace-impls/src/dir.rs @@ -21,6 +21,7 @@ use std::collections::HashMap; use std::io::Cursor; use std::sync::Arc; +use crate::context::DynamicContextProvider; use lance_namespace::models::{ CreateEmptyTableRequest, CreateEmptyTableResponse, 
CreateNamespaceRequest, CreateNamespaceResponse, CreateTableRequest, CreateTableResponse, DeclareTableRequest, @@ -85,7 +86,7 @@ pub(crate) struct TableStatus { /// # Ok(()) /// # } /// ``` -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct DirectoryNamespaceBuilder { root: String, storage_options: Option>, @@ -94,6 +95,26 @@ pub struct DirectoryNamespaceBuilder { dir_listing_enabled: bool, inline_optimization_enabled: bool, credential_vendor_properties: HashMap, + context_provider: Option>, +} + +impl std::fmt::Debug for DirectoryNamespaceBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DirectoryNamespaceBuilder") + .field("root", &self.root) + .field("storage_options", &self.storage_options) + .field("manifest_enabled", &self.manifest_enabled) + .field("dir_listing_enabled", &self.dir_listing_enabled) + .field( + "inline_optimization_enabled", + &self.inline_optimization_enabled, + ) + .field( + "context_provider", + &self.context_provider.as_ref().map(|_| "Some(...)"), + ) + .finish() + } } impl DirectoryNamespaceBuilder { @@ -111,6 +132,7 @@ impl DirectoryNamespaceBuilder { dir_listing_enabled: true, // Default to enabled for backwards compatibility inline_optimization_enabled: true, credential_vendor_properties: HashMap::new(), + context_provider: None, } } @@ -271,6 +293,7 @@ impl DirectoryNamespaceBuilder { dir_listing_enabled, inline_optimization_enabled, credential_vendor_properties, + context_provider: None, }) } @@ -362,6 +385,20 @@ impl DirectoryNamespaceBuilder { self } + /// Set a dynamic context provider for per-request context. + /// + /// The provider can be used to generate additional context for operations. + /// For DirectoryNamespace, the context is stored but not directly used + /// in operations (unlike RestNamespace where it's converted to HTTP headers). 
+ /// + /// # Arguments + /// + /// * `provider` - The context provider implementation + pub fn context_provider(mut self, provider: Arc) -> Self { + self.context_provider = Some(provider); + self + } + /// Build the DirectoryNamespace. /// /// # Returns @@ -423,6 +460,7 @@ impl DirectoryNamespaceBuilder { manifest_ns, dir_listing_enabled: self.dir_listing_enabled, credential_vendor, + context_provider: self.context_provider, }) } @@ -492,6 +530,10 @@ pub struct DirectoryNamespace { /// Credential vendor created once during initialization. /// Used to vend temporary credentials for table access. credential_vendor: Option>, + /// Dynamic context provider for per-request context. + /// Stored but not directly used in operations (available for future extensions). + #[allow(dead_code)] + context_provider: Option>, } impl std::fmt::Debug for DirectoryNamespace { diff --git a/rust/lance-namespace-impls/src/lib.rs b/rust/lance-namespace-impls/src/lib.rs index 88248841bcb..83fb93ddc0e 100644 --- a/rust/lance-namespace-impls/src/lib.rs +++ b/rust/lance-namespace-impls/src/lib.rs @@ -69,6 +69,7 @@ //! ``` pub mod connect; +pub mod context; pub mod credentials; pub mod dir; @@ -80,6 +81,7 @@ pub mod rest_adapter; // Re-export connect builder pub use connect::ConnectBuilder; +pub use context::{DynamicContextProvider, OperationInfo}; pub use dir::{manifest::ManifestNamespace, DirectoryNamespace, DirectoryNamespaceBuilder}; // Re-export credential vending diff --git a/rust/lance-namespace-impls/src/rest.rs b/rust/lance-namespace-impls/src/rest.rs index 020746487a4..0eae07e4ce2 100644 --- a/rust/lance-namespace-impls/src/rest.rs +++ b/rust/lance-namespace-impls/src/rest.rs @@ -4,13 +4,16 @@ //! 
REST implementation of Lance Namespace use std::collections::HashMap; +use std::str::FromStr; +use std::sync::Arc; use async_trait::async_trait; use bytes::Bytes; +use reqwest::header::{HeaderName, HeaderValue}; -use lance_namespace::apis::{ - configuration::Configuration, namespace_api, table_api, tag_api, transaction_api, -}; +use crate::context::{DynamicContextProvider, OperationInfo}; + +use lance_namespace::apis::urlencode; use lance_namespace::models::{ AlterTableAddColumnsRequest, AlterTableAddColumnsResponse, AlterTableAlterColumnsRequest, AlterTableAlterColumnsResponse, AlterTableDropColumnsRequest, AlterTableDropColumnsResponse, @@ -36,11 +39,102 @@ use lance_namespace::models::{ UpdateTableRequest, UpdateTableResponse, UpdateTableSchemaMetadataRequest, UpdateTableSchemaMetadataResponse, UpdateTableTagRequest, UpdateTableTagResponse, }; +use serde::{de::DeserializeOwned, Serialize}; use lance_core::{box_error, Error, Result}; use lance_namespace::LanceNamespace; +/// HTTP client wrapper that supports per-request header injection. +/// +/// This client wraps a single `reqwest::Client` and applies dynamic headers +/// to each request without recreating the client. This is more efficient than +/// creating a new client per request when using a `DynamicContextProvider`. +/// +/// The design follows lancedb's `RestfulLanceDbClient` pattern where headers +/// are applied to the built request using `headers_mut()` before execution. 
+#[derive(Clone)] +struct RestClient { + client: reqwest::Client, + base_path: String, + base_headers: HashMap, + context_provider: Option>, +} + +impl std::fmt::Debug for RestClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RestClient") + .field("base_path", &self.base_path) + .field("base_headers", &self.base_headers) + .field( + "context_provider", + &self.context_provider.as_ref().map(|_| "Some(...)"), + ) + .finish() + } +} + +impl RestClient { + /// Apply base headers and dynamic context headers to a request. + /// + /// This method mutates the request's headers directly, which is more efficient + /// than creating a new client with default_headers for each request. + fn apply_headers(&self, request: &mut reqwest::Request, operation: &str, object_id: &str) { + let request_headers = request.headers_mut(); + + // First apply base headers + for (key, value) in &self.base_headers { + if let (Ok(header_name), Ok(header_value)) = + (HeaderName::from_str(key), HeaderValue::from_str(value)) + { + request_headers.insert(header_name, header_value); + } + } + + // Then apply context headers (override base headers if conflict) + if let Some(provider) = &self.context_provider { + let info = OperationInfo::new(operation, object_id); + let context = provider.provide_context(&info); + + const HEADERS_PREFIX: &str = "headers."; + for (key, value) in context { + if let Some(header_name) = key.strip_prefix(HEADERS_PREFIX) { + if let (Ok(header_name), Ok(header_value)) = ( + HeaderName::from_str(header_name), + HeaderValue::from_str(&value), + ) { + request_headers.insert(header_name, header_value); + } + } + } + } + } + + /// Execute a request with dynamic headers applied. + /// + /// This method builds the request, applies headers, and executes it. 
+ async fn execute( + &self, + req_builder: reqwest::RequestBuilder, + operation: &str, + object_id: &str, + ) -> std::result::Result { + let mut request = req_builder.build()?; + self.apply_headers(&mut request, operation, object_id); + self.client.execute(request).await + } + + /// Get the base path URL + fn base_path(&self) -> &str { + &self.base_path + } + + /// Get a reference to the underlying reqwest client + fn client(&self) -> &reqwest::Client { + &self.client + } +} + /// Builder for creating a RestNamespace. /// /// This builder provides a fluent API for configuring and establishing @@ -59,7 +153,7 @@ use lance_namespace::LanceNamespace; /// # Ok(()) /// # } /// ``` -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct RestNamespaceBuilder { uri: String, delimiter: String, @@ -68,6 +162,25 @@ pub struct RestNamespaceBuilder { key_file: Option, ssl_ca_cert: Option, assert_hostname: bool, + context_provider: Option>, +} + +impl std::fmt::Debug for RestNamespaceBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RestNamespaceBuilder") + .field("uri", &self.uri) + .field("delimiter", &self.delimiter) + .field("headers", &self.headers) + .field("cert_file", &self.cert_file) + .field("key_file", &self.key_file) + .field("ssl_ca_cert", &self.ssl_ca_cert) + .field("assert_hostname", &self.assert_hostname) + .field( + "context_provider", + &self.context_provider.as_ref().map(|_| "Some(...)"), + ) + .finish() + } } impl RestNamespaceBuilder { @@ -88,6 +201,7 @@ impl RestNamespaceBuilder { key_file: None, ssl_ca_cert: None, assert_hostname: true, + context_provider: None, } } @@ -172,6 +286,7 @@ impl RestNamespaceBuilder { key_file, ssl_ca_cert, assert_hostname, + context_provider: None, }) } @@ -246,6 +361,44 @@ impl RestNamespaceBuilder { self } + /// Set a dynamic context provider for per-request context. + /// + /// The provider will be called before each HTTP request to generate + /// additional context. 
Context keys that start with `headers.` are converted + /// to HTTP headers by stripping the prefix. For example, `headers.Authorization` + /// becomes the `Authorization` header. Keys without the `headers.` prefix are ignored. + /// + /// # Arguments + /// + /// * `provider` - The context provider implementation + /// + /// # Examples + /// + /// ```ignore + /// use lance_namespace_impls::{RestNamespaceBuilder, DynamicContextProvider, OperationInfo}; + /// use std::collections::HashMap; + /// use std::sync::Arc; + /// + /// #[derive(Debug)] + /// struct MyProvider; + /// + /// impl DynamicContextProvider for MyProvider { + /// fn provide_context(&self, info: &OperationInfo) -> HashMap { + /// let mut ctx = HashMap::new(); + /// ctx.insert("headers.Authorization".to_string(), "Bearer my-token".to_string()); + /// ctx + /// } + /// } + /// + /// let namespace = RestNamespaceBuilder::new("http://localhost:8080") + /// .context_provider(Arc::new(MyProvider)) + /// .build(); + /// ``` + pub fn context_provider(mut self, provider: Arc) -> Self { + self.context_provider = Some(provider); + self + } + /// Build the RestNamespace. 
/// /// # Returns @@ -268,29 +421,6 @@ fn object_id_str(id: &Option>, delimiter: &str) -> Result { } } -/// Convert API error to lance core error -fn convert_api_error(err: lance_namespace::apis::Error) -> Error { - use lance_namespace::apis::Error as ApiError; - match err { - ApiError::Reqwest(e) => Error::IO { - source: box_error(e), - location: snafu::location!(), - }, - ApiError::Serde(e) => Error::Namespace { - source: format!("Serialization error: {}", e).into(), - location: snafu::location!(), - }, - ApiError::Io(e) => Error::IO { - source: box_error(e), - location: snafu::location!(), - }, - ApiError::ResponseError(e) => Error::Namespace { - source: format!("Response error: {:?}", e).into(), - location: snafu::location!(), - }, - } -} - /// REST implementation of Lance Namespace /// /// # Examples @@ -307,7 +437,8 @@ fn convert_api_error(err: lance_namespace::apis::Error) - #[derive(Clone)] pub struct RestNamespace { delimiter: String, - reqwest_config: Configuration, + /// REST client that handles per-request header injection efficiently. 
+ rest_client: RestClient, } impl std::fmt::Debug for RestNamespace { @@ -325,23 +456,9 @@ impl std::fmt::Display for RestNamespace { impl RestNamespace { /// Create a new REST namespace from builder pub(crate) fn from_builder(builder: RestNamespaceBuilder) -> Self { - // Build reqwest client with custom headers if provided + // Build reqwest client WITHOUT default headers - we'll apply headers per-request let mut client_builder = reqwest::Client::builder(); - // Add custom headers to the client - if !builder.headers.is_empty() { - let mut headers = reqwest::header::HeaderMap::new(); - for (key, value) in &builder.headers { - if let (Ok(header_name), Ok(header_value)) = ( - reqwest::header::HeaderName::from_bytes(key.as_bytes()), - reqwest::header::HeaderValue::from_str(value), - ) { - headers.insert(header_name, header_value); - } - } - client_builder = client_builder.default_headers(headers); - } - // Configure mTLS if certificate and key files are provided if let (Some(cert_file), Some(key_file)) = (&builder.cert_file, &builder.key_file) { if let (Ok(cert), Ok(key)) = (std::fs::read(cert_file), std::fs::read(key_file)) { @@ -367,28 +484,218 @@ impl RestNamespace { .build() .unwrap_or_else(|_| reqwest::Client::new()); - let mut reqwest_config = Configuration::new(); - reqwest_config.client = client; - reqwest_config.base_path = builder.uri; + // Create the RestClient that handles per-request header injection + let rest_client = RestClient { + client, + base_path: builder.uri, + base_headers: builder.headers, + context_provider: builder.context_provider, + }; Self { delimiter: builder.delimiter, - reqwest_config, + rest_client, } } - /// Create a new REST namespace with custom configuration (for testing) - #[cfg(test)] - pub fn with_configuration(delimiter: String, reqwest_config: Configuration) -> Self { - Self { - delimiter, - reqwest_config, + /// Execute a GET request and parse JSON response. 
+    ///
+    /// The request URL is the client's base path followed by `path`, and `query`
+    /// is appended as URL query parameters. `operation` and `object_id` are
+    /// forwarded to `RestClient::execute`, which uses them when applying the
+    /// per-request headers.
+    ///
+    /// # Errors
+    ///
+    /// * `Error::IO` if sending the request or reading the response body fails.
+    /// * `Error::Namespace` if the response status is not a success (the status
+    ///   and body text are included in the message), or if the body of a
+    ///   successful response cannot be deserialized into `T`.
+    async fn get_json<T: DeserializeOwned>(
+        &self,
+        path: &str,
+        query: &[(&str, &str)],
+        operation: &str,
+        object_id: &str,
+    ) -> Result<T> {
+        let url = format!("{}{}", self.rest_client.base_path(), path);
+        let req_builder = self.rest_client.client().get(&url).query(query);
+
+        let resp = self
+            .rest_client
+            .execute(req_builder, operation, object_id)
+            .await
+            .map_err(|e| Error::IO {
+                source: box_error(e),
+                location: snafu::location!(),
+            })?;
+
+        // Capture the status first: `text()` consumes the response.
+        let status = resp.status();
+        let content = resp.text().await.map_err(|e| Error::IO {
+            source: box_error(e),
+            location: snafu::location!(),
+        })?;
+
+        if !status.is_success() {
+            return Err(Error::Namespace {
+                source: format!("Response error: status={}, content={}", status, content).into(),
+                location: snafu::location!(),
+            });
+        }
+
+        serde_json::from_str(&content).map_err(|e| Error::Namespace {
+            source: format!("Failed to parse response: {}", e).into(),
+            location: snafu::location!(),
+        })
+    }
+
+    /// Execute a POST request with JSON body and parse JSON response.
+    ///
+    /// `body` is serialized as the JSON request body and `query` is appended as
+    /// URL query parameters. `operation` and `object_id` are forwarded to
+    /// `RestClient::execute`, which uses them when applying the per-request
+    /// headers.
+    ///
+    /// # Errors
+    ///
+    /// * `Error::IO` if sending the request or reading the response body fails.
+    /// * `Error::Namespace` if the response status is not a success (the status
+    ///   and body text are included in the message), or if the body of a
+    ///   successful response cannot be deserialized into `R`.
+    async fn post_json<T: Serialize, R: DeserializeOwned>(
+        &self,
+        path: &str,
+        query: &[(&str, &str)],
+        body: &T,
+        operation: &str,
+        object_id: &str,
+    ) -> Result<R> {
+        let url = format!("{}{}", self.rest_client.base_path(), path);
+        let req_builder = self.rest_client.client().post(&url).query(query).json(body);
+
+        let resp = self
+            .rest_client
+            .execute(req_builder, operation, object_id)
+            .await
+            .map_err(|e| Error::IO {
+                source: box_error(e),
+                location: snafu::location!(),
+            })?;
+
+        // Capture the status first: `text()` consumes the response.
+        let status = resp.status();
+        let content = resp.text().await.map_err(|e| Error::IO {
+            source: box_error(e),
+            location: snafu::location!(),
+        })?;
+
+        if !status.is_success() {
+            return Err(Error::Namespace {
+                source: format!("Response error: status={}, content={}", status, content).into(),
+                location: snafu::location!(),
+            });
+        }
+
+        serde_json::from_str(&content).map_err(|e| Error::Namespace {
+            source: format!("Failed to parse response: {}", e).into(),
+            location: snafu::location!(),
+        })
+    }
+
+    /// Execute a POST request that returns nothing (204 No Content expected).
+    ///
+    /// Any success status is accepted, and on success the response body is not
+    /// read. `body` is serialized as the JSON request body.
+    ///
+    /// # Errors
+    ///
+    /// * `Error::IO` if sending the request fails, or if reading the body of a
+    ///   non-success response fails.
+    /// * `Error::Namespace` if the response status is not a success; the status
+    ///   and body text are included in the message.
+    async fn post_json_no_content<T: Serialize>(
+        &self,
+        path: &str,
+        query: &[(&str, &str)],
+        body: &T,
+        operation: &str,
+        object_id: &str,
+    ) -> Result<()> {
+        let url = format!("{}{}", self.rest_client.base_path(), path);
+        let req_builder = self.rest_client.client().post(&url).query(query).json(body);
+
+        let resp = self
+            .rest_client
+            .execute(req_builder, operation, object_id)
+            .await
+            .map_err(|e| Error::IO {
+                source: box_error(e),
+                location: snafu::location!(),
+            })?;
+
+        let status = resp.status();
+        if status.is_success() {
+            return Ok(());
+        }
+
+        // Only the failure path needs the body, for the error message.
+        let content = resp.text().await.map_err(|e| Error::IO {
+            source: box_error(e),
+            location: snafu::location!(),
+        })?;
+        Err(Error::Namespace {
+            source: format!("Response error: status={}, content={}", status, content).into(),
+            location: snafu::location!(),
+        })
+    }
+
+    /// Execute a POST request with binary body and parse JSON response.
+    ///
+    /// `body` is sent as the raw request body and `query` is appended as URL
+    /// query parameters. `operation` and `object_id` are forwarded to
+    /// `RestClient::execute`, which uses them when applying the per-request
+    /// headers.
+    ///
+    /// # Errors
+    ///
+    /// * `Error::IO` if sending the request or reading the response body fails.
+    /// * `Error::Namespace` if the response status is not a success (the status
+    ///   and body text are included in the message), or if the body of a
+    ///   successful response cannot be deserialized into `T`.
+    async fn post_binary_json<T: DeserializeOwned>(
+        &self,
+        path: &str,
+        query: &[(&str, &str)],
+        body: Vec<u8>,
+        operation: &str,
+        object_id: &str,
+    ) -> Result<T> {
+        let url = format!("{}{}", self.rest_client.base_path(), path);
+        let req_builder = self.rest_client.client().post(&url).query(query).body(body);
+
+        let resp = self
+            .rest_client
+            .execute(req_builder, operation, object_id)
+            .await
+            .map_err(|e| Error::IO {
+                source: box_error(e),
+                location: snafu::location!(),
+            })?;
+
+        // Capture the status first: `text()` consumes the response.
+        let status = resp.status();
+        let content = resp.text().await.map_err(|e| Error::IO {
+            source: box_error(e),
+            location: snafu::location!(),
+        })?;
+
+        if !status.is_success() {
+            return Err(Error::Namespace {
+                source: format!("Response error: status={}, content={}", status, content).into(),
+                location: snafu::location!(),
+            });
+        }
+
+        serde_json::from_str(&content).map_err(|e| Error::Namespace {
+            source: format!("Failed to parse response: {}", e).into(),
+            location: snafu::location!(),
+        })
+    }
+
+    /// Execute a POST request with JSON body and get binary response.
+ #[allow(dead_code)] + async fn post_json_binary( + &self, + path: &str, + query: &[(&str, &str)], + body: &T, + operation: &str, + object_id: &str, + ) -> Result { + let url = format!("{}{}", self.rest_client.base_path(), path); + let req_builder = self.rest_client.client().post(&url).query(query).json(body); + + let resp = self + .rest_client + .execute(req_builder, operation, object_id) + .await + .map_err(|e| Error::IO { + source: box_error(e), + location: snafu::location!(), + })?; + + let status = resp.status(); + if status.is_success() { + resp.bytes().await.map_err(|e| Error::IO { + source: box_error(e), + location: snafu::location!(), + }) + } else { + let content = resp.text().await.map_err(|e| Error::IO { + source: box_error(e), + location: snafu::location!(), + })?; + Err(Error::Namespace { + source: format!("Response error: status={}, content={}", status, content).into(), + location: snafu::location!(), + }) } } /// Get the base endpoint URL for this namespace pub fn endpoint(&self) -> &str { - &self.reqwest_config.base_path + self.rest_client.base_path() } } @@ -399,16 +706,20 @@ impl LanceNamespace for RestNamespace { request: ListNamespacesRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - namespace_api::list_namespaces( - &self.reqwest_config, - &id, - Some(&self.delimiter), - request.page_token.as_deref(), - request.limit, - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/namespace/{}/list", encoded_id); + let mut query = vec![("delimiter", self.delimiter.as_str())]; + let page_token_str; + if let Some(ref pt) = request.page_token { + page_token_str = pt.clone(); + query.push(("page_token", page_token_str.as_str())); + } + let limit_str; + if let Some(limit) = request.limit { + limit_str = limit.to_string(); + query.push(("limit", limit_str.as_str())); + } + self.get_json(&path, &query, "list_namespaces", &id).await } async fn describe_namespace( @@ -416,10 
+727,11 @@ impl LanceNamespace for RestNamespace { request: DescribeNamespaceRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - namespace_api::describe_namespace(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/namespace/{}/describe", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "describe_namespace", &id) .await - .map_err(convert_api_error) } async fn create_namespace( @@ -427,79 +739,93 @@ impl LanceNamespace for RestNamespace { request: CreateNamespaceRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - namespace_api::create_namespace(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/namespace/{}/create", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "create_namespace", &id) .await - .map_err(convert_api_error) } async fn drop_namespace(&self, request: DropNamespaceRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - namespace_api::drop_namespace(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/namespace/{}/drop", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "drop_namespace", &id) .await - .map_err(convert_api_error) } async fn namespace_exists(&self, request: NamespaceExistsRequest) -> Result<()> { let id = object_id_str(&request.id, &self.delimiter)?; - - namespace_api::namespace_exists(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/namespace/{}/exists", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json_no_content(&path, &query, &request, "namespace_exists", &id) 
.await - .map_err(convert_api_error) } async fn list_tables(&self, request: ListTablesRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::list_tables( - &self.reqwest_config, - &id, - Some(&self.delimiter), - request.page_token.as_deref(), - request.limit, - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/namespace/{}/table/list", encoded_id); + let mut query = vec![("delimiter", self.delimiter.as_str())]; + let page_token_str; + if let Some(ref pt) = request.page_token { + page_token_str = pt.clone(); + query.push(("page_token", page_token_str.as_str())); + } + let limit_str; + if let Some(limit) = request.limit { + limit_str = limit.to_string(); + query.push(("limit", limit_str.as_str())); + } + self.get_json(&path, &query, "list_tables", &id).await } async fn describe_table(&self, request: DescribeTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::describe_table( - &self.reqwest_config, - &id, - request.clone(), - Some(&self.delimiter), - request.with_table_uri, - request.load_detailed_metadata, - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/describe", encoded_id); + let mut query = vec![("delimiter", self.delimiter.as_str())]; + let with_uri_str; + if let Some(with_uri) = request.with_table_uri { + with_uri_str = with_uri.to_string(); + query.push(("with_table_uri", with_uri_str.as_str())); + } + let detailed_str; + if let Some(detailed) = request.load_detailed_metadata { + detailed_str = detailed.to_string(); + query.push(("load_detailed_metadata", detailed_str.as_str())); + } + self.post_json(&path, &query, &request, "describe_table", &id) + .await } async fn register_table(&self, request: RegisterTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::register_table(&self.reqwest_config, &id, request, 
Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/register", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "register_table", &id) .await - .map_err(convert_api_error) } async fn table_exists(&self, request: TableExistsRequest) -> Result<()> { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::table_exists(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/exists", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json_no_content(&path, &query, &request, "table_exists", &id) .await - .map_err(convert_api_error) } async fn drop_table(&self, request: DropTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::drop_table(&self.reqwest_config, &id, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/drop", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "drop_table", &id) .await - .map_err(convert_api_error) } async fn deregister_table( @@ -507,18 +833,19 @@ impl LanceNamespace for RestNamespace { request: DeregisterTableRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::deregister_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/deregister", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "deregister_table", &id) .await - .map_err(convert_api_error) } async fn count_table_rows(&self, request: CountTableRowsRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::count_table_rows(&self.reqwest_config, &id, request, Some(&self.delimiter)) - .await - .map_err(convert_api_error) 
+ let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/count_rows", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.get_json(&path, &query, "count_table_rows", &id).await } async fn create_table( @@ -527,16 +854,16 @@ impl LanceNamespace for RestNamespace { request_data: Bytes, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::create_table( - &self.reqwest_config, - &id, - request_data.to_vec(), - Some(&self.delimiter), - request.mode.as_deref(), - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/create", encoded_id); + let mut query = vec![("delimiter", self.delimiter.as_str())]; + let mode_str; + if let Some(ref mode) = request.mode { + mode_str = mode.clone(); + query.push(("mode", mode_str.as_str())); + } + self.post_binary_json(&path, &query, request_data.to_vec(), "create_table", &id) + .await } async fn create_empty_table( @@ -544,18 +871,20 @@ impl LanceNamespace for RestNamespace { request: CreateEmptyTableRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::create_empty_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/create-empty", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "create_empty_table", &id) .await - .map_err(convert_api_error) } async fn declare_table(&self, request: DeclareTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::declare_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/declare", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "declare_table", &id) .await - .map_err(convert_api_error) } async fn insert_into_table( @@ 
-564,16 +893,22 @@ impl LanceNamespace for RestNamespace { request_data: Bytes, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::insert_into_table( - &self.reqwest_config, - &id, + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/insert", encoded_id); + let mut query = vec![("delimiter", self.delimiter.as_str())]; + let mode_str; + if let Some(ref mode) = request.mode { + mode_str = mode.clone(); + query.push(("mode", mode_str.as_str())); + } + self.post_binary_json( + &path, + &query, request_data.to_vec(), - Some(&self.delimiter), - request.mode.as_deref(), + "insert_into_table", + &id, ) .await - .map_err(convert_api_error) } async fn merge_insert_into_table( @@ -582,36 +917,72 @@ impl LanceNamespace for RestNamespace { request_data: Bytes, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; + let encoded_id = urlencode(&id); let on = request.on.as_deref().ok_or_else(|| Error::Namespace { source: "'on' field is required for merge insert".into(), location: snafu::location!(), })?; - table_api::merge_insert_into_table( - &self.reqwest_config, - &id, - on, + let path = format!("/v1/table/{}/merge_insert", encoded_id); + let mut query = vec![("delimiter", self.delimiter.as_str()), ("on", on)]; + + let when_matched_update_all_str; + if let Some(v) = request.when_matched_update_all { + when_matched_update_all_str = v.to_string(); + query.push(( + "when_matched_update_all", + when_matched_update_all_str.as_str(), + )); + } + if let Some(ref v) = request.when_matched_update_all_filt { + query.push(("when_matched_update_all_filt", v.as_str())); + } + let when_not_matched_insert_all_str; + if let Some(v) = request.when_not_matched_insert_all { + when_not_matched_insert_all_str = v.to_string(); + query.push(( + "when_not_matched_insert_all", + when_not_matched_insert_all_str.as_str(), + )); + } + let when_not_matched_by_source_delete_str; + if let Some(v) = request.when_not_matched_by_source_delete { + 
when_not_matched_by_source_delete_str = v.to_string(); + query.push(( + "when_not_matched_by_source_delete", + when_not_matched_by_source_delete_str.as_str(), + )); + } + if let Some(ref v) = request.when_not_matched_by_source_delete_filt { + query.push(("when_not_matched_by_source_delete_filt", v.as_str())); + } + if let Some(ref v) = request.timeout { + query.push(("timeout", v.as_str())); + } + let use_index_str; + if let Some(v) = request.use_index { + use_index_str = v.to_string(); + query.push(("use_index", use_index_str.as_str())); + } + + self.post_binary_json( + &path, + &query, request_data.to_vec(), - Some(&self.delimiter), - request.when_matched_update_all, - request.when_matched_update_all_filt.as_deref(), - request.when_not_matched_insert_all, - request.when_not_matched_by_source_delete, - request.when_not_matched_by_source_delete_filt.as_deref(), - request.timeout.as_deref(), - request.use_index, + "merge_insert_into_table", + &id, ) .await - .map_err(convert_api_error) } async fn update_table(&self, request: UpdateTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::update_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/update", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "update_table", &id) .await - .map_err(convert_api_error) } async fn delete_from_table( @@ -619,27 +990,52 @@ impl LanceNamespace for RestNamespace { request: DeleteFromTableRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::delete_from_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/delete", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "delete_from_table", &id) .await - 
.map_err(convert_api_error) } async fn query_table(&self, request: QueryTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/query", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + + let url = format!("{}{}", self.rest_client.base_path(), path); + let req_builder = self + .rest_client + .client() + .post(&url) + .query(&query) + .json(&request); + + let resp = self + .rest_client + .execute(req_builder, "query_table", &id) + .await + .map_err(|e| Error::IO { + source: box_error(e), + location: snafu::location!(), + })?; - let response = - table_api::query_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) - .await - .map_err(convert_api_error)?; - - // Convert response to bytes - let bytes = response.bytes().await.map_err(|e| Error::IO { - source: box_error(e), - location: snafu::location!(), - })?; - - Ok(bytes) + let status = resp.status(); + if status.is_success() { + resp.bytes().await.map_err(|e| Error::IO { + source: box_error(e), + location: snafu::location!(), + }) + } else { + let content = resp.text().await.map_err(|e| Error::IO { + source: box_error(e), + location: snafu::location!(), + })?; + Err(Error::Namespace { + source: format!("Response error: status={}, content={}", status, content).into(), + location: snafu::location!(), + }) + } } async fn create_table_index( @@ -647,10 +1043,11 @@ impl LanceNamespace for RestNamespace { request: CreateTableIndexRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::create_table_index(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/create_index", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "create_table_index", &id) .await - .map_err(convert_api_error) } async fn list_table_indices( @@ -658,10 
+1055,11 @@ impl LanceNamespace for RestNamespace { request: ListTableIndicesRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::list_table_indices(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/index/list", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "list_table_indices", &id) .await - .map_err(convert_api_error) } async fn describe_table_index_stats( @@ -669,20 +1067,16 @@ impl LanceNamespace for RestNamespace { request: DescribeTableIndexStatsRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - // Note: The index_name parameter seems to be missing from the request structure - // This might need to be adjusted based on the actual API - let index_name = ""; // This should come from somewhere in the request - - table_api::describe_table_index_stats( - &self.reqwest_config, - &id, - index_name, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let index_name = request.index_name.as_deref().unwrap_or(""); + let path = format!( + "/v1/table/{}/index/{}/stats", + encoded_id, + urlencode(index_name) + ); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "describe_table_index_stats", &id) + .await } async fn describe_transaction( @@ -690,15 +1084,11 @@ impl LanceNamespace for RestNamespace { request: DescribeTransactionRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - transaction_api::describe_transaction( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/transaction/{}/describe", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, 
"describe_transaction", &id) + .await } async fn alter_transaction( @@ -706,15 +1096,11 @@ impl LanceNamespace for RestNamespace { request: AlterTransactionRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - transaction_api::alter_transaction( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/transaction/{}/alter", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "alter_transaction", &id) + .await } async fn create_table_scalar_index( @@ -722,15 +1108,11 @@ impl LanceNamespace for RestNamespace { request: CreateTableIndexRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::create_table_scalar_index( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/create_scalar_index", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "create_table_scalar_index", &id) + .await } async fn drop_table_index( @@ -738,39 +1120,50 @@ impl LanceNamespace for RestNamespace { request: DropTableIndexRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - + let encoded_id = urlencode(&id); let index_name = request.index_name.as_deref().unwrap_or(""); - - table_api::drop_table_index(&self.reqwest_config, &id, index_name, Some(&self.delimiter)) + let path = format!( + "/v1/table/{}/index/{}/drop", + encoded_id, + urlencode(index_name) + ); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "drop_table_index", &id) .await - .map_err(convert_api_error) } async fn list_all_tables(&self, request: ListTablesRequest) -> Result { - table_api::list_all_tables( - &self.reqwest_config, - 
Some(&self.delimiter), - request.page_token.as_deref(), - request.limit, - ) - .await - .map_err(convert_api_error) + let path = "/v1/table"; + let mut query = vec![("delimiter", self.delimiter.as_str())]; + let page_token_str; + if let Some(ref pt) = request.page_token { + page_token_str = pt.clone(); + query.push(("page_token", page_token_str.as_str())); + } + let limit_str; + if let Some(limit) = request.limit { + limit_str = limit.to_string(); + query.push(("limit", limit_str.as_str())); + } + self.get_json(path, &query, "list_all_tables", "").await } async fn restore_table(&self, request: RestoreTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::restore_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/restore", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "restore_table", &id) .await - .map_err(convert_api_error) } async fn rename_table(&self, request: RenameTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::rename_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/rename", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "rename_table", &id) .await - .map_err(convert_api_error) } async fn list_table_versions( @@ -778,16 +1171,21 @@ impl LanceNamespace for RestNamespace { request: ListTableVersionsRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::list_table_versions( - &self.reqwest_config, - &id, - Some(&self.delimiter), - request.page_token.as_deref(), - request.limit, - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/version/list", encoded_id); + let mut query = 
vec![("delimiter", self.delimiter.as_str())]; + let page_token_str; + if let Some(ref pt) = request.page_token { + page_token_str = pt.clone(); + query.push(("page_token", page_token_str.as_str())); + } + let limit_str; + if let Some(limit) = request.limit { + limit_str = limit.to_string(); + query.push(("limit", limit_str.as_str())); + } + self.get_json(&path, &query, "list_table_versions", &id) + .await } async fn update_table_schema_metadata( @@ -795,18 +1193,19 @@ impl LanceNamespace for RestNamespace { request: UpdateTableSchemaMetadataRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/schema_metadata/update", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; let metadata = request.metadata.unwrap_or_default(); - - let result = table_api::update_table_schema_metadata( - &self.reqwest_config, - &id, - metadata, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error)?; - + let result: HashMap = self + .post_json( + &path, + &query, + &metadata, + "update_table_schema_metadata", + &id, + ) + .await?; Ok(UpdateTableSchemaMetadataResponse { metadata: Some(result), ..Default::default() @@ -818,10 +1217,11 @@ impl LanceNamespace for RestNamespace { request: GetTableStatsRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::get_table_stats(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/stats", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "get_table_stats", &id) .await - .map_err(convert_api_error) } async fn explain_table_query_plan( @@ -829,15 +1229,11 @@ impl LanceNamespace for RestNamespace { request: ExplainTableQueryPlanRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::explain_table_query_plan( - 
&self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/explain_plan", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "explain_table_query_plan", &id) + .await } async fn analyze_table_query_plan( @@ -845,15 +1241,11 @@ impl LanceNamespace for RestNamespace { request: AnalyzeTableQueryPlanRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::analyze_table_query_plan( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/analyze_plan", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "analyze_table_query_plan", &id) + .await } async fn alter_table_add_columns( @@ -861,15 +1253,11 @@ impl LanceNamespace for RestNamespace { request: AlterTableAddColumnsRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::alter_table_add_columns( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/add_columns", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "alter_table_add_columns", &id) + .await } async fn alter_table_alter_columns( @@ -877,15 +1265,11 @@ impl LanceNamespace for RestNamespace { request: AlterTableAlterColumnsRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::alter_table_alter_columns( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/alter_columns", encoded_id); + 
let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "alter_table_alter_columns", &id) + .await } async fn alter_table_drop_columns( @@ -893,15 +1277,11 @@ impl LanceNamespace for RestNamespace { request: AlterTableDropColumnsRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - table_api::alter_table_drop_columns( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/drop_columns", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "alter_table_drop_columns", &id) + .await } async fn list_table_tags( @@ -909,16 +1289,20 @@ impl LanceNamespace for RestNamespace { request: ListTableTagsRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - tag_api::list_table_tags( - &self.reqwest_config, - &id, - Some(&self.delimiter), - request.page_token.as_deref(), - request.limit, - ) - .await - .map_err(convert_api_error) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/tags/list", encoded_id); + let mut query = vec![("delimiter", self.delimiter.as_str())]; + let page_token_str; + if let Some(ref pt) = request.page_token { + page_token_str = pt.clone(); + query.push(("page_token", page_token_str.as_str())); + } + let limit_str; + if let Some(limit) = request.limit { + limit_str = limit.to_string(); + query.push(("limit", limit_str.as_str())); + } + self.get_json(&path, &query, "list_table_tags", &id).await } async fn get_table_tag_version( @@ -926,10 +1310,11 @@ impl LanceNamespace for RestNamespace { request: GetTableTagVersionRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - tag_api::get_table_tag_version(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = 
format!("/v1/table/{}/tags/version", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "get_table_tag_version", &id) .await - .map_err(convert_api_error) } async fn create_table_tag( @@ -937,10 +1322,11 @@ impl LanceNamespace for RestNamespace { request: CreateTableTagRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - tag_api::create_table_tag(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/tags/create", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "create_table_tag", &id) .await - .map_err(convert_api_error) } async fn delete_table_tag( @@ -948,10 +1334,11 @@ impl LanceNamespace for RestNamespace { request: DeleteTableTagRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - tag_api::delete_table_tag(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/tags/delete", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "delete_table_tag", &id) .await - .map_err(convert_api_error) } async fn update_table_tag( @@ -959,16 +1346,18 @@ impl LanceNamespace for RestNamespace { request: UpdateTableTagRequest, ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - - tag_api::update_table_tag(&self.reqwest_config, &id, request, Some(&self.delimiter)) + let encoded_id = urlencode(&id); + let path = format!("/v1/table/{}/tags/update", encoded_id); + let query = [("delimiter", self.delimiter.as_str())]; + self.post_json(&path, &query, &request, "update_table_tag", &id) .await - .map_err(convert_api_error) } fn namespace_id(&self) -> String { format!( "RestNamespace {{ endpoint: {:?}, delimiter: {:?} }}", - self.reqwest_config.base_path, self.delimiter + 
self.rest_client.base_path(), + self.delimiter ) } } @@ -1153,10 +1542,7 @@ mod tests { .await; // Create namespace with mock server URL - let mut reqwest_config = Configuration::new(); - reqwest_config.base_path = mock_server.uri(); - - let namespace = RestNamespace::with_configuration("$".to_string(), reqwest_config); + let namespace = RestNamespaceBuilder::new(mock_server.uri()).build(); let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), @@ -1192,10 +1578,7 @@ mod tests { .await; // Create namespace with mock server URL - let mut reqwest_config = Configuration::new(); - reqwest_config.base_path = mock_server.uri(); - - let namespace = RestNamespace::with_configuration("$".to_string(), reqwest_config); + let namespace = RestNamespaceBuilder::new(mock_server.uri()).build(); let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), @@ -1228,10 +1611,7 @@ mod tests { .await; // Create namespace with mock server URL - let mut reqwest_config = Configuration::new(); - reqwest_config.base_path = mock_server.uri(); - - let namespace = RestNamespace::with_configuration("$".to_string(), reqwest_config); + let namespace = RestNamespaceBuilder::new(mock_server.uri()).build(); let request = CreateNamespaceRequest { id: Some(vec!["test".to_string(), "newnamespace".to_string()]), @@ -1264,10 +1644,7 @@ mod tests { .await; // Create namespace with mock server URL - let mut reqwest_config = Configuration::new(); - reqwest_config.base_path = mock_server.uri(); - - let namespace = RestNamespace::with_configuration("$".to_string(), reqwest_config); + let namespace = RestNamespaceBuilder::new(mock_server.uri()).build(); let request = CreateTableRequest { id: Some(vec![ @@ -1302,10 +1679,7 @@ mod tests { .await; // Create namespace with mock server URL - let mut reqwest_config = Configuration::new(); - reqwest_config.base_path = mock_server.uri(); - - let namespace = RestNamespace::with_configuration("$".to_string(), reqwest_config); + let 
namespace = RestNamespaceBuilder::new(mock_server.uri()).build(); let request = InsertIntoTableRequest { id: Some(vec![ @@ -1325,4 +1699,176 @@ mod tests { let response = result.unwrap(); assert_eq!(response.transaction_id, Some("txn-123".to_string())); } + + // Integration tests for DynamicContextProvider + + #[derive(Debug)] + struct TestContextProvider { + headers: HashMap, + } + + impl DynamicContextProvider for TestContextProvider { + fn provide_context(&self, _info: &OperationInfo) -> HashMap { + self.headers.clone() + } + } + + #[tokio::test] + async fn test_context_provider_headers_sent() { + let mock_server = MockServer::start().await; + + // Mock expects the context header + Mock::given(method("GET")) + .and(path("/v1/namespace/test/list")) + .and(wiremock::matchers::header( + "X-Context-Token", + "dynamic-token", + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "namespaces": [] + }))) + .mount(&mock_server) + .await; + + // Create context provider + let mut context_headers = HashMap::new(); + context_headers.insert( + "headers.X-Context-Token".to_string(), + "dynamic-token".to_string(), + ); + let provider = Arc::new(TestContextProvider { + headers: context_headers, + }); + + let namespace = RestNamespaceBuilder::new(mock_server.uri()) + .context_provider(provider) + .build(); + + let request = ListNamespacesRequest { + id: Some(vec!["test".to_string()]), + ..Default::default() + }; + + let result = namespace.list_namespaces(request).await; + assert!(result.is_ok(), "Failed: {:?}", result.err()); + } + + #[tokio::test] + async fn test_base_headers_merged_with_context_headers() { + let mock_server = MockServer::start().await; + + // Mock expects BOTH base header AND context header + Mock::given(method("GET")) + .and(path("/v1/namespace/test/list")) + .and(wiremock::matchers::header( + "Authorization", + "Bearer base-token", + )) + .and(wiremock::matchers::header( + "X-Context-Token", + "dynamic-token", + )) + 
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "namespaces": [] + }))) + .mount(&mock_server) + .await; + + // Create context provider + let mut context_headers = HashMap::new(); + context_headers.insert( + "headers.X-Context-Token".to_string(), + "dynamic-token".to_string(), + ); + let provider = Arc::new(TestContextProvider { + headers: context_headers, + }); + + // Create namespace with base header AND context provider + let namespace = RestNamespaceBuilder::new(mock_server.uri()) + .header("Authorization", "Bearer base-token") + .context_provider(provider) + .build(); + + let request = ListNamespacesRequest { + id: Some(vec!["test".to_string()]), + ..Default::default() + }; + + let result = namespace.list_namespaces(request).await; + assert!(result.is_ok(), "Failed: {:?}", result.err()); + } + + #[tokio::test] + async fn test_context_headers_override_base_headers() { + let mock_server = MockServer::start().await; + + // Mock expects the CONTEXT header value (not base) + Mock::given(method("GET")) + .and(path("/v1/namespace/test/list")) + .and(wiremock::matchers::header( + "Authorization", + "Bearer context-override-token", + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "namespaces": [] + }))) + .mount(&mock_server) + .await; + + // Context provider that overrides Authorization header + let mut context_headers = HashMap::new(); + context_headers.insert( + "headers.Authorization".to_string(), + "Bearer context-override-token".to_string(), + ); + let provider = Arc::new(TestContextProvider { + headers: context_headers, + }); + + // Create namespace with base header that will be overridden + let namespace = RestNamespaceBuilder::new(mock_server.uri()) + .header("Authorization", "Bearer base-token") + .context_provider(provider) + .build(); + + let request = ListNamespacesRequest { + id: Some(vec!["test".to_string()]), + ..Default::default() + }; + + let result = 
namespace.list_namespaces(request).await; + assert!(result.is_ok(), "Failed: {:?}", result.err()); + } + + #[tokio::test] + async fn test_no_context_provider_uses_base_headers_only() { + let mock_server = MockServer::start().await; + + // Mock expects only the base header + Mock::given(method("GET")) + .and(path("/v1/namespace/test/list")) + .and(wiremock::matchers::header( + "Authorization", + "Bearer base-only", + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "namespaces": [] + }))) + .mount(&mock_server) + .await; + + // Create namespace WITHOUT context provider, only base headers + let namespace = RestNamespaceBuilder::new(mock_server.uri()) + .header("Authorization", "Bearer base-only") + .build(); + + let request = ListNamespacesRequest { + id: Some(vec!["test".to_string()]), + ..Default::default() + }; + + let result = namespace.list_namespaces(request).await; + assert!(result.is_ok(), "Failed: {:?}", result.err()); + } } diff --git a/rust/lance-namespace-impls/src/rest_adapter.rs b/rust/lance-namespace-impls/src/rest_adapter.rs index 4a12b92838a..899863793ff 100644 --- a/rust/lance-namespace-impls/src/rest_adapter.rs +++ b/rust/lance-namespace-impls/src/rest_adapter.rs @@ -2776,5 +2776,131 @@ mod tests { .unwrap(); assert_eq!(a_col.values(), &[100, 200]); } + + // ============================================================================ + // DynamicContextProvider Integration Test + // ============================================================================ + + use crate::context::{DynamicContextProvider, OperationInfo}; + use std::collections::HashMap; + + /// Test context provider that adds custom headers to every request. 
+ #[derive(Debug)] + struct TestDynamicContextProvider { + headers: HashMap, + } + + impl DynamicContextProvider for TestDynamicContextProvider { + fn provide_context(&self, _info: &OperationInfo) -> HashMap { + self.headers.clone() + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_rest_namespace_with_context_provider() { + let temp_dir = TempDir::new().unwrap(); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + + // Create DirectoryNamespace backend with manifest enabled + let backend = DirectoryNamespaceBuilder::new(&temp_path) + .manifest_enabled(true) + .build() + .await + .unwrap(); + let backend = Arc::new(backend); + + // Start REST server + let config = RestAdapterConfig { + port: 0, + ..Default::default() + }; + + let server = RestAdapter::new(backend.clone(), config); + let server_handle = server.start().await.unwrap(); + let actual_port = server_handle.port(); + + // Create context provider that adds custom headers + let mut context_headers = HashMap::new(); + context_headers.insert( + "headers.X-Custom-Auth".to_string(), + "test-auth-token".to_string(), + ); + context_headers.insert( + "headers.X-Request-Source".to_string(), + "integration-test".to_string(), + ); + + let provider = Arc::new(TestDynamicContextProvider { + headers: context_headers, + }); + + // Create RestNamespace client with context provider and base headers + let server_url = format!("http://127.0.0.1:{}", actual_port); + let namespace = RestNamespaceBuilder::new(&server_url) + .delimiter("$") + .header("X-Base-Header", "base-value") + .context_provider(provider) + .build(); + + // Create a namespace - should work with context provider + let create_req = CreateNamespaceRequest { + id: Some(vec!["context_test_ns".to_string()]), + properties: None, + mode: None, + identity: None, + context: None, + }; + let result = namespace.create_namespace(create_req).await; + assert!(result.is_ok(), "Failed to create namespace: {:?}", result); + 
+ // List namespaces - should also work + let list_req = ListNamespacesRequest { + id: Some(vec![]), + limit: Some(10), + page_token: None, + identity: None, + context: None, + }; + let result = namespace.list_namespaces(list_req).await; + assert!(result.is_ok(), "Failed to list namespaces: {:?}", result); + let response = result.unwrap(); + assert!( + response.namespaces.contains(&"context_test_ns".to_string()), + "Namespace not found in list" + ); + + // Create a table - should work with context provider + let table_data = create_test_arrow_data(); + let create_table_req = CreateTableRequest { + id: Some(vec![ + "context_test_ns".to_string(), + "test_table".to_string(), + ]), + mode: Some("create".to_string()), + identity: None, + context: None, + }; + let result = namespace.create_table(create_table_req, table_data).await; + assert!(result.is_ok(), "Failed to create table: {:?}", result); + + // Describe the table - should work with context provider + let describe_req = DescribeTableRequest { + id: Some(vec![ + "context_test_ns".to_string(), + "test_table".to_string(), + ]), + with_table_uri: None, + load_detailed_metadata: None, + vend_credentials: None, + version: None, + identity: None, + context: None, + }; + let result = namespace.describe_table(describe_req).await; + assert!(result.is_ok(), "Failed to describe table: {:?}", result); + + // Cleanup + server_handle.shutdown(); + } } }