Skip to content

Add ThreadPoolTaskExecutorRepeatTemplate #4815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package org.springframework.batch.repeat.support;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.batch.repeat.RepeatCallback;
import org.springframework.batch.repeat.RepeatContext;
import org.springframework.batch.repeat.RepeatStatus;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.Assert;

/**
* ThreadPoolTaskExecutorRepeatTemplate without throttleLimit setting.
*
* @author linus.yan
* @since 2025-04-25
*/
public class ThreadPoolTaskExecutorRepeatTemplate extends RepeatTemplate {

private static final Logger logger = LoggerFactory.getLogger(ThreadPoolTaskExecutorRepeatTemplate.class);

private ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();

// setter of taskExecutor
public void setTaskExecutor(ThreadPoolTaskExecutor taskExecutor) {
Assert.notNull(taskExecutor, "taskExecutor must not be null");
this.taskExecutor = taskExecutor;
}

private RepeatStatus status = RepeatStatus.CONTINUABLE;

protected RepeatStatus getNextResult(RepeatContext context, RepeatCallback callback, RepeatInternalState state)
throws Throwable {
RepeatStatusInternalState internalState = (RepeatStatusInternalState) state;

do {
ExecutingRunnable runnable = new ExecutingRunnable(callback, context, internalState);
this.taskExecutor.execute(runnable);
this.update(context);
}
while (internalState.getStatus().isContinuable() && !this.isComplete(context));

while (taskExecutor.getActiveCount() > 0) {
// wait for all tasks to finish
}

return internalState.getStatus();
}

protected boolean waitForResults(RepeatInternalState state) {
return ((RepeatStatusInternalState) state).getStatus().isContinuable();
}

protected RepeatInternalState createInternalState(RepeatContext context) {
return new RepeatStatusInternalState();
}

private class ExecutingRunnable implements Runnable {

private final RepeatCallback callback;

private final RepeatContext context;

private volatile RepeatStatusInternalState internalState;

private volatile Throwable error;

public ExecutingRunnable(RepeatCallback callback, RepeatContext context,
RepeatStatusInternalState internalState) {
this.callback = callback;
this.context = context;
this.internalState = internalState;
}

public void run() {
boolean clearContext = false;
RepeatStatus result = null;
try {
if (RepeatSynchronizationManager.getContext() == null) {
clearContext = true;
RepeatSynchronizationManager.register(this.context);
}

if (ThreadPoolTaskExecutorRepeatTemplate.this.logger.isDebugEnabled()) {
ThreadPoolTaskExecutorRepeatTemplate.this.logger
.debug("Repeat operation about to start at count=" + this.context.getStartedCount());
}

result = callback.doInIteration(context);
}
catch (Throwable e) {
this.error = e;
}
finally {
if (result == null) {
result = RepeatStatus.FINISHED;
}

internalState.setStatus(status.and(result.isContinuable()));

if (clearContext) {
RepeatSynchronizationManager.clear();
}
}

}

public Throwable getError() {
return this.error;
}

public RepeatContext getContext() {
return this.context;
}

}

private static class RepeatStatusInternalState extends RepeatInternalStateSupport {

private RepeatStatus status = RepeatStatus.CONTINUABLE;

public void setStatus(RepeatStatus status) {
this.status = status;
}

public RepeatStatus getStatus() {
return status;
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
/*
* Copyright 2006-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.batch.repeat.support;

import org.junit.jupiter.api.Test;
import org.springframework.batch.item.Chunk;
import org.springframework.batch.item.ExecutionContext;
import org.springframework.batch.repeat.RepeatCallback;
import org.springframework.batch.repeat.RepeatContext;
import org.springframework.batch.repeat.RepeatStatus;
import org.springframework.batch.repeat.callback.NestedRepeatCallback;
import org.springframework.batch.repeat.policy.SimpleCompletionPolicy;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

import static org.junit.jupiter.api.Assertions.*;

class ThreadPoolTaskExecutorRepeatTemplateAsynchronousTests extends AbstractTradeBatchTests {

private final RepeatTemplate template = getRepeatTemplate();

private int count = 0;

private RepeatTemplate getRepeatTemplate() {
ThreadPoolTaskExecutorRepeatTemplate template = new ThreadPoolTaskExecutorRepeatTemplate();
template.setTaskExecutor(new ThreadPoolTaskExecutor());
// Set default completion above number of items in input file
template.setCompletionPolicy(new SimpleCompletionPolicy(8));
return template;
}

@Test
void testEarlyCompletionWithException() {

ThreadPoolTaskExecutorRepeatTemplate template = new ThreadPoolTaskExecutorRepeatTemplate();
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
template.setCompletionPolicy(new SimpleCompletionPolicy(20));
template.setTaskExecutor(taskExecutor);
Exception exception = assertThrows(IllegalStateException.class, () -> template.iterate(context -> {
count++;
throw new IllegalStateException("foo!");
}));
assertEquals("foo!", exception.getMessage());

assertTrue(count >= 1, "Too few attempts: " + count);
assertTrue(count <= 10, "Too many attempts: " + count);

}

@Test
void testExceptionHandlerSwallowsException() {

ThreadPoolTaskExecutorRepeatTemplate template = new ThreadPoolTaskExecutorRepeatTemplate();
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
template.setCompletionPolicy(new SimpleCompletionPolicy(4));
template.setTaskExecutor(taskExecutor);

template.setExceptionHandler((context, throwable) -> count++);
template.iterate(context -> {
throw new IllegalStateException("foo!");
});

assertTrue(count >= 1, "Too few attempts: " + count);
assertTrue(count <= 10, "Too many attempts: " + count);

}

@Test
void testNestedSession() {

RepeatTemplate outer = getRepeatTemplate();
RepeatTemplate inner = new RepeatTemplate();

outer.iterate(new NestedRepeatCallback(inner, context -> {
count++;
assertNotNull(context);
assertNotSame(context, context.getParent(), "Nested batch should have new session");
assertSame(context, RepeatSynchronizationManager.getContext());
return RepeatStatus.FINISHED;
}) {
@Override
public RepeatStatus doInIteration(RepeatContext context) throws Exception {
count++;
assertNotNull(context);
assertSame(context, RepeatSynchronizationManager.getContext());
return super.doInIteration(context);
}
});

assertTrue(count >= 1, "Too few attempts: " + count);
assertTrue(count <= 10, "Too many attempts: " + count);

}

/**
* Run a batch with a single template that itself has an async task executor. The
* result is a batch that runs in multiple threads (up to the throttle limit of the
* template).
*/
@Test
void testMultiThreadAsynchronousExecution() {

final String threadName = Thread.currentThread().getName();
final Set<String> threadNames = new HashSet<>();

final RepeatCallback callback = context -> {
assertNotSame(threadName, Thread.currentThread().getName());
threadNames.add(Thread.currentThread().getName());
Thread.sleep(100);
Trade item = provider.read();
if (item != null) {
processor.write(Chunk.of(item));
}
return RepeatStatus.continueIf(item != null);
};

template.iterate(callback);
// Shouldn't be necessary to wait:
// Thread.sleep(500);
assertEquals(NUMBER_OF_ITEMS, processor.count);
assertTrue(threadNames.size() > 1);
}

@Test
@SuppressWarnings("removal")
void testThrottleLimit() {

int throttleLimit = 600;

ThreadPoolTaskExecutorRepeatTemplate template = new ThreadPoolTaskExecutorRepeatTemplate();
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
template.setTaskExecutor(taskExecutor);

String threadName = Thread.currentThread().getName();
Set<String> threadNames = ConcurrentHashMap.newKeySet();
List<String> items = Collections.synchronizedList(new ArrayList<>());

RepeatCallback callback = context -> {
assertNotSame(threadName, Thread.currentThread().getName());
Trade item = provider.read();
threadNames.add(Thread.currentThread().getName() + " : " + item);
items.add(String.valueOf(item));
if (item != null) {
processor.write(Chunk.of(item));
// Do some more I/O
for (int i = 0; i < 10; i++) {
TradeItemReader provider = new TradeItemReader(resource);
provider.open(new ExecutionContext());
while (provider.read() != null)
continue;
provider.close();
}
}
return RepeatStatus.continueIf(item != null);
};

template.iterate(callback);
// Shouldn't be necessary to wait:
// Thread.sleep(500);
assertEquals(NUMBER_OF_ITEMS, processor.count);
assertTrue(threadNames.size() > 1);
int frequency = Collections.frequency(items, "null");
assertTrue(frequency <= throttleLimit);
}

/**
* Wrap an otherwise synchronous batch in a callback to an asynchronous template.
*/
@Test
void testSingleThreadAsynchronousExecution() {
ThreadPoolTaskExecutorRepeatTemplate jobTemplate = new ThreadPoolTaskExecutorRepeatTemplate();
final RepeatTemplate stepTemplate = new RepeatTemplate();
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
jobTemplate.setTaskExecutor(taskExecutor);

final String threadName = Thread.currentThread().getName();
final Set<String> threadNames = new HashSet<>();

final RepeatCallback stepCallback = new ItemReaderRepeatCallback<>(provider, processor) {
@Override
public RepeatStatus doInIteration(RepeatContext context) throws Exception {
assertNotSame(threadName, Thread.currentThread().getName());
threadNames.add(Thread.currentThread().getName());
Thread.sleep(100);
TradeItemReader provider = new TradeItemReader(resource);
provider.open(new ExecutionContext());
while (provider.read() != null)
;
return super.doInIteration(context);
}
};
RepeatCallback jobCallback = context -> {
stepTemplate.iterate(stepCallback);
return RepeatStatus.FINISHED;
};

jobTemplate.iterate(jobCallback);
// Shouldn't be necessary to wait:
// Thread.sleep(500);
assertEquals(NUMBER_OF_ITEMS, processor.count);
// Because of the throttling and queueing internally to a TaskExecutor,
// more than one thread will be used - the number used is the
// concurrency limit in the task executor, plus 1.
assertTrue(threadNames.size() >= 1);
}

}
Loading