From 36cb2b7735d1a7006a69d2e9ee5bbb6ec67ef346 Mon Sep 17 00:00:00 2001 From: Andrew LeFevre Date: Fri, 16 Jun 2023 19:28:11 -0400 Subject: [PATCH] fix valid rules getting deleted in specific conditions after whalewall is restarted (#114) --- config.go | 12 ++- create.go | 42 ++++++----- whalewall_test.go | 187 +++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 206 insertions(+), 35 deletions(-) diff --git a/config.go b/config.go index f26b35d..5cd4ebb 100644 --- a/config.go +++ b/config.go @@ -64,11 +64,15 @@ func (r ruleConfig) MarshalLogObject(enc zapcore.ObjectEncoder) error { enc.AddString("container", r.Container) } enc.AddString("proto", r.Proto.String()) - if err := enc.AddArray("src_ports", portsList(r.SrcPorts)); err != nil { - return err + if len(r.SrcPorts) != 0 { + if err := enc.AddArray("src_ports", portsList(r.SrcPorts)); err != nil { + return err + } } - if err := enc.AddArray("dst_ports", portsList(r.DstPorts)); err != nil { - return err + if len(r.DstPorts) != 0 { + if err := enc.AddArray("dst_ports", portsList(r.DstPorts)); err != nil { + return err + } } if err := enc.AddObject("verdict", r.Verdict); err != nil { return err diff --git a/create.go b/create.go index 1dae0a0..07336c4 100644 --- a/create.go +++ b/create.go @@ -439,24 +439,27 @@ func (r *RuleManager) populateOutputRules(ctx context.Context, tx database.TX, c if !found { // we need to add rules to this container's chain, but it - // hasn't been processed yet; add the rule to the database - // so when we are processing this container, this rule will - // be created - var buf bytes.Buffer - encoder := gob.NewEncoder(&buf) - if err := encoder.Encode(ruleCfg); err != nil { - return fmt.Errorf("error encoding waiting container rule: %w", err) - } - err := tx.AddWaitingContainerRule(ctx, database.AddWaitingContainerRuleParams{ - SrcContainerID: id, - DstContainerName: ruleCfg.Container, - Rule: buf.Bytes(), - }) - if err != nil { - return fmt.Errorf("error adding waiting container rule to database: %w", err) - } + // hasn't been processed yet; wait until this container + // is processed to create the rules cfg.Output[i].skip = true } + // Add the rule to the database so when we are processing + // this container, this rule will be created. This is done + // even when the container has been processed so future + // rule creation will be idempotent. + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + if err := encoder.Encode(ruleCfg); err != nil { + return fmt.Errorf("error encoding waiting container rule: %w", err) + } + err := tx.AddWaitingContainerRule(ctx, database.AddWaitingContainerRuleParams{ + SrcContainerID: id, + DstContainerName: ruleCfg.Container, + Rule: buf.Bytes(), + }) + if err != nil { + return fmt.Errorf("error adding waiting container rule to database: %w", err) + } } } @@ -516,6 +519,7 @@ func buildChainName(name, id string) string { return fmt.Sprintf("%s%s-%s", chainPrefix, name, id[:12]) } +// TODO: avoid creating almost duplicate rules as output rules // createPortMappingRules adds nftables rules to allow or deny access to // mapped ports. func (r *RuleManager) createPortMappingRules(nfc firewallClient, logger *zap.Logger, container types.ContainerJSON, contName string, mappedPortsCfg mappedPorts, addrs map[string][]byte, chain *nftables.Chain) ([]*nftables.Rule, error) { @@ -911,9 +915,9 @@ func (r ruleDetails) MarshalLogObject(enc zapcore.ObjectEncoder) error { if r.estChain != nil { enc.AddString("est_chain", r.estChain.Name) } - enc.AddString("cont_id", r.contID) + enc.AddString("cont_id", r.contID[:12]) if r.estContID != "" { - enc.AddString("est_cont_id", r.estContID) + enc.AddString("est_cont_id", r.estContID[:12]) } return nil @@ -921,7 +925,7 @@ func (r ruleDetails) MarshalLogObject(enc zapcore.ObjectEncoder) error { // createNFTRules returns a slice of [*nftables.Rule] described by rd. func createNFTRules(nfc firewallClient, logger *zap.Logger, rd ruleDetails) ([]*nftables.Rule, error) { - logger.Debug("creating rule", zap.Object("rule", rd)) + logger.Debug("generating rule", zap.Object("rule", rd)) rules := make([]*nftables.Rule, 0, 3) estContID := rd.contID diff --git a/whalewall_test.go b/whalewall_test.go index f75d257..3120466 100644 --- a/whalewall_test.go +++ b/whalewall_test.go @@ -2062,7 +2062,7 @@ mapped_ports: return rulesEqual(logger, r1, r2) } - testCreatingRules := func(tt ruleCreationTest, allContainersStarted bool, clearRules bool) func(*testing.T) { + testCreatingRules := func(tt ruleCreationTest, allContainersStarted, clearRules bool) func(*testing.T) { return func(t *testing.T) { is := is.New(t) @@ -2072,7 +2072,7 @@ mapped_ports: var dockerCli *mockDockerClient if allContainersStarted { - dockerCli = newMockDockerClient(tt.containers) + dockerCli = newMockDockerClient(clone(tt.containers)) } else { dockerCli = newMockDockerClient(nil) } @@ -2110,12 +2110,16 @@ mapped_ports: subTestName := "containers are new" if !containerIsNew { subTestName = "containers are not new" + + if len(tt.containers) > 1 { + reverse(dockerCli.containers) + } } t.Run(subTestName, func(t *testing.T) { // create rules for _, c := range tt.containers { - if !allContainersStarted { + if !allContainersStarted && len(dockerCli.containers) < len(tt.containers) { dockerCli.containers = append(dockerCli.containers, c) } @@ -2172,15 +2176,33 @@ mapped_ports: } for _, tt := range tests { - if len(tt.containers) == 1 { - t.Run(tt.name+"/delete container rules", testCreatingRules(tt, true, false)) - t.Run(tt.name+"/clear all rules", testCreatingRules(tt, true, true)) - } else { - t.Run(tt.name+"/all containers started/delete container rules", testCreatingRules(tt, true, false)) - t.Run(tt.name+"/all containers started/clear all rules", testCreatingRules(tt, true, true)) - t.Run(tt.name+"/one container at a time/delete container rules", testCreatingRules(tt, false, false)) - t.Run(tt.name+"/one container at a time/clear all rules", testCreatingRules(tt, false, true)) - } + t.Run(tt.name, func(t *testing.T) { + if len(tt.containers) == 1 { + t.Run("delete container rules", testCreatingRules(tt, true, false)) + t.Run("clear all rules", testCreatingRules(tt, true, true)) + } else { + runTests := func(t *testing.T) { + t.Run("all containers started/delete container rules", testCreatingRules(tt, true, false)) + t.Run("all containers started/clear all rules", testCreatingRules(tt, true, true)) + t.Run("one container at a time/delete container rules", testCreatingRules(tt, false, false)) + t.Run("one container at a time/clear all rules", testCreatingRules(tt, false, true)) + } + + runTests(t) + // run same tests with containers in reverse order + reverse(tt.containers) + t.Run("container order reversed", func(t *testing.T) { + // reverse order of all expected rules except the + // drop rule (which will always be last) + for chain, rules := range tt.expectedRules { + reverse(rules[:len(rules)-1]) + tt.expectedRules[chain] = rules + } + + runTests(t) + }) + } + }) } } @@ -2326,6 +2348,141 @@ output: compareRules(t, comparer, cont2ChainName, cont2RulesBefore, cont2RulesAfter) } +func TestCreationIdempotency(t *testing.T) { + containers := []types.ContainerJSON{ + { + ContainerJSONBase: &types.ContainerJSONBase{ + ID: cont2ID, + Name: "/" + cont2Name, + }, + Config: &container.Config{ + Labels: map[string]string{ + enabledLabel: "true", + }, + }, + NetworkSettings: &types.NetworkSettings{ + Networks: map[string]*network.EndpointSettings{ + "cont_net": { + Gateway: gatewayAddr.String(), + IPAddress: cont2Addr.String(), + }, + }, + }, + }, + { + ContainerJSONBase: &types.ContainerJSONBase{ + ID: cont1ID, + Name: "/" + cont1Name, + }, + Config: &container.Config{ + Labels: map[string]string{ + enabledLabel: "true", + rulesLabel: ` +output: + - container: container2 + network: cont_net + proto: udp + dst_ports: + - 9001`, + }, + }, + NetworkSettings: &types.NetworkSettings{ + Networks: map[string]*network.EndpointSettings{ + "cont_net": { + Gateway: gatewayAddr.String(), + IPAddress: cont1Addr.String(), + }, + }, + }, + }, + } + + is := is.New(t) + logger, err := zap.NewDevelopment() + is.NoErr(err) + + comparer := func(r1, r2 *nftables.Rule) bool { + return rulesEqual(logger, r1, r2) + } + + dbFile := filepath.Join(t.TempDir(), "db.sqlite") + r, err := NewRuleManager(context.Background(), logger, dbFile, defaultTimeout) + is.NoErr(err) + + dockerCli := newMockDockerClient(nil) + r.newDockerClient = func() (dockerClient, error) { + return dockerCli, nil + } + + // create mock nftables client and add required prerequisite + // DOCKER-USER chain + mfc := newMockFirewall(logger) + mfc.AddTable(filterTable) + mfc.AddChain(&nftables.Chain{ + Name: dockerChainName, + Table: filterTable, + Type: nftables.ChainTypeFilter, + }) + is.NoErr(mfc.Flush()) + r.newFirewallClient = func() (firewallClient, error) { + return newMockFirewall(logger), nil + } + + // create new database and base rules + err = r.init(context.Background()) + is.NoErr(err) + err = r.createBaseRules() + is.NoErr(err) + t.Cleanup(func() { + err := r.clearRules(context.Background()) + is.NoErr(err) + }) + + // create rules + for _, c := range containers { + dockerCli.containers = append(dockerCli.containers, c) + err := r.createContainerRules(context.Background(), c, true) + is.NoErr(err) + } + + cont1ChainName := buildChainName(cont1Name, cont1ID) + cont1Chain := &nftables.Chain{ + Table: filterTable, + Name: cont1ChainName, + } + cont1RulesBefore, err := mfc.GetRules(filterTable, cont1Chain) + is.NoErr(err) + is.True(len(cont1RulesBefore) == 2) + + cont2ChainName := buildChainName(cont2Name, cont2ID) + cont2Chain := &nftables.Chain{ + Table: filterTable, + Name: cont2ChainName, + } + cont2RulesBefore, err := mfc.GetRules(filterTable, cont2Chain) + is.NoErr(err) + is.True(len(cont2RulesBefore) == 2) + + // recreate rules in opposite container order + reverse(containers) + for _, c := range containers { + err := r.createContainerRules(context.Background(), c, false) + is.NoErr(err) + } + + cont1RulesAfter, err := mfc.GetRules(filterTable, cont1Chain) + is.NoErr(err) + is.True(len(cont1RulesAfter) == 2) + + cont2RulesAfter, err := mfc.GetRules(filterTable, cont2Chain) + is.NoErr(err) + is.True(len(cont2RulesAfter) == 2) + + // ensure rules of both containers are the same as before + compareRules(t, comparer, cont1ChainName, cont1RulesBefore, cont1RulesAfter) + compareRules(t, comparer, cont2ChainName, cont2RulesBefore, cont2RulesAfter) +} + type dbOnCommit struct { database.DB onCommit func(database.TX) error @@ -2535,3 +2692,9 @@ func slicesJoin[T any](s ...[]T) (ret []T) { return ret } + +func reverse[E any](s []E) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +}