Skip to content

Commit

Permalink
fix valid rules getting deleted in specific conditions after whalewal…
Browse files Browse the repository at this point in the history
…l is restarted (#114)
  • Loading branch information
capnspacehook authored Jun 16, 2023
1 parent 878d763 commit 36cb2b7
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 35 deletions.
12 changes: 8 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,15 @@ func (r ruleConfig) MarshalLogObject(enc zapcore.ObjectEncoder) error {
enc.AddString("container", r.Container)
}
enc.AddString("proto", r.Proto.String())
if err := enc.AddArray("src_ports", portsList(r.SrcPorts)); err != nil {
return err
if len(r.SrcPorts) != 0 {
if err := enc.AddArray("src_ports", portsList(r.SrcPorts)); err != nil {
return err
}
}
if err := enc.AddArray("dst_ports", portsList(r.DstPorts)); err != nil {
return err
if len(r.DstPorts) != 0 {
if err := enc.AddArray("dst_ports", portsList(r.DstPorts)); err != nil {
return err
}
}
if err := enc.AddObject("verdict", r.Verdict); err != nil {
return err
Expand Down
42 changes: 23 additions & 19 deletions create.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,24 +439,27 @@ func (r *RuleManager) populateOutputRules(ctx context.Context, tx database.TX, c

if !found {
// we need to add rules to this container's chain, but it
// hasn't been processed yet; add the rule to the database
// so when we are processing this container, this rule will
// be created
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
if err := encoder.Encode(ruleCfg); err != nil {
return fmt.Errorf("error encoding waiting container rule: %w", err)
}
err := tx.AddWaitingContainerRule(ctx, database.AddWaitingContainerRuleParams{
SrcContainerID: id,
DstContainerName: ruleCfg.Container,
Rule: buf.Bytes(),
})
if err != nil {
return fmt.Errorf("error adding waiting container rule to database: %w", err)
}
// hasn't been processed yet; wait until this container
// is processed to create the rules
cfg.Output[i].skip = true
}
// Add the rule to the database so when we are processing
// this container, this rule will be created. This is done
// even when the container has been processed so future
// rule creation will be idempotent.
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
if err := encoder.Encode(ruleCfg); err != nil {
return fmt.Errorf("error encoding waiting container rule: %w", err)
}
err := tx.AddWaitingContainerRule(ctx, database.AddWaitingContainerRuleParams{
SrcContainerID: id,
DstContainerName: ruleCfg.Container,
Rule: buf.Bytes(),
})
if err != nil {
return fmt.Errorf("error adding waiting container rule to database: %w", err)
}
}
}

Expand Down Expand Up @@ -516,6 +519,7 @@ func buildChainName(name, id string) string {
return fmt.Sprintf("%s%s-%s", chainPrefix, name, id[:12])
}

// TODO: avoid creating almost duplicate rules as output rules
// createPortMappingRules adds nftables rules to allow or deny access to
// mapped ports.
func (r *RuleManager) createPortMappingRules(nfc firewallClient, logger *zap.Logger, container types.ContainerJSON, contName string, mappedPortsCfg mappedPorts, addrs map[string][]byte, chain *nftables.Chain) ([]*nftables.Rule, error) {
Expand Down Expand Up @@ -911,17 +915,17 @@ func (r ruleDetails) MarshalLogObject(enc zapcore.ObjectEncoder) error {
if r.estChain != nil {
enc.AddString("est_chain", r.estChain.Name)
}
enc.AddString("cont_id", r.contID)
enc.AddString("cont_id", r.contID[:12])
if r.estContID != "" {
enc.AddString("est_cont_id", r.estContID)
enc.AddString("est_cont_id", r.estContID[:12])
}

return nil
}

// createNFTRules returns a slice of [*nftables.Rule] described by rd.
func createNFTRules(nfc firewallClient, logger *zap.Logger, rd ruleDetails) ([]*nftables.Rule, error) {
logger.Debug("creating rule", zap.Object("rule", rd))
logger.Debug("generating rule", zap.Object("rule", rd))

rules := make([]*nftables.Rule, 0, 3)
estContID := rd.contID
Expand Down
187 changes: 175 additions & 12 deletions whalewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2062,7 +2062,7 @@ mapped_ports:
return rulesEqual(logger, r1, r2)
}

testCreatingRules := func(tt ruleCreationTest, allContainersStarted bool, clearRules bool) func(*testing.T) {
testCreatingRules := func(tt ruleCreationTest, allContainersStarted, clearRules bool) func(*testing.T) {
return func(t *testing.T) {
is := is.New(t)

Expand All @@ -2072,7 +2072,7 @@ mapped_ports:

var dockerCli *mockDockerClient
if allContainersStarted {
dockerCli = newMockDockerClient(tt.containers)
dockerCli = newMockDockerClient(clone(tt.containers))
} else {
dockerCli = newMockDockerClient(nil)
}
Expand Down Expand Up @@ -2110,12 +2110,16 @@ mapped_ports:
subTestName := "containers are new"
if !containerIsNew {
subTestName = "containers are not new"

if len(tt.containers) > 1 {
reverse(dockerCli.containers)
}
}

t.Run(subTestName, func(t *testing.T) {
// create rules
for _, c := range tt.containers {
if !allContainersStarted {
if !allContainersStarted && len(dockerCli.containers) < len(tt.containers) {
dockerCli.containers = append(dockerCli.containers, c)
}

Expand Down Expand Up @@ -2172,15 +2176,33 @@ mapped_ports:
}

for _, tt := range tests {
if len(tt.containers) == 1 {
t.Run(tt.name+"/delete container rules", testCreatingRules(tt, true, false))
t.Run(tt.name+"/clear all rules", testCreatingRules(tt, true, true))
} else {
t.Run(tt.name+"/all containers started/delete container rules", testCreatingRules(tt, true, false))
t.Run(tt.name+"/all containers started/clear all rules", testCreatingRules(tt, true, true))
t.Run(tt.name+"/one container at a time/delete container rules", testCreatingRules(tt, false, false))
t.Run(tt.name+"/one container at a time/clear all rules", testCreatingRules(tt, false, true))
}
t.Run(tt.name, func(t *testing.T) {
if len(tt.containers) == 1 {
t.Run("delete container rules", testCreatingRules(tt, true, false))
t.Run("clear all rules", testCreatingRules(tt, true, true))
} else {
runTests := func(t *testing.T) {
t.Run("all containers started/delete container rules", testCreatingRules(tt, true, false))
t.Run("all containers started/clear all rules", testCreatingRules(tt, true, true))
t.Run("one container at a time/delete container rules", testCreatingRules(tt, false, false))
t.Run("one container at a time/clear all rules", testCreatingRules(tt, false, true))
}

runTests(t)
// run same tests with containers in reverse order
reverse(tt.containers)
t.Run("container order reversed", func(t *testing.T) {
// reverse order of all expected rules except the
// drop rule (which will always be last)
for chain, rules := range tt.expectedRules {
reverse(rules[:len(rules)-1])
tt.expectedRules[chain] = rules
}

runTests(t)
})
}
})
}
}

Expand Down Expand Up @@ -2326,6 +2348,141 @@ output:
compareRules(t, comparer, cont2ChainName, cont2RulesBefore, cont2RulesAfter)
}

func TestCreationIdempotency(t *testing.T) {
containers := []types.ContainerJSON{
{
ContainerJSONBase: &types.ContainerJSONBase{
ID: cont2ID,
Name: "/" + cont2Name,
},
Config: &container.Config{
Labels: map[string]string{
enabledLabel: "true",
},
},
NetworkSettings: &types.NetworkSettings{
Networks: map[string]*network.EndpointSettings{
"cont_net": {
Gateway: gatewayAddr.String(),
IPAddress: cont2Addr.String(),
},
},
},
},
{
ContainerJSONBase: &types.ContainerJSONBase{
ID: cont1ID,
Name: "/" + cont1Name,
},
Config: &container.Config{
Labels: map[string]string{
enabledLabel: "true",
rulesLabel: `
output:
- container: container2
network: cont_net
proto: udp
dst_ports:
- 9001`,
},
},
NetworkSettings: &types.NetworkSettings{
Networks: map[string]*network.EndpointSettings{
"cont_net": {
Gateway: gatewayAddr.String(),
IPAddress: cont1Addr.String(),
},
},
},
},
}

is := is.New(t)
logger, err := zap.NewDevelopment()
is.NoErr(err)

comparer := func(r1, r2 *nftables.Rule) bool {
return rulesEqual(logger, r1, r2)
}

dbFile := filepath.Join(t.TempDir(), "db.sqlite")
r, err := NewRuleManager(context.Background(), logger, dbFile, defaultTimeout)
is.NoErr(err)

dockerCli := newMockDockerClient(nil)
r.newDockerClient = func() (dockerClient, error) {
return dockerCli, nil
}

// create mock nftables client and add required prerequisite
// DOCKER-USER chain
mfc := newMockFirewall(logger)
mfc.AddTable(filterTable)
mfc.AddChain(&nftables.Chain{
Name: dockerChainName,
Table: filterTable,
Type: nftables.ChainTypeFilter,
})
is.NoErr(mfc.Flush())
r.newFirewallClient = func() (firewallClient, error) {
return newMockFirewall(logger), nil
}

// create new database and base rules
err = r.init(context.Background())
is.NoErr(err)
err = r.createBaseRules()
is.NoErr(err)
t.Cleanup(func() {
err := r.clearRules(context.Background())
is.NoErr(err)
})

// create rules
for _, c := range containers {
dockerCli.containers = append(dockerCli.containers, c)
err := r.createContainerRules(context.Background(), c, true)
is.NoErr(err)
}

cont1ChainName := buildChainName(cont1Name, cont1ID)
cont1Chain := &nftables.Chain{
Table: filterTable,
Name: cont1ChainName,
}
cont1RulesBefore, err := mfc.GetRules(filterTable, cont1Chain)
is.NoErr(err)
is.True(len(cont1RulesBefore) == 2)

cont2ChainName := buildChainName(cont2Name, cont2ID)
cont2Chain := &nftables.Chain{
Table: filterTable,
Name: cont2ChainName,
}
cont2RulesBefore, err := mfc.GetRules(filterTable, cont2Chain)
is.NoErr(err)
is.True(len(cont2RulesBefore) == 2)

// recreate rules in opposite container order
reverse(containers)
for _, c := range containers {
err := r.createContainerRules(context.Background(), c, false)
is.NoErr(err)
}

cont1RulesAfter, err := mfc.GetRules(filterTable, cont1Chain)
is.NoErr(err)
is.True(len(cont1RulesAfter) == 2)

cont2RulesAfter, err := mfc.GetRules(filterTable, cont2Chain)
is.NoErr(err)
is.True(len(cont2RulesAfter) == 2)

// ensure rules of both containers are the same as before
compareRules(t, comparer, cont1ChainName, cont1RulesBefore, cont1RulesAfter)
compareRules(t, comparer, cont2ChainName, cont2RulesBefore, cont2RulesAfter)
}

type dbOnCommit struct {
database.DB
onCommit func(database.TX) error
Expand Down Expand Up @@ -2535,3 +2692,9 @@ func slicesJoin[T any](s ...[]T) (ret []T) {

return ret
}

func reverse[E any](s []E) {
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
s[i], s[j] = s[j], s[i]
}
}

0 comments on commit 36cb2b7

Please sign in to comment.