diff options
Diffstat (limited to 'src/models/psql_main_test.go')
| -rw-r--r-- | src/models/psql_main_test.go | 231 |
1 files changed, 231 insertions, 0 deletions
diff --git a/src/models/psql_main_test.go b/src/models/psql_main_test.go new file mode 100644 index 0000000..63c615f --- /dev/null +++ b/src/models/psql_main_test.go @@ -0,0 +1,231 @@ +// Code generated by SQLBoiler 4.16.1 (https://github.com/volatiletech/sqlboiler). DO NOT EDIT. +// This file is meant to be re-generated in place and/or deleted at any time. + +package models + +import ( + "bytes" + "database/sql" + "fmt" + "io" + "os" + "os/exec" + "regexp" + "strings" + + "github.com/friendsofgo/errors" + "github.com/kat-co/vala" + _ "github.com/lib/pq" + "github.com/spf13/viper" + "github.com/volatiletech/randomize" + "github.com/volatiletech/sqlboiler/v4/drivers/sqlboiler-psql/driver" +) + +var rgxPGFkey = regexp.MustCompile(`(?m)^ALTER TABLE .*\n\s+ADD CONSTRAINT .*? FOREIGN KEY .*?;\n`) + +type pgTester struct { + dbConn *sql.DB + + dbName string + host string + user string + pass string + sslmode string + port int + + pgPassFile string + + testDBName string + skipSQLCmd bool +} + +func init() { + dbMain = &pgTester{} +} + +// setup dumps the database schema and imports it into a temporary randomly +// generated test database so that tests can be run against it using the +// generated sqlboiler ORM package. +func (p *pgTester) setup() error { + var err error + + viper.SetDefault("psql.schema", "public") + viper.SetDefault("psql.port", 5432) + viper.SetDefault("psql.sslmode", "require") + + p.dbName = viper.GetString("psql.dbname") + p.host = viper.GetString("psql.host") + p.user = viper.GetString("psql.user") + p.pass = viper.GetString("psql.pass") + p.port = viper.GetInt("psql.port") + p.sslmode = viper.GetString("psql.sslmode") + p.testDBName = viper.GetString("psql.testdbname") + p.skipSQLCmd = viper.GetBool("psql.skipsqlcmd") + + err = vala.BeginValidation().Validate( + vala.StringNotEmpty(p.user, "psql.user"), + vala.StringNotEmpty(p.host, "psql.host"), + vala.Not(vala.Equals(p.port, 0, "psql.port")), + vala.StringNotEmpty(p.dbName, "psql.dbname"), + vala.StringNotEmpty(p.sslmode, "psql.sslmode"), + ).Check() + + if err != nil { + return err + } + + // if no testing DB passed + if len(p.testDBName) == 0 { + // Create a randomized db name. + p.testDBName = randomize.StableDBName(p.dbName) + } + + if err = p.makePGPassFile(); err != nil { + return err + } + + if !p.skipSQLCmd { + if err = p.dropTestDB(); err != nil { + return err + } + if err = p.createTestDB(); err != nil { + return err + } + + dumpCmd := exec.Command("pg_dump", "--schema-only", p.dbName) + dumpCmd.Env = append(os.Environ(), p.pgEnv()...) + createCmd := exec.Command("psql", p.testDBName) + createCmd.Env = append(os.Environ(), p.pgEnv()...) + + r, w := io.Pipe() + dumpCmdStderr := &bytes.Buffer{} + createCmdStderr := &bytes.Buffer{} + + dumpCmd.Stdout = w + dumpCmd.Stderr = dumpCmdStderr + + createCmd.Stdin = newFKeyDestroyer(rgxPGFkey, r) + createCmd.Stderr = createCmdStderr + + if err = dumpCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start pg_dump command") + } + if err = createCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start psql command") + } + + if err = dumpCmd.Wait(); err != nil { + fmt.Println(err) + fmt.Println(dumpCmdStderr.String()) + return errors.Wrap(err, "failed to wait for pg_dump command") + } + + _ = w.Close() // After dumpCmd is done, close the write end of the pipe + + if err = createCmd.Wait(); err != nil { + fmt.Println(err) + fmt.Println(createCmdStderr.String()) + return errors.Wrap(err, "failed to wait for psql command") + } + } + + return nil +} + +func (p *pgTester) runCmd(stdin, command string, args ...string) error { + cmd := exec.Command(command, args...) + cmd.Env = append(os.Environ(), p.pgEnv()...) + + if len(stdin) != 0 { + cmd.Stdin = strings.NewReader(stdin) + } + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + cmd.Stdout = stdout + cmd.Stderr = stderr + if err := cmd.Run(); err != nil { + fmt.Println("failed running:", command, args) + fmt.Println(stdout.String()) + fmt.Println(stderr.String()) + return err + } + + return nil +} + +func (p *pgTester) pgEnv() []string { + return []string{ + fmt.Sprintf("PGHOST=%s", p.host), + fmt.Sprintf("PGPORT=%d", p.port), + fmt.Sprintf("PGUSER=%s", p.user), + fmt.Sprintf("PGPASSFILE=%s", p.pgPassFile), + } +} + +func (p *pgTester) makePGPassFile() error { + tmp, err := os.CreateTemp("", "pgpass") + if err != nil { + return errors.Wrap(err, "failed to create option file") + } + + fmt.Fprintf(tmp, "%s:%d:postgres:%s", p.host, p.port, p.user) + if len(p.pass) != 0 { + fmt.Fprintf(tmp, ":%s", p.pass) + } + fmt.Fprintln(tmp) + + fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.dbName, p.user) + if len(p.pass) != 0 { + fmt.Fprintf(tmp, ":%s", p.pass) + } + fmt.Fprintln(tmp) + + fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user) + if len(p.pass) != 0 { + fmt.Fprintf(tmp, ":%s", p.pass) + } + fmt.Fprintln(tmp) + + p.pgPassFile = tmp.Name() + return tmp.Close() +} + +func (p *pgTester) createTestDB() error { + return p.runCmd("", "createdb", p.testDBName) +} + +func (p *pgTester) dropTestDB() error { + return p.runCmd("", "dropdb", "--if-exists", p.testDBName) +} + +// teardown executes cleanup tasks when the tests finish running +func (p *pgTester) teardown() error { + var err error + if err = p.dbConn.Close(); err != nil { + return err + } + p.dbConn = nil + + if !p.skipSQLCmd { + if err = p.dropTestDB(); err != nil { + return err + } + } + + return os.Remove(p.pgPassFile) +} + +func (p *pgTester) conn() (*sql.DB, error) { + if p.dbConn != nil { + return p.dbConn, nil + } + + var err error + p.dbConn, err = sql.Open("postgres", driver.PSQLBuildQueryString(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode)) + if err != nil { + return nil, err + } + + return p.dbConn, nil +} |
