diff --git a/consumer.go b/consumer.go index f160933..671f39f 100644 --- a/consumer.go +++ b/consumer.go @@ -7,9 +7,7 @@ import ( "sync/atomic" "time" - "github.com/goccy/go-json" "github.com/golang-queue/queue/core" - "github.com/golang-queue/queue/job" ) var _ core.Worker = (*Consumer)(nil) @@ -29,84 +27,9 @@ type Consumer struct { requestTimeout time.Duration } -func (s *Consumer) handle(m *job.Message) error { - // create channel with buffer size 1 to avoid goroutine leak - done := make(chan error, 1) - panicChan := make(chan interface{}, 1) - startTime := time.Now() - ctx, cancel := context.WithTimeout(context.Background(), m.Timeout) - defer func() { - cancel() - }() - - // run the job - go func() { - // handle panic issue - defer func() { - if p := recover(); p != nil { - panicChan <- p - } - }() - - // run custom process function - var err error - loop: - for { - if m.Task != nil { - err = m.Task(ctx) - } else { - err = s.runFunc(ctx, m) - } - - // check error and retry count - if err == nil || m.RetryCount == 0 { - break - } - m.RetryCount-- - - select { - case <-time.After(m.RetryDelay): // retry delay time - case <-ctx.Done(): // timeout reached - err = ctx.Err() - break loop - } - } - - done <- err - }() - - select { - case p := <-panicChan: - panic(p) - case <-ctx.Done(): // timeout reached - return ctx.Err() - case <-s.stop: // shutdown service - // cancel job - cancel() - - leftTime := m.Timeout - time.Since(startTime) - // wait job - select { - case <-time.After(leftTime): - return context.DeadlineExceeded - case err := <-done: // job finish - return err - case p := <-panicChan: - panic(p) - } - case err := <-done: // job finish - return err - } -} - // Run to execute new task -func (s *Consumer) Run(task core.QueuedMessage) error { - data := task.(*job.Message) - if data.Task == nil { - _ = json.Unmarshal(task.Bytes(), data) - } - - return s.handle(data) +func (s *Consumer) Run(ctx context.Context, task core.QueuedMessage) error { + return s.runFunc(ctx, task) } // Shutdown the worker diff --git a/consumer_test.go b/consumer_test.go index 0cfccda..18e41c8 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -207,121 +207,6 @@ func TestGoroutinePanic(t *testing.T) { q.Release() } -func TestHandleTimeout(t *testing.T) { - m := &job.Message{ - Timeout: 100 * time.Millisecond, - Payload: []byte("foo"), - } - w := NewConsumer( - WithFn(func(ctx context.Context, m core.QueuedMessage) error { - time.Sleep(200 * time.Millisecond) - return nil - }), - ) - - err := w.handle(m) - assert.Error(t, err) - assert.Equal(t, context.DeadlineExceeded, err) - - m = &job.Message{ - Timeout: 150 * time.Millisecond, - Payload: []byte("foo"), - } - - w = NewConsumer( - WithFn(func(ctx context.Context, m core.QueuedMessage) error { - time.Sleep(200 * time.Millisecond) - return nil - }), - ) - - done := make(chan error) - go func() { - done <- w.handle(m) - }() - - err = <-done - assert.Error(t, err) - assert.Equal(t, context.DeadlineExceeded, err) -} - -func TestJobComplete(t *testing.T) { - m := &job.Message{ - Timeout: 100 * time.Millisecond, - Payload: []byte("foo"), - } - w := NewConsumer( - WithFn(func(ctx context.Context, m core.QueuedMessage) error { - return errors.New("job completed") - }), - ) - - err := w.handle(m) - assert.Error(t, err) - assert.Equal(t, errors.New("job completed"), err) - - m = &job.Message{ - Timeout: 250 * time.Millisecond, - Payload: []byte("foo"), - } - - w = NewConsumer( - WithFn(func(ctx context.Context, m core.QueuedMessage) error { - time.Sleep(200 * time.Millisecond) - return errors.New("job completed") - }), - ) - - done := make(chan error) - go func() { - done <- w.handle(m) - }() - - err = <-done - assert.Error(t, err) - assert.Equal(t, errors.New("job completed"), err) -} - -func TestTaskJobComplete(t *testing.T) { - m := &job.Message{ - Timeout: 100 * time.Millisecond, - Task: func(ctx context.Context) error { - return errors.New("job completed") - }, - } - w := NewConsumer() - - err := w.handle(m) - assert.Error(t, err) - assert.Equal(t, errors.New("job completed"), err) - - m = &job.Message{ - Timeout: 250 * time.Millisecond, - Task: func(ctx context.Context) error { - return nil - }, - } - - w = NewConsumer() - done := make(chan error) - go func() { - done <- w.handle(m) - }() - - err = <-done - assert.NoError(t, err) - - // job timeout - m = &job.Message{ - Timeout: 50 * time.Millisecond, - Task: func(ctx context.Context) error { - time.Sleep(60 * time.Millisecond) - return nil - }, - } - assert.Equal(t, context.DeadlineExceeded, w.handle(m)) -} - func TestIncreaseWorkerCount(t *testing.T) { w := NewConsumer( WithLogger(NewEmptyLogger()), diff --git a/core/worker.go b/core/worker.go index 61f5bc2..56c4c53 100644 --- a/core/worker.go +++ b/core/worker.go @@ -1,9 +1,11 @@ package core +import "context" + // Worker interface type Worker interface { // Run is called to start the worker - Run(task QueuedMessage) error + Run(ctx context.Context, task QueuedMessage) error // Shutdown is called if stop all worker Shutdown() error // Queue to send message in Queue diff --git a/mocks/mock_worker.go b/mocks/mock_worker.go index 1d50bca..5440693 100644 --- a/mocks/mock_worker.go +++ b/mocks/mock_worker.go @@ -5,6 +5,7 @@ package mocks import ( + context "context" reflect "reflect" core "github.com/golang-queue/queue/core" @@ -64,17 +65,17 @@ func (mr *MockWorkerMockRecorder) Request() *gomock.Call { } // Run mocks base method. -func (m *MockWorker) Run(arg0 core.QueuedMessage) error { +func (m *MockWorker) Run(arg0 context.Context, arg1 core.QueuedMessage) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Run", arg0) + ret := m.ctrl.Call(m, "Run", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // Run indicates an expected call of Run. -func (mr *MockWorkerMockRecorder) Run(arg0 interface{}) *gomock.Call { +func (mr *MockWorkerMockRecorder) Run(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockWorker)(nil).Run), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockWorker)(nil).Run), arg0, arg1) } // Shutdown mocks base method. diff --git a/queue.go b/queue.go index 46a5d12..3298cf2 100644 --- a/queue.go +++ b/queue.go @@ -1,11 +1,13 @@ package queue import ( + "context" "errors" "sync" "sync/atomic" "time" + "github.com/goccy/go-json" "github.com/golang-queue/queue/core" "github.com/golang-queue/queue/job" ) @@ -167,11 +169,90 @@ func (q *Queue) work(task core.QueuedMessage) { } }() - if err = q.worker.Run(task); err != nil { + if err = q.run(task); err != nil { q.logger.Errorf("runtime error: %s", err.Error()) } } +func (q *Queue) run(task core.QueuedMessage) error { + data := task.(*job.Message) + if data.Task == nil { + _ = json.Unmarshal(task.Bytes(), data) + } + + return q.handle(data) +} + +func (q *Queue) handle(m *job.Message) error { + // create channel with buffer size 1 to avoid goroutine leak + done := make(chan error, 1) + panicChan := make(chan interface{}, 1) + startTime := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), m.Timeout) + defer func() { + cancel() + }() + + // run the job + go func() { + // handle panic issue + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + + // run custom process function + var err error + loop: + for { + if m.Task != nil { + err = m.Task(ctx) + } else { + err = q.worker.Run(ctx, m) + } + + // check error and retry count + if err == nil || m.RetryCount == 0 { + break + } + m.RetryCount-- + + select { + case <-time.After(m.RetryDelay): // retry delay time + case <-ctx.Done(): // timeout reached + err = ctx.Err() + break loop + } + } + + done <- err + }() + + select { + case p := <-panicChan: + panic(p) + case <-ctx.Done(): // timeout reached + return ctx.Err() + case <-q.quit: // shutdown service + // cancel job + cancel() + + leftTime := m.Timeout - time.Since(startTime) + // wait job + select { + case <-time.After(leftTime): + return context.DeadlineExceeded + case err := <-done: // job finish + return err + case p := <-panicChan: + panic(p) + } + case err := <-done: // job finish + return err + } +} + // UpdateWorkerCount to update worker number dynamically. func (q *Queue) UpdateWorkerCount(num int) { q.workerCount = num diff --git a/queue_test.go b/queue_test.go index be541bc..ab99a0d 100644 --- a/queue_test.go +++ b/queue_test.go @@ -1,6 +1,8 @@ package queue import ( + "context" + "errors" "testing" "time" @@ -61,7 +63,7 @@ func TestNewQueueWithDefaultWorker(t *testing.T) { m.EXPECT().Bytes().Return([]byte("test")).AnyTimes() w.EXPECT().Shutdown().Return(nil) w.EXPECT().Request().Return(m, nil).AnyTimes() - w.EXPECT().Run(m).Return(nil).AnyTimes() + w.EXPECT().Run(context.Background(), m).Return(nil).AnyTimes() q, err = NewQueue( WithWorker(w), ) @@ -141,3 +143,118 @@ func TestCloseQueueAfterShutdown(t *testing.T) { assert.Error(t, err) assert.Equal(t, ErrQueueShutdown, err) } + +func TestHandleTimeout(t *testing.T) { + m := &job.Message{ + Timeout: 100 * time.Millisecond, + Payload: []byte("foo"), + } + w := NewConsumer( + WithFn(func(ctx context.Context, m core.QueuedMessage) error { + time.Sleep(200 * time.Millisecond) + return nil + }), + ) + + q, err := NewQueue( + WithWorker(w), + ) + assert.NoError(t, err) + assert.NotNil(t, q) + + err = q.handle(m) + assert.Error(t, err) + assert.Equal(t, context.DeadlineExceeded, err) + + done := make(chan error) + go func() { + done <- q.handle(m) + }() + + err = <-done + assert.Error(t, err) + assert.Equal(t, context.DeadlineExceeded, err) +} + +func TestJobComplete(t *testing.T) { + m := &job.Message{ + Timeout: 100 * time.Millisecond, + Payload: []byte("foo"), + } + w := NewConsumer( + WithFn(func(ctx context.Context, m core.QueuedMessage) error { + return errors.New("job completed") + }), + ) + + q, err := NewQueue( + WithWorker(w), + ) + assert.NoError(t, err) + assert.NotNil(t, q) + + err = q.handle(m) + assert.Error(t, err) + assert.Equal(t, errors.New("job completed"), err) + + m = &job.Message{ + Timeout: 250 * time.Millisecond, + Payload: []byte("foo"), + } + + w = NewConsumer( + WithFn(func(ctx context.Context, m core.QueuedMessage) error { + time.Sleep(200 * time.Millisecond) + return errors.New("job completed") + }), + ) + + q, err = NewQueue( + WithWorker(w), + ) + assert.NoError(t, err) + assert.NotNil(t, q) + + err = q.handle(m) + assert.Error(t, err) + assert.Equal(t, errors.New("job completed"), err) +} + +func TestTaskJobComplete(t *testing.T) { + m := &job.Message{ + Timeout: 100 * time.Millisecond, + Task: func(ctx context.Context) error { + return errors.New("job completed") + }, + } + w := NewConsumer() + + q, err := NewQueue( + WithWorker(w), + ) + assert.NoError(t, err) + assert.NotNil(t, q) + + err = q.handle(m) + assert.Error(t, err) + assert.Equal(t, errors.New("job completed"), err) + + m = &job.Message{ + Timeout: 250 * time.Millisecond, + Task: func(ctx context.Context) error { + return nil + }, + } + + assert.NoError(t, q.handle(m)) + + // job timeout + m = &job.Message{ + Timeout: 50 * time.Millisecond, + Task: func(ctx context.Context) error { + time.Sleep(60 * time.Millisecond) + return nil + }, + } + assert.Equal(t, context.DeadlineExceeded, q.handle(m)) +} diff --git a/worker_message.go b/worker_message.go index 277fdba..8196027 100644 --- a/worker_message.go +++ b/worker_message.go @@ -1,6 +1,7 @@ package queue import ( + "context" "errors" "time" @@ -14,7 +15,7 @@ type messageWorker struct { messages chan core.QueuedMessage } -func (w *messageWorker) Run(task core.QueuedMessage) error { +func (w *messageWorker) Run(_ context.Context, task core.QueuedMessage) error { if string(task.Bytes()) == "panic" { panic("show panic") } diff --git a/worker_task.go b/worker_task.go index 1fea86b..faa627a 100644 --- a/worker_task.go +++ b/worker_task.go @@ -5,7 +5,6 @@ import ( "errors" "github.com/golang-queue/queue/core" - "github.com/golang-queue/queue/job" ) var _ core.Worker = (*taskWorker)(nil) @@ -15,12 +14,7 @@ type taskWorker struct { messages chan core.QueuedMessage } -func (w *taskWorker) Run(task core.QueuedMessage) error { - if v, ok := task.(*job.Message); ok { - if v.Task != nil { - _ = v.Task(context.Background()) - } - } +func (w *taskWorker) Run(_ context.Context, task core.QueuedMessage) error { return nil }