You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

120 lines
3.4 KiB
Go

package dbtimesprovider
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/doug-martin/goqu/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
_ "modernc.org/sqlite"
"prayertimes/internal/database"
"prayertimes/pkg/prayer"
)
// testDB returns a fresh in-memory SQLite database for one test.
// The package migrations have been applied, and a single row with id 1 has
// been inserted into "locations" (presumably the location that prayer_times
// rows with LocationID "1" refer to — confirm against the schema). The
// underlying *sql.DB is closed automatically when the test finishes, including
// when a setup step fails.
//
// NOTE(review): with database/sql, every pooled connection to ":memory:" gets
// its own empty database. This helper assumes database.NewSqliteDB (not
// visible here) keeps all work on a single connection — confirm.
func testDB(t *testing.T) *goqu.Database {
	t.Helper()

	db, err := database.NewSqliteDB(":memory:")
	require.NoError(t, err)

	// goqu exposes the handle as an interface, while Migrate and Close need the
	// concrete *sql.DB. Assert once and fail the test cleanly instead of
	// panicking if the wiring ever changes.
	sqlDB, ok := db.Db.(*sql.DB)
	require.Truef(t, ok, "expected *sql.DB, got %T", db.Db)

	// Register the cleanup before any further require.* call. A failure in
	// Migrate or the seed insert stops the test via FailNow, and the DB must
	// still be closed in that case.
	t.Cleanup(func() {
		assert.NoError(t, sqlDB.Close())
	})

	require.NoError(t, Migrate(sqlDB))

	_, err = db.Insert("locations").Rows(goqu.Record{"id": 1}).Executor().Exec()
	require.NoError(t, err)

	return db
}
// mockProvider is a stub prayer.TimesProvider backed by a plain function.
// Every call to Get simply returns whatever the wrapped function returns.
type mockProvider func() ([]prayer.Times, error)

// Get invokes the wrapped function. Both ctx and location are ignored.
func (f mockProvider) Get(ctx context.Context, location string) ([]prayer.Times, error) {
	times, err := f()
	return times, err
}

// Name identifies the stub; it is always "mock".
func (f mockProvider) Name() string {
	const name = "mock"
	return name
}
func TestProvider_Get(t *testing.T) {
then := time.Date(2023, 3, 5, 0, 0, 0, 0, time.UTC)
tests := []struct {
name string
setupDB func(t *testing.T, db *goqu.Database)
provider prayer.TimesProvider
clock time.Time
assertRes func(t *testing.T, db *goqu.Database, times []prayer.Times, err error)
}{
{
name: "provider succeeds, empty db",
provider: mockProvider(func() ([]prayer.Times, error) {
return []prayer.Times{
{Date: time.Date(2023, 3, 4, 0, 0, 0, 0, time.UTC)},
{Date: time.Date(2023, 3, 5, 0, 0, 0, 0, time.UTC)},
}, nil
}),
clock: then,
assertRes: func(t *testing.T, db *goqu.Database, times []prayer.Times, err error) {
assert.NoError(t, err)
assert.Len(t, times, 2)
cnt, err := db.From("prayer_times").Count()
assert.NoError(t, err)
assert.Equal(t, int64(2), cnt)
},
},
{
name: "provider fails, empty db",
provider: mockProvider(func() ([]prayer.Times, error) {
return nil, fmt.Errorf("no")
}),
clock: then,
assertRes: func(t *testing.T, db *goqu.Database, times []prayer.Times, err error) {
assert.Error(t, err)
assert.Empty(t, times)
},
},
{
name: "provider fails, populated db",
setupDB: func(t *testing.T, db *goqu.Database) {
_, err := db.Insert("prayer_times").Rows(
prayerTimesRow{ProviderID: 1, LocationID: "1", Date: time.Date(2023, 3, 4, 0, 0, 0, 0, time.UTC), Fajr: "01:00", Sunrise: "02:00", Dhuhr: "03:00", Asr: "04:00", Maghrib: "05:00", Isha: "06:00"},
prayerTimesRow{ProviderID: 1, LocationID: "1", Date: time.Date(2023, 3, 5, 0, 0, 0, 0, time.UTC), Fajr: "01:00", Sunrise: "02:00", Dhuhr: "03:00", Asr: "04:00", Maghrib: "05:00", Isha: "06:00"},
prayerTimesRow{ProviderID: 1, LocationID: "1", Date: time.Date(2023, 3, 6, 0, 0, 0, 0, time.UTC), Fajr: "01:00", Sunrise: "02:00", Dhuhr: "03:00", Asr: "04:00", Maghrib: "05:00", Isha: "06:00"},
).Executor().Exec()
require.NoError(t, err)
},
provider: mockProvider(func() ([]prayer.Times, error) {
return nil, fmt.Errorf("no")
}),
clock: then,
assertRes: func(t *testing.T, db *goqu.Database, times []prayer.Times, err error) {
assert.NoError(t, err)
assert.Len(t, times, 2)
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
db := testDB(t)
p := Provider{
db: db,
provider: tt.provider,
clockFunc: func() time.Time { return tt.clock },
}
if tt.setupDB != nil {
tt.setupDB(t, db)
}
actual, err := p.Get(context.Background(), "1")
tt.assertRes(t, db, actual, err)
})
}
}