Improve little things

This commit is contained in:
Sebastiaan de Schaetzen 2024-11-13 11:50:52 +01:00
parent 73ff32aa69
commit 2816d46962
7 changed files with 209 additions and 138 deletions

43
cli.go Normal file
View File

@ -0,0 +1,43 @@
package main
import (
"fmt"
"github.com/elk-language/go-prompt"
"github.com/ollama/ollama/api"
"strings"
)
func onUserInput(input string) {
if strings.HasPrefix(input, "/") {
executeCommand(input)
return
}
fmt.Print(colorRole(MT_ASSISTANT), ": ")
sendPromptInput(input, func(r api.ChatResponse) error {
_, err := fmt.Print(r.Message.Content)
return err
})
fmt.Println()
}
func runAsCommandLine() {
runner := prompt.New(
onUserInput,
prompt.WithTitle("llamachat"),
prompt.WithPrefix("user: "),
prompt.WithCompleter(Completer),
)
runner.Run()
}
func executeCommand(cli string) {
args := strings.Split(cli, " ")
for _, cmd := range Commands {
if cmd.Name == args[0] {
cmd.Action(args)
return
}
}
fmt.Println("Unknown command")
}

12
db.go
View File

@ -19,11 +19,14 @@ var db *sql.DB
func OpenDb(dbName string) { func OpenDb(dbName string) {
if dbName == "" || dbName == "-" { if dbName == "" || dbName == "-" {
log.Println("Using in-memory database") log.Println("Using in-memory database")
dbName = "memory" dbName = ":memory:"
} else {
dbName = dbName + ".chat"
log.Printf("Using database %s\n", dbName)
} }
var err error var err error
db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s.chat?_foreign_keys=on", dbName)) db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s?_foreign_keys=on", dbName))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -39,10 +42,7 @@ func OpenDb(dbName string) {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
//err = m.Down()
//if err != nil && !errors.Is(err, migrate.ErrNoChange) {
// log.Println("Could not go down:", err)
//}
err = m.Up() err = m.Up()
if err != nil && !errors.Is(err, migrate.ErrNoChange) { if err != nil && !errors.Is(err, migrate.ErrNoChange) {
log.Fatal(err) log.Fatal(err)

160
main.go
View File

@ -1,25 +1,17 @@
package main package main
import ( import (
"bytes"
"context" "context"
"encoding/json" "flag"
"fmt" "fmt"
"github.com/elk-language/go-prompt"
"github.com/fatih/color" "github.com/fatih/color"
"github.com/gorilla/websocket"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"html/template"
"io"
"log" "log"
"net/http"
"os" "os"
"strings"
) )
var conversation []api.Message var conversation []api.Message
var ollama *api.Client var ollama *api.Client
var upgrader = websocket.Upgrader{}
func convertRole(role MessageType) string { func convertRole(role MessageType) string {
if role == MT_SYSTEM { if role == MT_SYSTEM {
@ -69,34 +61,11 @@ func loadMessageFromDb() []api.Message {
Content: msg.Content, Content: msg.Content,
} }
chatMessages = append(chatMessages, message) chatMessages = append(chatMessages, message)
fmt.Printf("%s: %s\n", colorRole(roleToInt(message.Role)), message.Content) //fmt.Printf("%s: %s\n", colorRole(roleToInt(message.Role)), message.Content)
} }
return chatMessages return chatMessages
} }
func executeCommand(cli string) {
args := strings.Split(cli, " ")
for _, cmd := range Commands {
if cmd.Name == args[0] {
cmd.Action(args)
return
}
}
fmt.Println("Unknown command")
}
func onUserInput(input string) {
if strings.HasPrefix(input, "/") {
executeCommand(input)
return
}
sendPromptInput(input, func(r api.ChatResponse) error {
_, err := fmt.Print(r.Message.Content)
return err
})
fmt.Println()
}
func sendPromptInput(input string, handler func(response api.ChatResponse) error) { func sendPromptInput(input string, handler func(response api.ChatResponse) error) {
err := SaveMessage(MT_USER, input) err := SaveMessage(MT_USER, input)
if err != nil { if err != nil {
@ -127,7 +96,6 @@ func sendPromptInput(input string, handler func(response api.ChatResponse) error
return SaveMessage(MT_ASSISTANT, fullResponse) return SaveMessage(MT_ASSISTANT, fullResponse)
} }
fmt.Print(colorRole(MT_ASSISTANT), ": ")
err = ollama.Chat(ctx, req, respFunc) err = ollama.Chat(ctx, req, respFunc)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -138,112 +106,44 @@ func sendPromptInput(input string, handler func(response api.ChatResponse) error
}) })
} }
func executeTemplate(w io.Writer, name string, data any) error { const usage = `Usage:
t, err := template. llamachat [OPTION]... [CONVERSATION]
New("templates").
Funcs(template.FuncMap{
"trim": strings.TrimSpace,
}).
ParseGlob("templates/*.gohtml")
if err != nil {
return err
}
err = t.ExecuteTemplate(w, name, data)
if err != nil {
return err
}
return nil
}
func servePage(w http.ResponseWriter, _ *http.Request) { Options:
wm := WebModel{} -w, --web Access via a web page.
wm.AddMessages(conversation)
err := executeTemplate(w, "index", wm)
if err != nil {
log.Println(err)
}
}
func serveWebSocket(w http.ResponseWriter, r *http.Request) { If CONVERSATION is left out, the conversation will be kept in-memory. If it is
log.Println("Websocket connected") present, the conversation will be stored in a file called CONVERSATION.chat.
c, err := upgrader.Upgrade(w, r, nil) Alternatively, it can also be passed the literal string '-' (a singel dash) to
if err != nil { ensure an in-memory database is used.
log.Println("upgrade: ", err) `
return
}
defer c.Close()
for {
// Read the prompt
mt, message, err := c.ReadMessage()
if err != nil {
log.Println("Read error: ", err)
break
}
formInput := make(map[string]any)
err = json.Unmarshal(message, &formInput)
if err != nil {
log.Println("Unmarshal: ", err)
break
}
str, ok := formInput["message"].(string)
if !ok {
log.Println("Invalid user input: ", err)
break
}
// Send the request to Ollama
fullResponse := ""
firstResponse := true
sendPromptInput(str, func(response api.ChatResponse) error {
fullResponse = fullResponse + response.Message.Content
var content bytes.Buffer
err = executeTemplate(&content, "message", WebMessage{
Id: len(conversation) + 1,
New: firstResponse,
Replace: !firstResponse,
Content: fullResponse,
})
if err != nil {
return err
}
firstResponse = false
return c.WriteMessage(mt, content.Bytes())
})
}
}
func main() { func main() {
flag.Usage = func() { fmt.Fprintf(os.Stderr, "%s\n", usage) }
var err error var err error
dbName := "" var useWeb bool
if len(os.Args) > 1 { db := ""
dbName = os.Args[1] flag.BoolVar(&useWeb, "web", false, "run in web mode")
flag.BoolVar(&useWeb, "w", false, "run in web mode")
flag.Parse()
db = flag.Arg(0)
if db == "-" {
db = ""
} }
OpenDb(db)
conversation = loadMessageFromDb()
ollama, err = api.ClientFromEnvironment() ollama, err = api.ClientFromEnvironment()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
OpenDb(dbName) if useWeb {
conversation = loadMessageFromDb() runAsWeb()
} else {
fs := http.FileServer(http.Dir("static")) runAsCommandLine()
http.Handle("/static/", http.StripPrefix("/static/", fs)) }
http.HandleFunc("/", servePage)
http.HandleFunc("/ws", serveWebSocket)
go func() {
err := http.ListenAndServe(":8080", nil)
if err != nil {
log.Fatal(err)
}
}()
runner := prompt.New(
onUserInput,
prompt.WithTitle("llamachat"),
prompt.WithPrefix("user: "),
prompt.WithCompleter(Completer),
)
runner.Run()
} }

View File

@ -1,5 +1,4 @@
function scrollToBottom() { function scrollToBottom() {
console.log("Scrolling");
window.scrollTo(0, document.body.scrollHeight); window.scrollTo(0, document.body.scrollHeight);
} }

View File

@ -3,6 +3,7 @@ body {
} }
.message { .message {
background-color: whitesmoke;
border: 1px solid #000; border: 1px solid #000;
margin: 5px 0; margin: 5px 0;
padding: 16px; padding: 16px;

View File

@ -24,7 +24,7 @@
{{end}} {{end}}
</div> </div>
<form class="message-box" id="message-box" ws-send> <form class="message-box" id="message-box" ws-send hx-on::ws-before-send="this.reset();">
<input id="message-input" name="message" type="text" placeholder="Type a message"> <input id="message-input" name="message" type="text" placeholder="Type a message">
<input type="submit" value="Send"> <input type="submit" value="Send">
</form> </form>

128
web.go Normal file
View File

@ -0,0 +1,128 @@
package main
import (
"bytes"
"encoding/json"
"github.com/gorilla/websocket"
"github.com/ollama/ollama/api"
"html/template"
"io"
"log"
"net"
"net/http"
"strings"
)
var upgrader = websocket.Upgrader{}
func runAsWeb() {
log.Println("Starting web server")
fs := http.FileServer(http.Dir("static"))
http.Handle("/static/", http.StripPrefix("/static/", fs))
http.HandleFunc("/", servePage)
http.HandleFunc("/ws", serveWebSocket)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
log.Fatal(err)
}
log.Printf("Listening on http://127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port)
err = http.Serve(listener, nil)
if err != nil {
log.Fatal(err)
}
}
func executeTemplate(w io.Writer, name string, data any) error {
t, err := template.
New("templates").
Funcs(template.FuncMap{
"trim": strings.TrimSpace,
}).
ParseGlob("templates/*.gohtml")
if err != nil {
return err
}
err = t.ExecuteTemplate(w, name, data)
if err != nil {
return err
}
return nil
}
func servePage(w http.ResponseWriter, _ *http.Request) {
wm := WebModel{}
wm.AddMessages(conversation)
err := executeTemplate(w, "index", wm)
if err != nil {
log.Println(err)
}
}
func serveWebSocket(w http.ResponseWriter, r *http.Request) {
log.Println("Websocket connected")
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println("upgrade: ", err)
return
}
defer c.Close()
for {
// Read the prompt
mt, message, err := c.ReadMessage()
if err != nil {
log.Println("Read error: ", err)
break
}
formInput := make(map[string]any)
err = json.Unmarshal(message, &formInput)
if err != nil {
log.Println("Unmarshal: ", err)
break
}
userRequest, ok := formInput["message"].(string)
if !ok {
log.Println("Invalid user input: ", err)
break
}
err = func() error {
var content bytes.Buffer
err := executeTemplate(&content, "message", WebMessage{
Id: len(conversation) + 1,
New: true,
Replace: false,
Content: userRequest,
})
if err != nil {
return err
}
return c.WriteMessage(mt, content.Bytes())
}()
if err != nil {
log.Println("Sending initial response: ", err)
break
}
// Send the request to Ollama
fullResponse := ""
firstResponse := true
sendPromptInput(userRequest, func(response api.ChatResponse) error {
fullResponse = fullResponse + response.Message.Content
var content bytes.Buffer
err = executeTemplate(&content, "message", WebMessage{
Id: len(conversation) + 1,
New: firstResponse,
Replace: !firstResponse,
Content: fullResponse,
})
if err != nil {
return err
}
firstResponse = false
return c.WriteMessage(mt, content.Bytes())
})
}
}