fix: add checkpoint when enqueue scan tasks for scan all (#18680)

Fix scanAll not being stoppable when there is a large number of artifacts:
add a checkpoint before submitting scan tasks, and mark the scanAll stopped
flag in Redis.

Fixes: #18044

Signed-off-by: chlins <chenyuzh@vmware.com>
This commit is contained in:
Chlins Zhang 2023-06-05 15:12:54 +08:00 committed by GitHub
parent 9d28d1f43f
commit fbeeaa7537
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 127 additions and 19 deletions

View File

@ -40,7 +40,13 @@ func Iterator(ctx context.Context, chunkSize int, query *q.Query, option *Option
}
for _, artifact := range artifacts {
ch <- artifact
select {
case <-ctx.Done():
log.G(ctx).Errorf("context done, list artifacts exited, error: %v", ctx.Err())
return
case ch <- artifact:
continue
}
}
if len(artifacts) < chunkSize {

View File

@ -21,6 +21,7 @@ import (
"reflect"
"strings"
"sync"
"time"
"github.com/google/uuid"
@ -30,6 +31,7 @@ import (
sc "github.com/goharbor/harbor/src/controller/scanner"
"github.com/goharbor/harbor/src/controller/tag"
"github.com/goharbor/harbor/src/jobservice/job"
"github.com/goharbor/harbor/src/lib/cache"
"github.com/goharbor/harbor/src/lib/config"
"github.com/goharbor/harbor/src/lib/errors"
"github.com/goharbor/harbor/src/lib/log"
@ -50,8 +52,12 @@ import (
"github.com/goharbor/harbor/src/pkg/task"
)
// DefaultController is a default singleton scan API controller.
var DefaultController = NewController()
var (
// DefaultController is a default singleton scan API controller.
DefaultController = NewController()
errScanAllStopped = errors.New("scanAll stopped")
)
// const definitions
const (
@ -74,6 +80,9 @@ type uuidGenerator func() (string, error)
// utility methods.
type configGetter func(cfg string) (string, error)
// cacheGetter is a function that returns the cache.Cache instance; the scan
// controller uses it to keep the stop scan all marks.
type cacheGetter func() cache.Cache
// launchScanJobParam is a param to launch scan job.
type launchScanJobParam struct {
ExecutionID int64
@ -109,6 +118,8 @@ type basicController struct {
taskMgr task.Manager
// Converter for V1 report to V2 report
reportConverter postprocessors.NativeScanReportConverter
// cache stores the stop scan all marks
cache cacheGetter
}
// NewController news a scan API controller
@ -154,6 +165,9 @@ func NewController() Controller {
taskMgr: task.Mgr,
// Get the scan V1 to V2 report converters
reportConverter: postprocessors.Converter,
cache: func() cache.Cache {
return cache.Default()
},
}
}
@ -368,6 +382,44 @@ func (bc *basicController) ScanAll(ctx context.Context, trigger string, async bo
return executionID, nil
}
// StopScanAll stops the scan all execution identified by executionID.
//
// The stopped mark of the execution is recorded first, then the execution and
// its sub tasks are stopped through the execution manager. When async is true
// the work runs in a background goroutine: nil is returned immediately and any
// failure is only logged. Otherwise the first error encountered is returned.
func (bc *basicController) StopScanAll(ctx context.Context, executionID int64, async bool) error {
	// stop marks the scan all as stopped before stopping the execution itself.
	stop := func() error {
		if err := bc.markScanAllStopped(ctx, executionID); err != nil {
			return err
		}

		return bc.execMgr.Stop(ctx, executionID)
	}

	if !async {
		return stop()
	}

	go func() {
		err := stop()
		if err != nil {
			log.Errorf("failed to stop scan all, error: %v", err)
		}
	}()

	return nil
}
// scanAllStoppedKey builds the cache key that carries the stopped mark of the
// scan all execution with the given id.
func scanAllStoppedKey(execID int64) string {
	const keyFormat = "scan_all:execution_id:%d:stopped"
	return fmt.Sprintf(keyFormat, execID)
}
// markScanAllStopped saves the stopped mark of the given scan all execution
// into the cache and returns the error of the save, if any.
func (bc *basicController) markScanAllStopped(ctx context.Context, execID int64) error {
	// The mark expires after 2 hours, which should be long enough for the
	// controller to capture the stop flag. Recycling of the key is left to the
	// redis TTL: the scan controller does not clean it up, because a new scan
	// all gets a new unique execution id and the old key affects nothing.
	const markTTL = 2 * time.Hour

	c := bc.cache()
	return c.Save(ctx, scanAllStoppedKey(execID), "", markTTL)
}
// isScanAllStopped reports whether the cache contains the stopped mark of the
// given scan all execution.
func (bc *basicController) isScanAllStopped(ctx context.Context, execID int64) bool {
	c := bc.cache()
	return c.Contains(ctx, scanAllStoppedKey(execID))
}
func (bc *basicController) startScanAll(ctx context.Context, executionID int64) error {
batchSize := 50
@ -379,8 +431,15 @@ func (bc *basicController) startScanAll(ctx context.Context, executionID int64)
UnsupportCount int `json:"unsupport_count"`
UnknowCount int `json:"unknow_count"`
}{}
// with cancel function to signal downstream worker
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for artifact := range ar.Iterator(ctx, batchSize, nil, nil) {
if bc.isScanAllStopped(ctx, executionID) {
return errScanAllStopped
}
summary.TotalCount++
scan := func(ctx context.Context) error {

View File

@ -30,6 +30,7 @@ import (
"github.com/goharbor/harbor/src/common/rbac"
"github.com/goharbor/harbor/src/controller/artifact"
"github.com/goharbor/harbor/src/controller/robot"
"github.com/goharbor/harbor/src/lib/cache"
"github.com/goharbor/harbor/src/lib/config"
"github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/lib/q"
@ -49,6 +50,7 @@ import (
robottesting "github.com/goharbor/harbor/src/testing/controller/robot"
scannertesting "github.com/goharbor/harbor/src/testing/controller/scanner"
tagtesting "github.com/goharbor/harbor/src/testing/controller/tag"
mockcache "github.com/goharbor/harbor/src/testing/lib/cache"
ormtesting "github.com/goharbor/harbor/src/testing/lib/orm"
"github.com/goharbor/harbor/src/testing/mock"
accessorytesting "github.com/goharbor/harbor/src/testing/pkg/accessory"
@ -77,6 +79,7 @@ type ControllerTestSuite struct {
ar artifact.Controller
c Controller
reportConverter *postprocessorstesting.ScanReportV1ToV2Converter
cache *mockcache.Cache
}
// TestController is the entry point of ControllerTestSuite.
@ -271,6 +274,8 @@ func (suite *ControllerTestSuite) SetupSuite() {
suite.taskMgr = &tasktesting.Manager{}
suite.cache = &mockcache.Cache{}
suite.c = &basicController{
manager: mgr,
ar: suite.ar,
@ -298,6 +303,7 @@ func (suite *ControllerTestSuite) SetupSuite() {
execMgr: suite.execMgr,
taskMgr: suite.taskMgr,
reportConverter: &postprocessorstesting.ScanReportV1ToV2Converter{},
cache: func() cache.Cache { return suite.cache },
}
}
@ -522,25 +528,25 @@ func (suite *ControllerTestSuite) TestScanControllerGetMultiScanLog() {
func (suite *ControllerTestSuite) TestScanAll() {
{
// no artifacts found when scan all
ctx := context.TODO()
executionID := int64(1)
suite.execMgr.On(
"Create", ctx, "SCAN_ALL", int64(0), "SCHEDULE",
"Create", mock.Anything, "SCAN_ALL", int64(0), "SCHEDULE",
).Return(executionID, nil).Once()
mock.OnAnything(suite.accessoryMgr, "List").Return([]accessoryModel.Accessory{}, nil).Once()
mock.OnAnything(suite.artifactCtl, "List").Return([]*artifact.Artifact{}, nil).Once()
suite.taskMgr.On("Count", ctx, q.New(q.KeyWords{"execution_id": executionID})).Return(int64(0), nil).Once()
suite.taskMgr.On("Count", mock.Anything, q.New(q.KeyWords{"execution_id": executionID})).Return(int64(0), nil).Once()
mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once()
suite.execMgr.On("MarkDone", ctx, executionID, mock.Anything).Return(nil).Once()
suite.execMgr.On("MarkDone", mock.Anything, executionID, mock.Anything).Return(nil).Once()
_, err := suite.c.ScanAll(ctx, "SCHEDULE", false)
suite.cache.On("Contains", mock.Anything, scanAllStoppedKey(1)).Return(false).Once()
_, err := suite.c.ScanAll(context.TODO(), "SCHEDULE", false)
suite.NoError(err)
}
@ -551,7 +557,7 @@ func (suite *ControllerTestSuite) TestScanAll() {
executionID := int64(1)
suite.execMgr.On(
"Create", ctx, "SCAN_ALL", int64(0), "SCHEDULE",
"Create", mock.Anything, "SCAN_ALL", int64(0), "SCHEDULE",
).Return(executionID, nil).Once()
mock.OnAnything(suite.accessoryMgr, "List").Return([]accessoryModel.Accessory{}, nil).Once()
@ -568,13 +574,28 @@ func (suite *ControllerTestSuite) TestScanAll() {
mock.OnAnything(suite.reportMgr, "Create").Return("uuid", nil).Once()
mock.OnAnything(suite.taskMgr, "Create").Return(int64(0), fmt.Errorf("failed")).Once()
mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once()
suite.execMgr.On("MarkError", ctx, executionID, mock.Anything).Return(nil).Once()
suite.execMgr.On("MarkError", mock.Anything, executionID, mock.Anything).Return(nil).Once()
_, err := suite.c.ScanAll(ctx, "SCHEDULE", false)
suite.NoError(err)
}
}
// TestStopScanAll exercises StopScanAll in synchronous mode for both the
// failure of saving the stopped mark and the successful path.
func (suite *ControllerTestSuite) TestStopScanAll() {
	execID := int64(100)
	stoppedKey := scanAllStoppedKey(execID)

	{
		// saving the stopped mark fails, the error is returned to the caller
		saveErr := fmt.Errorf("stop scan all error")
		suite.cache.On("Save", mock.Anything, stoppedKey, mock.Anything, mock.Anything).Return(saveErr).Once()

		err := suite.c.StopScanAll(context.TODO(), execID, false)
		suite.EqualError(err, saveErr.Error())
	}

	{
		// the stopped mark is saved and the execution is stopped
		suite.cache.On("Save", mock.Anything, stoppedKey, mock.Anything, mock.Anything).Return(nil).Once()
		suite.execMgr.On("Stop", mock.Anything, execID).Return(nil).Once()

		err := suite.c.StopScanAll(context.TODO(), execID, false)
		suite.NoError(err)
	}
}
func (suite *ControllerTestSuite) TestDeleteReports() {
suite.reportMgr.On("DeleteByDigests", context.TODO(), "digest").Return(nil).Once()

View File

@ -157,7 +157,7 @@ func (suite *CallbackTestSuite) TestScanAllCallback() {
mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once()
suite.execMgr.On("MarkDone", context.TODO(), executionID, mock.Anything).Return(nil).Once()
suite.execMgr.On("MarkDone", mock.Anything, executionID, mock.Anything).Return(nil).Once()
suite.NoError(scanAllCallback(context.TODO(), ""))
}

View File

@ -115,6 +115,16 @@ type Controller interface {
// error : non nil error if any errors occurred
ScanAll(ctx context.Context, trigger string, async bool) (int64, error)
// StopScanAll stops the scanAll
//
// Arguments:
// ctx context.Context : the context for this method
// executionID int64 : the id of scan all execution
// async bool : stop scan all in background
// Returns:
// error : non nil error if any errors occurred
StopScanAll(ctx context.Context, executionID int64, async bool) error
// GetVulnerable returns the vulnerable of the artifact for the allowlist
//
// Arguments:

View File

@ -28,7 +28,6 @@ import (
"github.com/goharbor/harbor/src/controller/scanner"
"github.com/goharbor/harbor/src/jobservice/job"
"github.com/goharbor/harbor/src/lib/errors"
"github.com/goharbor/harbor/src/lib/log"
"github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/lib/q"
"github.com/goharbor/harbor/src/pkg/scheduler"
@ -74,12 +73,10 @@ func (s *scanAllAPI) StopScanAll(ctx context.Context, params operation.StopScanA
if execution == nil {
return s.SendError(ctx, errors.BadRequestError(nil).WithMessage("no scan all job is found currently"))
}
go func(ctx context.Context, eid int64) {
err := s.execMgr.Stop(ctx, eid)
if err != nil {
log.Errorf("failed to stop the execution of executionID=%+v", execution.ID)
}
}(s.makeCtx(), execution.ID)
if err = s.scanCtl.StopScanAll(s.makeCtx(), execution.ID, true); err != nil {
return s.SendError(ctx, err)
}
return operation.NewStopScanAllAccepted()
}

View File

@ -247,6 +247,7 @@ func (suite *ScanAllTestSuite) TestStopScanAll() {
times := 3
suite.Security.On("IsAuthenticated").Return(true).Times(times)
suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Times(times)
mock.OnAnything(suite.scanCtl, "StopScanAll").Return(nil).Times(times)
mock.OnAnything(suite.scannerCtl, "ListRegistrations").Return([]*scanner.Registration{{ID: int64(1)}}, nil).Times(times)
{

View File

@ -205,6 +205,20 @@ func (_m *Controller) Stop(ctx context.Context, _a1 *artifact.Artifact) error {
return r0
}
// StopScanAll provides a mock function with given fields: ctx, executionID, async
func (_m *Controller) StopScanAll(ctx context.Context, executionID int64, async bool) error {
	ret := _m.Called(ctx, executionID, async)

	// When the configured return value is a function with the matching
	// signature, call it with the arguments to produce the error.
	if rf, ok := ret.Get(0).(func(context.Context, int64, bool) error); ok {
		return rf(ctx, executionID, async)
	}

	return ret.Error(0)
}
type mockConstructorTestingTNewController interface {
mock.TestingT
Cleanup(func())