From ebad22fbee2c7b0401b0f462edcc46b30ed0080f Mon Sep 17 00:00:00 2001 From: Sebastiaan de Schaetzen Date: Thu, 7 Nov 2024 23:30:35 +0100 Subject: [PATCH] Initial commit --- .gitignore | 6 ++ commands.go | 58 ++++++++++++++ db.go | 83 +++++++++++++++++++ go.mod | 25 ++++++ go.sum | 61 ++++++++++++++ main.go | 147 ++++++++++++++++++++++++++++++++++ migrations/1_initial.down.sql | 1 + migrations/1_initial.up.sql | 12 +++ models.go | 14 ++++ 9 files changed, 407 insertions(+) create mode 100644 .gitignore create mode 100644 commands.go create mode 100644 db.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go create mode 100644 migrations/1_initial.down.sql create mode 100644 migrations/1_initial.up.sql create mode 100644 models.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..35949e7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +# Ignore IntelliJ / GoLand stuff +*.iml +*.idea + +# Ignore chat histories +*.chat diff --git a/commands.go b/commands.go new file mode 100644 index 0000000..912b2d4 --- /dev/null +++ b/commands.go @@ -0,0 +1,58 @@ +package main + +import ( + "github.com/elk-language/go-prompt" + pstrings "github.com/elk-language/go-prompt/strings" + "os" + "strings" +) + +type Action func(args []string) + +type CommandArg struct { + Name string + Action Action + Help string +} + +func command(name string, action Action) CommandArg { + return CommandArg{ + Name: "/" + name, + Action: action, + } +} + +func (arg CommandArg) withHelp(help string) CommandArg { + arg.Help = help + return arg +} + +func Completer(d prompt.Document) ([]prompt.Suggest, pstrings.RuneNumber, pstrings.RuneNumber) { + endIndex := d.CurrentRuneIndex() + + w := d.TextBeforeCursor() + startIndex := endIndex - pstrings.RuneCount([]byte(w)) + + if !strings.HasPrefix(w, "/") { + return nil, startIndex, endIndex + } + + var s []prompt.Suggest + for _, cmd := range Commands { + if strings.HasPrefix(cmd.Name, w) { + s = append(s, prompt.Suggest{ + Text: cmd.Name, + Description: cmd.Help, + }) + } + } + return prompt.FilterHasPrefix(s, w, true), startIndex, endIndex +} + +var Commands = []CommandArg{ + command("quit", cmdQuit).withHelp("Quit the program"), +} + +func cmdQuit(args []string) { + os.Exit(0) +} diff --git a/db.go b/db.go new file mode 100644 index 0000000..d70816e --- /dev/null +++ b/db.go @@ -0,0 +1,83 @@ +package main + +import ( + "database/sql" + "embed" + "errors" + "fmt" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/sqlite3" + "github.com/golang-migrate/migrate/v4/source/iofs" + "log" +) + +//go:embed migrations/*.sql +var fs embed.FS + +var db *sql.DB + +func OpenDb(dbName string) { + var err error + db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s.chat?_foreign_keys=on", dbName)) + if err != nil { + log.Fatal(err) + } + driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) + if err != nil { + log.Fatal(err) + } + d, err := iofs.New(fs, "migrations") + if err != nil { + log.Fatal(err) + } + m, err := migrate.NewWithInstance("iofs", d, "sqlite3", driver) + if err != nil { + log.Fatal(err) + } + //err = m.Down() + //if err != nil && !errors.Is(err, migrate.ErrNoChange) { + // log.Println("Could not go down:", err) + //} + err = m.Up() + if err != nil && !errors.Is(err, migrate.ErrNoChange) { + log.Fatal(err) + } +} + +func GetMesages() []Message { + var messages []Message + rows, err := db.Query("select type, message from messages order by id") + if err != nil { + log.Fatal(err) + } + defer rows.Close() + for rows.Next() { + var message Message + err = rows.Scan(&message.Type, &message.Content) + if err != nil { + log.Fatal(err) + } + messages = append(messages, message) + } + return messages +} + +func attempRollback(tx *sql.Tx) { + err := tx.Rollback() + if err != nil { + log.Printf("Failed to perform rollback: %v\n", err) + } +} + +func SaveMessage(role MessageType, content string) error { + tx, err := db.Begin() + if err != nil { + return err + } + _, err = tx.Exec("INSERT INTO messages (type, message) VALUES (?, ?)", role, content) + if err != nil { + attempRollback(tx) + return err + } + return tx.Commit() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..92e008d --- /dev/null +++ b/go.mod @@ -0,0 +1,25 @@ +module llamachat + +go 1.23.2 + +require ( + github.com/elk-language/go-prompt v1.1.5 + github.com/fatih/color v1.18.0 + github.com/golang-migrate/migrate/v4 v4.18.1 + github.com/ollama/ollama v0.4.0 +) + +require ( + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.14 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/mattn/go-tty v0.0.3 // indirect + github.com/pkg/term v1.2.0-beta.2 // indirect + github.com/rivo/uniseg v0.4.4 // indirect + go.uber.org/atomic v1.7.0 // indirect + golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect + golang.org/x/sys v0.25.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..cd3a2c6 --- /dev/null +++ b/go.sum @@ -0,0 +1,61 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/elk-language/go-prompt v1.1.5 h1:/pGHSmEICQbaJltkFYZtcQlm0fQ8WO3CgISkUnYXxYY= +github.com/elk-language/go-prompt v1.1.5/go.mod h1:iEK3nFtQZuRxpoUVZk7Tie27TyWL+RMLkv4wA39XkWc= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4CV3uAuvHGC+Y= +github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.6/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= +github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-tty v0.0.3 h1:5OfyWorkyO7xP52Mq7tB36ajHDG5OHrmBGIS/DtakQI= +github.com/mattn/go-tty v0.0.3/go.mod h1:ihxohKRERHTVzN+aSVRwACLCeqIoZAWpoICkkvrWyR0= +github.com/ollama/ollama v0.4.0 h1:8CxFpQxHmCUkXypNl4Suy9+X4SNdsCvE8DY5ih7BIrU= +github.com/ollama/ollama v0.4.0/go.mod h1:QDxM/t2teuubbfN/FT2pBRMPF0K1N3IakgT1OZBD4NY= +github.com/pkg/term v1.2.0-beta.2 h1:L3y/h2jkuBVFdWiJvNfYfKmzcCnILw7mJWm2JQuMppw= +github.com/pkg/term v1.2.0-beta.2/go.mod h1:E25nymQcrSllhX42Ok8MRm1+hyBdHY0dCeiKZ9jpNGw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= +github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= +golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go new file mode 100644 index 0000000..e1ab92f --- /dev/null +++ b/main.go @@ -0,0 +1,147 @@ +package main + +import ( + "context" + "fmt" + "github.com/elk-language/go-prompt" + "github.com/fatih/color" + "github.com/ollama/ollama/api" + "log" + "os" + "strings" +) + +var conversation []api.Message +var ollama *api.Client + +func convertRole(role MessageType) string { + if role == MT_SYSTEM { + return "system" + } else if role == MT_ASSISTANT { + return "assistant" + } else if role == MT_USER { + return "user" + } else { + log.Fatalf("Invalid role type %d\n", role) + return "" + } +} + +func roleToInt(role string) MessageType { + if role == "system" { + return MT_SYSTEM + } else if role == "assistant" { + return MT_ASSISTANT + } else if role == "user" { + return MT_USER + } else { + log.Fatalf("Invalid role type %s\n", role) + return 0 + } +} + +func colorRole(role MessageType) string { + if role == MT_SYSTEM { + return color.RedString("system") + } else if role == MT_ASSISTANT { + return color.GreenString("assistant") + } else if role == MT_USER { + return color.BlueString("user") + } else { + log.Fatalf("Invalid role type %d\n", role) + return "" + } +} + +func loadMessageFromDb() []api.Message { + dbMessages := GetMesages() + var chatMessages []api.Message + for _, msg := range dbMessages { + message := api.Message{ + Role: convertRole(msg.Type), + Content: msg.Content, + } + chatMessages = append(chatMessages, message) + fmt.Printf("%s: %s\n", colorRole(roleToInt(message.Role)), message.Content) + } + return chatMessages +} + +func executeCommand(cli string) { + args := strings.Split(cli, " ") + for _, cmd := range Commands { + if cmd.Name == args[0] { + cmd.Action(args) + return + } + } + fmt.Println("Unknown command") +} + +func onUserInput(input string) { + if strings.HasPrefix(input, "/") { + executeCommand(input) + return + } + err := SaveMessage(MT_USER, input) + if err != nil { + log.Fatal(err) + } + + conversation = append(conversation, api.Message{ + Role: convertRole(MT_USER), + Content: input, + }) + + ctx := context.Background() + req := &api.ChatRequest{ + Model: "llama3.2:1b", + Messages: conversation, + } + + fullResponse := "" + respFunc := func(resp api.ChatResponse) error { + fullResponse = fullResponse + resp.Message.Content + fmt.Print(resp.Message.Content) + if !resp.Done { + return nil + } + return SaveMessage(MT_ASSISTANT, fullResponse) + } + + fmt.Print(colorRole(MT_ASSISTANT), ": ") + err = ollama.Chat(ctx, req, respFunc) + if err != nil { + log.Fatal(err) + } + conversation = append(conversation, api.Message{ + Role: convertRole(MT_ASSISTANT), + Content: fullResponse, + }) + fmt.Println() +} + +func main() { + var err error + if len(os.Args) <= 1 { + log.Fatal("Missing command line parameter") + } + + ollama, err = api.ClientFromEnvironment() + if err != nil { + log.Fatal(err) + } + + dbName := os.Args[1] + OpenDb(dbName) + conversation = loadMessageFromDb() + + runner := prompt.New( + onUserInput, + prompt.WithTitle("llamachat"), + prompt.WithPrefix("user: "), + prompt.WithCompleter(Completer), + ) + + runner.Run() +} diff --git a/migrations/1_initial.down.sql b/migrations/1_initial.down.sql new file mode 100644 index 0000000..e177e51 --- /dev/null +++ b/migrations/1_initial.down.sql @@ -0,0 +1 @@ +drop table messages; diff --git a/migrations/1_initial.up.sql b/migrations/1_initial.up.sql new file mode 100644 index 0000000..9c17ab0 --- /dev/null +++ b/migrations/1_initial.up.sql @@ -0,0 +1,12 @@ +create table messages ( + id integer primary key autoincrement, + type int, + message text +); + +insert into messages (id, type, message) +values ( + 1, + 1, + 'You are a helpful assistant. It is your task to aid the user as best as possible.' +); diff --git a/models.go b/models.go new file mode 100644 index 0000000..92df2f2 --- /dev/null +++ b/models.go @@ -0,0 +1,14 @@ +package main + +type MessageType int + +const ( + MT_SYSTEM MessageType = 1 + MT_ASSISTANT = 2 + MT_USER = 3 +) + +type Message struct { + Type MessageType + Content string +}