Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bakito committed Jan 7, 2024
1 parent 1c34300 commit e127065
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 223 deletions.
7 changes: 3 additions & 4 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/mitchellh/go-homedir"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.uber.org/zap"
)

const (
Expand Down Expand Up @@ -131,7 +130,7 @@ func initConfig() {
}
}

func getConfig(logger *zap.SugaredLogger) (*types.Config, error) {
func getConfig() (*types.Config, error) {
cfg := &types.Config{}
if err := viper.Unmarshal(cfg); err != nil {
return nil, err
Expand All @@ -143,14 +142,14 @@ func getConfig(logger *zap.SugaredLogger) (*types.Config, error) {
}

if len(cfg.Replicas) == 0 {
cfg.Replicas = append(cfg.Replicas, collectEnvReplicas(logger)...)
cfg.Replicas = append(cfg.Replicas, collectEnvReplicas()...)
}

return cfg, nil
}

// Manually collect replicas from env.
func collectEnvReplicas(logger *zap.SugaredLogger) []types.AdGuardInstance {
func collectEnvReplicas() []types.AdGuardInstance {
var replicas []types.AdGuardInstance
for _, v := range os.Environ() {
if envReplicasURLPattern.MatchString(v) {
Expand Down
14 changes: 5 additions & 9 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ import (
"fmt"
"os"

"github.com/bakito/adguardhome-sync/pkg/log"
"github.com/bakito/adguardhome-sync/pkg/types"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"go.uber.org/zap"
)

var envVars = []string{
Expand All @@ -28,9 +26,7 @@ var envVars = []string{
}

var _ = Describe("Run", func() {
var logger *zap.SugaredLogger
BeforeEach(func() {
logger = log.GetLogger("root")
for _, envVar := range envVars {
Ω(os.Unsetenv(envVar)).ShouldNot(HaveOccurred())
}
Expand All @@ -43,23 +39,23 @@ var _ = Describe("Run", func() {
})
Context("getConfig", func() {
It("features should be true by default", func() {
cfg, err := getConfig(logger)
cfg, err := getConfig()
Ω(err).ShouldNot(HaveOccurred())
verifyFeatures(cfg, true)
})
It("features should be false", func() {
for _, envVar := range envVars {
Ω(os.Setenv(envVar, "false")).ShouldNot(HaveOccurred())
}
cfg, err := getConfig(logger)
cfg, err := getConfig()
Ω(err).ShouldNot(HaveOccurred())
verifyFeatures(cfg, false)
})
Context("interface name", func() {
It("should set interface name of replica 1", func() {
Ω(os.Setenv("REPLICA1_URL", "https://foo.bar")).ShouldNot(HaveOccurred())
Ω(os.Setenv(fmt.Sprintf(envReplicasInterfaceName, "1"), "eth0")).ShouldNot(HaveOccurred())
cfg, err := getConfig(logger)
cfg, err := getConfig()
Ω(err).ShouldNot(HaveOccurred())
Ω(cfg.Replicas[0].InterfaceName).Should(Equal("eth0"))
})
Expand All @@ -68,15 +64,15 @@ var _ = Describe("Run", func() {
It("should enable the dhcp server of replica 1", func() {
Ω(os.Setenv("REPLICA1_URL", "https://foo.bar")).ShouldNot(HaveOccurred())
Ω(os.Setenv(fmt.Sprintf(envDHCPServerEnabled, "1"), "true")).ShouldNot(HaveOccurred())
cfg, err := getConfig(logger)
cfg, err := getConfig()
Ω(err).ShouldNot(HaveOccurred())
Ω(cfg.Replicas[0].DHCPServerEnabled).ShouldNot(BeNil())
Ω(*cfg.Replicas[0].DHCPServerEnabled).Should(BeTrue())
})
It("should disable the dhcp server of replica 1", func() {
Ω(os.Setenv("REPLICA1_URL", "https://foo.bar")).ShouldNot(HaveOccurred())
Ω(os.Setenv(fmt.Sprintf(envDHCPServerEnabled, "1"), "false")).ShouldNot(HaveOccurred())
cfg, err := getConfig(logger)
cfg, err := getConfig()
Ω(err).ShouldNot(HaveOccurred())
Ω(cfg.Replicas[0].DHCPServerEnabled).ShouldNot(BeNil())
Ω(*cfg.Replicas[0].DHCPServerEnabled).Should(BeFalse())
Expand Down
2 changes: 1 addition & 1 deletion cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var doCmd = &cobra.Command{
Long: `Synchronizes the configuration form an origin instance to a replica`,
RunE: func(cmd *cobra.Command, args []string) error {
logger = log.GetLogger("run")
cfg, err := getConfig(logger)
cfg, err := getConfig()
if err != nil {
logger.Error(err)
return err
Expand Down
82 changes: 41 additions & 41 deletions pkg/sync/action-general.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,38 @@ var (
actionProfileInfo = func(ac *actionContext) error {
if pro, err := ac.client.ProfileInfo(); err != nil {
return err
} else if merged := pro.ShouldSyncFor(ac.o.profileInfo); merged != nil {
} else if merged := pro.ShouldSyncFor(ac.origin.profileInfo); merged != nil {
return ac.client.SetProfileInfo(merged)
}
return nil
}
actionProtection = func(ac *actionContext) error {
if ac.o.status.ProtectionEnabled != ac.rs.ProtectionEnabled {
return ac.client.ToggleProtection(ac.o.status.ProtectionEnabled)
if ac.origin.status.ProtectionEnabled != ac.replicaStatus.ProtectionEnabled {
return ac.client.ToggleProtection(ac.origin.status.ProtectionEnabled)
}
return nil
}
actionParental = func(ac *actionContext) error {
if rp, err := ac.client.Parental(); err != nil {
return err
} else if ac.o.parental != rp {
return ac.client.ToggleParental(ac.o.parental)
} else if ac.origin.parental != rp {
return ac.client.ToggleParental(ac.origin.parental)
}
return nil
}
actionSafeSearchConfig = func(ac *actionContext) error {
if ssc, err := ac.client.SafeSearchConfig(); err != nil {
return err
} else if !ac.o.safeSearch.Equals(ssc) {
return ac.client.SetSafeSearchConfig(ac.o.safeSearch)
} else if !ac.origin.safeSearch.Equals(ssc) {
return ac.client.SetSafeSearchConfig(ac.origin.safeSearch)
}
return nil
}
actionSafeBrowsing = func(ac *actionContext) error {
if rs, err := ac.client.SafeBrowsing(); err != nil {
return err
} else if ac.o.safeBrowsing != rs {
if err = ac.client.ToggleSafeBrowsing(ac.o.safeBrowsing); err != nil {
} else if ac.origin.safeBrowsing != rs {
if err = ac.client.ToggleSafeBrowsing(ac.origin.safeBrowsing); err != nil {
return err
}
}
Expand All @@ -53,8 +53,8 @@ var (
if err != nil {
return err
}
if !ac.o.queryLogConfig.Equals(qlc) {
return ac.client.SetQueryLogConfig(ac.o.queryLogConfig)
if !ac.origin.queryLogConfig.Equals(qlc) {
return ac.client.SetQueryLogConfig(ac.origin.queryLogConfig)
}
return nil
}
Expand All @@ -63,18 +63,18 @@ var (
if err != nil {
return err
}
if ac.o.statsConfig.Interval != sc.Interval {
return ac.client.SetStatsConfig(ac.o.statsConfig)
if ac.origin.statsConfig.Interval != sc.Interval {
return ac.client.SetStatsConfig(ac.origin.statsConfig)
}
return nil
}
dnsRewrites = func(ac *actionContext) error {
actionDNSRewrites = func(ac *actionContext) error {
replicaRewrites, err := ac.client.RewriteList()
if err != nil {
return err
}

a, r, d := replicaRewrites.Merge(ac.o.rewrites)
a, r, d := replicaRewrites.Merge(ac.origin.rewrites)

if err = ac.client.DeleteRewriteEntries(r...); err != nil {
return err
Expand All @@ -88,58 +88,58 @@ var (
}
return nil
}
filters = func(ac *actionContext) error {
actionFilters = func(ac *actionContext) error {
rf, err := ac.client.Filtering()
if err != nil {
return err
}

if err = syncFilterType(ac.rl, ac.o.filters.Filters, rf.Filters, false, ac.client, ac.continueOnError); err != nil {
if err = syncFilterType(ac.rl, ac.origin.filters.Filters, rf.Filters, false, ac.client, ac.continueOnError); err != nil {
return err
}
if err = syncFilterType(ac.rl, ac.o.filters.WhitelistFilters, rf.WhitelistFilters, true, ac.client, ac.continueOnError); err != nil {
if err = syncFilterType(ac.rl, ac.origin.filters.WhitelistFilters, rf.WhitelistFilters, true, ac.client, ac.continueOnError); err != nil {
return err
}

if utils.PtrToString(ac.o.filters.UserRules) != utils.PtrToString(rf.UserRules) {
return ac.client.SetCustomRules(ac.o.filters.UserRules)
if utils.PtrToString(ac.origin.filters.UserRules) != utils.PtrToString(rf.UserRules) {
return ac.client.SetCustomRules(ac.origin.filters.UserRules)
}

if ac.o.filters.Enabled != rf.Enabled || ac.o.filters.Interval != rf.Interval {
return ac.client.ToggleFiltering(*ac.o.filters.Enabled, *ac.o.filters.Interval)
if ac.origin.filters.Enabled != rf.Enabled || ac.origin.filters.Interval != rf.Interval {
return ac.client.ToggleFiltering(*ac.origin.filters.Enabled, *ac.origin.filters.Interval)
}
return nil
}

blockedServices = func(ac *actionContext) error {
actionBlockedServices = func(ac *actionContext) error {
rs, err := ac.client.BlockedServices()
if err != nil {
return err
}

if !model.EqualsStringSlice(ac.o.blockedServices, rs, true) {
return ac.client.SetBlockedServices(ac.o.blockedServices)
if !model.EqualsStringSlice(ac.origin.blockedServices, rs, true) {
return ac.client.SetBlockedServices(ac.origin.blockedServices)
}
return nil
}
blockedServicesSchedule = func(ac *actionContext) error {
actionBlockedServicesSchedule = func(ac *actionContext) error {
rbss, err := ac.client.BlockedServicesSchedule()
if err != nil {
return err
}

if !ac.o.blockedServicesSchedule.Equals(rbss) {
return ac.client.SetBlockedServicesSchedule(ac.o.blockedServicesSchedule)
if !ac.origin.blockedServicesSchedule.Equals(rbss) {
return ac.client.SetBlockedServicesSchedule(ac.origin.blockedServicesSchedule)
}
return nil
}
clientSettings = func(ac *actionContext) error {
actionClientSettings = func(ac *actionContext) error {
rc, err := ac.client.Clients()
if err != nil {
return err
}

a, u, r := rc.Merge(ac.o.clients)
a, u, r := rc.Merge(ac.origin.clients)

for _, client := range r {
if err := ac.client.DeleteClient(client); err != nil {
Expand Down Expand Up @@ -171,35 +171,35 @@ var (
return nil
}

dnsAccessLists = func(ac *actionContext) error {
actionDNSAccessLists = func(ac *actionContext) error {
al, err := ac.client.AccessList()
if err != nil {
return err
}
if !al.Equals(ac.o.accessList) {
return ac.client.SetAccessList(ac.o.accessList)
if !al.Equals(ac.origin.accessList) {
return ac.client.SetAccessList(ac.origin.accessList)
}
return nil
}
dnsServerConfig = func(ac *actionContext) error {
actionDNSServerConfig = func(ac *actionContext) error {
dc, err := ac.client.DNSConfig()
if err != nil {
return err
}
if !dc.Equals(ac.o.dnsConfig) {
if err = ac.client.SetDNSConfig(ac.o.dnsConfig); err != nil {
if !dc.Equals(ac.origin.dnsConfig) {
if err = ac.client.SetDNSConfig(ac.origin.dnsConfig); err != nil {
return err
}
}
return nil
}
dhcpServerConfig = func(ac *actionContext) error {
if ac.o.dhcpServerConfig.HasConfig() {
actionDHCPServerConfig = func(ac *actionContext) error {
if ac.origin.dhcpServerConfig.HasConfig() {
sc, err := ac.client.DhcpConfig()
if err != nil {
return err
}
origClone := ac.o.dhcpServerConfig.Clone()
origClone := ac.origin.dhcpServerConfig.Clone()
if ac.replica.InterfaceName != "" {
// overwrite interface name
origClone.InterfaceName = utils.Ptr(ac.replica.InterfaceName)
Expand All @@ -215,13 +215,13 @@ var (
}
return nil
}
dhcpStaticLeases = func(ac *actionContext) error {
actionDHCPStaticLeases = func(ac *actionContext) error {
sc, err := ac.client.DhcpConfig()
if err != nil {
return err
}

a, r := model.MergeDhcpStaticLeases(sc.StaticLeases, ac.o.dhcpServerConfig.StaticLeases)
a, r := model.MergeDhcpStaticLeases(sc.StaticLeases, ac.origin.dhcpServerConfig.StaticLeases)

for _, lease := range r {
if err := ac.client.DeleteDHCPStaticLease(lease); err != nil {
Expand Down
24 changes: 12 additions & 12 deletions pkg/sync/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,47 +29,47 @@ func setupActions(cfg *types.Config) (actions []syncAction) {
}
if cfg.Features.DNS.Rewrites {
actions = append(actions,
action("DNS rewrites", dnsRewrites),
action("DNS rewrites", actionDNSRewrites),
)
}
if cfg.Features.Filters {
actions = append(actions,
action("filters", filters),
action("actionFilters", actionFilters),
)
}
if cfg.Features.Services {
actions = append(actions,
action("blocked services", blockedServices),
action("blocked services schedule", blockedServicesSchedule),
action("blocked services", actionBlockedServices),
action("blocked services schedule", actionBlockedServicesSchedule),
)
}
if cfg.Features.ClientSettings {
actions = append(actions,
action("client settings", clientSettings),
action("client settings", actionClientSettings),
)
}
if cfg.Features.DNS.AccessLists {
actions = append(actions,
action("DNS access lists", dnsAccessLists),
action("DNS access lists", actionDNSAccessLists),
)
}

if cfg.Features.DNS.ServerConfig {
actions = append(actions,
action("DNS server config", dnsServerConfig),
action("DNS server config", actionDNSServerConfig),
)
}
if cfg.Features.DHCP.ServerConfig {
actions = append(actions,
action("DHCP server config", dhcpServerConfig),
action("DHCP server config", actionDHCPServerConfig),
)
}
if cfg.Features.DHCP.StaticLeases {
actions = append(actions,
action("DHCP static leases", dhcpStaticLeases),
action("DHCP static leases", actionDHCPStaticLeases),
)
}
return
return actions
}

type syncAction interface {
Expand All @@ -79,9 +79,9 @@ type syncAction interface {

type actionContext struct {
rl *zap.SugaredLogger
o *origin
origin *origin
client client.Client
rs *model.ServerStatus
replicaStatus *model.ServerStatus
continueOnError bool
replica types.AdGuardInstance
}
Expand Down
Loading

0 comments on commit e127065

Please sign in to comment.