diff --git a/cliv2/cmd/cliv2/instrumentation.go b/cliv2/cmd/cliv2/instrumentation.go index 09e7e81bf5..546f03cfbe 100644 --- a/cliv2/cmd/cliv2/instrumentation.go +++ b/cliv2/cmd/cliv2/instrumentation.go @@ -4,14 +4,19 @@ package main import _ "github.com/snyk/go-application-framework/pkg/networking/fips_enable" import ( + "context" + "encoding/json" "os/exec" + "strconv" "strings" "time" + "github.com/rs/zerolog" "github.com/snyk/go-application-framework/pkg/analytics" "github.com/snyk/go-application-framework/pkg/configuration" "github.com/snyk/go-application-framework/pkg/instrumentation" + "github.com/snyk/cli/cliv2/internal/constants" cli_utils "github.com/snyk/cli/cliv2/internal/utils" localworkflows "github.com/snyk/go-application-framework/pkg/local_workflows" @@ -74,3 +79,76 @@ func updateInstrumentationDataBeforeSending(cliAnalytics analytics.Analytics, st cliAnalytics.GetInstrumentation().SetStatus(analytics.Failure) } } + +func sendAnalytics(ctx context.Context, a analytics.Analytics, debugLogger *zerolog.Logger) { + debugLogger.Print("Sending Analytics") + + a.SetApiUrl(globalConfiguration.GetString(configuration.API_URL)) + + request, err := a.GetRequest() + if err != nil { + debugLogger.Err(err).Msg("Failed to create Analytics request") + return + } + + // Use context to respect teardown timeout + request = request.WithContext(ctx) + + client := globalEngine.GetNetworkAccess().GetHttpClient() + res, err := client.Do(request) + if err != nil { + debugLogger.Err(err).Msg("Failed to send Analytics") + return + } + defer func() { + _ = res.Body.Close() + }() + + successfullySend := 200 <= res.StatusCode && res.StatusCode < 300 + if successfullySend { + debugLogger.Print("Analytics successfully send") + } else { + debugLogger.Print("Failed to send Analytics:", res.Status) + } +} + +func sendInstrumentation(ctx context.Context, eng workflow.Engine, instrumentor analytics.InstrumentationCollector, logger *zerolog.Logger) { + // Avoid duplicate data to be sent for IDE integrations that use the CLI + if !shallSendInstrumentation(eng.GetConfiguration(), instrumentor) { + logger.Print("This CLI call is not instrumented!") + return + } + + // add temporary static nodejs binary flag, remove once linuxstatic is official + staticNodeJsBinaryBool, parseErr := strconv.ParseBool(constants.StaticNodeJsBinary) + if parseErr != nil { + logger.Print("Failed to parse staticNodeJsBinary:", parseErr) + } else { + // the legacycli:: prefix is added to maintain compatibility with our monitoring dashboard + instrumentor.AddExtension("legacycli::static-nodejs-binary", staticNodeJsBinaryBool) + } + + logger.Print("Sending Instrumentation") + data, err := analytics.GetV2InstrumentationObject(instrumentor, analytics.WithLogger(logger)) + if err != nil { + logger.Err(err).Msg("Failed to derive data object") + } + + v2InstrumentationData := utils.ValueOf(json.Marshal(data)) + localConfiguration := globalConfiguration.Clone() + // the report analytics workflow needs --experimental to run + // we pass the flag here so that we report at every interaction + localConfiguration.Set(configuration.FLAG_EXPERIMENTAL, true) + localConfiguration.Set("inputData", string(v2InstrumentationData)) + _, err = eng.Invoke( + localworkflows.WORKFLOWID_REPORT_ANALYTICS, + workflow.WithConfig(localConfiguration), + workflow.WithContext(ctx), + ) + + if err != nil { + logger.Err(err).Msg("Failed to send Instrumentation") + } else { + logger.Print("Instrumentation successfully sent") + } +} diff --git a/cliv2/cmd/cliv2/main.go b/cliv2/cmd/cliv2/main.go index 07bb5f1f18..b66ac5a16f 100644 --- a/cliv2/cmd/cliv2/main.go +++ b/cliv2/cmd/cliv2/main.go @@ -11,9 +11,10 @@ import ( "io" "os" "os/exec" - "strconv" + "os/signal" "strings" "sync" + "syscall" "time" "github.com/google/uuid" @@ -75,6 +76,7 @@ import ( var internalOS string var globalEngine workflow.Engine var globalConfiguration configuration.Configuration +var globalContext context.Context var helpProvided bool var noopLogger zerolog.Logger = zerolog.New(io.Discard) @@ -88,6 +90,7 @@ const ( debug_level_flag string = "log-level" integrationNameFlag string = "integration-name" maxNetworkRequestAttempts string = "max-attempts" + teardownTimeout = 5 * time.Second ) type JsonErrorStruct struct { @@ -194,98 +197,33 @@ func runMainWorkflow(config configuration.Configuration, cmd *cobra.Command, arg globalLogger.Print("Running ", name) globalEngine.GetAnalytics().SetCommand(name) - err = runWorkflowAndProcessData(globalEngine, globalLogger, name) + err = runWorkflowAndProcessData(globalContext, globalEngine, globalLogger, name) return err } -func runWorkflowAndProcessData(engine workflow.Engine, logger *zerolog.Logger, name string) error { +func runWorkflowAndProcessData(ctx context.Context, engine workflow.Engine, logger *zerolog.Logger, name string) error { ic := engine.GetAnalytics().GetInstrumentation() - output, err := engine.Invoke(workflow.NewWorkflowIdentifier(name), workflow.WithInstrumentationCollector(ic)) + output, err := engine.Invoke(workflow.NewWorkflowIdentifier(name), workflow.WithContext(ctx), workflow.WithInstrumentationCollector(ic)) if err != nil { logger.Print("Failed to execute the command! ", err) return err } - outputFiltered, err := engine.Invoke(localworkflows.WORKFLOWID_FILTER_FINDINGS, workflow.WithInput(output), workflow.WithInstrumentationCollector(ic)) + outputFiltered, err := engine.Invoke(localworkflows.WORKFLOWID_FILTER_FINDINGS, workflow.WithContext(ctx), workflow.WithInput(output), workflow.WithInstrumentationCollector(ic)) if err != nil { logger.Err(err).Msg(err.Error()) return err } - _, err = engine.Invoke(localworkflows.WORKFLOWID_OUTPUT_WORKFLOW, workflow.WithInput(outputFiltered), workflow.WithInstrumentationCollector(ic)) + _, err = engine.Invoke(localworkflows.WORKFLOWID_OUTPUT_WORKFLOW, workflow.WithContext(ctx), workflow.WithInput(outputFiltered), workflow.WithInstrumentationCollector(ic)) if err == nil { err = getErrorFromWorkFlowData(engine, outputFiltered) } return err } -func sendAnalytics(analytics analytics.Analytics, debugLogger *zerolog.Logger) { - debugLogger.Print("Sending Analytics") - - analytics.SetApiUrl(globalConfiguration.GetString(configuration.API_URL)) - - res, err := analytics.Send() - if err != nil { - debugLogger.Err(err).Msg("Failed to send Analytics") - return - } - defer func() { _ = res.Body.Close() }() - - successfullySend := 200 <= res.StatusCode && res.StatusCode < 300 - if successfullySend { - debugLogger.Print("Analytics successfully send") - } else { - var details string - if res != nil { - details = res.Status - } - - debugLogger.Print("Failed to send Analytics:", details) - } -} - -func sendInstrumentation(eng workflow.Engine, instrumentor analytics.InstrumentationCollector, logger *zerolog.Logger) { - // Avoid duplicate data to be sent for IDE integrations that use the CLI - if !shallSendInstrumentation(eng.GetConfiguration(), instrumentor) { - logger.Print("This CLI call is not instrumented!") - return - } - - // add temporary static nodejs binary flag, remove once linuxstatic is official - staticNodeJsBinaryBool, parseErr := strconv.ParseBool(constants.StaticNodeJsBinary) - if parseErr != nil { - logger.Print("Failed to parse staticNodeJsBinary:", parseErr) - } else { - // the legacycli:: prefix is added to maintain compatibility with our monitoring dashboard - instrumentor.AddExtension("legacycli::static-nodejs-binary", staticNodeJsBinaryBool) - } - - logger.Print("Sending Instrumentation") - data, err := analytics.GetV2InstrumentationObject(instrumentor, analytics.WithLogger(logger)) - if err != nil { - logger.Err(err).Msg("Failed to derive data object") - } - - v2InstrumentationData := utils.ValueOf(json.Marshal(data)) - localConfiguration := globalConfiguration.Clone() - // the report analytics workflow needs --experimental to run - // we pass the flag here so that we report at every interaction - localConfiguration.Set(configuration.FLAG_EXPERIMENTAL, true) - localConfiguration.Set("inputData", string(v2InstrumentationData)) - _, err = eng.InvokeWithConfig( - localworkflows.WORKFLOWID_REPORT_ANALYTICS, - localConfiguration, - ) - - if err != nil { - logger.Err(err).Msg("Failed to send Instrumentation") - } else { - logger.Print("Instrumentation successfully sent") - } -} - func help(_ *cobra.Command, _ []string) error { helpProvided = true args := utils.RemoveSimilar(os.Args[1:], "--") // remove all double dash arguments to avoid issues with the help command @@ -548,11 +486,55 @@ func initExtensions(engine workflow.Engine, config configuration.Configuration) } } +// tearDown handles sending analytics and instrumentation +// It is used both for normal exit and signal-triggered exit +func tearDown(ctx context.Context, err error, errorList []error, startTime time.Time, ua networking.UserAgentInfo, cliAnalytics analytics.Analytics, networkAccess networking.NetworkAccess) int { + // Create a context with timeout for teardown operations to ensure we don't hang indefinitely + teardownCtx, cancel := context.WithTimeout(ctx, teardownTimeout) + defer cancel() + + if err != nil { + errorList, err = processError(err, errorList) + + for _, tempError := range errorList { + if tempError != nil { + cliAnalytics.AddError(tempError) + } + } + } + + exitCode := cliv2.DeriveExitCode(err) + globalLogger.Printf("Deriving Exit Code %d (cause: %v)", exitCode, err) + + displayError(err, globalEngine.GetUserInterface(), globalConfiguration, teardownCtx) + + updateInstrumentationDataBeforeSending(cliAnalytics, startTime, ua, exitCode) + + if !globalConfiguration.GetBool(configuration.ANALYTICS_DISABLED) { + sendAnalytics(teardownCtx, cliAnalytics, globalLogger) + } + sendInstrumentation(teardownCtx, globalEngine, cliAnalytics.GetInstrumentation(), globalLogger) + + // cleanup resources in use + // WARNING: deferred actions will execute AFTER cleanup; only defer if not impacted by this + if _, cleanupErr := globalEngine.Invoke(basic_workflows.WORKFLOWID_GLOBAL_CLEANUP, workflow.WithContext(teardownCtx)); cleanupErr != nil { + globalLogger.Printf("Failed to cleanup %v", cleanupErr) + } + + if globalConfiguration.GetBool(configuration.DEBUG) { + writeLogFooter(exitCode, errorList, globalConfiguration, networkAccess) + } + + return exitCode +} + func MainWithErrorCode() int { initDebugBuild() errorList := []error{} errorListMutex := sync.Mutex{} + var tearDownOnce sync.Once + var finalExitCode int startTime := time.Now() var err error @@ -633,9 +615,11 @@ func MainWithErrorCode() int { return constants.SNYK_EXIT_CODE_ERROR } - // init context - ctx := context.Background() + // init context with cancel function for signal handling + ctx, ctxCancel := context.WithCancel(context.Background()) + defer ctxCancel() // ensure context is canceled on exit ctx = context.WithValue(ctx, networking.InteractionIdKey, instrumentation.AssembleUrnFromUUID(interactionId)) + globalContext = ctx // add output flags as persistent flags outputWorkflow, _ := globalEngine.GetWorkflow(localworkflows.WORKFLOWID_OUTPUT_WORKFLOW) @@ -656,6 +640,35 @@ func MainWithErrorCode() int { cliAnalytics.GetInstrumentation().SetStage(instrumentation.DetermineStage(cliAnalytics.IsCiEnvironment())) cliAnalytics.GetInstrumentation().SetStatus(analytics.Success) + // prepare for signal handling + signalChan := make(chan os.Signal, 1) + exitCodeChan := make(chan int, 1) + + if globalConfiguration.GetBool(configuration.PREVIEW_FEATURES_ENABLED) { + // Set up signal handling to send instrumentation on premature termination + signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-signalChan + globalLogger.Printf("Received signal %v, attempting to send instrumentation before exit", sig) + + // Cancel the context to terminate any running child processes + ctxCancel() + + tearDownOnce.Do(func() { + signalError := cli.NewTerminatedBySignalError(fmt.Sprintf("Signal: %v", sig)) + + errorListMutex.Lock() + errorListCopy := append([]error{}, errorList...) + errorListMutex.Unlock() + + finalExitCode = tearDown(ctx, signalError, errorListCopy, startTime, ua, cliAnalytics, networkAccess) + }) + // Send exit code to main goroutine instead of calling os.Exit directly + // This allows deferred functions (like lock cleanup) to run + exitCodeChan <- finalExitCode + }() + } + setTimeout(globalConfiguration, func() { os.Exit(constants.SNYK_EXIT_CODE_EX_UNAVAILABLE) }) @@ -681,40 +694,29 @@ func MainWithErrorCode() int { // ignore } - if err != nil { - errorList, err = processError(err, errorList) - - for _, tempError := range errorList { - if tempError != nil { - cliAnalytics.AddError(tempError) - } - } + // Check if signal handler already ran teardown + select { + case code := <-exitCodeChan: + // Signal was received and teardown completed - return its exit code + return code + default: + // No signal received - run normal teardown } - displayError(err, globalEngine.GetUserInterface(), globalConfiguration, ctx) - - exitCode := cliv2.DeriveExitCode(err) - globalLogger.Printf("Deriving Exit Code %d (cause: %v)", exitCode, err) - - updateInstrumentationDataBeforeSending(cliAnalytics, startTime, ua, exitCode) - - if !globalConfiguration.GetBool(configuration.ANALYTICS_DISABLED) { - sendAnalytics(cliAnalytics, globalLogger) + if globalConfiguration.GetBool(configuration.PREVIEW_FEATURES_ENABLED) { + // Stop signal handling before cleanup to prevent race conditions + signal.Stop(signalChan) } - sendInstrumentation(globalEngine, cliAnalytics.GetInstrumentation(), globalLogger) - // cleanup resources in use - // WARNING: deferred actions will execute AFTER cleanup; only defer if not impacted by this - _, err = globalEngine.Invoke(basic_workflows.WORKFLOWID_GLOBAL_CLEANUP) - if err != nil { - globalLogger.Printf("Failed to cleanup %v", err) - } + tearDownOnce.Do(func() { + errorListMutex.Lock() + errorListCopy := append([]error{}, errorList...) + errorListMutex.Unlock() - if debugEnabled { - writeLogFooter(exitCode, errorList, globalConfiguration, networkAccess) - } + finalExitCode = tearDown(ctx, err, errorListCopy, startTime, ua, cliAnalytics, networkAccess) + }) - return exitCode + return finalExitCode } func processError(err error, errorList []error) ([]error, error) { diff --git a/cliv2/cmd/cliv2/main_test.go b/cliv2/cmd/cliv2/main_test.go index 9fa76c314a..1828211851 100644 --- a/cliv2/cmd/cliv2/main_test.go +++ b/cliv2/cmd/cliv2/main_test.go @@ -466,7 +466,7 @@ func Test_runWorkflowAndProcessData(t *testing.T) { // invoke method under test logger := zerolog.New(os.Stderr) - err = runWorkflowAndProcessData(globalEngine, &logger, testCmnd) + err = runWorkflowAndProcessData(context.Background(), globalEngine, &logger, testCmnd) var expectedError *clierrors.ErrorWithExitCode assert.ErrorAs(t, err, &expectedError) @@ -560,7 +560,7 @@ func Test_runWorkflowAndProcessData_with_Filtering(t *testing.T) { assert.NoError(t, err) logger := zerolog.New(os.Stderr) - err = runWorkflowAndProcessData(globalEngine, &logger, testCmnd) + err = runWorkflowAndProcessData(context.Background(), globalEngine, &logger, testCmnd) } func Test_setTimeout(t *testing.T) { diff --git a/cliv2/go.mod b/cliv2/go.mod index c91d6f4173..3ec72dd52b 100644 --- a/cliv2/go.mod +++ b/cliv2/go.mod @@ -21,8 +21,8 @@ require ( github.com/snyk/cli-extension-secrets v0.0.0-20260330131056-456a17f6d188 github.com/snyk/code-client-go v1.26.2 github.com/snyk/container-cli v0.0.0-20260213211631-cd2b2cf8f3ea - github.com/snyk/error-catalog-golang-public v0.0.0-20260316131845-f02d7f42046b - github.com/snyk/go-application-framework v0.0.0-20260331133539-67257bb99539 + github.com/snyk/error-catalog-golang-public v0.0.0-20260326122451-686348fab429 + github.com/snyk/go-application-framework v0.0.0-20260402155353-0212397a709b github.com/snyk/go-httpauth v0.0.0-20240307114523-1f5ea3f55c65 github.com/snyk/snyk-iac-capture v0.6.5 github.com/snyk/snyk-ls v0.0.0-20260401163317-c1fe9ee766fd diff --git a/cliv2/go.sum b/cliv2/go.sum index 625cbc116b..937cd23653 100644 --- a/cliv2/go.sum +++ b/cliv2/go.sum @@ -559,10 +559,10 @@ github.com/snyk/container-cli v0.0.0-20260213211631-cd2b2cf8f3ea h1:/v48hCMPiZVj github.com/snyk/container-cli v0.0.0-20260213211631-cd2b2cf8f3ea/go.mod h1:P5yW8+jkwhYBsj5l2jtHeWujyX+SAtvkC8+LELKdlWI= github.com/snyk/dep-graph/go v0.0.0-20260127160647-c836da762c62 h1:kgZNQ5ztI4+n3YKLR5LJbqL8WJmUYgDSbFKREIY79g0= github.com/snyk/dep-graph/go v0.0.0-20260127160647-c836da762c62/go.mod h1:hTr91da/4ze2nk9q6ZW1BmfM2Z8rLUZSEZ3kK+6WGpc= -github.com/snyk/error-catalog-golang-public v0.0.0-20260316131845-f02d7f42046b h1:DM2SPu7rhsD/TNS7zhv4ZoqLLi2cFOqg1VTBCP6RfSg= -github.com/snyk/error-catalog-golang-public v0.0.0-20260316131845-f02d7f42046b/go.mod h1:Ytttq7Pw4vOCu9NtRQaOeDU2dhBYUyNBe6kX4+nIIQ4= -github.com/snyk/go-application-framework v0.0.0-20260331133539-67257bb99539 h1:5cetC+Ud3C9s+KlANfcQ9AlSkYKo2AWuUV1Erv6300M= -github.com/snyk/go-application-framework v0.0.0-20260331133539-67257bb99539/go.mod h1:7IOOtKxiQhtTbkrX7rax20QNJ/rwGill6n2Rejtld2I= +github.com/snyk/error-catalog-golang-public v0.0.0-20260326122451-686348fab429 h1:KUvautSov5PIOo3IQxbeu0d7zOVh5oO+sZ0N4lZkiJ8= +github.com/snyk/error-catalog-golang-public v0.0.0-20260326122451-686348fab429/go.mod h1:Ytttq7Pw4vOCu9NtRQaOeDU2dhBYUyNBe6kX4+nIIQ4= +github.com/snyk/go-application-framework v0.0.0-20260402155353-0212397a709b h1:DIqMmwwGno05IqrVYu+gJCgGI4T32CXBlLivryQ4NG0= +github.com/snyk/go-application-framework v0.0.0-20260402155353-0212397a709b/go.mod h1:7IOOtKxiQhtTbkrX7rax20QNJ/rwGill6n2Rejtld2I= github.com/snyk/go-httpauth v0.0.0-20240307114523-1f5ea3f55c65 h1:CEQuYv0Go6MEyRCD3YjLYM2u3Oxkx8GpCpFBd4rUTUk= github.com/snyk/go-httpauth v0.0.0-20240307114523-1f5ea3f55c65/go.mod h1:88KbbvGYlmLgee4OcQ19yr0bNpXpOr2kciOthaSzCAg= github.com/snyk/policy-engine v1.1.3 h1:MU+K8pxbN6VZ9P5wALUt8BwTjrPDpoEtmTtQqj7sKfY= diff --git a/cliv2/internal/cliv2/cliv2.go b/cliv2/internal/cliv2/cliv2.go index 98233e462b..bac899a594 100644 --- a/cliv2/internal/cliv2/cliv2.go +++ b/cliv2/internal/cliv2/cliv2.go @@ -263,8 +263,8 @@ func (c *CLI) commandVersion(passthroughArgs []string) error { } } -func (c *CLI) commandAbout(proxyInfo *proxy.ProxyInfo, passthroughArgs []string) error { - err := c.executeV1Default(proxyInfo, passthroughArgs) +func (c *CLI) commandAbout(ctx context.Context, proxyInfo *proxy.ProxyInfo, passthroughArgs []string) error { + err := c.executeV1Default(ctx, proxyInfo, passthroughArgs) if err != nil { return err } @@ -433,14 +433,11 @@ func (c *CLI) PrepareV1Command( return snykCmd, err } -func (c *CLI) executeV1Default(proxyInfo *proxy.ProxyInfo, passThroughArgs []string) error { +func (c *CLI) executeV1Default(ctx context.Context, proxyInfo *proxy.ProxyInfo, passThroughArgs []string) error { timeout := c.globalConfig.GetInt(configuration.TIMEOUT) - var ctx context.Context var cancel context.CancelFunc - if timeout == 0 { - ctx = context.Background() - } else { - ctx, cancel = context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + if timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second) defer cancel() } @@ -545,7 +542,7 @@ func GetErrorFromFile(execErr error, errFilePath string, config configuration.Co return nil, ErrIPCNoDataSent } -func (c *CLI) Execute(proxyInfo *proxy.ProxyInfo, passThroughArgs []string) error { +func (c *CLI) Execute(ctx context.Context, proxyInfo *proxy.ProxyInfo, passThroughArgs []string) error { var err error handler := determineHandler(passThroughArgs) @@ -553,11 +550,11 @@ func (c *CLI) Execute(proxyInfo *proxy.ProxyInfo, passThroughArgs []string) erro case V2_VERSION: err = c.commandVersion(passThroughArgs) case V2_ABOUT: - err = c.commandAbout(proxyInfo, passThroughArgs) + err = c.commandAbout(ctx, proxyInfo, passThroughArgs) case V1_DEFAULT: fallthrough default: - err = c.executeV1Default(proxyInfo, passThroughArgs) + err = c.executeV1Default(ctx, proxyInfo, passThroughArgs) } return err diff --git a/cliv2/internal/cliv2/cliv2_test.go b/cliv2/internal/cliv2/cliv2_test.go index be38d023fc..fba047b1cf 100644 --- a/cliv2/internal/cliv2/cliv2_test.go +++ b/cliv2/internal/cliv2/cliv2_test.go @@ -408,7 +408,7 @@ func Test_extractOnlyOnce(t *testing.T) { assert.NoError(t, cli.Init()) // run once - err = cli.Execute(getProxyInfoForTest(), []string{"--help"}) + err = cli.Execute(context.Background(), getProxyInfoForTest(), []string{"--help"}) assert.Error(t, err) // invalid binary expected here assert.FileExists(t, cli.GetBinaryLocation()) fileInfo1, err := os.Stat(cli.GetBinaryLocation()) @@ -419,7 +419,7 @@ func Test_extractOnlyOnce(t *testing.T) { // run twice assert.Nil(t, cli.Init()) - err = cli.Execute(getProxyInfoForTest(), []string{"--help"}) + err = cli.Execute(context.Background(), getProxyInfoForTest(), []string{"--help"}) assert.Error(t, err) // invalid binary expected here assert.FileExists(t, cli.GetBinaryLocation()) fileInfo2, err := os.Stat(cli.GetBinaryLocation()) @@ -479,7 +479,7 @@ func Test_executeRunV2only(t *testing.T) { assert.NoError(t, err) assert.NoError(t, cli.Init()) - actualReturnCode := cliv2.DeriveExitCode(cli.Execute(getProxyInfoForTest(), []string{"--version"})) + actualReturnCode := cliv2.DeriveExitCode(cli.Execute(context.Background(), getProxyInfoForTest(), []string{"--version"})) assert.Equal(t, expectedReturnCode, actualReturnCode) assert.FileExists(t, cli.GetBinaryLocation()) } @@ -496,7 +496,7 @@ func Test_executeUnknownCommand(t *testing.T) { assert.NoError(t, err) assert.NoError(t, cli.Init()) - actualReturnCode := cliv2.DeriveExitCode(cli.Execute(getProxyInfoForTest(), []string{"bogusCommand"})) + actualReturnCode := cliv2.DeriveExitCode(cli.Execute(context.Background(), getProxyInfoForTest(), []string{"bogusCommand"})) assert.Equal(t, expectedReturnCode, actualReturnCode) } @@ -590,7 +590,7 @@ func Test_setTimeout(t *testing.T) { // sleep for 2s cli.SetV1BinaryLocation("/bin/sleep") - err = cli.Execute(getProxyInfoForTest(), []string{"2"}) + err = cli.Execute(context.Background(), getProxyInfoForTest(), []string{"2"}) assert.ErrorIs(t, err, context.DeadlineExceeded) } diff --git a/cliv2/pkg/basic_workflows/legacycli.go b/cliv2/pkg/basic_workflows/legacycli.go index 9a9f711f37..24e4a7ffac 100644 --- a/cliv2/pkg/basic_workflows/legacycli.go +++ b/cliv2/pkg/basic_workflows/legacycli.go @@ -151,9 +151,9 @@ func legacycliWorkflow( return output, err } - // run the cli + // run the cli with context from invocation (allows cancellation on signal) proxyInfo := wrapperProxy.ProxyInfo() - err = cli.Execute(proxyInfo, finalizeArguments(args, config.GetStringSlice(configuration.UNKNOWN_ARGS))) + err = cli.Execute(invocation.Context(), proxyInfo, finalizeArguments(args, config.GetStringSlice(configuration.UNKNOWN_ARGS))) if !useStdIo { _ = outWriter.Flush() diff --git a/jest.config.js b/jest.config.js index 7b32377e82..39f8f6fdb9 100644 --- a/jest.config.js +++ b/jest.config.js @@ -5,5 +5,6 @@ module.exports = createJestConfig({ displayName: 'coreCli', projects: ['', '/packages/*'], globalSetup: './test/setup.js', + globalTeardown: './test/teardown.js', setupFilesAfterEnv: ['./test/setup-jest.ts'], }); diff --git a/package-lock.json b/package-lock.json index 484da96489..9f77b27ac8 100644 --- a/package-lock.json +++ b/package-lock.json @@ -127,6 +127,7 @@ "node-loader": "^2.0.0", "npm-run-all": "^4.1.5", "patch-package": "^6.5.0", + "pidtree": "^0.6.0", "portfinder": "^1.0.38", "prettier": "^3.3.3", "proxyquire": "^1.7.4", @@ -16193,6 +16194,18 @@ "node": ">= 4" } }, + "node_modules/npm-run-all/node_modules/pidtree": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/pidtree/-/pidtree-0.3.1.tgz", + "integrity": "sha512-qQbW94hLHEqCg7nhby4yRC7G2+jYHY4Rguc2bjw7Uug4GIJuu1tvf2uHaZv5Q8zdt+WKJ6qK1FOI6amaWUo5FA==", + "dev": true, + "bin": { + "pidtree": "bin/pidtree.js" + }, + "engines": { + "node": ">=0.10" + } + }, "node_modules/npm-run-path": { "version": "4.0.1", "license": "MIT", @@ -16992,9 +17005,10 @@ } }, "node_modules/pidtree": { - "version": "0.3.1", + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/pidtree/-/pidtree-0.6.0.tgz", + "integrity": "sha512-eG2dWTVw5bzqGRztnHExczNxt5VGsE6OwTeCG3fdUf9KBsZzO3R5OIIIzWR+iZA0NtZ+RDVdaoE2dK1cn6jH4g==", "dev": true, - "license": "MIT", "bin": { "pidtree": "bin/pidtree.js" }, diff --git a/package.json b/package.json index 340948260b..6e45343abe 100644 --- a/package.json +++ b/package.json @@ -172,6 +172,7 @@ "node-loader": "^2.0.0", "npm-run-all": "^4.1.5", "patch-package": "^6.5.0", + "pidtree": "^0.6.0", "portfinder": "^1.0.38", "prettier": "^3.3.3", "proxyquire": "^1.7.4", diff --git a/test/jest/acceptance/signal-handling.spec.ts b/test/jest/acceptance/signal-handling.spec.ts new file mode 100644 index 0000000000..6e6c9407a9 --- /dev/null +++ b/test/jest/acceptance/signal-handling.spec.ts @@ -0,0 +1,44 @@ +import { startSnykCLI } from '../util/startSnykCLI'; + +jest.setTimeout(1000 * 60); + +describe('signal handling', () => { + const env = { + ...process.env, + SNYK_CFG_PREVIEW_FEATURES_ENABLED: 'true', + }; + + it('exits with SNYK-CLI-0025 when receiving SIGINT', async () => { + // Use 'test -d' with debug output which takes time to initialize + const cli = await startSnykCLI('test -d', { env }); + + // Wait for CLI to initialize and start processing + await new Promise((r) => setTimeout(r, 2000)); + + // Send SIGINT to the CLI process + cli.process.kill('SIGINT'); + + const exitCode = await cli.wait({ timeout: 15000 }); + const stdout = cli.stdout.get(); + + expect(exitCode).toBe(2); + expect(stdout).toContain('SNYK-CLI-0025'); + }); + + it('exits with SNYK-CLI-0025 when receiving SIGTERM', async () => { + // Use 'test -d' with debug output which takes time to initialize + const cli = await startSnykCLI('test -d', { env }); + + // Wait for CLI to initialize and start processing + await new Promise((r) => setTimeout(r, 2000)); + + // Send SIGTERM to the CLI process + cli.process.kill('SIGTERM'); + + const exitCode = await cli.wait({ timeout: 15000 }); + const stdout = cli.stdout.get(); + + expect(exitCode).toBe(2); + expect(stdout).toContain('SNYK-CLI-0025'); + }); +}); diff --git a/test/jest/util/startSnykCLI.ts b/test/jest/util/startSnykCLI.ts index 9b891e6528..296c5ae394 100644 --- a/test/jest/util/startSnykCLI.ts +++ b/test/jest/util/startSnykCLI.ts @@ -86,31 +86,46 @@ const createTestCLI = (child: ChildProcessWithoutNullStreams) => { /** * Waits for process to exit and provides the exit code. + * When a process is killed by a signal, code may be null and signal will be set. */ const wait = async ({ timeout = DEFAULT_ASSERTION_TIMEOUT, } = {}): Promise => { - if (child.killed) { - return child.exitCode || 0; + // If process has already exited, return the exit code immediately + if (child.exitCode !== null) { + return child.exitCode; } return new Promise((resolve, reject) => { const onTimeout = () => { child.removeListener('error', onError); child.removeListener('close', onClose); + child.removeListener('exit', onExit); reject(new AssertionTimeoutError('wait', timeout)); }; const onError = (error) => { clearTimeout(timeoutId); + child.removeListener('close', onClose); + child.removeListener('exit', onExit); reject(error); }; + let exitCode: number | null = null; const onClose = (code) => { clearTimeout(timeoutId); - resolve(code || 0); + child.removeListener('error', onError); + child.removeListener('exit', onExit); + // Use exit code captured from 'exit' event, or from 'close' event, or child.exitCode + const finalCode = exitCode ?? code ?? child.exitCode ?? 0; + resolve(finalCode); + }; + const onExit = (code) => { + // 'exit' fires before 'close' - capture the exit code here + exitCode = code; }; const timeoutId = setTimeout(onTimeout, timeout); child.once('error', onError); child.once('close', onClose); + child.once('exit', onExit); }); }; diff --git a/test/teardown.js b/test/teardown.js new file mode 100644 index 0000000000..11539aa0a8 --- /dev/null +++ b/test/teardown.js @@ -0,0 +1,121 @@ +const { execFileSync } = require('child_process'); +const pidtree = require('pidtree'); + +// Timeout to wait for graceful shutdown before sending SIGKILL +const GRACEFUL_SHUTDOWN_TIMEOUT = 5000; + +/** + * Get process name/command for a given PID. + * Returns null if process doesn't exist or on error. + * Uses execFileSync with array args to avoid command injection. + */ +function getProcessName(pid) { + // Validate PID is a positive integer + const pidNum = Number(pid); + if (!Number.isInteger(pidNum) || pidNum <= 0) { + return null; + } + const pidStr = String(pidNum); + + try { + if (process.platform === 'win32') { + const output = execFileSync( + 'wmic', + ['process', 'where', `ProcessId=${pidStr}`, 'get', 'CommandLine'], + { encoding: 'utf-8', timeout: 1000 }, + ); + return output.toLowerCase(); + } else { + // Unix: use ps to get command + const output = execFileSync('ps', ['-p', pidStr, '-o', 'comm='], { + encoding: 'utf-8', + timeout: 1000, + }); + return output.trim().toLowerCase(); + } + } catch (err) { + return null; + } +} + +/** + * Check if a process is a Snyk CLI process. + */ +function isSnykProcess(pid) { + const name = getProcessName(pid); + if (!name) return false; + return name.includes('snyk'); +} + +/** + * Global teardown that kills any orphaned Snyk CLI processes. + * + * When Jest tests timeout, spawned CLI processes may be left running. + * This teardown finds all descendant processes of the Jest runner, + * filters for Snyk CLI processes, and sends SIGTERM (graceful) then + * SIGKILL (force) to ensure they exit and have a chance to send + * instrumentation data. + */ +module.exports = async function globalTeardown() { + const jestPid = process.pid; + + let childPids; + try { + childPids = await pidtree(jestPid); + } catch (err) { + // No children or error getting process tree + return; + } + + if (!childPids || childPids.length === 0) { + return; + } + + // Filter for Snyk CLI processes only + const snykPids = childPids.filter(isSnykProcess); + + if (snykPids.length === 0) { + return; + } + + console.log( + `[teardown] Found ${snykPids.length} orphaned Snyk CLI process(es), sending SIGTERM...`, + ); + + // Send SIGTERM to Snyk processes for graceful shutdown + for (const pid of snykPids) { + try { + process.kill(pid, 'SIGTERM'); + } catch (err) { + // Process may have already exited + } + } + + // Wait for graceful shutdown (CLI teardown timeout is 5s) + await new Promise((resolve) => + setTimeout(resolve, GRACEFUL_SHUTDOWN_TIMEOUT), + ); + + // Check which Snyk processes are still running and send SIGKILL + let remainingPids; + try { + remainingPids = await pidtree(jestPid); + } catch (err) { + return; + } + + const remainingSnykPids = (remainingPids || []).filter(isSnykProcess); + + if (remainingSnykPids.length > 0) { + console.log( + `[teardown] ${remainingSnykPids.length} Snyk process(es) still running, sending SIGKILL...`, + ); + for (const pid of remainingSnykPids) { + try { + process.kill(pid, 'SIGKILL'); + } catch (err) { + // Process may have already exited + } + } + } +};