diff --git a/src/main/java/org/dataloader/DataLoaderHelper.java b/src/main/java/org/dataloader/DataLoaderHelper.java index f4e3915..d01b930 100644 --- a/src/main/java/org/dataloader/DataLoaderHelper.java +++ b/src/main/java/org/dataloader/DataLoaderHelper.java @@ -155,11 +155,13 @@ CompletableFuture load(K key, Object loadContext) { } } + @SuppressWarnings("unchecked") Object getCacheKey(K key) { return loaderOptions.cacheKeyFunction().isPresent() ? loaderOptions.cacheKeyFunction().get().getKey(key) : key; } + @SuppressWarnings("unchecked") Object getCacheKeyWithContext(K key, Object context) { return loaderOptions.cacheKeyFunction().isPresent() ? loaderOptions.cacheKeyFunction().get().getKeyWithContext(key, context) : key; @@ -511,6 +513,7 @@ private CompletableFuture> invokeBatchPublisher(List keys, List loadFunction = (BatchPublisherWithContext) batchLoadFunction; if (batchLoaderScheduler != null) { BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber, environment); @@ -519,6 +522,7 @@ private CompletableFuture> invokeBatchPublisher(List keys, List loadFunction = (BatchPublisher) batchLoadFunction; if (batchLoaderScheduler != null) { BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber); @@ -536,6 +540,7 @@ private CompletableFuture> invokeMappedBatchPublisher(List keys, List BatchLoaderScheduler batchLoaderScheduler = loaderOptions.getBatchLoaderScheduler(); if (batchLoadFunction instanceof MappedBatchPublisherWithContext) { + //noinspection unchecked MappedBatchPublisherWithContext loadFunction = (MappedBatchPublisherWithContext) batchLoadFunction; if (batchLoaderScheduler != null) { BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber, environment); @@ -544,6 +549,7 @@ private CompletableFuture> invokeMappedBatchPublisher(List keys, List loadFunction.load(keys, subscriber, environment); } } else { + //noinspection unchecked MappedBatchPublisher loadFunction = (MappedBatchPublisher) batchLoadFunction; if (batchLoaderScheduler != null) { BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber); @@ -618,24 +624,23 @@ private static DispatchResult emptyDispatchResult() { return (DispatchResult) EMPTY_DISPATCH_RESULT; } - private class DataLoaderSubscriber implements Subscriber { + private abstract class DataLoaderSubscriberBase implements Subscriber { - private final CompletableFuture> valuesFuture; - private final List keys; - private final List callContexts; - private final List> queuedFutures; + final CompletableFuture> valuesFuture; + final List keys; + final List callContexts; + final List> queuedFutures; - private final List clearCacheKeys = new ArrayList<>(); - private final List completedValues = new ArrayList<>(); - private int idx = 0; - private boolean onErrorCalled = false; - private boolean onCompleteCalled = false; + List clearCacheKeys = new ArrayList<>(); + List completedValues = new ArrayList<>(); + boolean onErrorCalled = false; + boolean onCompleteCalled = false; - private DataLoaderSubscriber( - CompletableFuture> valuesFuture, - List keys, - List callContexts, - List> queuedFutures + DataLoaderSubscriberBase( + CompletableFuture> valuesFuture, + List keys, + List callContexts, + List> queuedFutures ) { this.valuesFuture = valuesFuture; this.keys = keys; @@ -648,40 +653,87 @@ public void onSubscribe(Subscription subscription) { subscription.request(keys.size()); } - // onNext may be called by multiple threads - for the time being, we pass 'synchronized' to guarantee - // correctness (at the cost of speed). @Override - public synchronized void onNext(V value) { + public void onNext(T v) { assertState(!onErrorCalled, () -> "onError has already been called; onNext may not be invoked."); assertState(!onCompleteCalled, () -> "onComplete has already been called; onNext may not be invoked."); + } - K key = keys.get(idx); - Object callContext = callContexts.get(idx); - CompletableFuture future = queuedFutures.get(idx); + @Override + public void onComplete() { + assertState(!onErrorCalled, () -> "onError has already been called; onComplete may not be invoked."); + onCompleteCalled = true; + } + + @Override + public void onError(Throwable throwable) { + assertState(!onCompleteCalled, () -> "onComplete has already been called; onError may not be invoked."); + onErrorCalled = true; + + stats.incrementBatchLoadExceptionCount(new IncrementBatchLoadExceptionCountStatisticsContext<>(keys, callContexts)); + } + + /* + * A value has arrived - how do we complete the future that's associated with it in a common way + */ + void onNextValue(K key, V value, Object callContext, List> futures) { if (value instanceof Try) { // we allow the batch loader to return a Try so we can better represent a computation // that might have worked or not. + //noinspection unchecked Try tryValue = (Try) value; if (tryValue.isSuccess()) { - future.complete(tryValue.get()); + futures.forEach(f -> f.complete(tryValue.get())); } else { stats.incrementLoadErrorCount(new IncrementLoadErrorCountStatisticsContext<>(key, callContext)); - future.completeExceptionally(tryValue.getThrowable()); - clearCacheKeys.add(keys.get(idx)); + futures.forEach(f -> f.completeExceptionally(tryValue.getThrowable())); + clearCacheKeys.add(key); } } else { - future.complete(value); + futures.forEach(f -> f.complete(value)); } + } + + Throwable unwrapThrowable(Throwable ex) { + if (ex instanceof CompletionException) { + ex = ex.getCause(); + } + return ex; + } + } + + private class DataLoaderSubscriber extends DataLoaderSubscriberBase { + + private int idx = 0; + + private DataLoaderSubscriber( + CompletableFuture> valuesFuture, + List keys, + List callContexts, + List> queuedFutures + ) { + super(valuesFuture, keys, callContexts, queuedFutures); + } + + // onNext may be called by multiple threads - for the time being, we pass 'synchronized' to guarantee + // correctness (at the cost of speed). + @Override + public synchronized void onNext(V value) { + super.onNext(value); + + K key = keys.get(idx); + Object callContext = callContexts.get(idx); + CompletableFuture future = queuedFutures.get(idx); + onNextValue(key, value, callContext, List.of(future)); completedValues.add(value); idx++; } - @Override - public void onComplete() { - assertState(!onErrorCalled, () -> "onError has already been called; onComplete may not be invoked."); - onCompleteCalled = true; + @Override + public synchronized void onComplete() { + super.onComplete(); assertResultSize(keys, completedValues); possiblyClearCacheEntriesOnExceptions(clearCacheKeys); @@ -689,14 +741,9 @@ public void onComplete() { } @Override - public void onError(Throwable ex) { - assertState(!onCompleteCalled, () -> "onComplete has already been called; onError may not be invoked."); - onErrorCalled = true; - - stats.incrementBatchLoadExceptionCount(new IncrementBatchLoadExceptionCountStatisticsContext<>(keys, callContexts)); - if (ex instanceof CompletionException) { - ex = ex.getCause(); - } + public synchronized void onError(Throwable ex) { + super.onError(ex); + ex = unwrapThrowable(ex); // Set the remaining keys to the exception. for (int i = idx; i < queuedFutures.size(); i++) { K key = keys.get(i); @@ -705,33 +752,25 @@ public void onError(Throwable ex) { // clear any cached view of this key because they all failed dataLoader.clear(key); } + valuesFuture.completeExceptionally(ex); } + } - private class DataLoaderMapEntrySubscriber implements Subscriber> { - private final CompletableFuture> valuesFuture; - private final List keys; - private final List callContexts; - private final List> queuedFutures; + private class DataLoaderMapEntrySubscriber extends DataLoaderSubscriberBase> { + private final Map callContextByKey; private final Map>> queuedFuturesByKey; - - private final List clearCacheKeys = new ArrayList<>(); private final Map completedValuesByKey = new HashMap<>(); - private boolean onErrorCalled = false; - private boolean onCompleteCalled = false; + private DataLoaderMapEntrySubscriber( - CompletableFuture> valuesFuture, - List keys, - List callContexts, - List> queuedFutures + CompletableFuture> valuesFuture, + List keys, + List callContexts, + List> queuedFutures ) { - this.valuesFuture = valuesFuture; - this.keys = keys; - this.callContexts = callContexts; - this.queuedFutures = queuedFutures; - + super(valuesFuture, keys, callContexts, queuedFutures); this.callContextByKey = new HashMap<>(); this.queuedFuturesByKey = new HashMap<>(); for (int idx = 0; idx < queuedFutures.size(); idx++) { @@ -743,42 +782,24 @@ private DataLoaderMapEntrySubscriber( } } - @Override - public void onSubscribe(Subscription subscription) { - subscription.request(keys.size()); - } @Override - public void onNext(Map.Entry entry) { - assertState(!onErrorCalled, () -> "onError has already been called; onNext may not be invoked."); - assertState(!onCompleteCalled, () -> "onComplete has already been called; onNext may not be invoked."); + public synchronized void onNext(Map.Entry entry) { + super.onNext(entry); K key = entry.getKey(); V value = entry.getValue(); Object callContext = callContextByKey.get(key); List> futures = queuedFuturesByKey.get(key); - if (value instanceof Try) { - // we allow the batch loader to return a Try so we can better represent a computation - // that might have worked or not. - Try tryValue = (Try) value; - if (tryValue.isSuccess()) { - futures.forEach(f -> f.complete(tryValue.get())); - } else { - stats.incrementLoadErrorCount(new IncrementLoadErrorCountStatisticsContext<>(key, callContext)); - futures.forEach(f -> f.completeExceptionally(tryValue.getThrowable())); - clearCacheKeys.add(key); - } - } else { - futures.forEach(f -> f.complete(value)); - } + + onNextValue(key, value, callContext, futures); completedValuesByKey.put(key, value); } @Override - public void onComplete() { - assertState(!onErrorCalled, () -> "onError has already been called; onComplete may not be invoked."); - onCompleteCalled = true; + public synchronized void onComplete() { + super.onComplete(); possiblyClearCacheEntriesOnExceptions(clearCacheKeys); List values = new ArrayList<>(keys.size()); @@ -790,14 +811,9 @@ public void onComplete() { } @Override - public void onError(Throwable ex) { - assertState(!onCompleteCalled, () -> "onComplete has already been called; onError may not be invoked."); - onErrorCalled = true; - - stats.incrementBatchLoadExceptionCount(new IncrementBatchLoadExceptionCountStatisticsContext<>(keys, callContexts)); - if (ex instanceof CompletionException) { - ex = ex.getCause(); - } + public synchronized void onError(Throwable ex) { + super.onError(ex); + ex = unwrapThrowable(ex); // Complete the futures for the remaining keys with the exception. for (int idx = 0; idx < queuedFutures.size(); idx++) { K key = keys.get(idx); @@ -810,6 +826,7 @@ public void onError(Throwable ex) { dataLoader.clear(key); } } + valuesFuture.completeExceptionally(ex); } } }