package main import ( "bytes" "embed" "encoding/binary" "errors" "fmt" "io" "log" "os" "os/exec" "os/signal" "sync" "syscall" "time" "github.com/creack/pty" "github.com/ebitengine/oto/v3" "golang.org/x/term" ) //go:embed bells.wav var bellFile embed.FS // --- Audio engine --- type audioEngine struct { once sync.Once ctx *oto.Context pcm []byte initErr error bellCh chan struct{} } var engine = &audioEngine{ bellCh: make(chan struct{}, 1), } func (e *audioEngine) init() { e.once.Do(func() { data, err := bellFile.ReadFile("bells.wav") if err != nil { e.initErr = fmt.Errorf("read bell file: %w", err) return } pcm, sampleRate, channels, format, err := parseWAV(data) if err != nil { e.initErr = fmt.Errorf("parse WAV: %w", err) return } // oto's FormatUnsignedInt8 path has an integer underflow bug: it computes // float32(v8-(1<<7)) with uint8 arithmetic, wrapping values below 128 into // large positive numbers. Convert to signed 16-bit to use the correct path. if format == oto.FormatUnsignedInt8 { pcm16 := make([]byte, len(pcm)*2) for i, b := range pcm { s := (int16(b) - 128) * 256 pcm16[2*i] = byte(s) pcm16[2*i+1] = byte(s >> 8) } pcm = pcm16 format = oto.FormatSignedInt16LE } ctx, ready, err := oto.NewContext(&oto.NewContextOptions{ SampleRate: sampleRate, ChannelCount: channels, Format: format, BufferSize: time.Millisecond * 20, }) if err != nil { e.initErr = fmt.Errorf("create audio context: %w", err) return } <-ready e.ctx = ctx e.pcm = pcm }) } // ring queues at most one pending bell, dropping extras while one is already queued. func (e *audioEngine) ring() { select { case e.bellCh <- struct{}{}: default: } } // run is the single bell-playing goroutine. It must be started once before ring is called. func (e *audioEngine) run() { e.init() if e.initErr != nil { log.Printf("audio init failed (%v); falling back to terminal bell", e.initErr) } for range e.bellCh { if e.initErr != nil { os.Stdout.Write([]byte{'\x07'}) continue } player := e.ctx.NewPlayer(bytes.NewReader(e.pcm)) player.Play() for player.IsPlaying() { time.Sleep(time.Millisecond) } } } // --- WAV parser --- func parseWAV(data []byte) (pcm []byte, sampleRate, channels int, format oto.Format, err error) { r := bytes.NewReader(data) var id [4]byte if err = binary.Read(r, binary.LittleEndian, &id); err != nil { return nil, 0, 0, 0, fmt.Errorf("read RIFF id: %w", err) } if string(id[:]) != "RIFF" { return nil, 0, 0, 0, fmt.Errorf("not a RIFF file") } var fileSize uint32 if err = binary.Read(r, binary.LittleEndian, &fileSize); err != nil { return nil, 0, 0, 0, fmt.Errorf("read file size: %w", err) } if err = binary.Read(r, binary.LittleEndian, &id); err != nil { return nil, 0, 0, 0, fmt.Errorf("read WAVE id: %w", err) } if string(id[:]) != "WAVE" { return nil, 0, 0, 0, fmt.Errorf("not a WAVE file") } var hdr struct { AudioFormat uint16 NumChannels uint16 SampleRate uint32 ByteRate uint32 BlockAlign uint16 BitsPerSample uint16 } var fmtFound bool for { var chunkID [4]byte if err = binary.Read(r, binary.LittleEndian, &chunkID); err != nil { if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { break } return nil, 0, 0, 0, fmt.Errorf("read chunk id: %w", err) } var chunkSize uint32 if err = binary.Read(r, binary.LittleEndian, &chunkSize); err != nil { return nil, 0, 0, 0, fmt.Errorf("read chunk size: %w", err) } switch string(chunkID[:]) { case "fmt ": if chunkSize < 16 { return nil, 0, 0, 0, fmt.Errorf("fmt chunk too small (%d bytes)", chunkSize) } if err = binary.Read(r, binary.LittleEndian, &hdr); err != nil { return nil, 0, 0, 0, fmt.Errorf("read fmt chunk: %w", err) } if chunkSize > 16 { skip := int64(chunkSize - 16) if chunkSize%2 != 0 { skip++ } if _, err = r.Seek(skip, io.SeekCurrent); err != nil { return nil, 0, 0, 0, fmt.Errorf("skip fmt extra: %w", err) } } fmtFound = true case "data": if !fmtFound { return nil, 0, 0, 0, fmt.Errorf("data chunk before fmt chunk") } sz := chunkSize if sz > uint32(r.Len()) { sz = uint32(r.Len()) } pcm = make([]byte, sz) if _, err = io.ReadFull(r, pcm); err != nil { return nil, 0, 0, 0, fmt.Errorf("read data chunk: %w", err) } if chunkSize%2 != 0 { r.Seek(1, io.SeekCurrent) } default: skip := int64(chunkSize) if chunkSize%2 != 0 { skip++ } if _, err = r.Seek(skip, io.SeekCurrent); err != nil { return nil, 0, 0, 0, fmt.Errorf("skip chunk: %w", err) } } } if !fmtFound { return nil, 0, 0, 0, fmt.Errorf("no fmt chunk in WAV file") } if pcm == nil { return nil, 0, 0, 0, fmt.Errorf("no data chunk in WAV file") } switch { case hdr.AudioFormat == 1 && hdr.BitsPerSample == 8: format = oto.FormatUnsignedInt8 case hdr.AudioFormat == 1 && hdr.BitsPerSample == 16: format = oto.FormatSignedInt16LE case hdr.AudioFormat == 3 && hdr.BitsPerSample == 32: format = oto.FormatFloat32LE default: return nil, 0, 0, 0, fmt.Errorf("unsupported WAV format: audio_format=%d, bits=%d", hdr.AudioFormat, hdr.BitsPerSample) } return pcm, int(hdr.SampleRate), int(hdr.NumChannels), format, nil } // --- PTY filter loop --- // filterLoop copies src to dst, stripping standalone BEL (0x07) bytes and calling bell() // for each one. BEL bytes that appear as OSC string terminators are passed through unchanged // so that the outer terminal does not get stuck mid-sequence. // // State machine covers the subset of ANSI/VT escape sequences that can contain 0x07: // // Normal → ESC → CSI (ESC [ ... final) // ESC → OSC (ESC ] ... BEL or ST) // ESC → Str (ESC P/X/^/_ ... ST) (DCS, SOS, PM, APC) // ESC → // // ST = String Terminator = ESC \ func filterLoop(dst io.Writer, src io.Reader, bell func()) error { const ( stNormal = iota stEsc // saw 0x1b stCSI // saw 0x1b [ — terminated by 0x40–0x7E stOSC // saw 0x1b ] — terminated by BEL or ST stStr // saw 0x1b P/X/^/_ — terminated by ST only stST // saw 0x1b inside stOSC or stStr, checking for '\' (ST) ) buf := make([]byte, 32*1024) state := stNormal parent := stNormal // parent of stST (stOSC or stStr) esc := make([]byte, 0, 256) writeAll := func(p []byte) error { _, err := dst.Write(p) return err } flushEsc := func() error { err := writeAll(esc) esc = esc[:0] return err } for { n, readErr := src.Read(buf) p := buf[:n] i := 0 for i < len(p) { b := p[i] // Fast path: bulk-copy normal bytes until a special byte is hit. if state == stNormal { j := i for i < len(p) && p[i] != 0x07 && p[i] != 0x1b { i++ } if i > j { if err := writeAll(p[j:i]); err != nil { return err } } if i >= len(p) { break } b = p[i] } i++ switch state { case stNormal: if b == 0x07 { bell() } else { // 0x1b esc = append(esc[:0], b) state = stEsc } case stEsc: switch b { case '[': esc = append(esc, b) state = stCSI case ']': esc = append(esc, b) state = stOSC case 'P', 'X', '^', '_': // DCS, SOS, PM, APC esc = append(esc, b) state = stStr case 0x1b: // Another ESC: flush previous incomplete sequence, start fresh. if err := writeAll(esc); err != nil { return err } esc = esc[:0] esc = append(esc, b) // stay in stEsc case 0x07: // BEL right after bare ESC: flush ESC, ring bell. if err := flushEsc(); err != nil { return err } state = stNormal bell() default: // Two-character escape sequence complete. esc = append(esc, b) if err := flushEsc(); err != nil { return err } state = stNormal } case stCSI: if b == 0x07 { // Unexpected BEL mid-CSI: flush incomplete sequence, ring bell. if err := flushEsc(); err != nil { return err } state = stNormal bell() } else { esc = append(esc, b) if b >= 0x40 && b <= 0x7e { // final byte if err := flushEsc(); err != nil { return err } state = stNormal } } case stOSC: esc = append(esc, b) if b == 0x07 { // BEL terminates OSC — pass the whole sequence through unchanged. if err := flushEsc(); err != nil { return err } state = stNormal } else if b == 0x1b { parent = stOSC state = stST } case stStr: // DCS/SOS/PM/APC — only ST terminates esc = append(esc, b) if b == 0x1b { parent = stStr state = stST } case stST: esc = append(esc, b) switch b { case '\\': // String Terminator complete if err := flushEsc(); err != nil { return err } state = stNormal case 0x07: if parent == stOSC { // BEL terminates OSC even after an intermediate ESC. if err := flushEsc(); err != nil { return err } state = stNormal } else { // Not a terminator for DCS/SOS/PM/APC; back to parent. state = parent } case 0x1b: // Another ESC inside string; stay in stST. default: // Not ST; back to parent string state. state = parent } } } if readErr != nil { _ = flushEsc() // flush any buffered incomplete sequence if isExpectedPTYErr(readErr) { return nil } return readErr } } } // isExpectedPTYErr returns true for errors that indicate normal PTY shutdown. func isExpectedPTYErr(err error) bool { if errors.Is(err, io.EOF) { return true } var errno syscall.Errno if errors.As(err, &errno) { return errno == syscall.EIO } return false } // --- Main --- func main() { if len(os.Args) < 2 { fmt.Fprintf(os.Stderr, "Usage: bellpilot [args...]\n") os.Exit(1) } os.Exit(run(os.Args[1], os.Args[2:])) } func run(name string, args []string) int { go engine.run() c := exec.Command(name, args...) ptmx, err := pty.Start(c) if err != nil { log.Printf("start: %v", err) return 1 } defer ptmx.Close() // Propagate terminal resize to PTY. winchCh := make(chan os.Signal, 1) signal.Notify(winchCh, syscall.SIGWINCH) go func() { for range winchCh { if err := pty.InheritSize(os.Stdin, ptmx); err != nil { log.Printf("resize: %v", err) } } }() winchCh <- syscall.SIGWINCH // Forward termination signals to the child. sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) go func() { for sig := range sigCh { c.Process.Signal(sig) } }() // Put stdin in raw mode if it is a terminal. var oldState *term.State if term.IsTerminal(int(os.Stdin.Fd())) { oldState, err = term.MakeRaw(int(os.Stdin.Fd())) if err != nil { log.Printf("make raw: %v", err) return 1 } defer term.Restore(int(os.Stdin.Fd()), oldState) } childPid := c.Process.Pid termState := oldState // updated on each MakeRaw; only touched by suspend goroutine var suspendMu sync.Mutex // ensures only one suspend/resume cycle runs at a time // Copy stdin → PTY, intercepting Ctrl+Z (0x1A) for job control. go func() { buf := make([]byte, 4096) for { n, err := os.Stdin.Read(buf) if n > 0 { i := 0 for j := 0; j < n; j++ { if buf[j] == 0x1A { // Ctrl+Z if j > i { ptmx.Write(buf[i:j]) } // Run the full suspend/resume cycle in one goroutine so // there is no cross-goroutine state to synchronize. go func() { if !suspendMu.TryLock() { return // already suspending } defer suspendMu.Unlock() // Stop the child immediately without forwarding 0x1A. // Forwarding 0x1A causes copilot to write partial teardown // sequences that glitch the display before we can stop it. // We send to both the direct PID and the process group so // that subprocesses are also stopped regardless of whether // pty.Start created a new process group. // SIGSTOP is synchronous — the child is stopped before // Kill() returns, so no sleep is needed. syscall.Kill(childPid, syscall.SIGSTOP) syscall.Kill(-childPid, syscall.SIGSTOP) // Exit alternate screen and show cursor in case copilot had // them active, so the shell prompt appears on a clean screen. os.Stdout.Write([]byte("\x1b[?1049l\x1b[?25h\r\n")) // Hand the terminal back to the shell. if termState != nil { term.Restore(int(os.Stdin.Fd()), termState) } // Stop ourselves; the shell takes over. syscall.Kill(os.Getpid(), syscall.SIGSTOP) // ── Execution resumes here after "fg" ────────── // Re-apply raw mode. if termState != nil { if s, err := term.MakeRaw(int(os.Stdin.Fd())); err == nil { termState = s } } // Continue the child (direct PID + process group). syscall.Kill(childPid, syscall.SIGCONT) syscall.Kill(-childPid, syscall.SIGCONT) // Sync terminal size and ask copilot to redraw its UI. pty.InheritSize(os.Stdin, ptmx) syscall.Kill(childPid, syscall.SIGWINCH) }() i = j + 1 } } if i < n { ptmx.Write(buf[i:n]) } } if err != nil { return } } }() // exitCodeCh receives the child's exit code when it terminates. exitCodeCh := make(chan int, 1) // waitChild monitors the child for exit/signal; stop events are handled by // the suspend goroutine above so we just loop past them here. go func() { for { var ws syscall.WaitStatus _, err := syscall.Wait4(childPid, &ws, syscall.WUNTRACED, nil) if err != nil { if err == syscall.EINTR { continue } exitCodeCh <- 1 return } if ws.Exited() { exitCodeCh <- ws.ExitStatus() return } if ws.Signaled() { exitCodeCh <- 128 + int(ws.Signal()) return } // ws.Stopped() — suspend goroutine handles this; loop for next event. } }() if err := filterLoop(os.Stdout, ptmx, engine.ring); err != nil { log.Printf("filter: %v", err) } return <-exitCodeCh }