diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index 28e40306..5c1c16a2 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -715,8 +715,24 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq } func isConsoleInitializationRequired(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments) bool { + // Password input always requires console initialization + if connect.RequiresPassword() { + return true + } + + // Check if stdin is from a terminal or a redirection + file, err := os.Stdin.Stat() + if err == nil { + // If stdin is not a character device, it's coming from a pipe or redirect + if (file.Mode() & os.ModeCharDevice) == 0 { + // Non-interactive: stdin is redirected + return false + } + } + + // If we get here, stdin is from a terminal or we couldn't determine iactive := args.InputFile == nil && args.Query == "" && len(args.ChangePasswordAndExit) == 0 - return iactive || connect.RequiresPassword() + return iactive } func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { diff --git a/cmd/sqlcmd/stdin_console_test.go b/cmd/sqlcmd/stdin_console_test.go new file mode 100644 index 00000000..2c9f7ba5 --- /dev/null +++ b/cmd/sqlcmd/stdin_console_test.go @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "os" + "testing" + + "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" + "github.com/stretchr/testify/assert" +) + +func TestIsConsoleInitializationRequiredWithRedirectedStdin(t *testing.T) { + // Create a temp file to simulate redirected stdin + tempFile, err := os.CreateTemp("", "stdin-test-*.txt") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tempFile.Name()) + defer tempFile.Close() + + // Write some data to it + _, err = tempFile.WriteString("SELECT 1;\nGO\n") + if err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + + // Remember the original stdin + originalStdin := os.Stdin + defer func() { os.Stdin = originalStdin }() + + // Test with a file redirection + stdinFile, err := os.Open(tempFile.Name()) + if err != nil { + t.Fatalf("Failed to open temp file: %v", err) + } + defer stdinFile.Close() + + // Replace stdin with our redirected file + os.Stdin = stdinFile + + // Set up a connect settings instance for SQL authentication + connectConfig := sqlcmd.ConnectSettings{ + UserName: "testuser", // This will trigger SQL authentication, requiring a password + } + + // Test regular args + args := &SQLCmdArguments{} + + // Print file stat mode for debugging + fileStat, _ := os.Stdin.Stat() + t.Logf("File mode: %v", fileStat.Mode()) + t.Logf("Is character device: %v", (fileStat.Mode()&os.ModeCharDevice) != 0) + t.Logf("Connection config: %+v", connectConfig) + t.Logf("RequiresPassword() returns: %v", connectConfig.RequiresPassword()) + + // Test with SQL authentication that requires a password + res := isConsoleInitializationRequired(&connectConfig, args) + // Should be true since password is required, even with redirected stdin + assert.True(t, res, "Console initialization should be required when SQL authentication is used") + + // Now test with no authentication (no password required) + connectConfig = sqlcmd.ConnectSettings{} + res = isConsoleInitializationRequired(&connectConfig, args) + // Should be false since stdin is redirected and no password is required + assert.False(t, res, "Console initialization should not be required with redirected stdin and no password") +} diff --git a/go.mod b/go.mod index 7245c502..bc51d653 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.14.0 github.com/stretchr/testify v1.10.0 - golang.org/x/sys v0.32.0 + golang.org/x/sys v0.33.0 golang.org/x/text v0.24.0 golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d gopkg.in/yaml.v2 v2.4.0 @@ -85,6 +85,7 @@ require ( golang.org/x/mod v0.17.0 // indirect golang.org/x/net v0.39.0 // indirect golang.org/x/sync v0.13.0 // indirect + golang.org/x/term v0.32.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250425173222-7b384671a197 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250425173222-7b384671a197 // indirect google.golang.org/grpc v1.71.1 // indirect diff --git a/go.sum b/go.sum index 7aad4257..993b9f39 100644 --- a/go.sum +++ b/go.sum @@ -519,9 +519,11 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= -golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= +golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/pkg/console/console.go b/pkg/console/console.go index 4e0bb001..4eebded7 100644 --- a/pkg/console/console.go +++ b/pkg/console/console.go @@ -4,6 +4,7 @@ package console import ( + "bufio" "os" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" @@ -11,9 +12,11 @@ import ( ) type console struct { - impl *liner.State - historyFile string - prompt string + impl *liner.State + historyFile string + prompt string + stdinRedirected bool + stdinReader *bufio.Reader } // NewConsole creates a sqlcmdConsole implementation that provides these features: @@ -21,15 +24,21 @@ type console struct { // - Simple tab key completion of SQL keywords func NewConsole(historyFile string) sqlcmd.Console { c := &console{ - impl: liner.NewLiner(), - historyFile: historyFile, + impl: liner.NewLiner(), + historyFile: historyFile, + stdinRedirected: isStdinRedirected(), } - c.impl.SetCtrlCAborts(true) - c.impl.SetCompleter(CompleteLine) - if c.historyFile != "" { - if f, err := os.Open(historyFile); err == nil { - _, _ = c.impl.ReadHistory(f) - f.Close() + + if c.stdinRedirected { + c.stdinReader = bufio.NewReader(os.Stdin) + } else { + c.impl.SetCtrlCAborts(true) + c.impl.SetCompleter(CompleteLine) + if c.historyFile != "" { + if f, err := os.Open(historyFile); err == nil { + _, _ = c.impl.ReadHistory(f) + f.Close() + } } } return c @@ -37,19 +46,41 @@ func NewConsole(historyFile string) sqlcmd.Console { // Close writes out the history data to disk and closes the console buffers func (c *console) Close() { - if c.historyFile != "" { + if !c.stdinRedirected && c.historyFile != "" { if f, err := os.Create(c.historyFile); err == nil { _, _ = c.impl.WriteHistory(f) f.Close() } } - c.impl.Close() + + if !c.stdinRedirected { + c.impl.Close() + } } // Readline displays the current prompt and returns a line of text entered by the user. // It appends the returned line to the history buffer. // If the user presses Ctrl-C the error returned is sqlcmd.ErrCtrlC +// If stdin is redirected, it reads directly from stdin without displaying prompts func (c *console) Readline() (string, error) { + // Handle redirected stdin without displaying prompts + if c.stdinRedirected { + line, err := c.stdinReader.ReadString('\n') + if err != nil { + return "", err + } + // Trim the trailing newline + if len(line) > 0 && line[len(line)-1] == '\n' { + line = line[:len(line)-1] + // Also trim carriage return if present + if len(line) > 0 && line[len(line)-1] == '\r' { + line = line[:len(line)-1] + } + } + return line, nil + } + + // Interactive terminal mode with prompts s, err := c.impl.Prompt(c.prompt) if err == liner.ErrPromptAborted { return "", sqlcmd.ErrCtrlC @@ -61,6 +92,8 @@ func (c *console) Readline() (string, error) { // ReadPassword displays the given prompt and returns the password entered by the user. // If the user presses Ctrl-C the error returned is sqlcmd.ErrCtrlC func (c *console) ReadPassword(prompt string) ([]byte, error) { + // Even when stdin is redirected, we need to use the prompt for passwords + // since they should not be read from the redirected input b, err := c.impl.PasswordPrompt(prompt) if err == liner.ErrPromptAborted { return []byte{}, sqlcmd.ErrCtrlC diff --git a/pkg/console/console_redirect.go b/pkg/console/console_redirect.go new file mode 100644 index 00000000..b09486d9 --- /dev/null +++ b/pkg/console/console_redirect.go @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package console + +import ( + "os" + "golang.org/x/term" +) + +// isStdinRedirected checks if stdin is coming from a pipe or redirection +func isStdinRedirected() bool { + stat, err := os.Stdin.Stat() + if err != nil { + // If we can't determine, assume it's not redirected + return false + } + + // If it's not a character device, it's coming from a pipe or redirection + if (stat.Mode() & os.ModeCharDevice) == 0 { + return true + } + + // Double-check using term.IsTerminal + fd := int(os.Stdin.Fd()) + return !term.IsTerminal(fd) +} \ No newline at end of file diff --git a/pkg/console/console_redirect_test.go b/pkg/console/console_redirect_test.go new file mode 100644 index 00000000..f26cab98 --- /dev/null +++ b/pkg/console/console_redirect_test.go @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package console + +import ( + "io" + "os" + "testing" +) + +func TestStdinRedirectionDetection(t *testing.T) { + // Save original stdin + origStdin := os.Stdin + defer func() { os.Stdin = origStdin }() + + // Create a pipe + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("Couldn't create pipe: %v", err) + } + defer r.Close() + defer w.Close() + + // Replace stdin with our pipe + os.Stdin = r + + // Test if stdin is properly detected as redirected + if !isStdinRedirected() { + t.Errorf("Pipe input should be detected as redirected") + } + + // Write some test input + go func() { + _, _ = io.WriteString(w, "test input\n") + w.Close() + }() + + // Create console with redirected stdin + console := NewConsole("") + + // Test readline + line, err := console.Readline() + if err != nil { + t.Fatalf("Failed to read from redirected stdin: %v", err) + } + + if line != "test input" { + t.Errorf("Expected 'test input', got '%s'", line) + } + + // Clean up + console.Close() +} \ No newline at end of file