Initial commit
This commit is contained in:
commit
ebad22fbee
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
# Ignore IntelliJ / GoLand stuff
|
||||
*.iml
|
||||
*.idea
|
||||
|
||||
# Ignore chat histories
|
||||
*.chat
|
58
commands.go
Normal file
58
commands.go
Normal file
@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/elk-language/go-prompt"
|
||||
pstrings "github.com/elk-language/go-prompt/strings"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Action func(args []string)
|
||||
|
||||
type CommandArg struct {
|
||||
Name string
|
||||
Action Action
|
||||
Help string
|
||||
}
|
||||
|
||||
func command(name string, action Action) CommandArg {
|
||||
return CommandArg{
|
||||
Name: "/" + name,
|
||||
Action: action,
|
||||
}
|
||||
}
|
||||
|
||||
func (arg CommandArg) withHelp(help string) CommandArg {
|
||||
arg.Help = help
|
||||
return arg
|
||||
}
|
||||
|
||||
func Completer(d prompt.Document) ([]prompt.Suggest, pstrings.RuneNumber, pstrings.RuneNumber) {
|
||||
endIndex := d.CurrentRuneIndex()
|
||||
|
||||
w := d.TextBeforeCursor()
|
||||
startIndex := endIndex - pstrings.RuneCount([]byte(w))
|
||||
|
||||
if !strings.HasPrefix(w, "/") {
|
||||
return nil, startIndex, endIndex
|
||||
}
|
||||
|
||||
var s []prompt.Suggest
|
||||
for _, cmd := range Commands {
|
||||
if strings.HasPrefix(cmd.Name, w) {
|
||||
s = append(s, prompt.Suggest{
|
||||
Text: cmd.Name,
|
||||
Description: cmd.Help,
|
||||
})
|
||||
}
|
||||
}
|
||||
return prompt.FilterHasPrefix(s, w, true), startIndex, endIndex
|
||||
}
|
||||
|
||||
var Commands = []CommandArg{
|
||||
command("quit", cmdQuit).withHelp("Quit the program"),
|
||||
}
|
||||
|
||||
func cmdQuit(args []string) {
|
||||
os.Exit(0)
|
||||
}
|
83
db.go
Normal file
83
db.go
Normal file
@ -0,0 +1,83 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
"log"
|
||||
)
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
var fs embed.FS
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
func OpenDb(dbName string) {
|
||||
var err error
|
||||
db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s.chat?_foreign_keys=on", dbName))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
d, err := iofs.New(fs, "migrations")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m, err := migrate.NewWithInstance("iofs", d, "sqlite3", driver)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
//err = m.Down()
|
||||
//if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
// log.Println("Could not go down:", err)
|
||||
//}
|
||||
err = m.Up()
|
||||
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func GetMesages() []Message {
|
||||
var messages []Message
|
||||
rows, err := db.Query("select type, message from messages order by id")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var message Message
|
||||
err = rows.Scan(&message.Type, &message.Content)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
messages = append(messages, message)
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
func attempRollback(tx *sql.Tx) {
|
||||
err := tx.Rollback()
|
||||
if err != nil {
|
||||
log.Printf("Failed to perform rollback: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func SaveMessage(role MessageType, content string) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("INSERT INTO messages (type, message) VALUES (?, ?)", role, content)
|
||||
if err != nil {
|
||||
attempRollback(tx)
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
25
go.mod
Normal file
25
go.mod
Normal file
@ -0,0 +1,25 @@
|
||||
module llamachat
|
||||
|
||||
go 1.23.2
|
||||
|
||||
require (
|
||||
github.com/elk-language/go-prompt v1.1.5
|
||||
github.com/fatih/color v1.18.0
|
||||
github.com/golang-migrate/migrate/v4 v4.18.1
|
||||
github.com/ollama/ollama v0.4.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.14 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/mattn/go-tty v0.0.3 // indirect
|
||||
github.com/pkg/term v1.2.0-beta.2 // indirect
|
||||
github.com/rivo/uniseg v0.4.4 // indirect
|
||||
go.uber.org/atomic v1.7.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect
|
||||
golang.org/x/sys v0.25.0 // indirect
|
||||
)
|
61
go.sum
Normal file
61
go.sum
Normal file
@ -0,0 +1,61 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/elk-language/go-prompt v1.1.5 h1:/pGHSmEICQbaJltkFYZtcQlm0fQ8WO3CgISkUnYXxYY=
|
||||
github.com/elk-language/go-prompt v1.1.5/go.mod h1:iEK3nFtQZuRxpoUVZk7Tie27TyWL+RMLkv4wA39XkWc=
|
||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||
github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4CV3uAuvHGC+Y=
|
||||
github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.6/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
|
||||
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mattn/go-tty v0.0.3 h1:5OfyWorkyO7xP52Mq7tB36ajHDG5OHrmBGIS/DtakQI=
|
||||
github.com/mattn/go-tty v0.0.3/go.mod h1:ihxohKRERHTVzN+aSVRwACLCeqIoZAWpoICkkvrWyR0=
|
||||
github.com/ollama/ollama v0.4.0 h1:8CxFpQxHmCUkXypNl4Suy9+X4SNdsCvE8DY5ih7BIrU=
|
||||
github.com/ollama/ollama v0.4.0/go.mod h1:QDxM/t2teuubbfN/FT2pBRMPF0K1N3IakgT1OZBD4NY=
|
||||
github.com/pkg/term v1.2.0-beta.2 h1:L3y/h2jkuBVFdWiJvNfYfKmzcCnILw7mJWm2JQuMppw=
|
||||
github.com/pkg/term v1.2.0-beta.2/go.mod h1:E25nymQcrSllhX42Ok8MRm1+hyBdHY0dCeiKZ9jpNGw=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
|
||||
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
|
||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
|
||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
|
||||
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
147
main.go
Normal file
147
main.go
Normal file
@ -0,0 +1,147 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/elk-language/go-prompt"
|
||||
"github.com/fatih/color"
|
||||
"github.com/ollama/ollama/api"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var conversation []api.Message
|
||||
var ollama *api.Client
|
||||
|
||||
func convertRole(role MessageType) string {
|
||||
if role == MT_SYSTEM {
|
||||
return "system"
|
||||
} else if role == MT_ASSISTANT {
|
||||
return "assistant"
|
||||
} else if role == MT_USER {
|
||||
return "user"
|
||||
} else {
|
||||
log.Fatalf("Invalid role type %d\n", role)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func roleToInt(role string) MessageType {
|
||||
if role == "system" {
|
||||
return MT_SYSTEM
|
||||
} else if role == "assistant" {
|
||||
return MT_ASSISTANT
|
||||
} else if role == "user" {
|
||||
return MT_USER
|
||||
} else {
|
||||
log.Fatalf("Invalid role type %s\n", role)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func colorRole(role MessageType) string {
|
||||
if role == MT_SYSTEM {
|
||||
return color.RedString("system")
|
||||
} else if role == MT_ASSISTANT {
|
||||
return color.GreenString("assistant")
|
||||
} else if role == MT_USER {
|
||||
return color.BlueString("user")
|
||||
} else {
|
||||
log.Fatalf("Invalid role type %d\n", role)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func loadMessageFromDb() []api.Message {
|
||||
dbMessages := GetMesages()
|
||||
var chatMessages []api.Message
|
||||
for _, msg := range dbMessages {
|
||||
message := api.Message{
|
||||
Role: convertRole(msg.Type),
|
||||
Content: msg.Content,
|
||||
}
|
||||
chatMessages = append(chatMessages, message)
|
||||
fmt.Printf("%s: %s\n", colorRole(roleToInt(message.Role)), message.Content)
|
||||
}
|
||||
return chatMessages
|
||||
}
|
||||
|
||||
func executeCommand(cli string) {
|
||||
args := strings.Split(cli, " ")
|
||||
for _, cmd := range Commands {
|
||||
if cmd.Name == args[0] {
|
||||
cmd.Action(args)
|
||||
return
|
||||
}
|
||||
}
|
||||
fmt.Println("Unknown command")
|
||||
}
|
||||
|
||||
func onUserInput(input string) {
|
||||
if strings.HasPrefix(input, "/") {
|
||||
executeCommand(input)
|
||||
return
|
||||
}
|
||||
err := SaveMessage(MT_USER, input)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
conversation = append(conversation, api.Message{
|
||||
Role: convertRole(MT_USER),
|
||||
Content: input,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
req := &api.ChatRequest{
|
||||
Model: "llama3.2:1b",
|
||||
Messages: conversation,
|
||||
}
|
||||
|
||||
fullResponse := ""
|
||||
respFunc := func(resp api.ChatResponse) error {
|
||||
fullResponse = fullResponse + resp.Message.Content
|
||||
fmt.Print(resp.Message.Content)
|
||||
if !resp.Done {
|
||||
return nil
|
||||
}
|
||||
return SaveMessage(MT_ASSISTANT, fullResponse)
|
||||
}
|
||||
|
||||
fmt.Print(colorRole(MT_ASSISTANT), ": ")
|
||||
err = ollama.Chat(ctx, req, respFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
conversation = append(conversation, api.Message{
|
||||
Role: convertRole(MT_ASSISTANT),
|
||||
Content: fullResponse,
|
||||
})
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
func main() {
|
||||
var err error
|
||||
if len(os.Args) <= 1 {
|
||||
log.Fatal("Missing command line parameter")
|
||||
}
|
||||
|
||||
ollama, err = api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
dbName := os.Args[1]
|
||||
OpenDb(dbName)
|
||||
conversation = loadMessageFromDb()
|
||||
|
||||
runner := prompt.New(
|
||||
onUserInput,
|
||||
prompt.WithTitle("llamachat"),
|
||||
prompt.WithPrefix("user: "),
|
||||
prompt.WithCompleter(Completer),
|
||||
)
|
||||
|
||||
runner.Run()
|
||||
}
|
1
migrations/1_initial.down.sql
Normal file
1
migrations/1_initial.down.sql
Normal file
@ -0,0 +1 @@
|
||||
drop table messages;
|
12
migrations/1_initial.up.sql
Normal file
12
migrations/1_initial.up.sql
Normal file
@ -0,0 +1,12 @@
|
||||
create table messages (
|
||||
id integer primary key autoincrement,
|
||||
type int,
|
||||
message text
|
||||
);
|
||||
|
||||
insert into messages (id, type, message)
|
||||
values (
|
||||
1,
|
||||
1,
|
||||
'You are a helpful assistant. It is your task to aid the user as best as possible.'
|
||||
);
|
Loading…
x
Reference in New Issue
Block a user