488 lines
11 KiB
Go
488 lines
11 KiB
Go
package main
|
||
|
||
import (
|
||
"bytes"
|
||
"embed"
|
||
"encoding/binary"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"os"
|
||
"os/exec"
|
||
"os/signal"
|
||
"sync"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/creack/pty"
|
||
"github.com/ebitengine/oto/v3"
|
||
"golang.org/x/term"
|
||
)
|
||
|
||
// bellFile holds bells.wav, embedded into the binary at build time.
// (*audioEngine).init reads it with ReadFile and decodes it with parseWAV.
//
//go:embed bells.wav
var bellFile embed.FS
|
||
|
||
// --- Audio engine ---
|
||
|
||
type audioEngine struct {
|
||
once sync.Once
|
||
ctx *oto.Context
|
||
pcm []byte
|
||
initErr error
|
||
bellCh chan struct{}
|
||
}
|
||
|
||
var engine = &audioEngine{
|
||
bellCh: make(chan struct{}, 1),
|
||
}
|
||
|
||
func (e *audioEngine) init() {
|
||
e.once.Do(func() {
|
||
data, err := bellFile.ReadFile("bells.wav")
|
||
if err != nil {
|
||
e.initErr = fmt.Errorf("read bell file: %w", err)
|
||
return
|
||
}
|
||
|
||
pcm, sampleRate, channels, format, err := parseWAV(data)
|
||
if err != nil {
|
||
e.initErr = fmt.Errorf("parse WAV: %w", err)
|
||
return
|
||
}
|
||
|
||
// oto's FormatUnsignedInt8 path has an integer underflow bug: it computes
|
||
// float32(v8-(1<<7)) with uint8 arithmetic, wrapping values below 128 into
|
||
// large positive numbers. Convert to signed 16-bit to use the correct path.
|
||
if format == oto.FormatUnsignedInt8 {
|
||
pcm16 := make([]byte, len(pcm)*2)
|
||
for i, b := range pcm {
|
||
s := (int16(b) - 128) * 256
|
||
pcm16[2*i] = byte(s)
|
||
pcm16[2*i+1] = byte(s >> 8)
|
||
}
|
||
pcm = pcm16
|
||
format = oto.FormatSignedInt16LE
|
||
}
|
||
|
||
ctx, ready, err := oto.NewContext(&oto.NewContextOptions{
|
||
SampleRate: sampleRate,
|
||
ChannelCount: channels,
|
||
Format: format,
|
||
BufferSize: time.Millisecond * 20,
|
||
})
|
||
if err != nil {
|
||
e.initErr = fmt.Errorf("create audio context: %w", err)
|
||
return
|
||
}
|
||
<-ready
|
||
|
||
e.ctx = ctx
|
||
e.pcm = pcm
|
||
})
|
||
}
|
||
|
||
// ring queues at most one pending bell, dropping extras while one is already queued.
|
||
func (e *audioEngine) ring() {
|
||
select {
|
||
case e.bellCh <- struct{}{}:
|
||
default:
|
||
}
|
||
}
|
||
|
||
// run is the single bell-playing goroutine. It must be started once before ring is called.
|
||
func (e *audioEngine) run() {
|
||
e.init()
|
||
if e.initErr != nil {
|
||
log.Printf("audio init failed (%v); falling back to terminal bell", e.initErr)
|
||
}
|
||
for range e.bellCh {
|
||
if e.initErr != nil {
|
||
os.Stdout.Write([]byte{'\x07'})
|
||
continue
|
||
}
|
||
player := e.ctx.NewPlayer(bytes.NewReader(e.pcm))
|
||
player.Play()
|
||
for player.IsPlaying() {
|
||
time.Sleep(time.Millisecond)
|
||
}
|
||
}
|
||
}
|
||
|
||
// --- WAV parser ---
|
||
|
||
func parseWAV(data []byte) (pcm []byte, sampleRate, channels int, format oto.Format, err error) {
|
||
r := bytes.NewReader(data)
|
||
|
||
var id [4]byte
|
||
if err = binary.Read(r, binary.LittleEndian, &id); err != nil {
|
||
return nil, 0, 0, 0, fmt.Errorf("read RIFF id: %w", err)
|
||
}
|
||
if string(id[:]) != "RIFF" {
|
||
return nil, 0, 0, 0, fmt.Errorf("not a RIFF file")
|
||
}
|
||
var fileSize uint32
|
||
if err = binary.Read(r, binary.LittleEndian, &fileSize); err != nil {
|
||
return nil, 0, 0, 0, fmt.Errorf("read file size: %w", err)
|
||
}
|
||
if err = binary.Read(r, binary.LittleEndian, &id); err != nil {
|
||
return nil, 0, 0, 0, fmt.Errorf("read WAVE id: %w", err)
|
||
}
|
||
if string(id[:]) != "WAVE" {
|
||
return nil, 0, 0, 0, fmt.Errorf("not a WAVE file")
|
||
}
|
||
|
||
var hdr struct {
|
||
AudioFormat uint16
|
||
NumChannels uint16
|
||
SampleRate uint32
|
||
ByteRate uint32
|
||
BlockAlign uint16
|
||
BitsPerSample uint16
|
||
}
|
||
var fmtFound bool
|
||
|
||
for {
|
||
var chunkID [4]byte
|
||
if err = binary.Read(r, binary.LittleEndian, &chunkID); err != nil {
|
||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||
break
|
||
}
|
||
return nil, 0, 0, 0, fmt.Errorf("read chunk id: %w", err)
|
||
}
|
||
var chunkSize uint32
|
||
if err = binary.Read(r, binary.LittleEndian, &chunkSize); err != nil {
|
||
return nil, 0, 0, 0, fmt.Errorf("read chunk size: %w", err)
|
||
}
|
||
|
||
switch string(chunkID[:]) {
|
||
case "fmt ":
|
||
if chunkSize < 16 {
|
||
return nil, 0, 0, 0, fmt.Errorf("fmt chunk too small (%d bytes)", chunkSize)
|
||
}
|
||
if err = binary.Read(r, binary.LittleEndian, &hdr); err != nil {
|
||
return nil, 0, 0, 0, fmt.Errorf("read fmt chunk: %w", err)
|
||
}
|
||
if chunkSize > 16 {
|
||
skip := int64(chunkSize - 16)
|
||
if chunkSize%2 != 0 {
|
||
skip++
|
||
}
|
||
if _, err = r.Seek(skip, io.SeekCurrent); err != nil {
|
||
return nil, 0, 0, 0, fmt.Errorf("skip fmt extra: %w", err)
|
||
}
|
||
}
|
||
fmtFound = true
|
||
case "data":
|
||
if !fmtFound {
|
||
return nil, 0, 0, 0, fmt.Errorf("data chunk before fmt chunk")
|
||
}
|
||
sz := chunkSize
|
||
if sz > uint32(r.Len()) {
|
||
sz = uint32(r.Len())
|
||
}
|
||
pcm = make([]byte, sz)
|
||
if _, err = io.ReadFull(r, pcm); err != nil {
|
||
return nil, 0, 0, 0, fmt.Errorf("read data chunk: %w", err)
|
||
}
|
||
if chunkSize%2 != 0 {
|
||
r.Seek(1, io.SeekCurrent)
|
||
}
|
||
default:
|
||
skip := int64(chunkSize)
|
||
if chunkSize%2 != 0 {
|
||
skip++
|
||
}
|
||
if _, err = r.Seek(skip, io.SeekCurrent); err != nil {
|
||
return nil, 0, 0, 0, fmt.Errorf("skip chunk: %w", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
if !fmtFound {
|
||
return nil, 0, 0, 0, fmt.Errorf("no fmt chunk in WAV file")
|
||
}
|
||
if pcm == nil {
|
||
return nil, 0, 0, 0, fmt.Errorf("no data chunk in WAV file")
|
||
}
|
||
|
||
switch {
|
||
case hdr.AudioFormat == 1 && hdr.BitsPerSample == 8:
|
||
format = oto.FormatUnsignedInt8
|
||
case hdr.AudioFormat == 1 && hdr.BitsPerSample == 16:
|
||
format = oto.FormatSignedInt16LE
|
||
case hdr.AudioFormat == 3 && hdr.BitsPerSample == 32:
|
||
format = oto.FormatFloat32LE
|
||
default:
|
||
return nil, 0, 0, 0, fmt.Errorf("unsupported WAV format: audio_format=%d, bits=%d", hdr.AudioFormat, hdr.BitsPerSample)
|
||
}
|
||
|
||
return pcm, int(hdr.SampleRate), int(hdr.NumChannels), format, nil
|
||
}
|
||
|
||
// --- PTY filter loop ---
|
||
|
||
// filterLoop copies src to dst, stripping standalone BEL (0x07) bytes and calling bell()
// for each one. BEL bytes that appear as OSC string terminators are passed through unchanged
// so that the outer terminal does not get stuck mid-sequence. Every other byte is forwarded
// unchanged and in order.
//
// State machine covers the subset of ANSI/VT escape sequences that can contain 0x07:
//
//	Normal → ESC → CSI  (ESC [ ... final)
//	         ESC → OSC  (ESC ] ... BEL or ST)
//	         ESC → Str  (ESC P/X/^/_ ... ST)   (DCS, SOS, PM, APC)
//	         ESC → <two-char>
//
// ST = String Terminator = ESC \
//
// An ESC that arrives inside an unfinished CSI aborts that CSI and starts a new escape
// sequence, as in the DEC/VT500 parser model. Otherwise the ESC would be swallowed into the
// CSI, and the BEL that terminates a following OSC would be mistaken for a standalone bell.
//
// Escape-sequence bytes are buffered so a complete sequence is normally written with one
// Write. The buffer is flushed once it holds maxPendingEsc bytes, so long DCS/APC payloads
// do not accumulate without bound.
//
// filterLoop returns nil when src ends with an error accepted by isExpectedPTYErr. Otherwise
// it returns the first write error from dst or the read error from src.
func filterLoop(dst io.Writer, src io.Reader, bell func()) error {
	const (
		stNormal = iota
		stEsc    // saw 0x1b
		stCSI    // saw 0x1b [ — terminated by 0x40–0x7E
		stOSC    // saw 0x1b ] — terminated by BEL or ST
		stStr    // saw 0x1b P/X/^/_ — terminated by ST only
		stST     // saw 0x1b inside stOSC or stStr, checking for '\' (ST)
	)
	// maxPendingEsc caps how many escape-sequence bytes are held back. Reaching it
	// flushes the buffer without changing state. BEL handling depends only on state,
	// never on the buffered bytes, so the emitted byte stream is unaffected.
	const maxPendingEsc = 64 * 1024

	buf := make([]byte, 32*1024)
	state := stNormal
	parent := stNormal // parent of stST (stOSC or stStr)
	esc := make([]byte, 0, 256)

	writeAll := func(p []byte) error {
		_, err := dst.Write(p)
		return err
	}
	flushEsc := func() error {
		if len(esc) == 0 {
			return nil
		}
		err := writeAll(esc)
		esc = esc[:0]
		return err
	}

	for {
		n, readErr := src.Read(buf)
		p := buf[:n]
		i := 0
		for i < len(p) {
			b := p[i]

			// Fast path: bulk-copy normal bytes until a special byte is hit.
			if state == stNormal {
				j := i
				for i < len(p) && p[i] != 0x07 && p[i] != 0x1b {
					i++
				}
				if i > j {
					if err := writeAll(p[j:i]); err != nil {
						return err
					}
				}
				if i >= len(p) {
					break
				}
				b = p[i]
			}
			i++

			switch state {
			case stNormal:
				if b == 0x07 {
					bell()
				} else { // 0x1b
					esc = append(esc[:0], b)
					state = stEsc
				}

			case stEsc:
				switch b {
				case '[':
					esc = append(esc, b)
					state = stCSI
				case ']':
					esc = append(esc, b)
					state = stOSC
				case 'P', 'X', '^', '_': // DCS, SOS, PM, APC
					esc = append(esc, b)
					state = stStr
				case 0x1b:
					// Another ESC: flush previous incomplete sequence, start fresh.
					if err := flushEsc(); err != nil {
						return err
					}
					esc = append(esc, b)
					// stay in stEsc
				case 0x07:
					// BEL right after bare ESC: flush ESC, ring bell.
					if err := flushEsc(); err != nil {
						return err
					}
					state = stNormal
					bell()
				default:
					// Two-character escape sequence complete.
					esc = append(esc, b)
					if err := flushEsc(); err != nil {
						return err
					}
					state = stNormal
				}

			case stCSI:
				switch {
				case b == 0x07:
					// Unexpected BEL mid-CSI: flush incomplete sequence, ring bell.
					if err := flushEsc(); err != nil {
						return err
					}
					state = stNormal
					bell()
				case b == 0x1b:
					// ESC aborts the unfinished CSI and begins a new escape sequence.
					if err := flushEsc(); err != nil {
						return err
					}
					esc = append(esc, b)
					state = stEsc
				default:
					esc = append(esc, b)
					if b >= 0x40 && b <= 0x7e { // final byte
						if err := flushEsc(); err != nil {
							return err
						}
						state = stNormal
					}
				}

			case stOSC:
				esc = append(esc, b)
				if b == 0x07 {
					// BEL terminates OSC — pass the whole sequence through unchanged.
					if err := flushEsc(); err != nil {
						return err
					}
					state = stNormal
				} else if b == 0x1b {
					parent = stOSC
					state = stST
				}

			case stStr: // DCS/SOS/PM/APC — only ST terminates
				esc = append(esc, b)
				if b == 0x1b {
					parent = stStr
					state = stST
				}

			case stST:
				esc = append(esc, b)
				switch b {
				case '\\': // String Terminator complete
					if err := flushEsc(); err != nil {
						return err
					}
					state = stNormal
				case 0x07:
					if parent == stOSC {
						// BEL terminates OSC even after an intermediate ESC.
						if err := flushEsc(); err != nil {
							return err
						}
						state = stNormal
					} else {
						// Not a terminator for DCS/SOS/PM/APC; back to parent.
						state = parent
					}
				case 0x1b:
					// Another ESC inside string; stay in stST.
				default:
					// Not ST; back to parent string state.
					state = parent
				}
			}

			if len(esc) >= maxPendingEsc {
				if err := flushEsc(); err != nil {
					return err
				}
			}
		}

		if readErr != nil {
			// Best effort: emit any buffered incomplete sequence. The stream is
			// ending, so a write failure here is not reported separately.
			_ = flushEsc()
			if isExpectedPTYErr(readErr) {
				return nil
			}
			return readErr
		}
	}
}

// isExpectedPTYErr reports whether err signals normal end of the PTY stream:
// io.EOF, or EIO anywhere in the error chain. Reading the PTY master can yield
// EIO once the child side has gone away.
func isExpectedPTYErr(err error) bool {
	return errors.Is(err, io.EOF) || errors.Is(err, syscall.EIO)
}
|
||
|
||
// --- Main ---
|
||
|
||
func main() {
|
||
if len(os.Args) < 2 {
|
||
fmt.Fprintf(os.Stderr, "Usage: bellpilot <command> [args...]\n")
|
||
os.Exit(1)
|
||
}
|
||
os.Exit(run(os.Args[1], os.Args[2:]))
|
||
}
|
||
|
||
func run(name string, args []string) int {
|
||
go engine.run()
|
||
|
||
c := exec.Command(name, args...)
|
||
|
||
ptmx, err := pty.Start(c)
|
||
if err != nil {
|
||
log.Printf("start: %v", err)
|
||
return 1
|
||
}
|
||
defer ptmx.Close()
|
||
|
||
// Propagate terminal resize to PTY.
|
||
winchCh := make(chan os.Signal, 1)
|
||
signal.Notify(winchCh, syscall.SIGWINCH)
|
||
go func() {
|
||
for range winchCh {
|
||
if err := pty.InheritSize(os.Stdin, ptmx); err != nil {
|
||
log.Printf("resize: %v", err)
|
||
}
|
||
}
|
||
}()
|
||
winchCh <- syscall.SIGWINCH
|
||
|
||
// Forward termination signals to the child.
|
||
sigCh := make(chan os.Signal, 1)
|
||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||
go func() {
|
||
for sig := range sigCh {
|
||
c.Process.Signal(sig)
|
||
}
|
||
}()
|
||
|
||
// Put stdin in raw mode if it is a terminal.
|
||
if term.IsTerminal(int(os.Stdin.Fd())) {
|
||
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
|
||
if err != nil {
|
||
log.Printf("make raw: %v", err)
|
||
return 1
|
||
}
|
||
defer term.Restore(int(os.Stdin.Fd()), oldState)
|
||
}
|
||
|
||
go io.Copy(ptmx, os.Stdin)
|
||
|
||
if err := filterLoop(os.Stdout, ptmx, engine.ring); err != nil {
|
||
log.Printf("filter: %v", err)
|
||
}
|
||
|
||
if err := c.Wait(); err != nil {
|
||
var exitErr *exec.ExitError
|
||
if errors.As(err, &exitErr) {
|
||
return exitErr.ExitCode()
|
||
}
|
||
log.Printf("wait: %v", err)
|
||
return 1
|
||
}
|
||
return 0
|
||
}
|