diff --git a/src/store.rs b/src/store.rs index 1e5fab472..32d595443 100644 --- a/src/store.rs +++ b/src/store.rs @@ -16,6 +16,7 @@ // under the License. use std::sync::Arc; +use std::time::Duration; use pyo3::prelude::*; @@ -24,6 +25,7 @@ use object_store::azure::{MicrosoftAzure, MicrosoftAzureBuilder}; use object_store::gcp::{GoogleCloudStorage, GoogleCloudStorageBuilder}; use object_store::http::{HttpBuilder, HttpStore}; use object_store::local::LocalFileSystem; +use object_store::ClientOptions; use pyo3::exceptions::PyValueError; use url::Url; @@ -164,6 +166,41 @@ impl PyGoogleCloudContext { } } +#[pyclass(name = "ClientOptions", module = "datafusion.store", subclass)] +#[derive(Debug, Clone)] +pub struct PyClientOptions { + pub inner: ClientOptions, +} + +impl Default for PyClientOptions { + fn default() -> Self { + Self::new() + } +} + +#[pymethods] +impl PyClientOptions { + #[pyo3(signature=())] + #[new] + pub fn new() -> Self { + Self { + inner: ClientOptions::new(), + } + } + + #[pyo3(signature = (timeout))] + pub fn with_timeout(&mut self, timeout: Duration) -> Self { + self.inner = self.inner.clone().with_timeout(timeout); + self.clone() + } + + #[pyo3(signature = (timeout))] + pub fn with_connect_timeout(&mut self, timeout: Duration) -> Self { + self.inner = self.inner.clone().with_connect_timeout(timeout); + self.clone() + } +} + #[pyclass(name = "AmazonS3", module = "datafusion.store", subclass)] #[derive(Debug, Clone)] pub struct PyAmazonS3Context { @@ -174,7 +211,7 @@ pub struct PyAmazonS3Context { #[pymethods] impl PyAmazonS3Context { #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (bucket_name, region=None, access_key_id=None, secret_access_key=None, endpoint=None, allow_http=false, imdsv1_fallback=false))] + #[pyo3(signature = (bucket_name, region=None, access_key_id=None, secret_access_key=None, endpoint=None, client_options=None, allow_http=false, imdsv1_fallback=false))] #[new] fn new( bucket_name: String, @@ -182,6 +219,7 @@ impl PyAmazonS3Context { access_key_id: Option, secret_access_key: Option, endpoint: Option, + client_options: Option, //retry_config: RetryConfig, allow_http: bool, imdsv1_fallback: bool, @@ -209,6 +247,10 @@ impl PyAmazonS3Context { builder = builder.with_imdsv1_fallback(); }; + if let Some(client_options) = client_options { + builder = builder.with_client_options(client_options.inner); + }; + let store = builder .with_bucket_name(bucket_name.clone()) //.with_retry_config(retry_config) #TODO: add later @@ -250,6 +292,7 @@ impl PyHttpContext { } pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?;