Files
bellpilot/main.go

488 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bytes"
"embed"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"os"
"os/exec"
"os/signal"
"sync"
"syscall"
"time"
"github.com/creack/pty"
"github.com/ebitengine/oto/v3"
"golang.org/x/term"
)
// bellFile holds bells.wav, embedded into the binary at build time;
// audioEngine.init decodes it as the bell sound.
//
//go:embed bells.wav
var bellFile embed.FS

// --- Audio engine ---

// audioEngine owns audio output for the bell. init fills in ctx and pcm (or
// initErr) exactly once; run consumes bell requests from bellCh, which ring
// feeds.
type audioEngine struct {
	once    sync.Once     // guards the one-time setup in init
	ctx     *oto.Context  // output context; nil if init failed
	pcm     []byte        // decoded bell samples, replayed through a new player per bell
	initErr error         // set if init failed; run then falls back to a terminal BEL
	bellCh  chan struct{} // pending bell requests
}

// engine is the process-wide audio engine. bellCh has room for a single
// request, so ring drops bells that arrive while one is already queued.
var engine = &audioEngine{
	bellCh: make(chan struct{}, 1),
}
// init prepares the engine on first use; later calls do nothing (e.once).
// It decodes the embedded bells.wav, widens 8-bit unsigned samples to signed
// 16-bit little-endian, and opens an oto context matching the file's sample
// rate, channel count and format, waiting until the context is ready.
// On any failure the cause is stored in e.initErr and e.ctx / e.pcm stay nil.
func (e *audioEngine) init() {
	e.once.Do(func() {
		setup := func() (*oto.Context, []byte, error) {
			raw, err := bellFile.ReadFile("bells.wav")
			if err != nil {
				return nil, nil, fmt.Errorf("read bell file: %w", err)
			}
			samples, rate, numCh, sampleFmt, err := parseWAV(raw)
			if err != nil {
				return nil, nil, fmt.Errorf("parse WAV: %w", err)
			}
			// oto's FormatUnsignedInt8 path has an integer underflow bug: it computes
			// float32(v8-(1<<7)) with uint8 arithmetic, wrapping values below 128 into
			// large positive numbers. Convert to signed 16-bit to use the correct path.
			if sampleFmt == oto.FormatUnsignedInt8 {
				wide := make([]byte, 0, 2*len(samples))
				for _, u := range samples {
					v := (int16(u) - 128) * 256
					wide = append(wide, byte(v), byte(v>>8))
				}
				samples, sampleFmt = wide, oto.FormatSignedInt16LE
			}
			otoCtx, ready, err := oto.NewContext(&oto.NewContextOptions{
				SampleRate:   rate,
				ChannelCount: numCh,
				Format:       sampleFmt,
				BufferSize:   20 * time.Millisecond,
			})
			if err != nil {
				return nil, nil, fmt.Errorf("create audio context: %w", err)
			}
			<-ready
			return otoCtx, samples, nil
		}

		ctx, pcm, err := setup()
		if err != nil {
			e.initErr = err
			return
		}
		e.ctx, e.pcm = ctx, pcm
	})
}
// ring requests a bell without ever blocking the caller. If bellCh has no
// free slot (with the capacity-1 channel the package engine uses, that means
// a bell is already waiting), the request is simply discarded.
func (e *audioEngine) ring() {
	var request struct{}
	select {
	default:
		// A bell is already pending; drop this one.
	case e.bellCh <- request:
	}
}
// run is the engine's only bell-playing goroutine and must be started once,
// before ring is first called. After initialising the engine it serves bellCh
// until the channel is closed. Each request plays the whole bell sound through
// a fresh oto player, polling every millisecond until it finishes. If
// initialisation failed, the error is logged once and each request writes a
// BEL byte to stdout instead.
func (e *audioEngine) run() {
	e.init()

	playOnce := func() {
		p := e.ctx.NewPlayer(bytes.NewReader(e.pcm))
		p.Play()
		for p.IsPlaying() {
			time.Sleep(time.Millisecond)
		}
	}
	if e.initErr != nil {
		log.Printf("audio init failed (%v); falling back to terminal bell", e.initErr)
		playOnce = func() {
			os.Stdout.Write([]byte{'\x07'})
		}
	}

	for range e.bellCh {
		playOnce()
	}
}
// --- WAV parser ---
// parseWAV decodes a RIFF/WAVE file held in data. It returns the sample bytes
// together with the sample rate, channel count and oto sample format needed to
// play them.
//
// Chunks are walked in file order:
//   - "fmt " supplies the sample format.
//   - "data" supplies pcm, clamped to the bytes actually present.
//   - Any other chunk is skipped.
//
// An odd-sized chunk is followed by one pad byte. A chunk list that ends in the
// middle of a chunk ID is treated as the end of the file.
//
// Accepted encodings are plain PCM (format tag 1, 8- or 16-bit), IEEE float
// (tag 3, 32-bit), and WAVE_FORMAT_EXTENSIBLE (tag 0xFFFE) whose SubFormat GUID
// is one of the standard GUIDs for those tags. The returned pcm is trimmed to a
// whole number of frames.
//
// It returns an error in any of these cases:
//   - The header is not RIFF/WAVE, or a chunk header or fmt body cannot be read.
//   - The fmt chunk is missing, too small, or comes after the data chunk.
//   - The data chunk is missing.
//   - The channel count or sample rate is zero.
//   - The encoding is not one of those listed above.
func parseWAV(data []byte) (pcm []byte, sampleRate, channels int, format oto.Format, err error) {
	const (
		wavFormatPCM        = 1      // integer PCM
		wavFormatIEEEFloat  = 3      // IEEE 754 float
		wavFormatExtensible = 0xFFFE // WAVE_FORMAT_EXTENSIBLE: real tag is in SubFormat
		fmtBaseSize         = 16     // AudioFormat through BitsPerSample
		fmtExtensionSize    = 24     // cbSize, wValidBitsPerSample, dwChannelMask, SubFormat
	)
	// Bytes 2-15 of the standard SubFormat GUIDs
	// (xxxxxxxx-0000-0010-8000-00AA00389B71, stored little-endian in the file).
	// Bytes 0-1 hold the equivalent plain format tag.
	subFormatTail := []byte{0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}

	r := bytes.NewReader(data)

	var id [4]byte
	if err = binary.Read(r, binary.LittleEndian, &id); err != nil {
		return nil, 0, 0, 0, fmt.Errorf("read RIFF id: %w", err)
	}
	if string(id[:]) != "RIFF" {
		return nil, 0, 0, 0, fmt.Errorf("not a RIFF file")
	}
	// The RIFF size is read only to step past it; the data chunk is clamped
	// to the bytes actually present instead.
	var fileSize uint32
	if err = binary.Read(r, binary.LittleEndian, &fileSize); err != nil {
		return nil, 0, 0, 0, fmt.Errorf("read file size: %w", err)
	}
	if err = binary.Read(r, binary.LittleEndian, &id); err != nil {
		return nil, 0, 0, 0, fmt.Errorf("read WAVE id: %w", err)
	}
	if string(id[:]) != "WAVE" {
		return nil, 0, 0, 0, fmt.Errorf("not a WAVE file")
	}

	var hdr struct {
		AudioFormat   uint16
		NumChannels   uint16
		SampleRate    uint32
		ByteRate      uint32
		BlockAlign    uint16
		BitsPerSample uint16
	}
	var fmtFound bool
	for {
		var chunkID [4]byte
		if err = binary.Read(r, binary.LittleEndian, &chunkID); err != nil {
			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
				break // end of the chunk list (a trailing partial ID is tolerated)
			}
			return nil, 0, 0, 0, fmt.Errorf("read chunk id: %w", err)
		}
		var chunkSize uint32
		if err = binary.Read(r, binary.LittleEndian, &chunkSize); err != nil {
			return nil, 0, 0, 0, fmt.Errorf("read chunk size: %w", err)
		}
		// Chunks are word-aligned: an odd-sized chunk is followed by one pad byte.
		pad := int64(chunkSize % 2)

		switch string(chunkID[:]) {
		case "fmt ":
			if chunkSize < fmtBaseSize {
				return nil, 0, 0, 0, fmt.Errorf("fmt chunk too small (%d bytes)", chunkSize)
			}
			if err = binary.Read(r, binary.LittleEndian, &hdr); err != nil {
				return nil, 0, 0, 0, fmt.Errorf("read fmt chunk: %w", err)
			}
			rest := int64(chunkSize) - fmtBaseSize
			if hdr.AudioFormat == wavFormatExtensible && rest >= fmtExtensionSize {
				var ext struct {
					CbSize             uint16
					ValidBitsPerSample uint16
					ChannelMask        uint32
					SubFormat          [16]byte
				}
				if err = binary.Read(r, binary.LittleEndian, &ext); err != nil {
					return nil, 0, 0, 0, fmt.Errorf("read fmt extension: %w", err)
				}
				rest -= fmtExtensionSize
				// Only the standard GUIDs derived from a plain format tag are
				// recognised. Anything else keeps the extensible tag and is
				// rejected below as unsupported.
				if bytes.Equal(ext.SubFormat[2:], subFormatTail) {
					hdr.AudioFormat = binary.LittleEndian.Uint16(ext.SubFormat[:2])
				}
			}
			if skip := rest + pad; skip > 0 {
				if _, err = r.Seek(skip, io.SeekCurrent); err != nil {
					return nil, 0, 0, 0, fmt.Errorf("skip fmt extra: %w", err)
				}
			}
			fmtFound = true
		case "data":
			if !fmtFound {
				return nil, 0, 0, 0, fmt.Errorf("data chunk before fmt chunk")
			}
			sz := chunkSize
			if sz > uint32(r.Len()) {
				sz = uint32(r.Len())
			}
			pcm = make([]byte, sz)
			if _, err = io.ReadFull(r, pcm); err != nil {
				return nil, 0, 0, 0, fmt.Errorf("read data chunk: %w", err)
			}
			// bytes.Reader.Seek only fails for negative positions, so a forward
			// skip of the pad byte cannot fail.
			_, _ = r.Seek(pad, io.SeekCurrent)
		default:
			if _, err = r.Seek(int64(chunkSize)+pad, io.SeekCurrent); err != nil {
				return nil, 0, 0, 0, fmt.Errorf("skip chunk: %w", err)
			}
		}
	}

	if !fmtFound {
		return nil, 0, 0, 0, fmt.Errorf("no fmt chunk in WAV file")
	}
	if pcm == nil {
		return nil, 0, 0, 0, fmt.Errorf("no data chunk in WAV file")
	}
	if hdr.NumChannels == 0 || hdr.SampleRate == 0 {
		return nil, 0, 0, 0, fmt.Errorf("invalid fmt chunk: channels=%d, sample_rate=%d", hdr.NumChannels, hdr.SampleRate)
	}

	switch {
	case hdr.AudioFormat == wavFormatPCM && hdr.BitsPerSample == 8:
		format = oto.FormatUnsignedInt8
	case hdr.AudioFormat == wavFormatPCM && hdr.BitsPerSample == 16:
		format = oto.FormatSignedInt16LE
	case hdr.AudioFormat == wavFormatIEEEFloat && hdr.BitsPerSample == 32:
		format = oto.FormatFloat32LE
	default:
		return nil, 0, 0, 0, fmt.Errorf("unsupported WAV format: audio_format=%d, bits=%d", hdr.AudioFormat, hdr.BitsPerSample)
	}

	// The data chunk is clamped to the bytes present, so a truncated file can
	// end mid-frame. Drop the partial frame so every sample stays on its channel.
	frameSize := int(hdr.NumChannels) * int(hdr.BitsPerSample/8)
	pcm = pcm[:len(pcm)-len(pcm)%frameSize]

	return pcm, int(hdr.SampleRate), int(hdr.NumChannels), format, nil
}
// --- PTY filter loop ---
// filterLoop copies src to dst. Each BEL (0x07) that a terminal would act on as
// a bell is removed, and bell() is called for it instead. Two kinds of BEL are
// passed through unchanged:
//   - a BEL that terminates an OSC string, so the outer terminal does not get
//     stuck mid-sequence;
//   - a BEL inside a DCS/SOS/PM/APC string, which is string content.
//
// The state machine tracks the subset of ANSI/VT escape sequences that can
// contain 0x07. It is modelled on the DEC/ANSI parser used by common terminal
// emulators:
//
//	Normal -> ESC -> CSI  (ESC [ ... final byte 0x40-0x7E)
//	          ESC -> OSC  (ESC ] ... BEL or ST)
//	          ESC -> Str  (ESC P/X/^/_ ... ST)  (DCS, SOS, PM, APC)
//	          ESC -> <two-char>
//
// ST = String Terminator = ESC \
//
// As in that model:
//   - An ESC always abandons the sequence in progress and starts a new one, so
//     an ESC inside OSC/DCS ends the string even when no '\' follows.
//   - CAN (0x18) and SUB (0x1A) cancel any sequence.
//   - A BEL inside ESC or CSI rings without ending the sequence.
//
// Keeping the filter's state in step with the terminal's prevents it from
// stripping a BEL that the terminal needs as an OSC terminator.
//
// Bytes of an unfinished sequence are held back, so each sequence normally
// reaches dst in a single Write. At most maxPending bytes are held before an
// early flush. Anything still pending is flushed, best effort, when src returns
// an error.
//
// It returns nil once src fails with an error that isExpectedPTYErr accepts.
// Otherwise it returns that read error, or the first error from dst.Write.
func filterLoop(dst io.Writer, src io.Reader, bell func()) error {
	// Parser states.
	const (
		stNormal = iota
		stEsc    // saw 0x1b
		stCSI    // saw 0x1b [ ; terminated by a final byte in 0x40-0x7E
		stOSC    // saw 0x1b ] ; terminated by BEL or ST
		stStr    // saw 0x1b P/X/^/_ ; terminated by ST only
		stST     // saw 0x1b inside stOSC or stStr, checking for '\' (ST)
	)
	// Control bytes with special meaning to the parser.
	const (
		bel = 0x07
		can = 0x18
		sub = 0x1a
		esc = 0x1b
	)
	// maxPending bounds the memory held for a sequence that never terminates.
	// Flushing part of a sequence early is harmless: dst receives exactly the
	// same bytes, only split across more Write calls.
	const maxPending = 1 << 20

	buf := make([]byte, 32*1024)
	state := stNormal
	pending := make([]byte, 0, 256) // bytes of the sequence being parsed

	write := func(p []byte) error {
		if len(p) == 0 {
			return nil
		}
		_, err := dst.Write(p)
		return err
	}
	flush := func() error {
		err := write(pending)
		pending = pending[:0]
		return err
	}

	for {
		n, readErr := src.Read(buf)
		p := buf[:n]
		for i := 0; i < len(p); {
			if state == stNormal {
				// Fast path: bulk-copy plain bytes up to the next BEL or ESC.
				j := i
				for i < len(p) && p[i] != bel && p[i] != esc {
					i++
				}
				if err := write(p[j:i]); err != nil {
					return err
				}
				if i == len(p) {
					break
				}
			}
			b := p[i]
			i++

			// CAN and SUB cancel any sequence in progress and return the
			// terminal to its ground state; pass the byte through and do the same.
			if state != stNormal && (b == can || b == sub) {
				pending = append(pending, b)
				if err := flush(); err != nil {
					return err
				}
				state = stNormal
				continue
			}

			switch state {
			case stNormal:
				if b == bel {
					bell()
				} else { // ESC
					pending = append(pending[:0], b)
					state = stEsc
				}
			case stEsc:
				switch b {
				case '[':
					pending = append(pending, b)
					state = stCSI
				case ']':
					pending = append(pending, b)
					state = stOSC
				case 'P', 'X', '^', '_': // DCS, SOS, PM, APC
					pending = append(pending, b)
					state = stStr
				case esc:
					// Another ESC: flush the incomplete sequence and start afresh.
					if err := flush(); err != nil {
						return err
					}
					pending = append(pending, b)
				case bel:
					// The terminal executes BEL without leaving the escape state,
					// so ring, strip it, and keep waiting for the next byte.
					bell()
				default:
					// Two-character escape sequence complete.
					pending = append(pending, b)
					if err := flush(); err != nil {
						return err
					}
					state = stNormal
				}
			case stCSI:
				switch {
				case b == esc:
					// ESC aborts the unfinished CSI and begins a new sequence.
					if err := flush(); err != nil {
						return err
					}
					pending = append(pending, b)
					state = stEsc
				case b == bel:
					// Executed mid-sequence: ring, strip, and continue the CSI.
					bell()
				default:
					pending = append(pending, b)
					if b >= 0x40 && b <= 0x7e { // final byte
						if err := flush(); err != nil {
							return err
						}
						state = stNormal
					}
				}
			case stOSC, stStr:
				// BEL inside DCS/SOS/PM/APC is string content and passes through.
				pending = append(pending, b)
				switch {
				case b == esc:
					state = stST // possibly the start of ST (ESC \)
				case b == bel && state == stOSC:
					// BEL terminates OSC: pass the whole sequence through unchanged.
					if err := flush(); err != nil {
						return err
					}
					state = stNormal
				}
			case stST:
				if b == '\\' { // String Terminator complete
					pending = append(pending, b)
					if err := flush(); err != nil {
						return err
					}
					state = stNormal
					continue
				}
				// Any other byte: the ESC ended the string and starts a new escape
				// sequence. Flush the string, keep the ESC pending, and re-examine
				// b in the escape state.
				if err := write(pending[:len(pending)-1]); err != nil {
					return err
				}
				pending = append(pending[:0], esc)
				state = stEsc
				i--
			}

			// In these states the last pending byte never needs to be held for a
			// decision, so an early partial flush keeps memory bounded safely.
			if len(pending) >= maxPending && (state == stCSI || state == stOSC || state == stStr) {
				if err := flush(); err != nil {
					return err
				}
			}
		}
		if readErr != nil {
			// Best effort: the copy is ending anyway, so a failed final write is
			// not reported.
			_ = flush()
			if isExpectedPTYErr(readErr) {
				return nil
			}
			return readErr
		}
	}
}
// isExpectedPTYErr reports whether err (possibly wrapped) marks the ordinary
// end of PTY output: io.EOF, or a syscall.Errno equal to EIO. Every other
// error, and nil, yields false.
func isExpectedPTYErr(err error) bool {
	var errno syscall.Errno
	switch {
	case errors.Is(err, io.EOF):
		return true
	case errors.As(err, &errno):
		return errno == syscall.EIO
	default:
		return false
	}
}
// --- Main ---
// main runs the command named on the command line under bellpilot and exits
// with the status run reports. Without a command it prints usage and exits 1.
func main() {
	argv := os.Args
	if len(argv) >= 2 {
		os.Exit(run(argv[1], argv[2:]))
	}
	fmt.Fprintf(os.Stderr, "Usage: bellpilot <command> [args...]\n")
	os.Exit(1)
}
// run starts name with args on a new PTY, relays stdin to it, and copies its
// output to stdout through filterLoop, so bells are handed to the audio engine
// instead of reaching the terminal.
//
// It returns the status bellpilot should exit with:
//   - the child's exit code;
//   - 128+N if the child was killed by signal N (the usual shell convention);
//   - 1 if bellpilot itself failed to start or wait for the child.
//
// While it runs:
//   - Terminal resizes (SIGWINCH) are propagated to the PTY.
//   - SIGINT, SIGTERM and SIGHUP received by bellpilot are forwarded to the child.
//   - If stdin is a terminal, it is in raw mode; it is restored on return.
func run(name string, args []string) int {
	go engine.run()

	c := exec.Command(name, args...)
	ptmx, err := pty.Start(c)
	if err != nil {
		log.Printf("start: %v", err)
		return 1
	}
	defer ptmx.Close()

	// Propagate terminal resize to PTY; the explicit send below applies the
	// current size once at startup.
	winchCh := make(chan os.Signal, 1)
	signal.Notify(winchCh, syscall.SIGWINCH)
	go func() {
		for range winchCh {
			if err := pty.InheritSize(os.Stdin, ptmx); err != nil {
				log.Printf("resize: %v", err)
			}
		}
	}()
	winchCh <- syscall.SIGWINCH

	// Forward termination signals to the child.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
	go func() {
		for sig := range sigCh {
			// Best effort: the child may already have exited.
			_ = c.Process.Signal(sig)
		}
	}()

	// Put stdin in raw mode if it is a terminal.
	if term.IsTerminal(int(os.Stdin.Fd())) {
		oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
		if err != nil {
			log.Printf("make raw: %v", err)
			return 1
		}
		defer term.Restore(int(os.Stdin.Fd()), oldState)
	}

	// Relay input to the child; this goroutine is never stopped explicitly.
	go io.Copy(ptmx, os.Stdin)
	if err := filterLoop(os.Stdout, ptmx, engine.ring); err != nil {
		log.Printf("filter: %v", err)
	}

	if err := c.Wait(); err != nil {
		var exitErr *exec.ExitError
		if errors.As(err, &exitErr) {
			// ExitCode reports -1 for a child killed by a signal, which os.Exit
			// would turn into status 255; report 128+signal like a shell does.
			if ws, ok := exitErr.Sys().(syscall.WaitStatus); ok && ws.Signaled() {
				return 128 + int(ws.Signal())
			}
			return exitErr.ExitCode()
		}
		log.Printf("wait: %v", err)
		return 1
	}
	return 0
}