aboutsummaryrefslogtreecommitdiff
path: root/src/models/psql_main_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/models/psql_main_test.go')
-rw-r--r--src/models/psql_main_test.go231
1 files changed, 231 insertions, 0 deletions
diff --git a/src/models/psql_main_test.go b/src/models/psql_main_test.go
new file mode 100644
index 0000000..63c615f
--- /dev/null
+++ b/src/models/psql_main_test.go
@@ -0,0 +1,231 @@
+// Code generated by SQLBoiler 4.16.1 (https://github.com/volatiletech/sqlboiler). DO NOT EDIT.
+// This file is meant to be re-generated in place and/or deleted at any time.
+
+package models
+
+import (
+ "bytes"
+ "database/sql"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "regexp"
+ "strings"
+
+ "github.com/friendsofgo/errors"
+ "github.com/kat-co/vala"
+ _ "github.com/lib/pq"
+ "github.com/spf13/viper"
+ "github.com/volatiletech/randomize"
+ "github.com/volatiletech/sqlboiler/v4/drivers/sqlboiler-psql/driver"
+)
+
+var rgxPGFkey = regexp.MustCompile(`(?m)^ALTER TABLE .*\n\s+ADD CONSTRAINT .*? FOREIGN KEY .*?;\n`)
+
+type pgTester struct {
+ dbConn *sql.DB
+
+ dbName string
+ host string
+ user string
+ pass string
+ sslmode string
+ port int
+
+ pgPassFile string
+
+ testDBName string
+ skipSQLCmd bool
+}
+
+func init() {
+ dbMain = &pgTester{}
+}
+
+// setup dumps the database schema and imports it into a temporary randomly
+// generated test database so that tests can be run against it using the
+// generated sqlboiler ORM package.
+func (p *pgTester) setup() error {
+ var err error
+
+ viper.SetDefault("psql.schema", "public")
+ viper.SetDefault("psql.port", 5432)
+ viper.SetDefault("psql.sslmode", "require")
+
+ p.dbName = viper.GetString("psql.dbname")
+ p.host = viper.GetString("psql.host")
+ p.user = viper.GetString("psql.user")
+ p.pass = viper.GetString("psql.pass")
+ p.port = viper.GetInt("psql.port")
+ p.sslmode = viper.GetString("psql.sslmode")
+ p.testDBName = viper.GetString("psql.testdbname")
+ p.skipSQLCmd = viper.GetBool("psql.skipsqlcmd")
+
+ err = vala.BeginValidation().Validate(
+ vala.StringNotEmpty(p.user, "psql.user"),
+ vala.StringNotEmpty(p.host, "psql.host"),
+ vala.Not(vala.Equals(p.port, 0, "psql.port")),
+ vala.StringNotEmpty(p.dbName, "psql.dbname"),
+ vala.StringNotEmpty(p.sslmode, "psql.sslmode"),
+ ).Check()
+
+ if err != nil {
+ return err
+ }
+
+ // if no testing DB passed
+ if len(p.testDBName) == 0 {
+ // Create a randomized db name.
+ p.testDBName = randomize.StableDBName(p.dbName)
+ }
+
+ if err = p.makePGPassFile(); err != nil {
+ return err
+ }
+
+ if !p.skipSQLCmd {
+ if err = p.dropTestDB(); err != nil {
+ return err
+ }
+ if err = p.createTestDB(); err != nil {
+ return err
+ }
+
+ dumpCmd := exec.Command("pg_dump", "--schema-only", p.dbName)
+ dumpCmd.Env = append(os.Environ(), p.pgEnv()...)
+ createCmd := exec.Command("psql", p.testDBName)
+ createCmd.Env = append(os.Environ(), p.pgEnv()...)
+
+ r, w := io.Pipe()
+ dumpCmdStderr := &bytes.Buffer{}
+ createCmdStderr := &bytes.Buffer{}
+
+ dumpCmd.Stdout = w
+ dumpCmd.Stderr = dumpCmdStderr
+
+ createCmd.Stdin = newFKeyDestroyer(rgxPGFkey, r)
+ createCmd.Stderr = createCmdStderr
+
+ if err = dumpCmd.Start(); err != nil {
+ return errors.Wrap(err, "failed to start pg_dump command")
+ }
+ if err = createCmd.Start(); err != nil {
+ return errors.Wrap(err, "failed to start psql command")
+ }
+
+ if err = dumpCmd.Wait(); err != nil {
+ fmt.Println(err)
+ fmt.Println(dumpCmdStderr.String())
+ return errors.Wrap(err, "failed to wait for pg_dump command")
+ }
+
+ _ = w.Close() // After dumpCmd is done, close the write end of the pipe
+
+ if err = createCmd.Wait(); err != nil {
+ fmt.Println(err)
+ fmt.Println(createCmdStderr.String())
+ return errors.Wrap(err, "failed to wait for psql command")
+ }
+ }
+
+ return nil
+}
+
+func (p *pgTester) runCmd(stdin, command string, args ...string) error {
+ cmd := exec.Command(command, args...)
+ cmd.Env = append(os.Environ(), p.pgEnv()...)
+
+ if len(stdin) != 0 {
+ cmd.Stdin = strings.NewReader(stdin)
+ }
+
+ stdout := &bytes.Buffer{}
+ stderr := &bytes.Buffer{}
+ cmd.Stdout = stdout
+ cmd.Stderr = stderr
+ if err := cmd.Run(); err != nil {
+ fmt.Println("failed running:", command, args)
+ fmt.Println(stdout.String())
+ fmt.Println(stderr.String())
+ return err
+ }
+
+ return nil
+}
+
+func (p *pgTester) pgEnv() []string {
+ return []string{
+ fmt.Sprintf("PGHOST=%s", p.host),
+ fmt.Sprintf("PGPORT=%d", p.port),
+ fmt.Sprintf("PGUSER=%s", p.user),
+ fmt.Sprintf("PGPASSFILE=%s", p.pgPassFile),
+ }
+}
+
+func (p *pgTester) makePGPassFile() error {
+ tmp, err := os.CreateTemp("", "pgpass")
+ if err != nil {
+ return errors.Wrap(err, "failed to create option file")
+ }
+
+ fmt.Fprintf(tmp, "%s:%d:postgres:%s", p.host, p.port, p.user)
+ if len(p.pass) != 0 {
+ fmt.Fprintf(tmp, ":%s", p.pass)
+ }
+ fmt.Fprintln(tmp)
+
+ fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.dbName, p.user)
+ if len(p.pass) != 0 {
+ fmt.Fprintf(tmp, ":%s", p.pass)
+ }
+ fmt.Fprintln(tmp)
+
+ fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user)
+ if len(p.pass) != 0 {
+ fmt.Fprintf(tmp, ":%s", p.pass)
+ }
+ fmt.Fprintln(tmp)
+
+ p.pgPassFile = tmp.Name()
+ return tmp.Close()
+}
+
+func (p *pgTester) createTestDB() error {
+ return p.runCmd("", "createdb", p.testDBName)
+}
+
+func (p *pgTester) dropTestDB() error {
+ return p.runCmd("", "dropdb", "--if-exists", p.testDBName)
+}
+
+// teardown executes cleanup tasks when the tests finish running
+func (p *pgTester) teardown() error {
+ var err error
+ if err = p.dbConn.Close(); err != nil {
+ return err
+ }
+ p.dbConn = nil
+
+ if !p.skipSQLCmd {
+ if err = p.dropTestDB(); err != nil {
+ return err
+ }
+ }
+
+ return os.Remove(p.pgPassFile)
+}
+
+func (p *pgTester) conn() (*sql.DB, error) {
+ if p.dbConn != nil {
+ return p.dbConn, nil
+ }
+
+ var err error
+ p.dbConn, err = sql.Open("postgres", driver.PSQLBuildQueryString(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode))
+ if err != nil {
+ return nil, err
+ }
+
+ return p.dbConn, nil
+}