diff --git a/.gitignore b/.gitignore index 35949e7..23f9614 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ # Ignore chat histories *.chat + +# The binary +llamachat diff --git a/db.go b/db.go index d70816e..e352330 100644 --- a/db.go +++ b/db.go @@ -17,6 +17,11 @@ var fs embed.FS var db *sql.DB func OpenDb(dbName string) { + if dbName == "" || dbName == "-" { + log.Println("Using in-memory database") + dbName = "memory" + } + var err error db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s.chat?_foreign_keys=on", dbName)) if err != nil { diff --git a/go.mod b/go.mod index 92e008d..cb4530c 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/elk-language/go-prompt v1.1.5 github.com/fatih/color v1.18.0 github.com/golang-migrate/migrate/v4 v4.18.1 + github.com/gorilla/websocket v1.5.3 github.com/ollama/ollama v0.4.0 ) diff --git a/go.sum b/go.sum index cd3a2c6..93ac290 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,8 @@ github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4C github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= diff --git a/main.go b/main.go index e1ab92f..6bdb993 100644 --- a/main.go +++ b/main.go @@ -1,18 +1,25 @@ package main import ( + "bytes" "context" + "encoding/json" "fmt" "github.com/elk-language/go-prompt" "github.com/fatih/color" + "github.com/gorilla/websocket" "github.com/ollama/ollama/api" + "html/template" + "io" "log" + "net/http" "os" "strings" ) var conversation []api.Message var ollama *api.Client +var upgrader = websocket.Upgrader{} func convertRole(role MessageType) string { if role == MT_SYSTEM { @@ -83,6 +90,14 @@ func onUserInput(input string) { executeCommand(input) return } + sendPromptInput(input, func(r api.ChatResponse) error { + _, err := fmt.Print(r.Message.Content) + return err + }) + fmt.Println() +} + +func sendPromptInput(input string, handler func(response api.ChatResponse) error) { err := SaveMessage(MT_USER, input) if err != nil { log.Fatal(err) @@ -102,7 +117,10 @@ func onUserInput(input string) { fullResponse := "" respFunc := func(resp api.ChatResponse) error { fullResponse = fullResponse + resp.Message.Content - fmt.Print(resp.Message.Content) + err := handler(resp) + if err != nil { + return err + } if !resp.Done { return nil } @@ -118,13 +136,87 @@ func onUserInput(input string) { Role: convertRole(MT_ASSISTANT), Content: fullResponse, }) - fmt.Println() +} + +func executeTemplate(w io.Writer, name string, data any) error { + t, err := template. + New("templates"). + Funcs(template.FuncMap{ + "trim": strings.TrimSpace, + }). + ParseGlob("templates/*.gohtml") + if err != nil { + return err + } + err = t.ExecuteTemplate(w, name, data) + if err != nil { + return err + } + return nil +} + +func servePage(w http.ResponseWriter, _ *http.Request) { + wm := WebModel{} + wm.AddMessages(conversation) + err := executeTemplate(w, "index", wm) + if err != nil { + log.Println(err) + } +} + +func serveWebSocket(w http.ResponseWriter, r *http.Request) { + log.Println("Websocket connected") + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println("upgrade: ", err) + return + } + defer c.Close() + for { + // Read the prompt + mt, message, err := c.ReadMessage() + if err != nil { + log.Println("Read error: ", err) + break + } + formInput := make(map[string]any) + err = json.Unmarshal(message, &formInput) + if err != nil { + log.Println("Unmarshal: ", err) + break + } + str, ok := formInput["message"].(string) + if !ok { + log.Println("Invalid user input: ", err) + break + } + + // Send the request to Ollama + fullResponse := "" + firstResponse := true + sendPromptInput(str, func(response api.ChatResponse) error { + fullResponse = fullResponse + response.Message.Content + var content bytes.Buffer + err = executeTemplate(&content, "message", WebMessage{ + Id: len(conversation) + 1, + New: firstResponse, + Replace: !firstResponse, + Content: fullResponse, + }) + if err != nil { + return err + } + firstResponse = false + return c.WriteMessage(mt, content.Bytes()) + }) + } } func main() { var err error - if len(os.Args) <= 1 { - log.Fatal("Missing command line parameter") + dbName := "" + if len(os.Args) > 1 { + dbName = os.Args[1] } ollama, err = api.ClientFromEnvironment() @@ -132,10 +224,20 @@ func main() { log.Fatal(err) } - dbName := os.Args[1] OpenDb(dbName) conversation = loadMessageFromDb() + fs := http.FileServer(http.Dir("static")) + http.Handle("/static/", http.StripPrefix("/static/", fs)) + http.HandleFunc("/", servePage) + http.HandleFunc("/ws", serveWebSocket) + go func() { + err := http.ListenAndServe(":8080", nil) + if err != nil { + log.Fatal(err) + } + }() + runner := prompt.New( onUserInput, prompt.WithTitle("llamachat"), diff --git a/static/script.js b/static/script.js new file mode 100644 index 0000000..f24be41 --- /dev/null +++ b/static/script.js @@ -0,0 +1,12 @@ +function scrollToBottom() { + console.log("Scrolling"); + window.scrollTo(0, document.body.scrollHeight); +} + +window.onload = function() { + scrollToBottom(); + htmx.config.allowNestedOobSwaps = false + htmx.config.wsReconnectDelay = function(retryCount) { + return 250; + } +} diff --git a/static/style.css b/static/style.css new file mode 100644 index 0000000..a3fd122 --- /dev/null +++ b/static/style.css @@ -0,0 +1,26 @@ +body { + background-color: #e4faff; +} + +.message { + border: 1px solid #000; + margin: 5px 0; + padding: 16px; +} + +.content { + white-space: pre-wrap; +} + +.content-box { + width: 80%; + margin: 0 auto; +} + +.message-box { + margin: 64px 0 64px; +} + +.message-input { + width: 100%; +} diff --git a/templates/index.gohtml b/templates/index.gohtml new file mode 100644 index 0000000..49ab0db --- /dev/null +++ b/templates/index.gohtml @@ -0,0 +1,35 @@ +{{- /*gotype:llamachat.WebPageModel*/ -}} +{{define "index"}} + + + + + Llamachat + + + + + + + +
+
+ {{range .Conversation}} + {{- /*gotype:llamachat.WebMessage*/ -}} + {{block "message-frag" .}} +
+ {{ .Content }} +
+ {{end}} + {{end}} +
+ +
+ + +
+
+ + + +{{end}} diff --git a/templates/message.gohtml b/templates/message.gohtml new file mode 100644 index 0000000..be6c7b8 --- /dev/null +++ b/templates/message.gohtml @@ -0,0 +1,9 @@ +{{- /*gotype:llamachat.WebMessage*/ -}} +{{define "message"}} +
+ {{template "message-frag" .}} +
+{{end}} diff --git a/webmodels.go b/webmodels.go new file mode 100644 index 0000000..0d64150 --- /dev/null +++ b/webmodels.go @@ -0,0 +1,31 @@ +package main + +import "github.com/ollama/ollama/api" + +type WebModel struct { + Conversation []WebMessage +} + +type WebMessage struct { + Id int + Replace bool + New bool + Content string +} + +func ConvertMessage(message api.Message, id int) WebMessage { + return WebMessage{ + Id: id, + Content: message.Content, + } +} + +func (wm *WebModel) AddMessage(message api.Message, id int) { + wm.Conversation = append(wm.Conversation, ConvertMessage(message, id)) +} + +func (wm *WebModel) AddMessages(messages []api.Message) { + for id, message := range messages { + wm.AddMessage(message, id) + } +}