Add web support

This commit is contained in:
Sebastiaan de Schaetzen 2024-11-12 13:28:14 +01:00
parent ebad22fbee
commit 73ff32aa69
10 changed files with 231 additions and 5 deletions

3
.gitignore vendored
View File

@ -4,3 +4,6 @@
# Ignore chat histories # Ignore chat histories
*.chat *.chat
# The binary
llamachat

5
db.go
View File

@ -17,6 +17,11 @@ var fs embed.FS
var db *sql.DB var db *sql.DB
func OpenDb(dbName string) { func OpenDb(dbName string) {
if dbName == "" || dbName == "-" {
log.Println("Using in-memory database")
dbName = "memory"
}
var err error var err error
db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s.chat?_foreign_keys=on", dbName)) db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s.chat?_foreign_keys=on", dbName))
if err != nil { if err != nil {

1
go.mod
View File

@ -6,6 +6,7 @@ require (
github.com/elk-language/go-prompt v1.1.5 github.com/elk-language/go-prompt v1.1.5
github.com/fatih/color v1.18.0 github.com/fatih/color v1.18.0
github.com/golang-migrate/migrate/v4 v4.18.1 github.com/golang-migrate/migrate/v4 v4.18.1
github.com/gorilla/websocket v1.5.3
github.com/ollama/ollama v0.4.0 github.com/ollama/ollama v0.4.0
) )

2
go.sum
View File

@ -9,6 +9,8 @@ github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4C
github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks= github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=

112
main.go
View File

@ -1,18 +1,25 @@
package main package main
import ( import (
"bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"github.com/elk-language/go-prompt" "github.com/elk-language/go-prompt"
"github.com/fatih/color" "github.com/fatih/color"
"github.com/gorilla/websocket"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"html/template"
"io"
"log" "log"
"net/http"
"os" "os"
"strings" "strings"
) )
var conversation []api.Message var conversation []api.Message
var ollama *api.Client var ollama *api.Client
var upgrader = websocket.Upgrader{}
func convertRole(role MessageType) string { func convertRole(role MessageType) string {
if role == MT_SYSTEM { if role == MT_SYSTEM {
@ -83,6 +90,14 @@ func onUserInput(input string) {
executeCommand(input) executeCommand(input)
return return
} }
sendPromptInput(input, func(r api.ChatResponse) error {
_, err := fmt.Print(r.Message.Content)
return err
})
fmt.Println()
}
func sendPromptInput(input string, handler func(response api.ChatResponse) error) {
err := SaveMessage(MT_USER, input) err := SaveMessage(MT_USER, input)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -102,7 +117,10 @@ func onUserInput(input string) {
fullResponse := "" fullResponse := ""
respFunc := func(resp api.ChatResponse) error { respFunc := func(resp api.ChatResponse) error {
fullResponse = fullResponse + resp.Message.Content fullResponse = fullResponse + resp.Message.Content
fmt.Print(resp.Message.Content) err := handler(resp)
if err != nil {
return err
}
if !resp.Done { if !resp.Done {
return nil return nil
} }
@ -118,13 +136,87 @@ func onUserInput(input string) {
Role: convertRole(MT_ASSISTANT), Role: convertRole(MT_ASSISTANT),
Content: fullResponse, Content: fullResponse,
}) })
fmt.Println() }
func executeTemplate(w io.Writer, name string, data any) error {
t, err := template.
New("templates").
Funcs(template.FuncMap{
"trim": strings.TrimSpace,
}).
ParseGlob("templates/*.gohtml")
if err != nil {
return err
}
err = t.ExecuteTemplate(w, name, data)
if err != nil {
return err
}
return nil
}
func servePage(w http.ResponseWriter, _ *http.Request) {
wm := WebModel{}
wm.AddMessages(conversation)
err := executeTemplate(w, "index", wm)
if err != nil {
log.Println(err)
}
}
func serveWebSocket(w http.ResponseWriter, r *http.Request) {
log.Println("Websocket connected")
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println("upgrade: ", err)
return
}
defer c.Close()
for {
// Read the prompt
mt, message, err := c.ReadMessage()
if err != nil {
log.Println("Read error: ", err)
break
}
formInput := make(map[string]any)
err = json.Unmarshal(message, &formInput)
if err != nil {
log.Println("Unmarshal: ", err)
break
}
str, ok := formInput["message"].(string)
if !ok {
log.Println("Invalid user input: ", err)
break
}
// Send the request to Ollama
fullResponse := ""
firstResponse := true
sendPromptInput(str, func(response api.ChatResponse) error {
fullResponse = fullResponse + response.Message.Content
var content bytes.Buffer
err = executeTemplate(&content, "message", WebMessage{
Id: len(conversation) + 1,
New: firstResponse,
Replace: !firstResponse,
Content: fullResponse,
})
if err != nil {
return err
}
firstResponse = false
return c.WriteMessage(mt, content.Bytes())
})
}
} }
func main() { func main() {
var err error var err error
if len(os.Args) <= 1 { dbName := ""
log.Fatal("Missing command line parameter") if len(os.Args) > 1 {
dbName = os.Args[1]
} }
ollama, err = api.ClientFromEnvironment() ollama, err = api.ClientFromEnvironment()
@ -132,10 +224,20 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
dbName := os.Args[1]
OpenDb(dbName) OpenDb(dbName)
conversation = loadMessageFromDb() conversation = loadMessageFromDb()
fs := http.FileServer(http.Dir("static"))
http.Handle("/static/", http.StripPrefix("/static/", fs))
http.HandleFunc("/", servePage)
http.HandleFunc("/ws", serveWebSocket)
go func() {
err := http.ListenAndServe(":8080", nil)
if err != nil {
log.Fatal(err)
}
}()
runner := prompt.New( runner := prompt.New(
onUserInput, onUserInput,
prompt.WithTitle("llamachat"), prompt.WithTitle("llamachat"),

12
static/script.js Normal file
View File

@ -0,0 +1,12 @@
function scrollToBottom() {
console.log("Scrolling");
window.scrollTo(0, document.body.scrollHeight);
}
window.onload = function() {
scrollToBottom();
htmx.config.allowNestedOobSwaps = false
htmx.config.wsReconnectDelay = function(retryCount) {
return 250;
}
}

26
static/style.css Normal file
View File

@ -0,0 +1,26 @@
body {
background-color: #e4faff;
}
.message {
border: 1px solid #000;
margin: 5px 0;
padding: 16px;
}
.content {
white-space: pre-wrap;
}
.content-box {
width: 80%;
margin: 0 auto;
}
.message-box {
margin: 64px 0 64px;
}
.message-input {
width: 100%;
}

35
templates/index.gohtml Normal file
View File

@ -0,0 +1,35 @@
{{- /*gotype:llamachat.WebPageModel*/ -}}
{{define "index"}}
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Llamachat</title>
<link rel="stylesheet" href="static/style.css">
<script src="static/script.js" type="application/javascript"></script>
<script src="https://unpkg.com/htmx.org@2.0.3"></script>
<script src="https://unpkg.com/htmx-ext-ws@2.0.1/ws.js"></script>
</head>
<body hx-ext="ws">
<div class="content-box" id="conversation" ws-connect="/ws" hx-on::ws-after-message="scrollToBottom();">
<div id="conversation">
{{range .Conversation}}
{{- /*gotype:llamachat.WebMessage*/ -}}
{{block "message-frag" .}}
<div class="message" id="message-{{.Id}}">
<span class="content">{{ .Content }}</span>
</div>
{{end}}
{{end}}
</div>
<form class="message-box" id="message-box" ws-send>
<input id="message-input" name="message" type="text" placeholder="Type a message">
<input type="submit" value="Send">
</form>
</div>
</body>
</html>
{{end}}

9
templates/message.gohtml Normal file
View File

@ -0,0 +1,9 @@
{{- /*gotype:llamachat.WebMessage*/ -}}
{{define "message"}}
<div
{{if .New}} hx-swap-oob="beforebegin" id="message-box"
{{else if .Replace}} hx-swap-oob="outerHTML" id="message-{{.Id}}"
{{end}}>
{{template "message-frag" .}}
</div>
{{end}}

31
webmodels.go Normal file
View File

@ -0,0 +1,31 @@
package main
import "github.com/ollama/ollama/api"
type WebModel struct {
Conversation []WebMessage
}
type WebMessage struct {
Id int
Replace bool
New bool
Content string
}
func ConvertMessage(message api.Message, id int) WebMessage {
return WebMessage{
Id: id,
Content: message.Content,
}
}
func (wm *WebModel) AddMessage(message api.Message, id int) {
wm.Conversation = append(wm.Conversation, ConvertMessage(message, id))
}
func (wm *WebModel) AddMessages(messages []api.Message) {
for id, message := range messages {
wm.AddMessage(message, id)
}
}