diff --git a/cli.go b/cli.go new file mode 100644 index 0000000..e2b9460 --- /dev/null +++ b/cli.go @@ -0,0 +1,43 @@ +package main + +import ( + "fmt" + "github.com/elk-language/go-prompt" + "github.com/ollama/ollama/api" + "strings" +) + +func onUserInput(input string) { + if strings.HasPrefix(input, "/") { + executeCommand(input) + return + } + fmt.Print(colorRole(MT_ASSISTANT), ": ") + sendPromptInput(input, func(r api.ChatResponse) error { + _, err := fmt.Print(r.Message.Content) + return err + }) + fmt.Println() +} + +func runAsCommandLine() { + runner := prompt.New( + onUserInput, + prompt.WithTitle("llamachat"), + prompt.WithPrefix("user: "), + prompt.WithCompleter(Completer), + ) + + runner.Run() +} + +func executeCommand(cli string) { + args := strings.Split(cli, " ") + for _, cmd := range Commands { + if cmd.Name == args[0] { + cmd.Action(args) + return + } + } + fmt.Println("Unknown command") +} diff --git a/db.go b/db.go index e352330..7e92ace 100644 --- a/db.go +++ b/db.go @@ -19,11 +19,14 @@ var db *sql.DB func OpenDb(dbName string) { if dbName == "" || dbName == "-" { log.Println("Using in-memory database") - dbName = "memory" + dbName = ":memory:" + } else { + dbName = dbName + ".chat" + log.Printf("Using database %s\n", dbName) } var err error - db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s.chat?_foreign_keys=on", dbName)) + db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s?_foreign_keys=on", dbName)) if err != nil { log.Fatal(err) } @@ -39,10 +42,7 @@ func OpenDb(dbName string) { if err != nil { log.Fatal(err) } - //err = m.Down() - //if err != nil && !errors.Is(err, migrate.ErrNoChange) { - // log.Println("Could not go down:", err) - //} + err = m.Up() if err != nil && !errors.Is(err, migrate.ErrNoChange) { log.Fatal(err) diff --git a/main.go b/main.go index 6bdb993..e4e3b2c 100644 --- a/main.go +++ b/main.go @@ -1,25 +1,17 @@ package main import ( - "bytes" "context" - "encoding/json" + "flag" "fmt" - "github.com/elk-language/go-prompt" "github.com/fatih/color" - "github.com/gorilla/websocket" "github.com/ollama/ollama/api" - "html/template" - "io" "log" - "net/http" "os" - "strings" ) var conversation []api.Message var ollama *api.Client -var upgrader = websocket.Upgrader{} func convertRole(role MessageType) string { if role == MT_SYSTEM { @@ -69,34 +61,11 @@ func loadMessageFromDb() []api.Message { Content: msg.Content, } chatMessages = append(chatMessages, message) - fmt.Printf("%s: %s\n", colorRole(roleToInt(message.Role)), message.Content) + //fmt.Printf("%s: %s\n", colorRole(roleToInt(message.Role)), message.Content) } return chatMessages } -func executeCommand(cli string) { - args := strings.Split(cli, " ") - for _, cmd := range Commands { - if cmd.Name == args[0] { - cmd.Action(args) - return - } - } - fmt.Println("Unknown command") -} - -func onUserInput(input string) { - if strings.HasPrefix(input, "/") { - executeCommand(input) - return - } - sendPromptInput(input, func(r api.ChatResponse) error { - _, err := fmt.Print(r.Message.Content) - return err - }) - fmt.Println() -} - func sendPromptInput(input string, handler func(response api.ChatResponse) error) { err := SaveMessage(MT_USER, input) if err != nil { @@ -127,7 +96,6 @@ func sendPromptInput(input string, handler func(response api.ChatResponse) error return SaveMessage(MT_ASSISTANT, fullResponse) } - fmt.Print(colorRole(MT_ASSISTANT), ": ") err = ollama.Chat(ctx, req, respFunc) if err != nil { log.Fatal(err) @@ -138,112 +106,44 @@ func sendPromptInput(input string, handler func(response api.ChatResponse) error }) } -func executeTemplate(w io.Writer, name string, data any) error { - t, err := template. - New("templates"). - Funcs(template.FuncMap{ - "trim": strings.TrimSpace, - }). - ParseGlob("templates/*.gohtml") - if err != nil { - return err - } - err = t.ExecuteTemplate(w, name, data) - if err != nil { - return err - } - return nil -} +const usage = `Usage: +llamachat [OPTION]... [CONVERSATION] -func servePage(w http.ResponseWriter, _ *http.Request) { - wm := WebModel{} - wm.AddMessages(conversation) - err := executeTemplate(w, "index", wm) - if err != nil { - log.Println(err) - } -} +Options: + -w, --web Access via a web page. -func serveWebSocket(w http.ResponseWriter, r *http.Request) { - log.Println("Websocket connected") - c, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Println("upgrade: ", err) - return - } - defer c.Close() - for { - // Read the prompt - mt, message, err := c.ReadMessage() - if err != nil { - log.Println("Read error: ", err) - break - } - formInput := make(map[string]any) - err = json.Unmarshal(message, &formInput) - if err != nil { - log.Println("Unmarshal: ", err) - break - } - str, ok := formInput["message"].(string) - if !ok { - log.Println("Invalid user input: ", err) - break - } - - // Send the request to Ollama - fullResponse := "" - firstResponse := true - sendPromptInput(str, func(response api.ChatResponse) error { - fullResponse = fullResponse + response.Message.Content - var content bytes.Buffer - err = executeTemplate(&content, "message", WebMessage{ - Id: len(conversation) + 1, - New: firstResponse, - Replace: !firstResponse, - Content: fullResponse, - }) - if err != nil { - return err - } - firstResponse = false - return c.WriteMessage(mt, content.Bytes()) - }) - } -} +If CONVERSATION is left out, the conversation will be kept in-memory. If it is +present, the conversation will be stored in a file called CONVERSATION.chat. +Alternatively, it can also be passed the literal string '-' (a singel dash) to +ensure an in-memory database is used. +` func main() { + flag.Usage = func() { fmt.Fprintf(os.Stderr, "%s\n", usage) } + var err error - dbName := "" - if len(os.Args) > 1 { - dbName = os.Args[1] + var useWeb bool + db := "" + flag.BoolVar(&useWeb, "web", false, "run in web mode") + flag.BoolVar(&useWeb, "w", false, "run in web mode") + + flag.Parse() + + db = flag.Arg(0) + if db == "-" { + db = "" } + OpenDb(db) + conversation = loadMessageFromDb() ollama, err = api.ClientFromEnvironment() if err != nil { log.Fatal(err) } - OpenDb(dbName) - conversation = loadMessageFromDb() - - fs := http.FileServer(http.Dir("static")) - http.Handle("/static/", http.StripPrefix("/static/", fs)) - http.HandleFunc("/", servePage) - http.HandleFunc("/ws", serveWebSocket) - go func() { - err := http.ListenAndServe(":8080", nil) - if err != nil { - log.Fatal(err) - } - }() - - runner := prompt.New( - onUserInput, - prompt.WithTitle("llamachat"), - prompt.WithPrefix("user: "), - prompt.WithCompleter(Completer), - ) - - runner.Run() + if useWeb { + runAsWeb() + } else { + runAsCommandLine() + } } diff --git a/static/script.js b/static/script.js index f24be41..8f22028 100644 --- a/static/script.js +++ b/static/script.js @@ -1,5 +1,4 @@ function scrollToBottom() { - console.log("Scrolling"); window.scrollTo(0, document.body.scrollHeight); } diff --git a/static/style.css b/static/style.css index a3fd122..c2f2d8e 100644 --- a/static/style.css +++ b/static/style.css @@ -3,6 +3,7 @@ body { } .message { + background-color: whitesmoke; border: 1px solid #000; margin: 5px 0; padding: 16px; diff --git a/templates/index.gohtml b/templates/index.gohtml index 49ab0db..95b5aea 100644 --- a/templates/index.gohtml +++ b/templates/index.gohtml @@ -24,7 +24,7 @@ {{end}} -
diff --git a/web.go b/web.go new file mode 100644 index 0000000..6fb3ba0 --- /dev/null +++ b/web.go @@ -0,0 +1,128 @@ +package main + +import ( + "bytes" + "encoding/json" + "github.com/gorilla/websocket" + "github.com/ollama/ollama/api" + "html/template" + "io" + "log" + "net" + "net/http" + "strings" +) + +var upgrader = websocket.Upgrader{} + +func runAsWeb() { + log.Println("Starting web server") + fs := http.FileServer(http.Dir("static")) + http.Handle("/static/", http.StripPrefix("/static/", fs)) + http.HandleFunc("/", servePage) + http.HandleFunc("/ws", serveWebSocket) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + log.Fatal(err) + } + log.Printf("Listening on http://127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port) + + err = http.Serve(listener, nil) + if err != nil { + log.Fatal(err) + } +} + +func executeTemplate(w io.Writer, name string, data any) error { + t, err := template. + New("templates"). + Funcs(template.FuncMap{ + "trim": strings.TrimSpace, + }). + ParseGlob("templates/*.gohtml") + if err != nil { + return err + } + err = t.ExecuteTemplate(w, name, data) + if err != nil { + return err + } + return nil +} + +func servePage(w http.ResponseWriter, _ *http.Request) { + wm := WebModel{} + wm.AddMessages(conversation) + err := executeTemplate(w, "index", wm) + if err != nil { + log.Println(err) + } +} + +func serveWebSocket(w http.ResponseWriter, r *http.Request) { + log.Println("Websocket connected") + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println("upgrade: ", err) + return + } + + defer c.Close() + for { + // Read the prompt + mt, message, err := c.ReadMessage() + if err != nil { + log.Println("Read error: ", err) + break + } + formInput := make(map[string]any) + err = json.Unmarshal(message, &formInput) + if err != nil { + log.Println("Unmarshal: ", err) + break + } + userRequest, ok := formInput["message"].(string) + if !ok { + log.Println("Invalid user input: ", err) + break + } + + err = func() error { + var content bytes.Buffer + err := executeTemplate(&content, "message", WebMessage{ + Id: len(conversation) + 1, + New: true, + Replace: false, + Content: userRequest, + }) + if err != nil { + return err + } + return c.WriteMessage(mt, content.Bytes()) + }() + if err != nil { + log.Println("Sending initial response: ", err) + break + } + + // Send the request to Ollama + fullResponse := "" + firstResponse := true + sendPromptInput(userRequest, func(response api.ChatResponse) error { + fullResponse = fullResponse + response.Message.Content + var content bytes.Buffer + err = executeTemplate(&content, "message", WebMessage{ + Id: len(conversation) + 1, + New: firstResponse, + Replace: !firstResponse, + Content: fullResponse, + }) + if err != nil { + return err + } + firstResponse = false + return c.WriteMessage(mt, content.Bytes()) + }) + } +}