llamachat/main.go

250 lines
5.1 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/elk-language/go-prompt"
"github.com/fatih/color"
"github.com/gorilla/websocket"
"github.com/ollama/ollama/api"
"html/template"
"io"
"log"
"net/http"
"os"
"strings"
)
// Package-level state shared by the terminal prompt and the HTTP handlers.
var (
	// conversation is the in-memory chat history; the whole slice is sent as
	// the message list of every chat request.
	conversation []api.Message
	// ollama is the model client; it is assigned in main.
	ollama *api.Client
	// upgrader turns /ws requests into websocket connections using the
	// zero-value (default) Upgrader configuration.
	upgrader websocket.Upgrader
)
// convertRole returns the api.Message role string ("system", "assistant" or
// "user") for a MessageType. Any other value terminates the process via
// log.Fatalf.
func convertRole(role MessageType) string {
	switch role {
	case MT_SYSTEM:
		return "system"
	case MT_ASSISTANT:
		return "assistant"
	case MT_USER:
		return "user"
	}
	log.Fatalf("Invalid role type %d\n", role)
	return "" // unreachable: log.Fatalf exits, but the compiler needs a return
}
// roleToInt is the inverse of convertRole: it maps an api.Message role string
// back to its MessageType. An unrecognised role terminates the process via
// log.Fatalf.
func roleToInt(role string) MessageType {
	switch role {
	case "system":
		return MT_SYSTEM
	case "assistant":
		return MT_ASSISTANT
	case "user":
		return MT_USER
	default:
		log.Fatalf("Invalid role type %s\n", role)
		return 0 // unreachable: log.Fatalf exits
	}
}
// colorRole returns the role label for terminal output, coloured per role:
// red for system, green for assistant, blue for user. Any other value
// terminates the process via log.Fatalf.
func colorRole(role MessageType) string {
	switch role {
	case MT_SYSTEM:
		return color.RedString("system")
	case MT_ASSISTANT:
		return color.GreenString("assistant")
	case MT_USER:
		return color.BlueString("user")
	}
	log.Fatalf("Invalid role type %d\n", role)
	return "" // unreachable: log.Fatalf exits
}
// loadMessageFromDb fetches the stored messages via GetMesages and converts
// each one to an api.Message. As a side effect it prints each message to
// stdout as "<coloured role>: <content>".
func loadMessageFromDb() []api.Message {
	var history []api.Message
	for _, stored := range GetMesages() {
		// convertRole runs first so an invalid type aborts before printing.
		entry := api.Message{Role: convertRole(stored.Type), Content: stored.Content}
		history = append(history, entry)
		fmt.Printf("%s: %s\n", colorRole(stored.Type), entry.Content)
	}
	return history
}
// executeCommand runs a slash command typed at the terminal.
//
// The line is split into whitespace-separated fields. The first field is
// compared with the Name of each entry in Commands, and the first match
// receives all fields, including the command name itself, via its Action.
// When no command matches, "Unknown command" is printed.
func executeCommand(cli string) {
	// strings.Fields instead of strings.Split(cli, " "): repeated, leading or
	// trailing spaces must not turn into empty-string arguments.
	args := strings.Fields(cli)
	if len(args) == 0 {
		fmt.Println("Unknown command")
		return
	}
	for _, cmd := range Commands {
		if cmd.Name == args[0] {
			cmd.Action(args)
			return
		}
	}
	fmt.Println("Unknown command")
}
// onUserInput is the executor passed to prompt.New for the terminal UI.
//
// Lines starting with "/" are dispatched to executeCommand, and blank lines
// are ignored. Anything else is sent to the model, with each reply chunk
// printed to stdout as it arrives, followed by a final newline.
func onUserInput(input string) {
	if strings.HasPrefix(input, "/") {
		executeCommand(input)
		return
	}
	// Pressing Enter on an empty line would otherwise persist an empty user
	// message and send it to the model.
	if strings.TrimSpace(input) == "" {
		return
	}
	sendPromptInput(input, func(r api.ChatResponse) error {
		_, err := fmt.Print(r.Message.Content)
		return err
	})
	fmt.Println()
}
// sendPromptInput persists input as a user message, appends it to
// conversation, and sends the whole conversation to the model.
//
// Every response chunk is passed to handler. When a chunk reports Done, the
// accumulated reply is persisted. After Chat returns successfully, the reply
// is also appended to conversation as an assistant message.
//
// Failures are logged and abort only this exchange; they no longer terminate
// the process, which also hosts the web server. Failures include persisting
// a message, the Ollama request itself, and handler returning an error (for
// example a websocket write to a client that has gone away). The user message
// stays in conversation because it was already saved; a partial assistant
// reply is discarded, matching what was persisted.
//
// NOTE(review): conversation is package state mutated here from both the
// terminal goroutine and HTTP handler goroutines without synchronisation.
func sendPromptInput(input string, handler func(response api.ChatResponse) error) {
	if err := SaveMessage(MT_USER, input); err != nil {
		log.Println("Saving user message: ", err)
		return
	}
	conversation = append(conversation, api.Message{
		Role:    convertRole(MT_USER),
		Content: input,
	})
	req := &api.ChatRequest{
		Model:    "llama3.2:1b",
		Messages: conversation,
	}
	// strings.Builder avoids re-copying the whole reply for every chunk.
	var fullResponse strings.Builder
	respFunc := func(resp api.ChatResponse) error {
		fullResponse.WriteString(resp.Message.Content)
		if err := handler(resp); err != nil {
			return err
		}
		if !resp.Done {
			return nil
		}
		return SaveMessage(MT_ASSISTANT, fullResponse.String())
	}
	fmt.Print(colorRole(MT_ASSISTANT), ": ")
	if err := ollama.Chat(context.Background(), req, respFunc); err != nil {
		log.Println("Chat: ", err)
		return
	}
	conversation = append(conversation, api.Message{
		Role:    convertRole(MT_ASSISTANT),
		Content: fullResponse.String(),
	})
}
// executeTemplate parses every templates/*.gohtml file into one template set
// and renders the template called name with data into w. The set exposes a
// "trim" function bound to strings.TrimSpace. Parse and execution errors are
// returned unchanged.
//
// NOTE(review): the files are re-read from disk on every call, including once
// per streamed chunk in serveWebSocket; consider caching the parsed set if
// that becomes a bottleneck.
func executeTemplate(w io.Writer, name string, data any) error {
	helpers := template.FuncMap{"trim": strings.TrimSpace}
	base := template.New("templates").Funcs(helpers)
	set, err := base.ParseGlob("templates/*.gohtml")
	if err != nil {
		return err
	}
	return set.ExecuteTemplate(w, name, data)
}
// servePage renders the "index" template over a WebModel built from the
// current conversation.
//
// The page is rendered into a buffer first. A template error therefore
// produces a logged 500 response instead of an empty or half-written 200 page.
func servePage(w http.ResponseWriter, _ *http.Request) {
	wm := WebModel{}
	wm.AddMessages(conversation)
	var page bytes.Buffer
	if err := executeTemplate(&page, "index", wm); err != nil {
		log.Println(err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	// Write errors here normally mean the client went away; log and move on.
	if _, err := page.WriteTo(w); err != nil {
		log.Println(err)
	}
}
// serveWebSocket upgrades the request to a websocket connection and handles
// prompts sent over it.
//
// Each incoming frame must be a JSON object whose "message" field is a string
// prompt; the prompt is forwarded to sendPromptInput. For every response
// chunk, the "message" template is rendered over the reply accumulated so far
// and written back with the same frame type. New is set on the first chunk
// and Replace on later ones.
//
// A read error, a frame that is not valid JSON, or a non-string "message"
// ends the loop, and the deferred Close then closes the connection. Blank
// prompts are skipped.
func serveWebSocket(w http.ResponseWriter, r *http.Request) {
	log.Println("Websocket connected")
	c, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("upgrade: ", err)
		return
	}
	defer c.Close()
	for {
		// Read the prompt
		mt, message, err := c.ReadMessage()
		if err != nil {
			log.Println("Read error: ", err)
			break
		}
		formInput := make(map[string]any)
		if err := json.Unmarshal(message, &formInput); err != nil {
			log.Println("Unmarshal: ", err)
			break
		}
		str, ok := formInput["message"].(string)
		if !ok {
			// There is no error value here (it would always be nil); report
			// the offending field value instead.
			log.Printf("Invalid user input: %v\n", formInput["message"])
			break
		}
		// Don't persist or send empty prompts to the model.
		if strings.TrimSpace(str) == "" {
			continue
		}
		// Send the request to Ollama
		var fullResponse strings.Builder
		firstResponse := true
		sendPromptInput(str, func(response api.ChatResponse) error {
			fullResponse.WriteString(response.Message.Content)
			var content bytes.Buffer
			// Local err: don't clobber the loop's err from the closure.
			if err := executeTemplate(&content, "message", WebMessage{
				Id:      len(conversation) + 1,
				New:     firstResponse,
				Replace: !firstResponse,
				Content: fullResponse.String(),
			}); err != nil {
				return err
			}
			firstResponse = false
			return c.WriteMessage(mt, content.Bytes())
		})
	}
}
// main wires the application together, in this order:
//  1. Create the Ollama client from the environment.
//  2. Open the message database, optionally named by the first command-line
//     argument.
//  3. Load and print the stored conversation.
//  4. Serve the web UI (static files, index page, websocket) on :8080 in a
//     background goroutine.
//  5. Run the interactive terminal prompt.
func main() {
	var dbName string
	if args := os.Args[1:]; len(args) > 0 {
		dbName = args[0]
	}

	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	ollama = client

	OpenDb(dbName)
	conversation = loadMessageFromDb()

	static := http.StripPrefix("/static/", http.FileServer(http.Dir("static")))
	http.Handle("/static/", static)
	http.HandleFunc("/", servePage)
	http.HandleFunc("/ws", serveWebSocket)

	// The HTTP server runs in the background while the terminal prompt below
	// occupies the main goroutine.
	go func() {
		if err := http.ListenAndServe(":8080", nil); err != nil {
			log.Fatal(err)
		}
	}()

	prompt.New(
		onUserInput,
		prompt.WithTitle("llamachat"),
		prompt.WithPrefix("user: "),
		prompt.WithCompleter(Completer),
	).Run()
}