Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion cmd/sqlcmd/sqlcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -715,8 +715,24 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq
}

func isConsoleInitializationRequired(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments) bool {
// Password input always requires console initialization
if connect.RequiresPassword() {
return true
}

// Check if stdin is from a terminal or a redirection
file, err := os.Stdin.Stat()
if err == nil {
// If stdin is not a character device, it's coming from a pipe or redirect
if (file.Mode() & os.ModeCharDevice) == 0 {
// Non-interactive: stdin is redirected
return false
}
}

// If we get here, stdin is from a terminal or we couldn't determine
iactive := args.InputFile == nil && args.Query == "" && len(args.ChangePasswordAndExit) == 0
return iactive || connect.RequiresPassword()
return iactive
}

func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
Expand Down
68 changes: 68 additions & 0 deletions cmd/sqlcmd/stdin_console_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package sqlcmd

import (
"os"
"testing"

"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
"github.com/stretchr/testify/assert"
)

func TestIsConsoleInitializationRequiredWithRedirectedStdin(t *testing.T) {
// Create a temp file to simulate redirected stdin
tempFile, err := os.CreateTemp("", "stdin-test-*.txt")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tempFile.Name())
defer tempFile.Close()

// Write some data to it
_, err = tempFile.WriteString("SELECT 1;\nGO\n")
if err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}

// Remember the original stdin
originalStdin := os.Stdin
defer func() { os.Stdin = originalStdin }()

// Test with a file redirection
stdinFile, err := os.Open(tempFile.Name())
if err != nil {
t.Fatalf("Failed to open temp file: %v", err)
}
defer stdinFile.Close()

// Replace stdin with our redirected file
os.Stdin = stdinFile

// Set up a connect settings instance for SQL authentication
connectConfig := sqlcmd.ConnectSettings{
UserName: "testuser", // This will trigger SQL authentication, requiring a password
}

// Test regular args
args := &SQLCmdArguments{}

// Print file stat mode for debugging
fileStat, _ := os.Stdin.Stat()
t.Logf("File mode: %v", fileStat.Mode())
t.Logf("Is character device: %v", (fileStat.Mode()&os.ModeCharDevice) != 0)
t.Logf("Connection config: %+v", connectConfig)
t.Logf("RequiresPassword() returns: %v", connectConfig.RequiresPassword())
Comment on lines +53 to +56
Copy link

Copilot AI May 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] These t.Logf calls appear to be leftover debug logging and do not contribute to the test assertions. Consider removing them to keep the test output clean.

Suggested change
t.Logf("File mode: %v", fileStat.Mode())
t.Logf("Is character device: %v", (fileStat.Mode()&os.ModeCharDevice) != 0)
t.Logf("Connection config: %+v", connectConfig)
t.Logf("RequiresPassword() returns: %v", connectConfig.RequiresPassword())

Copilot uses AI. Check for mistakes.

// Test with SQL authentication that requires a password
res := isConsoleInitializationRequired(&connectConfig, args)
// Should be true since password is required, even with redirected stdin
assert.True(t, res, "Console initialization should be required when SQL authentication is used")

// Now test with no authentication (no password required)
connectConfig = sqlcmd.ConnectSettings{}
res = isConsoleInitializationRequired(&connectConfig, args)
// Should be false since stdin is redirected and no password is required
assert.False(t, res, "Console initialization should not be required with redirected stdin and no password")
}
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ require (
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.14.0
github.com/stretchr/testify v1.10.0
golang.org/x/sys v0.32.0
golang.org/x/sys v0.33.0
golang.org/x/text v0.24.0
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d
gopkg.in/yaml.v2 v2.4.0
Expand Down Expand Up @@ -85,6 +85,7 @@ require (
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.39.0 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/term v0.32.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250425173222-7b384671a197 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250425173222-7b384671a197 // indirect
google.golang.org/grpc v1.71.1 // indirect
Expand Down
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -519,9 +519,11 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
59 changes: 46 additions & 13 deletions pkg/console/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,83 @@
package console

import (
"bufio"
"os"

"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
"github.com/peterh/liner"
)

type console struct {
impl *liner.State
historyFile string
prompt string
impl *liner.State
historyFile string
prompt string
stdinRedirected bool
stdinReader *bufio.Reader
}

// NewConsole creates a sqlcmdConsole implementation that provides these features:
// - Storage of input history to a local file. History can be scrolled through using the up and down arrow keys.
// - Simple tab key completion of SQL keywords
func NewConsole(historyFile string) sqlcmd.Console {
c := &console{
impl: liner.NewLiner(),
historyFile: historyFile,
impl: liner.NewLiner(),
historyFile: historyFile,
stdinRedirected: isStdinRedirected(),
Copy link

Copilot AI May 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prompt field is never initialized in NewConsole, so interactive prompts will always be empty. Consider setting a default prompt value (e.g. "sqlcmd> ") when stdin is a terminal.

Suggested change
stdinRedirected: isStdinRedirected(),
stdinRedirected: isStdinRedirected(),
prompt: "sqlcmd> ",

Copilot uses AI. Check for mistakes.
}
c.impl.SetCtrlCAborts(true)
c.impl.SetCompleter(CompleteLine)
if c.historyFile != "" {
if f, err := os.Open(historyFile); err == nil {
_, _ = c.impl.ReadHistory(f)
f.Close()

if c.stdinRedirected {
c.stdinReader = bufio.NewReader(os.Stdin)
} else {
c.impl.SetCtrlCAborts(true)
c.impl.SetCompleter(CompleteLine)
if c.historyFile != "" {
if f, err := os.Open(historyFile); err == nil {
_, _ = c.impl.ReadHistory(f)
f.Close()
}
}
}
return c
}

// Close writes out the history data to disk and closes the console buffers
func (c *console) Close() {
if c.historyFile != "" {
if !c.stdinRedirected && c.historyFile != "" {
if f, err := os.Create(c.historyFile); err == nil {
_, _ = c.impl.WriteHistory(f)
f.Close()
}
}
c.impl.Close()

if !c.stdinRedirected {
c.impl.Close()
}
}

// Readline displays the current prompt and returns a line of text entered by the user.
// It appends the returned line to the history buffer.
// If the user presses Ctrl-C the error returned is sqlcmd.ErrCtrlC
// If stdin is redirected, it reads directly from stdin without displaying prompts
func (c *console) Readline() (string, error) {
// Handle redirected stdin without displaying prompts
if c.stdinRedirected {
line, err := c.stdinReader.ReadString('\n')
if err != nil {
return "", err
}
// Trim the trailing newline
if len(line) > 0 && line[len(line)-1] == '\n' {
line = line[:len(line)-1]
// Also trim carriage return if present
if len(line) > 0 && line[len(line)-1] == '\r' {
line = line[:len(line)-1]
}
}
return line, nil
}

// Interactive terminal mode with prompts
s, err := c.impl.Prompt(c.prompt)
if err == liner.ErrPromptAborted {
return "", sqlcmd.ErrCtrlC
Expand All @@ -61,6 +92,8 @@ func (c *console) Readline() (string, error) {
// ReadPassword displays the given prompt and returns the password entered by the user.
// If the user presses Ctrl-C the error returned is sqlcmd.ErrCtrlC
func (c *console) ReadPassword(prompt string) ([]byte, error) {
// Even when stdin is redirected, we need to use the prompt for passwords
// since they should not be read from the redirected input
b, err := c.impl.PasswordPrompt(prompt)
if err == liner.ErrPromptAborted {
return []byte{}, sqlcmd.ErrCtrlC
Expand Down
54 changes: 54 additions & 0 deletions pkg/console/console_redirect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package console

import (
"io"
"os"
"testing"
)

func TestStdinRedirectionDetection(t *testing.T) {
// Save original stdin
origStdin := os.Stdin
defer func() { os.Stdin = origStdin }()

// Create a pipe
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("Couldn't create pipe: %v", err)
}
defer r.Close()
defer w.Close()

// Replace stdin with our pipe
os.Stdin = r

// Test if stdin is properly detected as redirected
if !isStdinRedirected() {
t.Errorf("Pipe input should be detected as redirected")
}

// Write some test input
go func() {
_, _ = io.WriteString(w, "test input\n")
w.Close()
}()

// Create console with redirected stdin
console := NewConsole("")

// Test readline
line, err := console.Readline()
if err != nil {
t.Fatalf("Failed to read from redirected stdin: %v", err)
}

if line != "test input" {
t.Errorf("Expected 'test input', got '%s'", line)
}

// Clean up
console.Close()
}
30 changes: 30 additions & 0 deletions pkg/console/console_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

//go:build !windows
// +build !windows

package console

import (
"os"
"golang.org/x/term"
)

// isStdinRedirected checks if stdin is coming from a pipe or redirection
func isStdinRedirected() bool {
stat, err := os.Stdin.Stat()
if err != nil {
// If we can't determine, assume it's not redirected
return false
}

// If it's not a character device, it's coming from a pipe or redirection
if (stat.Mode() & os.ModeCharDevice) == 0 {
return true
}

// Double-check using isatty
fd := int(os.Stdin.Fd())
return !term.IsTerminal(fd)
}
30 changes: 30 additions & 0 deletions pkg/console/console_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

//go:build windows
// +build windows

package console

import (
"os"
"golang.org/x/term"
)

// isStdinRedirected checks if stdin is coming from a pipe or redirection
func isStdinRedirected() bool {
stat, err := os.Stdin.Stat()
if err != nil {
// If we can't determine, assume it's not redirected
return false
}

// If it's not a character device, it's coming from a pipe or redirection
if (stat.Mode() & os.ModeCharDevice) == 0 {
return true
}

// For Windows, check if stdin is a terminal
fd := int(os.Stdin.Fd())
return !term.IsTerminal(fd)
}
Loading