250 lines
5.1 KiB
Go
250 lines
5.1 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/elk-language/go-prompt"
|
|
"github.com/fatih/color"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/ollama/ollama/api"
|
|
"html/template"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
)
|
|
|
|
var conversation []api.Message
|
|
var ollama *api.Client
|
|
var upgrader = websocket.Upgrader{}
|
|
|
|
// convertRole maps a MessageType to the role name used in api.Message
// ("system", "assistant" or "user"). Any other value is treated as a
// programming error and terminates the process via log.Fatalf.
func convertRole(role MessageType) string {
	switch role {
	case MT_SYSTEM:
		return "system"
	case MT_ASSISTANT:
		return "assistant"
	case MT_USER:
		return "user"
	}
	log.Fatalf("Invalid role type %d\n", role)
	return "" // unreachable: log.Fatalf exits, but the compiler needs a return
}
|
|
|
|
// roleToInt is the inverse of convertRole: it maps a role name
// ("system", "assistant" or "user") back to its MessageType. An unknown
// name terminates the process via log.Fatalf.
func roleToInt(role string) MessageType {
	switch role {
	case "system":
		return MT_SYSTEM
	case "assistant":
		return MT_ASSISTANT
	case "user":
		return MT_USER
	}
	log.Fatalf("Invalid role type %s\n", role)
	return 0 // unreachable: log.Fatalf exits, but the compiler needs a return
}
|
|
|
|
// colorRole returns the role name for terminal output, colored red for
// system, green for assistant and blue for user. An unknown role terminates
// the process via log.Fatalf.
func colorRole(role MessageType) string {
	switch role {
	case MT_SYSTEM:
		return color.RedString("system")
	case MT_ASSISTANT:
		return color.GreenString("assistant")
	case MT_USER:
		return color.BlueString("user")
	}
	log.Fatalf("Invalid role type %d\n", role)
	return "" // unreachable: log.Fatalf exits, but the compiler needs a return
}
|
|
|
|
func loadMessageFromDb() []api.Message {
|
|
dbMessages := GetMesages()
|
|
var chatMessages []api.Message
|
|
for _, msg := range dbMessages {
|
|
message := api.Message{
|
|
Role: convertRole(msg.Type),
|
|
Content: msg.Content,
|
|
}
|
|
chatMessages = append(chatMessages, message)
|
|
fmt.Printf("%s: %s\n", colorRole(roleToInt(message.Role)), message.Content)
|
|
}
|
|
return chatMessages
|
|
}
|
|
|
|
// executeCommand splits cli into whitespace-separated words and runs the
// first entry in Commands whose Name equals the first word verbatim. It
// passes all words to that command's Action, with the command word itself
// at index 0. If cli is blank or no command matches, it prints
// "Unknown command".
//
// strings.Fields is used instead of splitting on a single space so that
// repeated or trailing whitespace (e.g. left behind by completion) does not
// produce empty arguments.
func executeCommand(cli string) {
	args := strings.Fields(cli)
	if len(args) == 0 {
		fmt.Println("Unknown command")
		return
	}
	for _, cmd := range Commands {
		if cmd.Name == args[0] {
			cmd.Action(args)
			return
		}
	}
	fmt.Println("Unknown command")
}
|
|
|
|
// onUserInput is the executor passed to prompt.New in main; it handles each
// line entered at the CLI. Lines starting with "/" are dispatched to
// executeCommand. Anything else is sent to the model, and the reply is
// printed to stdout chunk by chunk as it streams in, followed by a newline.
//
// Blank or whitespace-only lines are ignored. Otherwise pressing Enter on an
// empty prompt would save an empty user message and query the model with it.
func onUserInput(input string) {
	if strings.TrimSpace(input) == "" {
		return
	}
	if strings.HasPrefix(input, "/") {
		executeCommand(input)
		return
	}

	sendPromptInput(input, func(r api.ChatResponse) error {
		_, err := fmt.Print(r.Message.Content)
		return err
	})
	fmt.Println()
}
|
|
|
|
// chatTurn serializes sendPromptInput. main runs the HTTP server on a
// goroutine separate from the CLI prompt, and net/http serves each request on
// its own goroutine. All of them share the global conversation slice, so
// unsynchronized prompts would race on append and interleave their
// user/assistant turns. This one-slot buffered channel acts as a lock.
var chatTurn = make(chan struct{}, 1)

// sendPromptInput records input as a user message and streams the model's
// reply to handler. The user message is saved with SaveMessage and appended
// to conversation. The whole conversation is then sent to Ollama, and each
// streamed chunk is passed to handler. When the final (Done) chunk arrives,
// the accumulated reply is saved as an assistant message; after the stream
// completes, it is appended to conversation. The colored "assistant: " label
// is printed to stdout before streaming begins.
//
// Failing to save the user message is fatal. A failure during the chat is
// logged and the partial reply is discarded. Such failures include Ollama
// being unreachable, handler returning an error (e.g. a closed websocket),
// and failing to save the reply. Discarding rather than exiting means one
// dropped client cannot terminate the whole process.
func sendPromptInput(input string, handler func(response api.ChatResponse) error) {
	const model = "llama3.2:1b"

	chatTurn <- struct{}{}
	defer func() { <-chatTurn }()

	if err := SaveMessage(MT_USER, input); err != nil {
		log.Fatal(err)
	}
	conversation = append(conversation, api.Message{
		Role:    convertRole(MT_USER),
		Content: input,
	})

	req := &api.ChatRequest{
		Model:    model,
		Messages: conversation,
	}

	// Accumulate chunks in a Builder; repeated string += is quadratic.
	var reply strings.Builder
	respFunc := func(resp api.ChatResponse) error {
		reply.WriteString(resp.Message.Content)
		if err := handler(resp); err != nil {
			return err
		}
		if !resp.Done {
			return nil
		}
		return SaveMessage(MT_ASSISTANT, reply.String())
	}

	fmt.Print(colorRole(MT_ASSISTANT), ": ")
	if err := ollama.Chat(context.Background(), req, respFunc); err != nil {
		log.Printf("chat request failed: %v", err)
		return
	}
	conversation = append(conversation, api.Message{
		Role:    convertRole(MT_ASSISTANT),
		Content: reply.String(),
	})
}
|
|
|
|
// executeTemplate parses every templates/*.gohtml file, with a "trim" helper
// bound to strings.TrimSpace, and renders the template called name with data
// into w. Parse and execution errors are returned unchanged.
//
// NOTE(review): the templates are re-read from disk on every call, including
// once per streamed chunk in serveWebSocket. Confirm this is intended (e.g.
// to support live template editing) before caching the parsed set.
func executeTemplate(w io.Writer, name string, data any) error {
	helpers := template.FuncMap{"trim": strings.TrimSpace}
	base := template.New("templates").Funcs(helpers)

	parsed, err := base.ParseGlob("templates/*.gohtml")
	if err != nil {
		return err
	}
	return parsed.ExecuteTemplate(w, name, data)
}
|
|
|
|
func servePage(w http.ResponseWriter, _ *http.Request) {
|
|
wm := WebModel{}
|
|
wm.AddMessages(conversation)
|
|
err := executeTemplate(w, "index", wm)
|
|
if err != nil {
|
|
log.Println(err)
|
|
}
|
|
}
|
|
|
|
// serveWebSocket upgrades the request to a websocket and loops on it. Each
// incoming frame must be a JSON object with a string "message" field, which
// is sent to the model via sendPromptInput. Every streamed chunk is rendered
// through the "message" template with the accumulated reply so far and
// written back using the frame type the client sent. The first chunk is
// marked New and later chunks Replace. The loop (and the connection) ends on
// the first read error, malformed JSON, or missing/non-string "message".
func serveWebSocket(w http.ResponseWriter, r *http.Request) {
	c, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("upgrade: ", err)
		return
	}
	defer c.Close()
	log.Println("Websocket connected")

	for {
		// Read the prompt.
		mt, message, err := c.ReadMessage()
		if err != nil {
			log.Println("Read error: ", err)
			break
		}

		formInput := make(map[string]any)
		err = json.Unmarshal(message, &formInput)
		if err != nil {
			log.Println("Unmarshal: ", err)
			break
		}

		str, ok := formInput["message"].(string)
		if !ok {
			// err is nil here; log the offending value instead.
			log.Println("Invalid user input: ", formInput["message"])
			break
		}

		// Send the request to Ollama, re-rendering the growing reply per chunk.
		var fullResponse strings.Builder
		firstResponse := true
		sendPromptInput(str, func(response api.ChatResponse) error {
			fullResponse.WriteString(response.Message.Content)

			var content bytes.Buffer
			if err := executeTemplate(&content, "message", WebMessage{
				Id:      len(conversation) + 1,
				New:     firstResponse,
				Replace: !firstResponse,
				Content: fullResponse.String(),
			}); err != nil {
				return err
			}
			firstResponse = false
			return c.WriteMessage(mt, content.Bytes())
		})
	}
}
|
|
|
|
// main wires the app together. It creates the Ollama client from the
// environment and opens the database named by the optional first argument.
// It restores the saved conversation, serves the web UI on :8080 in the
// background, and runs the interactive CLI prompt on the main goroutine.
func main() {
	var dbName string
	if len(os.Args) > 1 {
		dbName = os.Args[1]
	}

	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	ollama = client

	OpenDb(dbName)
	conversation = loadMessageFromDb()

	// Web UI: static assets, the page itself, and the websocket endpoint.
	http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
	http.HandleFunc("/", servePage)
	http.HandleFunc("/ws", serveWebSocket)

	go func() {
		if err := http.ListenAndServe(":8080", nil); err != nil {
			log.Fatal(err)
		}
	}()

	prompt.New(
		onUserInput,
		prompt.WithTitle("llamachat"),
		prompt.WithPrefix("user: "),
		prompt.WithCompleter(Completer),
	).Run()
}
|