Skip to content

Commit c04d408

Browse files
committed
Support custom cache for OAuth2 tokens
1 parent aee6064 commit c04d408

File tree

3 files changed

+99
-10
lines changed

3 files changed

+99
-10
lines changed

README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,38 @@ The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` ins
227227
)
228228
```
229229

230+
A custom caching implementation can be provided by creating a class implementing the `trino.auth.OAuth2TokenCache` abstract class and adding it as in `OAuth2Authentication(cache=my_custom_cache_impl)`. The custom caching implementation enables usage in multi-user environments (notebooks, web applications) in combination with a custom `redirect_auth_url_handler` as explained above.
231+
232+
```python
233+
from typing import Optional
234+
235+
from trino.auth import OAuth2Authentication, OAuth2TokenCache
236+
from trino.dbapi import connect
237+
238+
239+
class MyCustomCacheImpl(OAuth2TokenCache):
240+
def get_token_from_cache(self, host: str) -> Optional[str]:
241+
# Retrieve your cached token from a distributed system
242+
# and return it
243+
pass
244+
245+
def store_token_to_cache(self, host: str, token: str) -> None:
246+
# Store your cached token in a distributed system
247+
pass
248+
249+
250+
def my_custom_redirect_handler(url: str) -> None:
251+
# ensure the url is opened by the user that should perform the authentication
252+
pass
253+
254+
conn = connect(
255+
user="<username>",
256+
auth=OAuth2Authentication(cache=MyCustomCacheImpl(), redirect_auth_url_handler=my_custom_redirect_handler),
257+
http_scheme="https",
258+
...
259+
)
260+
```
261+
230262
### Certificate Authentication
231263

232264
`CertificateAuthentication` class can be used to connect to Trino cluster configured with [certificate based authentication](https://trino.io/docs/current/security/certificate.html). `CertificateAuthentication` requires paths to a valid client certificate and private key.

tests/unit/test_dbapi.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# limitations under the License.
1212
import threading
1313
import uuid
14-
from unittest.mock import patch
14+
from unittest.mock import patch, MagicMock
1515

1616
import httpretty
1717
from httpretty import httprettified
@@ -20,7 +20,7 @@
2020
from tests.unit.oauth_test_utils import _post_statement_requests, _get_token_requests, RedirectHandler, \
2121
GetTokenCallback, REDIRECT_RESOURCE, TOKEN_RESOURCE, PostStatementCallback, SERVER_ADDRESS
2222
from trino import constants
23-
from trino.auth import OAuth2Authentication
23+
from trino.auth import OAuth2Authentication, OAuth2TokenCache
2424
from trino.dbapi import connect
2525

2626

@@ -107,6 +107,53 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data):
107107
assert len(_get_token_requests(challenge_id)) == 2
108108

109109

110+
@httprettified
111+
def test_custom_token_cache_is_invoked(sample_post_response_data):
112+
host = "coordinator"
113+
token = str(uuid.uuid4())
114+
challenge_id = str(uuid.uuid4())
115+
116+
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
117+
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
118+
119+
post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
120+
121+
# bind post statement
122+
httpretty.register_uri(
123+
method=httpretty.POST,
124+
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
125+
body=post_statement_callback)
126+
127+
# bind get token
128+
get_token_callback = GetTokenCallback(token_server, token)
129+
httpretty.register_uri(
130+
method=httpretty.GET,
131+
uri=token_server,
132+
body=get_token_callback)
133+
134+
redirect_handler = RedirectHandler()
135+
136+
custom_cache = MagicMock(OAuth2TokenCache)
137+
custom_cache.get_token_from_cache = MagicMock(side_effect=[None, token, token, token])
138+
custom_cache.store_token_to_cache = MagicMock()
139+
140+
with connect(
141+
host,
142+
user="test",
143+
auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler, cache=custom_cache),
144+
http_scheme=constants.HTTPS
145+
) as conn:
146+
conn.cursor().execute("SELECT 1")
147+
conn.cursor().execute("SELECT 2")
148+
conn.cursor().execute("SELECT 3")
149+
150+
assert len(_get_token_requests(challenge_id)) == 1
151+
custom_cache.get_token_from_cache.assert_called_with(host)
152+
assert custom_cache.get_token_from_cache.call_count == 4
153+
custom_cache.store_token_to_cache.assert_called_with(host, token)
154+
assert custom_cache.store_token_to_cache.call_count == 1
155+
156+
110157
@httprettified
111158
def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data):
112159
token = str(uuid.uuid4())

trino/auth.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def __call__(self, url: str):
202202
handler(url)
203203

204204

205-
class _OAuth2TokenCache(metaclass=abc.ABCMeta):
205+
class OAuth2TokenCache(metaclass=abc.ABCMeta):
206206
"""
207207
Abstract class for OAuth token cache, inherit from this class to implement your own token cache.
208208
"""
@@ -216,7 +216,7 @@ def store_token_to_cache(self, host: str, token: str) -> None:
216216
pass
217217

218218

219-
class _OAuth2TokenInMemoryCache(_OAuth2TokenCache):
219+
class _OAuth2TokenInMemoryCache(OAuth2TokenCache):
220220
"""
221221
In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache.
222222
"""
@@ -231,7 +231,7 @@ def store_token_to_cache(self, host: str, token: str) -> None:
231231
self._cache[host] = token
232232

233233

234-
class _OAuth2KeyRingTokenCache(_OAuth2TokenCache):
234+
class _OAuth2KeyRingTokenCache(OAuth2TokenCache):
235235
"""
236236
Keyring Token Cache implementation
237237
"""
@@ -272,10 +272,9 @@ class _OAuth2TokenBearer(AuthBase):
272272
MAX_OAUTH_ATTEMPTS = 5
273273
_BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE)
274274

275-
def __init__(self, redirect_auth_url_handler: Callable[[str], None]):
275+
def __init__(self, redirect_auth_url_handler: Callable[[str], None], custom_cache: Optional[OAuth2TokenCache]):
276276
self._redirect_auth_url = redirect_auth_url_handler
277-
keyring_cache = _OAuth2KeyRingTokenCache()
278-
self._token_cache = keyring_cache if keyring_cache.is_keyring_available() else _OAuth2TokenInMemoryCache()
277+
self._token_cache = self._setup_cache(custom_cache)
279278
self._token_lock = threading.Lock()
280279
self._inside_oauth_attempt_lock = threading.Lock()
281280
self._inside_oauth_attempt_blocker = threading.Event()
@@ -291,6 +290,17 @@ def __call__(self, r):
291290

292291
return r
293292

293+
def _setup_cache(self, custom_cache):
294+
if custom_cache is not None:
295+
if not isinstance(custom_cache, OAuth2TokenCache):
296+
raise exceptions.TrinoAuthError("Custom cache does not implement `trino.auth.OAuth2TokenCache` "
297+
"interface")
298+
return custom_cache
299+
keyring_cache = _OAuth2KeyRingTokenCache()
300+
if keyring_cache.is_keyring_available():
301+
return keyring_cache
302+
return _OAuth2TokenInMemoryCache()
303+
294304
def _authenticate(self, response, **kwargs):
295305
if not 400 <= response.status_code < 500:
296306
return response
@@ -396,9 +406,9 @@ class OAuth2Authentication(Authentication):
396406
def __init__(self, redirect_auth_url_handler=CompositeRedirectHandler([
397407
WebBrowserRedirectHandler(),
398408
ConsoleRedirectHandler()
399-
])):
409+
]), cache: Optional[OAuth2TokenCache] = None):
400410
self._redirect_auth_url = redirect_auth_url_handler
401-
self._bearer = _OAuth2TokenBearer(self._redirect_auth_url)
411+
self._bearer = _OAuth2TokenBearer(self._redirect_auth_url, custom_cache=cache)
402412

403413
def set_http_session(self, http_session):
404414
http_session.auth = self._bearer

0 commit comments

Comments
 (0)