package mysqlite

import "log"

// Tx is a transaction started by Db.Begin. While active it holds the lock
// taken with Db.Lock in Begin. db is set to nil when the transaction is
// committed or rolled back, which marks the Tx as finished.
type Tx struct {
	db *Db
}

// Begin takes d's lock (d.Lock) and starts a new transaction by executing
// BEGIN. On success the lock stays held until the returned Tx is committed
// or rolled back. Assuming d.Lock is a mutex, this serializes transactions
// on d.
//
// If BEGIN fails, the lock is released before the error is returned.
func (d *Db) Begin() (*Tx, error) {
	d.Lock()
	if err := d.query("BEGIN").Exec(); err != nil {
		// No Tx is handed out on failure, so nothing else would ever
		// release the lock; without this every later Begin would deadlock.
		d.Unlock()
		return nil, err
	}
	return &Tx{db: d}, nil
}

// MustBegin is like Begin, but it panics with Begin's error instead of
// returning it.
func (d *Db) MustBegin() *Tx {
	tx, err := d.Begin()
	if err == nil {
		return tx
	}
	panic(err)
}

// Commit executes COMMIT and then releases the lock taken by Begin. After
// Commit returns, the Tx is finished: Rollback becomes a no-op and Query
// panics. Calling Commit on an already finished Tx panics (via Query).
//
// If COMMIT fails, a best-effort ROLLBACK is issued before the lock is
// released, and the COMMIT error is returned. SQLite documents that a
// failed COMMIT (e.g. SQLITE_BUSY) can leave the transaction active.
// Once this Tx is detached from the Db, nothing could close that
// transaction, and the connection would be left inside it.
func (tx *Tx) Commit() error {
	defer tx.unlock()
	err := tx.Query("COMMIT").Exec()
	if err != nil {
		// Ignored on purpose: if SQLite already rolled the transaction
		// back, this ROLLBACK fails harmlessly, and the COMMIT error is the
		// one the caller needs to see.
		_ = tx.Query("ROLLBACK").Exec()
	}
	return err
}

// MustCommit is like Commit, but it panics with Commit's error instead of
// returning it.
func (tx *Tx) MustCommit() {
	if err := tx.Commit(); err != nil {
		panic(err)
	}
}

// Rollback executes ROLLBACK and releases the lock taken by Begin. If the Tx
// is already finished (committed or rolled back), it does nothing and returns
// nil, so it can be deferred right after Begin alongside an explicit Commit.
func (tx *Tx) Rollback() error {
	if tx.db != nil {
		defer tx.unlock()
		return tx.Query("ROLLBACK").Exec()
	}
	// Already finished by Commit or an earlier Rollback: nothing to undo.
	return nil
}

// MustRollback is like Rollback, but on failure it logs the error and panics
// via log.Panicf.
func (tx *Tx) MustRollback() {
	if err := tx.Rollback(); err != nil {
		log.Panicf("error doing rollback: %v", err)
	}
}

// unlock releases the Db lock held by tx and then detaches tx from the Db,
// marking the transaction as finished. It is a no-op on a finished Tx.
func (tx *Tx) unlock() {
	d := tx.db
	if d == nil {
		return
	}
	d.Unlock()
	tx.db = nil
}

// Query returns the *Query built by the underlying Db's query method for the
// given SQL text. It panics if tx has already been committed or rolled back.
func (tx *Tx) Query(query string) *Query {
	d := tx.db
	if d != nil {
		return d.query(query)
	}
	panic("query was performed on a transaction after Commit or Rollback")
}