package main import ( "bytes" "context" "encoding/json" "fmt" "github.com/elk-language/go-prompt" "github.com/fatih/color" "github.com/gorilla/websocket" "github.com/ollama/ollama/api" "html/template" "io" "log" "net/http" "os" "strings" ) var conversation []api.Message var ollama *api.Client var upgrader = websocket.Upgrader{} func convertRole(role MessageType) string { if role == MT_SYSTEM { return "system" } else if role == MT_ASSISTANT { return "assistant" } else if role == MT_USER { return "user" } else { log.Fatalf("Invalid role type %d\n", role) return "" } } func roleToInt(role string) MessageType { if role == "system" { return MT_SYSTEM } else if role == "assistant" { return MT_ASSISTANT } else if role == "user" { return MT_USER } else { log.Fatalf("Invalid role type %s\n", role) return 0 } } func colorRole(role MessageType) string { if role == MT_SYSTEM { return color.RedString("system") } else if role == MT_ASSISTANT { return color.GreenString("assistant") } else if role == MT_USER { return color.BlueString("user") } else { log.Fatalf("Invalid role type %d\n", role) return "" } } func loadMessageFromDb() []api.Message { dbMessages := GetMesages() var chatMessages []api.Message for _, msg := range dbMessages { message := api.Message{ Role: convertRole(msg.Type), Content: msg.Content, } chatMessages = append(chatMessages, message) fmt.Printf("%s: %s\n", colorRole(roleToInt(message.Role)), message.Content) } return chatMessages } func executeCommand(cli string) { args := strings.Split(cli, " ") for _, cmd := range Commands { if cmd.Name == args[0] { cmd.Action(args) return } } fmt.Println("Unknown command") } func onUserInput(input string) { if strings.HasPrefix(input, "/") { executeCommand(input) return } sendPromptInput(input, func(r api.ChatResponse) error { _, err := fmt.Print(r.Message.Content) return err }) fmt.Println() } func sendPromptInput(input string, handler func(response api.ChatResponse) error) { err := SaveMessage(MT_USER, input) if err != nil { log.Fatal(err) } conversation = append(conversation, api.Message{ Role: convertRole(MT_USER), Content: input, }) ctx := context.Background() req := &api.ChatRequest{ Model: "llama3.2:1b", Messages: conversation, } fullResponse := "" respFunc := func(resp api.ChatResponse) error { fullResponse = fullResponse + resp.Message.Content err := handler(resp) if err != nil { return err } if !resp.Done { return nil } return SaveMessage(MT_ASSISTANT, fullResponse) } fmt.Print(colorRole(MT_ASSISTANT), ": ") err = ollama.Chat(ctx, req, respFunc) if err != nil { log.Fatal(err) } conversation = append(conversation, api.Message{ Role: convertRole(MT_ASSISTANT), Content: fullResponse, }) } func executeTemplate(w io.Writer, name string, data any) error { t, err := template. New("templates"). Funcs(template.FuncMap{ "trim": strings.TrimSpace, }). ParseGlob("templates/*.gohtml") if err != nil { return err } err = t.ExecuteTemplate(w, name, data) if err != nil { return err } return nil } func servePage(w http.ResponseWriter, _ *http.Request) { wm := WebModel{} wm.AddMessages(conversation) err := executeTemplate(w, "index", wm) if err != nil { log.Println(err) } } func serveWebSocket(w http.ResponseWriter, r *http.Request) { log.Println("Websocket connected") c, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Println("upgrade: ", err) return } defer c.Close() for { // Read the prompt mt, message, err := c.ReadMessage() if err != nil { log.Println("Read error: ", err) break } formInput := make(map[string]any) err = json.Unmarshal(message, &formInput) if err != nil { log.Println("Unmarshal: ", err) break } str, ok := formInput["message"].(string) if !ok { log.Println("Invalid user input: ", err) break } // Send the request to Ollama fullResponse := "" firstResponse := true sendPromptInput(str, func(response api.ChatResponse) error { fullResponse = fullResponse + response.Message.Content var content bytes.Buffer err = executeTemplate(&content, "message", WebMessage{ Id: len(conversation) + 1, New: firstResponse, Replace: !firstResponse, Content: fullResponse, }) if err != nil { return err } firstResponse = false return c.WriteMessage(mt, content.Bytes()) }) } } func main() { var err error dbName := "" if len(os.Args) > 1 { dbName = os.Args[1] } ollama, err = api.ClientFromEnvironment() if err != nil { log.Fatal(err) } OpenDb(dbName) conversation = loadMessageFromDb() fs := http.FileServer(http.Dir("static")) http.Handle("/static/", http.StripPrefix("/static/", fs)) http.HandleFunc("/", servePage) http.HandleFunc("/ws", serveWebSocket) go func() { err := http.ListenAndServe(":8080", nil) if err != nil { log.Fatal(err) } }() runner := prompt.New( onUserInput, prompt.WithTitle("llamachat"), prompt.WithPrefix("user: "), prompt.WithCompleter(Completer), ) runner.Run() }