diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5d70849..814c7f1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,6 +16,10 @@ jobs: services: redis: image: redis + postgres: + image: postgres + env: + POSTGRES_PASSWORD: postgres steps: - name: Install common dependencies run: apk add --no-cache gcc libc-dev @@ -23,3 +27,8 @@ jobs: - run: go test -v ./... env: REDIS_ADDR: redis:6379 + DATABASE_ADDR: postgres:5432 + DATABASE_USER: postgres + DATABASE_PASSWORD: postgres + DATABASE_DATABASE: postgres + DATABASE_SSL: false diff --git a/README.md b/README.md index 6bbf703..ca4110b 100644 --- a/README.md +++ b/README.md @@ -76,4 +76,5 @@ database: password: "" database: "" pool: 10 + ssl: true ``` \ No newline at end of file diff --git a/clients/pg.go b/clients/pg.go index 6ca5943..6c97919 100644 --- a/clients/pg.go +++ b/clients/pg.go @@ -2,9 +2,16 @@ package clients import ( "context" + "crypto/tls" + "database/sql" + "errors" + "fmt" + "net/url" + "strings" "time" - "github.com/go-pg/pg/v10" + pg "github.com/go-pg/pg/v10" + _ "github.com/lib/pq" "github.com/sirupsen/logrus" "github.com/spf13/viper" ) @@ -15,6 +22,16 @@ type ( } ctxKey int + + dbConfig struct { + addr string + user string + password string + database string + poolSize int + ssl bool + debug bool + } ) const ctxRequestStartKey ctxKey = 1 + iota @@ -43,24 +60,86 @@ func (d dbQueryHook) AfterQuery(ctx context.Context, event *pg.QueryEvent) error return nil } -func NewPostgreSQL(config *viper.Viper, logger *logrus.Logger) *pg.DB { - config.SetDefault("database.pool", 10) - config.SetDefault("database.debug", false) +func NewPostgreSQL(config *viper.Viper, logger *logrus.Logger) (*pg.DB, error) { + cfg, err := parseDBConfig(config) + if err != nil { + return nil, err + } - connection := pg.Connect(&pg.Options{ - Addr: config.GetString("database.addr"), - User: config.GetString("database.user"), - Password: config.GetString("database.password"), - Database: config.GetString("database.database"), - PoolSize: config.GetInt("database.pool"), - }) + opts := &pg.Options{ + Addr: cfg.addr, + User: cfg.user, + Password: cfg.password, + Database: cfg.database, + PoolSize: cfg.poolSize, + } - if config.GetBool("database.debug") { + if cfg.ssl { + hp := strings.Split(cfg.addr, ":") + if len(hp) != 2 { + return nil, errors.New("database address has wrong format") + } + + opts.TLSConfig = &tls.Config{ + InsecureSkipVerify: false, + ServerName: hp[0], + } + } + + connection := pg.Connect(opts) + + if cfg.debug { entry := logger.WithField("module", "db") connection.AddQueryHook(dbQueryHook{ logger: entry, }) } - return connection + return connection, nil +} + +// NewPostgreSQLForMigrations is a connection that is used for migrations. +// Migrations are implemented with `goose`, which supports only `*sql.DB`. +func NewPostgreSQLForMigrations(config *viper.Viper) (*sql.DB, error) { + cfg, err := parseDBConfig(config) + if err != nil { + return nil, err + } + + dsn := fmt.Sprintf( + "postgres://%s:%s@%s/%s", + cfg.user, + strings.ReplaceAll(url.QueryEscape(cfg.password), ":", "%3A"), + cfg.addr, + cfg.database, + ) + + if cfg.ssl { + dsn += "?sslmode=verify-ca" + } else { + dsn += "?sslmode=disable" + } + + return sql.Open("postgres", dsn) +} + +func parseDBConfig(config *viper.Viper) (dbConfig, error) { + config.SetDefault("database.pool", 10) + config.SetDefault("database.debug", false) + config.SetDefault("database.ssl", false) + + dbAddr := config.GetString("database.addr") + if dbAddr == "" { + return dbConfig{}, errors.New("missing database address") + } + + return dbConfig{ + addr: dbAddr, + user: config.GetString("database.user"), + password: config.GetString("database.password"), + database: config.GetString("database.database"), + poolSize: config.GetInt("database.pool"), + ssl: config.GetBool("database.ssl"), + debug: config.GetBool("database.debug"), + }, nil } diff --git a/clients/pg_test.go b/clients/pg_test.go new file mode 100644 index 0000000..1f99ef4 --- /dev/null +++ b/clients/pg_test.go @@ -0,0 +1,73 @@ +package clients + +import ( + "os" + "testing" + + "github.com/sirupsen/logrus" + "github.com/spf13/viper" +) + +func TestNewPostgreSQL(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode") + } + + cfg := setupConfig() + logger := logrus.New() + + db, err := NewPostgreSQL(cfg, logger) + if err != nil { + t.Fatal(err) + } + + type StringResult struct { + Message string + } + var res StringResult + _, err = db.QueryOne(&res, "SELECT 'hello' AS message") + if err != nil { + t.Fatal(err) + } + + if res.Message != "hello" { + t.Error("unexpected message") + } +} + +func TestNewPostgreSQLForMigrations(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode") + } + + cfg := setupConfig() + + db, err := NewPostgreSQLForMigrations(cfg) + if err != nil { + t.Fatal(err) + } + + type StringResult struct { + Message string + } + var res StringResult + err = db.QueryRow("SELECT 'hello' AS message").Scan(&res.Message) + if err != nil { + t.Fatal(err) + } + + if res.Message != "hello" { + t.Error("unexpected message") + } +} + +func setupConfig() *viper.Viper { + cfg := viper.New() + cfg.Set("database.addr", os.Getenv("DATABASE_ADDR")) + cfg.Set("database.user", os.Getenv("DATABASE_USER")) + cfg.Set("database.password", os.Getenv("DATABASE_PASSWORD")) + cfg.Set("database.database", os.Getenv("DATABASE_DATABASE")) + cfg.Set("database.ssl", os.Getenv("DATABASE_SSL")) + + return cfg +} diff --git a/narada/commands/migrations.go b/narada/commands/migrations.go index ce385a6..b785b2e 100644 --- a/narada/commands/migrations.go +++ b/narada/commands/migrations.go @@ -1,11 +1,10 @@ package commands import ( - "database/sql" "errors" - "fmt" "github.com/cryptopay-dev/narada" + "github.com/cryptopay-dev/narada/clients" "github.com/pressly/goose" "github.com/sirupsen/logrus" @@ -31,7 +30,7 @@ func MigrateUp(p *narada.Narada) *cli.Command { logger.Println("starting migrations") dir := c.String("dir") - db, err := connect(v) + db, err := clients.NewPostgreSQLForMigrations(v) if err != nil { return err } @@ -61,7 +60,7 @@ func MigrateDown(p *narada.Narada) *cli.Command { logger.Println("rolling back migration") dir := c.String("dir") - db, err := connect(v) + db, err := clients.NewPostgreSQLForMigrations(v) if err != nil { return err } @@ -107,15 +106,3 @@ func CreateMigration(p *narada.Narada) *cli.Command { }, } } - -func connect(v *viper.Viper) (*sql.DB, error) { - dsn := fmt.Sprintf( - "postgres://%s:%s@%s/%s?sslmode=disable", - v.GetString("database.user"), - v.GetString("database.password"), - v.GetString("database.addr"), - v.GetString("database.database"), - ) - - return sql.Open("postgres", dsn) -}