Files
bellpilot/main.go

581 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bytes"
"embed"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"os"
"os/exec"
"os/signal"
"sync"
"syscall"
"time"
"github.com/creack/pty"
"github.com/ebitengine/oto/v3"
"golang.org/x/term"
)
// bellFile holds bells.wav from the package directory, embedded at build
// time; audioEngine.init reads it back with bellFile.ReadFile("bells.wav").
//
//go:embed bells.wav
var bellFile embed.FS
// --- Audio engine ---
// audioEngine owns the decoded bell sound and the oto playback context.
// Bell requests arrive on bellCh and are consumed by a single run goroutine;
// decoding and context creation happen lazily, exactly once, in init.
type audioEngine struct {
	// bellCh carries pending bell requests. engine below gives it a capacity
	// of 1, which lets ring coalesce bursts into a single queued request.
	bellCh chan struct{}

	// once guards init; initErr, ctx and pcm are written only inside it.
	once    sync.Once
	initErr error        // non-nil if the bell sound could not be prepared
	ctx     *oto.Context // playback context; set only when initErr is nil
	pcm     []byte       // sample bytes in the format ctx was opened with
}

// engine is the process-wide bell player: run starts its goroutine and hands
// engine.ring to filterLoop as the bell callback.
var engine = &audioEngine{bellCh: make(chan struct{}, 1)}
// init prepares the engine for playback exactly once (guarded by e.once).
//
// It reads the embedded bells.wav, decodes it with parseWAV, widens 8-bit
// unsigned PCM to signed 16-bit little-endian, and opens an oto context with
// the file's sample rate, channel count and (possibly converted) format,
// blocking until the context reports ready. On success e.ctx and e.pcm are
// set; on failure e.initErr names the step that failed and e.ctx/e.pcm are
// left untouched.
func (e *audioEngine) init() {
	e.once.Do(func() {
		e.initErr = func() error {
			raw, err := bellFile.ReadFile("bells.wav")
			if err != nil {
				return fmt.Errorf("read bell file: %w", err)
			}
			samples, rate, chans, sampleFmt, err := parseWAV(raw)
			if err != nil {
				return fmt.Errorf("parse WAV: %w", err)
			}
			// oto's FormatUnsignedInt8 path has an integer underflow bug: it computes
			// float32(v8-(1<<7)) with uint8 arithmetic, wrapping values below 128 into
			// large positive numbers. Widen to signed 16-bit so the correct path is used.
			if sampleFmt == oto.FormatUnsignedInt8 {
				wide := make([]byte, 2*len(samples))
				for k, u := range samples {
					binary.LittleEndian.PutUint16(wide[2*k:], uint16((int16(u)-128)*256))
				}
				samples, sampleFmt = wide, oto.FormatSignedInt16LE
			}
			ctx, ready, err := oto.NewContext(&oto.NewContextOptions{
				SampleRate:   rate,
				ChannelCount: chans,
				Format:       sampleFmt,
				BufferSize:   20 * time.Millisecond,
			})
			if err != nil {
				return fmt.Errorf("create audio context: %w", err)
			}
			<-ready
			e.ctx, e.pcm = ctx, samples
			return nil
		}()
	})
}
// ring requests a bell without ever blocking the caller. bellCh holds at most
// one request; if one is already waiting, this call is dropped, so a burst of
// bells arriving while a request is pending collapses into that request.
func (e *audioEngine) ring() {
	var req struct{}
	select {
	case e.bellCh <- req:
	default: // a request is already pending; drop this one
	}
}
// run is the one goroutine that plays bells; start it once before ring is
// called. It first runs init. If that failed, it logs the error once and
// answers every queued bell by writing a BEL byte (0x07) to stdout instead.
// Otherwise each bell plays e.pcm on a fresh oto player, polling every
// millisecond until playback ends, so bells never overlap. Nothing in this
// file closes bellCh, so run does not return.
func (e *audioEngine) run() {
	e.init()
	fallback := e.initErr != nil
	if fallback {
		log.Printf("audio init failed (%v); falling back to terminal bell", e.initErr)
	}
	for range e.bellCh {
		if fallback {
			os.Stdout.Write([]byte{0x07})
		} else {
			player := e.ctx.NewPlayer(bytes.NewReader(e.pcm))
			player.Play()
			// Block until this bell finishes before taking the next request.
			for player.IsPlaying() {
				time.Sleep(time.Millisecond)
			}
		}
	}
}
// --- WAV parser ---
// parseWAV decodes an in-memory RIFF/WAVE file into raw sample bytes plus
// the parameters needed to play them.
//
// It walks the RIFF chunk list, reading the "fmt " chunk (which must come
// before "data") and the "data" chunk, and skips every other chunk. It
// honours the RIFF rule that an odd-sized chunk is followed by one pad byte.
// A data chunk whose declared size exceeds the remaining input is clamped to
// what is present. Any trailing partial frame is dropped, so pcm always holds
// a whole number of frames.
//
// Supported encodings are 8-bit unsigned and 16-bit signed integer PCM
// (format tag 1) and 32-bit IEEE float (format tag 3). Files that declare
// them through WAVE_FORMAT_EXTENSIBLE (tag 0xFFFE) with a standard SubFormat
// GUID are accepted too.
//
// It returns an error if the input is not RIFF/WAVE or is truncated inside
// the header or fmt chunk. It also errors if the fmt or data chunk is
// missing, the encoding is unsupported, zero channels or a zero sample rate
// are declared, or the data chunk contains no complete frame.
func parseWAV(data []byte) (pcm []byte, sampleRate, channels int, format oto.Format, err error) {
	const (
		tagPCM        = 1      // WAVE_FORMAT_PCM
		tagFloat      = 3      // WAVE_FORMAT_IEEE_FLOAT
		tagExtensible = 0xFFFE // WAVE_FORMAT_EXTENSIBLE
		fmtBaseLen    = 16     // bytes of the fmt chunk decoded into hdr
		fmtExtLen     = 40     // fmtBaseLen + cbSize(2) + validBits(2) + channelMask(4) + SubFormat GUID(16)
	)
	// Tail shared by the standard KSDATAFORMAT_SUBTYPE_* GUIDs (stored
	// little-endian); the first two GUID bytes carry the classic format tag.
	guidTail := []byte{0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x80, 0x00, 0x00, 0xAA, 0x00, 0x38, 0x9B, 0x71}

	r := bytes.NewReader(data)
	// skipPadded advances past n bytes of chunk payload, plus the pad byte
	// RIFF places after a chunk whose declared size is odd.
	skipPadded := func(n int64, chunkSize uint32) error {
		if chunkSize%2 != 0 {
			n++
		}
		if n == 0 {
			return nil
		}
		_, err := r.Seek(n, io.SeekCurrent)
		return err
	}

	var id [4]byte
	if err = binary.Read(r, binary.LittleEndian, &id); err != nil {
		return nil, 0, 0, 0, fmt.Errorf("read RIFF id: %w", err)
	}
	if string(id[:]) != "RIFF" {
		return nil, 0, 0, 0, fmt.Errorf("not a RIFF file")
	}
	// The RIFF size is read only to advance past it; chunk sizes drive parsing.
	var fileSize uint32
	if err = binary.Read(r, binary.LittleEndian, &fileSize); err != nil {
		return nil, 0, 0, 0, fmt.Errorf("read file size: %w", err)
	}
	if err = binary.Read(r, binary.LittleEndian, &id); err != nil {
		return nil, 0, 0, 0, fmt.Errorf("read WAVE id: %w", err)
	}
	if string(id[:]) != "WAVE" {
		return nil, 0, 0, 0, fmt.Errorf("not a WAVE file")
	}

	var hdr struct {
		AudioFormat   uint16
		NumChannels   uint16
		SampleRate    uint32
		ByteRate      uint32
		BlockAlign    uint16
		BitsPerSample uint16
	}
	var (
		audioFormat uint16 // effective format tag, resolved through WAVE_FORMAT_EXTENSIBLE
		fmtFound    bool
	)
	for {
		var chunkID [4]byte
		if err = binary.Read(r, binary.LittleEndian, &chunkID); err != nil {
			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
				break // end of chunk list (a partial trailing id is ignored)
			}
			return nil, 0, 0, 0, fmt.Errorf("read chunk id: %w", err)
		}
		var chunkSize uint32
		if err = binary.Read(r, binary.LittleEndian, &chunkSize); err != nil {
			return nil, 0, 0, 0, fmt.Errorf("read chunk size: %w", err)
		}
		switch string(chunkID[:]) {
		case "fmt ":
			if chunkSize < fmtBaseLen {
				return nil, 0, 0, 0, fmt.Errorf("fmt chunk too small (%d bytes)", chunkSize)
			}
			if err = binary.Read(r, binary.LittleEndian, &hdr); err != nil {
				return nil, 0, 0, 0, fmt.Errorf("read fmt chunk: %w", err)
			}
			audioFormat = hdr.AudioFormat
			consumed := uint32(fmtBaseLen)
			if audioFormat == tagExtensible && chunkSize >= fmtExtLen {
				var ext struct {
					CbSize      uint16
					ValidBits   uint16
					ChannelMask uint32
					SubFormat   [16]byte
				}
				if err = binary.Read(r, binary.LittleEndian, &ext); err != nil {
					return nil, 0, 0, 0, fmt.Errorf("read fmt extension: %w", err)
				}
				consumed = fmtExtLen
				// Only trust the embedded tag for the standard GUID family;
				// anything else stays 0xFFFE and is reported as unsupported.
				if bytes.Equal(ext.SubFormat[2:], guidTail) {
					audioFormat = binary.LittleEndian.Uint16(ext.SubFormat[:2])
				}
			}
			if err = skipPadded(int64(chunkSize-consumed), chunkSize); err != nil {
				return nil, 0, 0, 0, fmt.Errorf("skip fmt extra: %w", err)
			}
			fmtFound = true
		case "data":
			if !fmtFound {
				return nil, 0, 0, 0, fmt.Errorf("data chunk before fmt chunk")
			}
			// Clamp to what is actually present (e.g. a truncated file or a
			// streaming writer's placeholder size).
			sz := int64(chunkSize)
			if rem := int64(r.Len()); sz > rem {
				sz = rem
			}
			pcm = make([]byte, sz)
			if _, err = io.ReadFull(r, pcm); err != nil {
				return nil, 0, 0, 0, fmt.Errorf("read data chunk: %w", err)
			}
			if err = skipPadded(0, chunkSize); err != nil {
				return nil, 0, 0, 0, fmt.Errorf("skip data padding: %w", err)
			}
		default:
			if err = skipPadded(int64(chunkSize), chunkSize); err != nil {
				return nil, 0, 0, 0, fmt.Errorf("skip chunk: %w", err)
			}
		}
	}
	if !fmtFound {
		return nil, 0, 0, 0, fmt.Errorf("no fmt chunk in WAV file")
	}
	if pcm == nil {
		return nil, 0, 0, 0, fmt.Errorf("no data chunk in WAV file")
	}
	switch {
	case audioFormat == tagPCM && hdr.BitsPerSample == 8:
		format = oto.FormatUnsignedInt8
	case audioFormat == tagPCM && hdr.BitsPerSample == 16:
		format = oto.FormatSignedInt16LE
	case audioFormat == tagFloat && hdr.BitsPerSample == 32:
		format = oto.FormatFloat32LE
	default:
		return nil, 0, 0, 0, fmt.Errorf("unsupported WAV format: audio_format=%d, bits=%d", audioFormat, hdr.BitsPerSample)
	}
	if hdr.NumChannels == 0 {
		return nil, 0, 0, 0, fmt.Errorf("invalid WAV: zero channels")
	}
	if hdr.SampleRate == 0 {
		return nil, 0, 0, 0, fmt.Errorf("invalid WAV: zero sample rate")
	}
	// Drop a trailing partial frame (e.g. after the clamp above) so the
	// player only ever sees whole frames.
	frameLen := int(hdr.NumChannels) * int(hdr.BitsPerSample/8)
	pcm = pcm[:len(pcm)-len(pcm)%frameLen]
	if len(pcm) == 0 {
		return nil, 0, 0, 0, fmt.Errorf("no complete frames in WAV data chunk")
	}
	return pcm, int(hdr.SampleRate), int(hdr.NumChannels), format, nil
}
// --- PTY filter loop ---
// filterLoop copies src to dst, stripping standalone BEL (0x07) bytes and calling bell()
// for each one. BEL bytes that terminate an OSC string, or that appear inside a
// DCS/SOS/PM/APC string, are passed through unchanged so that the outer terminal does
// not get stuck mid-sequence.
//
// State machine covers the subset of ANSI/VT escape sequences that can contain 0x07:
//
//	Normal → ESC → CSI (ESC [ ... final)
//	         ESC → OSC (ESC ] ... BEL or ST)
//	         ESC → Str (ESC P/X/^/_ ... ST)  (DCS, SOS, PM, APC)
//	         ESC → <two-char>
//
// ST = String Terminator = ESC \
//
// Two extra rules apply, as in real terminals:
//   - CAN (0x18) and SUB (0x1A) abort any sequence in progress.
//   - An ESC inside a CSI abandons it and starts a new escape sequence.
//
// Without the second rule, "ESC [ 1 ; ESC ] ..." would treat the ']' as a CSI final
// byte. The OSC's terminating BEL would then be stripped, leaving the outer terminal
// stuck in OSC mode.
//
// Bytes of an escape sequence in progress are held back and written with a single
// Write once the sequence completes. They are also written early once maxPending
// bytes accumulate, so a huge or unterminated string sequence cannot grow memory
// without bound. Any leftover partial sequence is written when src returns an error.
//
// It returns nil when src ends with an error accepted by isExpectedPTYErr. Otherwise
// it returns the first write error or the unexpected read error.
func filterLoop(dst io.Writer, src io.Reader, bell func()) error {
	const (
		stNormal = iota
		stEsc    // saw 0x1b
		stCSI    // saw 0x1b [ — terminated by a final byte in 0x40–0x7E
		stOSC    // saw 0x1b ] — terminated by BEL or ST
		stStr    // saw 0x1b P/X/^/_ — terminated by ST only
		stST     // saw 0x1b inside stOSC or stStr, checking for '\' (ST)
	)
	const (
		bel = 0x07
		can = 0x18 // cancels any sequence in progress
		sub = 0x1a // cancels any sequence in progress
		esc = 0x1b
	)
	const maxPending = 8 * 1024

	buf := make([]byte, 32*1024)
	state := stNormal
	parent := stNormal              // string state (stOSC or stStr) that stST falls back to
	pending := make([]byte, 0, 256) // escape sequence in progress, not yet written

	write := func(p []byte) error {
		_, err := dst.Write(p)
		return err
	}
	// flush writes the held-back sequence in one Write and clears it.
	flush := func() error {
		if len(pending) == 0 {
			return nil
		}
		err := write(pending)
		pending = pending[:0]
		return err
	}
	// restart flushes the interrupted sequence and begins a new one with ESC.
	restart := func() error {
		if err := flush(); err != nil {
			return err
		}
		pending = append(pending, esc)
		state = stEsc
		return nil
	}

	for {
		n, readErr := src.Read(buf)
		p := buf[:n]
		for i := 0; i < len(p); {
			if state == stNormal {
				// Fast path: bulk-copy ordinary bytes up to the next BEL or ESC.
				j := i
				for j < len(p) && p[j] != bel && p[j] != esc {
					j++
				}
				if j > i {
					if err := write(p[i:j]); err != nil {
						return err
					}
				}
				if j == len(p) {
					break
				}
				i = j
			}
			b := p[i]
			i++

			if state != stNormal && (b == can || b == sub) {
				// Aborted sequence: pass it (and the CAN/SUB) through, back to normal.
				pending = append(pending, b)
				if err := flush(); err != nil {
					return err
				}
				state = stNormal
				continue
			}

			switch state {
			case stNormal:
				// Only BEL or ESC get here; the fast path copied everything else.
				if b == bel {
					bell()
				} else {
					pending = append(pending[:0], b)
					state = stEsc
				}
			case stEsc:
				switch b {
				case '[':
					pending = append(pending, b)
					state = stCSI
				case ']':
					pending = append(pending, b)
					state = stOSC
				case 'P', 'X', '^', '_': // DCS, SOS, PM, APC
					pending = append(pending, b)
					state = stStr
				case esc:
					// Another ESC: flush the incomplete sequence, start fresh.
					if err := restart(); err != nil {
						return err
					}
				case bel:
					// BEL right after a bare ESC: pass the ESC through, ring the bell.
					if err := flush(); err != nil {
						return err
					}
					state = stNormal
					bell()
				default:
					// Two-character escape sequence complete.
					pending = append(pending, b)
					if err := flush(); err != nil {
						return err
					}
					state = stNormal
				}
			case stCSI:
				switch {
				case b == bel:
					// Unexpected BEL mid-CSI: flush the incomplete sequence, ring the bell.
					if err := flush(); err != nil {
						return err
					}
					state = stNormal
					bell()
				case b == esc:
					// ESC abandons the CSI and starts a new escape sequence.
					if err := restart(); err != nil {
						return err
					}
				default:
					pending = append(pending, b)
					if b >= 0x40 && b <= 0x7e { // final byte
						if err := flush(); err != nil {
							return err
						}
						state = stNormal
					}
				}
			case stOSC:
				pending = append(pending, b)
				switch b {
				case bel:
					// BEL terminates OSC — pass the whole sequence through unchanged.
					if err := flush(); err != nil {
						return err
					}
					state = stNormal
				case esc:
					parent = stOSC
					state = stST
				}
			case stStr: // DCS/SOS/PM/APC — only ST terminates
				pending = append(pending, b)
				if b == esc {
					parent = stStr
					state = stST
				}
			case stST:
				pending = append(pending, b)
				switch {
				case b == '\\': // String Terminator complete
					if err := flush(); err != nil {
						return err
					}
					state = stNormal
				case b == bel && parent == stOSC:
					// BEL terminates OSC even after an intermediate ESC.
					if err := flush(); err != nil {
						return err
					}
					state = stNormal
				case b == esc:
					// Another ESC inside the string (e.g. doubled ESC); stay in stST.
				default:
					// Not ST (including BEL inside DCS/SOS/PM/APC): back to the string state.
					state = parent
				}
			}

			if len(pending) >= maxPending {
				if err := flush(); err != nil {
					return err
				}
			}
		}
		if readErr != nil {
			_ = flush() // best effort: src is finished, emit any partial sequence
			if isExpectedPTYErr(readErr) {
				return nil
			}
			return readErr
		}
	}
}
// isExpectedPTYErr reports whether err marks an ordinary end of PTY output:
// io.EOF anywhere in the wrap chain, or, when the first syscall.Errno found
// in the chain is EIO, that error.
func isExpectedPTYErr(err error) bool {
	var errno syscall.Errno
	switch {
	case errors.Is(err, io.EOF):
		return true
	case errors.As(err, &errno):
		return errno == syscall.EIO
	default:
		return false
	}
}
// --- Main ---
// main runs the command named on the command line under bellpilot and exits
// with the status run returns; with no command it prints usage and exits 1.
func main() {
	if len(os.Args) > 1 {
		os.Exit(run(os.Args[1], os.Args[2:]))
	}
	fmt.Fprintf(os.Stderr, "Usage: bellpilot <command> [args...]\n")
	os.Exit(1)
}
// run launches name with args on a new pseudo-terminal, relays the user's
// terminal to it, and returns the exit status bellpilot should exit with.
//
// Wiring set up here, in order:
//   - the bell goroutine (engine.run) is started;
//   - the child is started with pty.Start, and the PTY master is closed on return;
//   - SIGWINCH copies the size of os.Stdin to the PTY, and one SIGWINCH is
//     injected immediately so the child starts with the current size;
//   - SIGINT, SIGTERM and SIGHUP received by bellpilot are forwarded to the child;
//   - if stdin is a terminal, it is put in raw mode and restored on return;
//   - stdin is copied to the PTY, except that Ctrl+Z (0x1A) triggers a
//     suspend/resume cycle instead of being forwarded;
//   - a goroutine waits on the child with Wait4 to learn its exit status;
//   - the child's output is copied to stdout through filterLoop, which turns
//     standalone BEL bytes into engine.ring calls.
//
// It returns the child's exit code, or 128+signal number if the child was killed
// by a signal. It returns 1 if the child could not be started, raw mode could
// not be enabled, or Wait4 failed. It returns only after filterLoop has finished
// reading the PTY.
func run(name string, args []string) int {
	go engine.run()
	c := exec.Command(name, args...)
	ptmx, err := pty.Start(c)
	if err != nil {
		log.Printf("start: %v", err)
		return 1
	}
	defer ptmx.Close()
	// Propagate terminal resize to PTY.
	winchCh := make(chan os.Signal, 1)
	signal.Notify(winchCh, syscall.SIGWINCH)
	go func() {
		for range winchCh {
			if err := pty.InheritSize(os.Stdin, ptmx); err != nil {
				log.Printf("resize: %v", err)
			}
		}
	}()
	// Apply the initial size now instead of waiting for a real resize.
	winchCh <- syscall.SIGWINCH
	// Forward termination signals to the child.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
	go func() {
		for sig := range sigCh {
			// Best effort: the error from Signal is ignored.
			c.Process.Signal(sig)
		}
	}()
	// Put stdin in raw mode if it is a terminal.
	var oldState *term.State
	if term.IsTerminal(int(os.Stdin.Fd())) {
		oldState, err = term.MakeRaw(int(os.Stdin.Fd()))
		if err != nil {
			// NOTE(review): the child has already been started at this point;
			// it is not signalled here, only the PTY master is closed by the defer.
			log.Printf("make raw: %v", err)
			return 1
		}
		defer term.Restore(int(os.Stdin.Fd()), oldState)
	}
	childPid := c.Process.Pid
	termState := oldState    // updated on each MakeRaw; only touched by suspend goroutine
	var suspendMu sync.Mutex // ensures only one suspend/resume cycle runs at a time
	// Copy stdin → PTY, intercepting Ctrl+Z (0x1A) for job control.
	// Write errors to the PTY are ignored; the goroutine exits on the first
	// stdin read error.
	go func() {
		buf := make([]byte, 4096)
		for {
			n, err := os.Stdin.Read(buf)
			if n > 0 {
				// i marks the start of the bytes not yet written to the PTY.
				i := 0
				for j := 0; j < n; j++ {
					if buf[j] == 0x1A { // Ctrl+Z
						if j > i {
							ptmx.Write(buf[i:j])
						}
						// Run the full suspend/resume cycle in one goroutine so
						// there is no cross-goroutine state to synchronize.
						// NOTE(review): this goroutine runs asynchronously, so bytes
						// after the Ctrl+Z in the same read are written to the PTY
						// below without waiting for the child to be stopped.
						go func() {
							if !suspendMu.TryLock() {
								return // already suspending
							}
							defer suspendMu.Unlock()
							// Stop the child immediately without forwarding 0x1A.
							// Forwarding 0x1A causes copilot to write partial teardown
							// sequences that glitch the display before we can stop it.
							// We send to both the direct PID and the process group so
							// that subprocesses are also stopped regardless of whether
							// pty.Start created a new process group.
							// NOTE(review): this relies on SIGSTOP taking effect before
							// Kill() returns; kill(2) only guarantees the signal was sent,
							// so the child may still emit a little output — confirm that
							// no wait is needed here.
							syscall.Kill(childPid, syscall.SIGSTOP)
							syscall.Kill(-childPid, syscall.SIGSTOP)
							// Exit alternate screen and show cursor in case copilot had
							// them active, so the shell prompt appears on a clean screen.
							os.Stdout.Write([]byte("\x1b[?1049l\x1b[?25h\r\n"))
							// Hand the terminal back to the shell.
							if termState != nil {
								term.Restore(int(os.Stdin.Fd()), termState)
							}
							// Stop ourselves; the shell takes over.
							syscall.Kill(os.Getpid(), syscall.SIGSTOP)
							// ── Execution resumes here after "fg" ──────────
							// Re-apply raw mode, remembering the state it replaced so
							// the next suspend restores that one.
							if termState != nil {
								if s, err := term.MakeRaw(int(os.Stdin.Fd())); err == nil {
									termState = s
								}
							}
							// Continue the child (direct PID + process group).
							syscall.Kill(childPid, syscall.SIGCONT)
							syscall.Kill(-childPid, syscall.SIGCONT)
							// Sync terminal size and ask copilot to redraw its UI.
							pty.InheritSize(os.Stdin, ptmx)
							syscall.Kill(childPid, syscall.SIGWINCH)
						}()
						i = j + 1
					}
				}
				if i < n {
					ptmx.Write(buf[i:n])
				}
			}
			if err != nil {
				return
			}
		}
	}()
	// exitCodeCh receives the child's exit code when it terminates.
	exitCodeCh := make(chan int, 1)
	// waitChild monitors the child for exit/signal; stop events are handled by
	// the suspend goroutine above so we just loop past them here.
	// NOTE(review): the child is reaped here with Wait4; c.Wait is never called.
	go func() {
		for {
			var ws syscall.WaitStatus
			_, err := syscall.Wait4(childPid, &ws, syscall.WUNTRACED, nil)
			if err != nil {
				if err == syscall.EINTR {
					continue
				}
				exitCodeCh <- 1
				return
			}
			if ws.Exited() {
				exitCodeCh <- ws.ExitStatus()
				return
			}
			if ws.Signaled() {
				exitCodeCh <- 128 + int(ws.Signal())
				return
			}
			// ws.Stopped() — suspend goroutine handles this; loop for next event.
		}
	}()
	// Blocks until the PTY stops producing output; BEL bytes go to engine.ring.
	if err := filterLoop(os.Stdout, ptmx, engine.ring); err != nil {
		log.Printf("filter: %v", err)
	}
	return <-exitCodeCh
}