Selaa lähdekoodia

Simplify terminal setup functions.

Remove SetupGlobal and Sanitize and replace them with a SetupForEval, which is
similar to Setup, but called before evaluating code and returns a function to
call after evaluating code.
Qi Xiao 2 vuotta sitten
vanhempi
säilyke
e2901118a5

+ 4 - 8
pkg/cli/term/setup.go

@@ -15,14 +15,10 @@ func Setup(in, out *os.File) (func() error, error) {
 	return setup(in, out)
 }
 
-// SetupGlobal sets up the terminal for the entire Elvish session.
-func SetupGlobal(in, out *os.File) func() {
-	return setupGlobal(in, out)
-}
-
-// Sanitize sanitizes the terminal after an external command has executed.
-func Sanitize(in, out *os.File) {
-	sanitize(in, out)
+// SetupForEval sets up the terminal for evaluating Elvish code. It returns a
+// function to call after the evaluation finishes.
+func SetupForEval(in, out *os.File) func() {
+	return setupForEval(in, out)
 }
 
 const (

+ 4 - 2
pkg/cli/term/setup_unix.go

@@ -52,8 +52,10 @@ func setup(in, out *os.File) (func() error, error) {
 	return restore, errSetupVT
 }
 
-func setupGlobal(in, out *os.File) func() {
-	return func() {}
+func setupForEval(in, out *os.File) func() {
+	// There is nothing to set up on UNIX, but we try to sanitize the terminal
+	// when evaluation finishes.
+	return func() { sanitize(in, out) }
 }
 
 func sanitize(in, out *os.File) {

+ 12 - 34
pkg/cli/term/setup_windows.go

@@ -8,17 +8,13 @@ import (
 )
 
 const (
-	wantedInMode = windows.ENABLE_WINDOW_INPUT |
+	inMode = windows.ENABLE_WINDOW_INPUT |
 		windows.ENABLE_MOUSE_INPUT | windows.ENABLE_PROCESSED_INPUT
-	wantedOutMode = windows.ENABLE_PROCESSED_OUTPUT |
+	outMode = windows.ENABLE_PROCESSED_OUTPUT |
 		windows.ENABLE_WRAP_AT_EOL_OUTPUT |
 		windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING
-
-	additionalGlobalOutMode = windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING
 )
 
-var globalOldInMode, globalOldOutMode uint32
-
 func setup(in, out *os.File) (func() error, error) {
 	hIn := windows.Handle(in.Fd())
 	hOut := windows.Handle(out.Fd())
@@ -33,8 +29,8 @@ func setup(in, out *os.File) (func() error, error) {
 		return nil, err
 	}
 
-	errSetIn := windows.SetConsoleMode(hIn, wantedInMode)
-	errSetOut := windows.SetConsoleMode(hOut, wantedOutMode)
+	errSetIn := windows.SetConsoleMode(hIn, inMode)
+	errSetOut := windows.SetConsoleMode(hOut, outMode)
 	errVT := setupVT(out)
 
 	return func() error {
@@ -45,32 +41,14 @@ func setup(in, out *os.File) (func() error, error) {
 	}, diag.Errors(errSetIn, errSetOut, errVT)
 }
 
-func setupGlobal(in, out *os.File) func() {
-	hIn := windows.Handle(in.Fd())
-	hOut := windows.Handle(out.Fd())
+const outFlagForEval = windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING
 
-	err := windows.GetConsoleMode(hIn, &globalOldInMode)
-	if err != nil {
-		return func() {}
-	}
-	err = windows.GetConsoleMode(hOut, &globalOldOutMode)
-	if err != nil {
-		return func() {}
+func setupForEval(_, out *os.File) func() {
+	h := windows.Handle(out.Fd())
+	var oldOutMode uint32
+	err := windows.GetConsoleMode(h, &oldOutMode)
+	if err == nil {
+		windows.SetConsoleMode(h, oldOutMode|outFlagForEval)
 	}
-
-	windows.SetConsoleMode(hIn, globalOldInMode)
-	windows.SetConsoleMode(hOut, globalOldOutMode|additionalGlobalOutMode)
-
-	return func() {
-		windows.SetConsoleMode(hIn, globalOldInMode)
-		windows.SetConsoleMode(hOut, globalOldOutMode)
-	}
-}
-
-func sanitize(in, out *os.File) {
-	hIn := windows.Handle(in.Fd())
-	hOut := windows.Handle(out.Fd())
-
-	windows.SetConsoleMode(hIn, globalOldInMode)
-	windows.SetConsoleMode(hOut, globalOldOutMode|additionalGlobalOutMode)
+	return func() {}
 }

+ 35 - 80
pkg/cli/term/setup_windows_test.go

@@ -1,109 +1,64 @@
 package term
 
 import (
-	"fmt"
 	"os"
 	"testing"
 
 	"golang.org/x/sys/windows"
 )
 
-func TestSetupGlobalTerminal(t *testing.T) {
-	in, out, release, err := createStdInOut()
-	if err != nil {
-		t.Errorf("cannot open stdin/stdout %v", err)
-		return
-	}
-	defer release()
-
-	initialOutMode, _ := getConsoleMode(out)
-	initialOutMode = initialOutMode &^ windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING
-	setConsoleMode(out, initialOutMode)
+func TestSetupForEval(t *testing.T) {
+	// open CONOUT$ manually because os.Stdout is redirected during testing
+	out := openFile(t, "CONOUT$", os.O_RDWR, 0)
+	defer out.Close()
 
-	// check that mode is for control sequences
-	restore := setupGlobal(in, out)
-	err = assertConsoleMode(
-		out,
-		initialOutMode|windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING)
-	if err != nil {
-		t.Errorf("got err %v, want nil", err)
-		return
-	}
+	// Start with ENABLE_VIRTUAL_TERMINAL_PROCESSING
+	initialOutMode := getConsoleMode(t, out) | windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING
+	setConsoleMode(t, out, initialOutMode)
 
-	// check that mode is restored
-	restore()
-	err = assertConsoleMode(
-		out,
-		initialOutMode)
-	if err != nil {
-		t.Errorf("got err %v, want nil", err)
-		return
-	}
-}
+	// Clear ENABLE_VIRTUAL_TERMINAL_PROCESSING
+	modifiedOutMode := initialOutMode &^ windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING
+	setConsoleMode(t, out, modifiedOutMode)
 
-func TestSanitizeTerminal(t *testing.T) {
-	in, out, release, err := createStdInOut()
-	if err != nil {
-		t.Errorf("cannot open stdin/stdout %v", err)
-		return
+	// Check that SetupForEval sets ENABLE_VIRTUAL_TERMINAL_PROCESSING without
+	// changing other bits
+	restore := setupForEval(nil, out)
+	if got := getConsoleMode(t, out); got != initialOutMode {
+		t.Errorf("got console mode %v, want %v", got, initialOutMode)
 	}
-	defer release()
-
-	initialOutMode, _ := getConsoleMode(out)
-	setConsoleMode(out, initialOutMode)
-
-	setupGlobal(in, out)
 
-	// break console mode
-	setConsoleMode(out, initialOutMode&^windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING)
+	// Check that restore is a no-op
+	setConsoleMode(t, out, modifiedOutMode)
 
-	sanitize(in, out)
-
-	// check that sanitized
-	err = assertConsoleMode(
-		out,
-		initialOutMode|windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING)
-	if err != nil {
-		t.Errorf("got err %v, want nil", err)
-		return
+	restore()
+	if got := getConsoleMode(t, out); got != modifiedOutMode {
+		t.Errorf("got console mode %v, want %v", got, modifiedOutMode)
 	}
 }
 
-func assertConsoleMode(file *os.File, want uint32) error {
-	got, err := getConsoleMode(file)
+func openFile(t *testing.T, name string, flag int, perm os.FileMode) *os.File {
+	t.Helper()
+	out, err := os.OpenFile(name, flag, perm)
 	if err != nil {
-		return err
-	} else if got != want {
-		return fmt.Errorf("got %b, want %b", got, want)
-	} else {
-		return nil
+		t.Fatalf("open %s: %v", name, err)
 	}
+	return out
 }
 
-// open stdin/stdout manually because os.Stdin/os.Stdout cannot use in testing
-func createStdInOut() (*os.File, *os.File, func(), error) {
-	in, err := os.OpenFile("CONIN$", os.O_RDWR, 0)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	out, err := os.OpenFile("CONOUT$", os.O_RDWR, 0)
+func setConsoleMode(t *testing.T, file *os.File, mode uint32) {
+	t.Helper()
+	err := windows.SetConsoleMode(windows.Handle(file.Fd()), mode)
 	if err != nil {
-		return nil, nil, nil, err
+		t.Fatal("SetConsoleMode:", err)
 	}
-	release := func() {
-		in.Close()
-		out.Close()
-	}
-	return in, out, release, nil
-}
-
-func setConsoleMode(file *os.File, mode uint32) error {
-	err := windows.SetConsoleMode(windows.Handle(file.Fd()), mode)
-	return err
 }
 
-func getConsoleMode(file *os.File) (uint32, error) {
+func getConsoleMode(t *testing.T, file *os.File) uint32 {
+	t.Helper()
 	var mode uint32
 	err := windows.GetConsoleMode(windows.Handle(file.Fd()), &mode)
-	return mode, err
+	if err != nil {
+		t.Fatal("GetConsoleMode:", err)
+	}
+	return mode
 }

+ 3 - 11
pkg/shell/interact.go

@@ -11,7 +11,6 @@ import (
 	"time"
 
 	"src.elv.sh/pkg/cli"
-	"src.elv.sh/pkg/cli/term"
 	"src.elv.sh/pkg/daemon/daemondefs"
 	"src.elv.sh/pkg/diag"
 	"src.elv.sh/pkg/edit"
@@ -92,8 +91,6 @@ func interact(ev *eval.Evaler, fds [3]*os.File, cfg *interactCfg) {
 		}
 	}
 
-	term.Sanitize(fds[0], fds[2])
-
 	cooldown := time.Second
 	cmdNum := 0
 
@@ -128,10 +125,8 @@ func interact(ev *eval.Evaler, fds [3]*os.File, cfg *interactCfg) {
 		if strings.TrimSpace(line) == "" {
 			continue
 		}
-		src := parse.Source{Name: fmt.Sprintf("[tty %v]", cmdNum), Code: line}
-		duration, err := evalInTTY(ev, fds, src)
-		ed.RunAfterCommandHooks(src, duration, err)
-		term.Sanitize(fds[0], fds[2])
+		err = evalInTTY(fds, ev, ed,
+			parse.Source{Name: fmt.Sprintf("[tty %v]", cmdNum), Code: line})
 		if err != nil {
 			diag.ShowError(fds[2], err)
 		}
@@ -163,10 +158,7 @@ func sourceRC(fds [3]*os.File, ev *eval.Evaler, ed editor, rcPath string) error
 		}
 		return err
 	}
-	src := parse.Source{Name: absPath, Code: code, IsFile: true}
-	duration, err := evalInTTY(ev, fds, src)
-	ed.RunAfterCommandHooks(src, duration, err)
-	return err
+	return evalInTTY(fds, ev, ed, parse.Source{Name: absPath, Code: code, IsFile: true})
 }
 
 type minEditor struct {

+ 1 - 1
pkg/shell/script.go

@@ -61,7 +61,7 @@ func script(ev *eval.Evaler, fds [3]*os.File, args []string, cfg *scriptCfg) int
 			return 2
 		}
 	} else {
-		_, err := evalInTTY(ev, fds, src)
+		err := evalInTTY(fds, ev, nil, src)
 		if err != nil {
 			diag.ShowError(fds[2], err)
 			return 2

+ 10 - 11
pkg/shell/shell.go

@@ -57,7 +57,7 @@ func (p *Program) RegisterFlags(fs *prog.FlagSet) {
 func (p *Program) Run(fds [3]*os.File, args []string) error {
 	cleanup1 := IncSHLVL()
 	defer cleanup1()
-	cleanup2 := initTTYAndSignal(fds)
+	cleanup2 := initSignal(fds)
 	defer cleanup2()
 
 	ev := MakeEvaler(fds[2])
@@ -132,9 +132,7 @@ func IncSHLVL() func() {
 	}
 }
 
-func initTTYAndSignal(fds [3]*os.File) func() {
-	restoreTTY := term.SetupGlobal(fds[0], fds[2])
-
+func initSignal(fds [3]*os.File) func() {
 	sigCh := sys.NotifySignals()
 	go func() {
 		for sig := range sigCh {
@@ -143,18 +141,19 @@ func initTTYAndSignal(fds [3]*os.File) func() {
 		}
 	}()
 
-	return func() {
-		signal.Stop(sigCh)
-		restoreTTY()
-	}
+	return func() { signal.Stop(sigCh) }
 }
 
-func evalInTTY(ev *eval.Evaler, fds [3]*os.File, src parse.Source) (float64, error) {
+func evalInTTY(fds [3]*os.File, ev *eval.Evaler, ed editor, src parse.Source) error {
 	start := time.Now()
 	ports, cleanup := eval.PortsFromFiles(fds, ev.ValuePrefix())
 	defer cleanup()
+	restore := term.SetupForEval(fds[0], fds[1])
+	defer restore()
 	err := ev.Eval(src, eval.EvalCfg{
 		Ports: ports, Interrupt: eval.ListenInterrupts, PutInFg: true})
-	end := time.Now()
-	return end.Sub(start).Seconds(), err
+	if ed != nil {
+		ed.RunAfterCommandHooks(src, time.Since(start).Seconds(), err)
+	}
+	return err
 }