Skip to content

Commit

Permalink
Fixes issue verifying Windows CSP profiles that contain ADMX policies.
Browse files Browse the repository at this point in the history
  • Loading branch information
getvictor committed Jan 16, 2025
1 parent 7ab3470 commit 9e9972c
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 7 deletions.
1 change: 1 addition & 0 deletions changes/24790-admx-policies
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixes issue verifying Windows CSP profiles that contain ADMX policies.
123 changes: 123 additions & 0 deletions server/mdm/microsoft/admx/admx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Package admx handles ADMX (Administrative Template File) policies for Microsoft MDM server.
// See: https://learn.microsoft.com/en-us/windows/client-management/understanding-admx-backed-policies
//
// ADMX policy payload example:
// <![CDATA[
//
// <enabled/>
// <data id="Publishing_Server2_Name_Prompt" value="Name"/>
// <data id="Publishing_Server_URL_Prompt" value="http://someuri"/>
// <data id="Global_Publishing_Refresh_Options" value="1"/>
//
// ]]>
package admx

import (
"bytes"
"encoding/xml"
"errors"
"fmt"
"html"
"regexp"
"slices"
"strings"
)

var cdataRegexp = regexp.MustCompile(`(?s)<!\[CDATA\[(.*?)]]>`)

func IsADMX(text string) bool {
// We try to unmarshal the string to see if it looks like a valid ADMX policy
policy, err := unmarshal(text)
if err != nil {
return false
}
return policy.Enabled.Local == "enabled" || policy.Disabled.Local == "disabled" || len(policy.Data) > 0
}

func Equal(a, b string) (bool, error) {
aPolicy, err := unmarshal(a)
if err != nil {
return false, fmt.Errorf("unmarshalling ADMX policy a: %w", err)
}
bPolicy, err := unmarshal(b)
if err != nil {
return false, fmt.Errorf("unmarshalling ADMX policy b: %w", err)
}
return aPolicy.Equal(bPolicy), nil
}

func unmarshal(a string) (admxPolicy, error) {
a, err := convertToXMLString(a)
if err != nil {
return admxPolicy{}, fmt.Errorf("converting ADMX policy to XML string: %w", err)
}
var policy admxPolicy
err = xml.NewDecoder(bytes.NewReader([]byte(a))).Decode(&policy)
if err != nil {
return admxPolicy{}, fmt.Errorf("unmarshalling ADMX policy: %w", err)
}
return policy, nil
}

func convertToXMLString(a string) (string, error) {
matches := cdataRegexp.FindAllStringSubmatch(a, -1)
if len(matches) > 1 {
return "", errors.New("multiple CDATA matches found in ADMX policy")
}
if len(matches) == 1 && len(matches[0]) > 1 {
// If CDATA is present, we extract the content. Otherwise, we use the original string.
a = matches[0][1]
}
a = html.UnescapeString(a)
// ADMX policy elements are not case-sensitive. For example: <enabled/> and <Enabled/> are equivalent
// For simplicity, we compare everything in lowercase.
a = strings.ToLower(a)
// We wrap the policy in a <policy> tag to ensure it can be unmarshalled by the XML decoder
a = `<policy>` + a + `</policy>`
return a, nil
}

type admxPolicy struct {
Enabled xml.Name `xml:"enabled,omitempty"`
Disabled xml.Name `xml:"disabled,omitempty"`
Data []admxPolicyItem `xml:"data"`
}

func (a admxPolicy) Equal(b admxPolicy) bool {
if a.Disabled.Local != b.Disabled.Local {
return false
}
if a.Disabled.Local == "disabled" {
// If the ADMX policy is disabled, the data is not relevant
return true
}
if a.Enabled.Local != b.Enabled.Local {
return false
}
if len(a.Data) != len(b.Data) {
return false
}
a.sortData()
b.sortData()
for i := range a.Data {
if !a.Data[i].Equal(b.Data[i]) {
return false
}
}
return true
}

func (a *admxPolicy) sortData() {
slices.SortFunc(a.Data, func(i, j admxPolicyItem) int {
return strings.Compare(i.ID, j.ID)
})
}

type admxPolicyItem struct {
ID string `xml:"id,attr"`
Value string `xml:"value,attr"`
}

func (a admxPolicyItem) Equal(b admxPolicyItem) bool {
return a.ID == b.ID && a.Value == b.Value
}
140 changes: 140 additions & 0 deletions server/mdm/microsoft/admx/admx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package admx

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestIsADMX(t *testing.T) {
t.Parallel()
assert.False(t, IsADMX(""))
assert.False(t, IsADMX("not an ADMX policy"))
assert.False(t, IsADMX(`<![CDATA[bozo]]>`))
assert.False(t, IsADMX(`<![CDATA[<bozo/>]]>`))
assert.True(t, IsADMX(`<![CDATA[<enabled/>]]>`))
assert.True(t, IsADMX(`<![CDATA[<disabled/>]]>`))
assert.True(t, IsADMX(`<![CDATA[<data id="id" value="value"/>]]>`))
assert.True(t, IsADMX(
` <![CDATA[
<enabled/>
]]>`))
assert.True(t,
IsADMX("&lt;Enabled/&gt;&lt;Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/&gt;&lt;Data id=\"ExecutionPolicy\" value=\"AllSigned\"/&gt;&lt;Data id=\"Listbox_ModuleNames\" value=\"*\"/&gt;&lt;Data id=\"OutputDirectory\" value=\"false\"/&gt;&lt;Data id=\"SourcePathForUpdateHelp\" value=\"false\"/&gt;"))
}

func TestEqual(t *testing.T) {
t.Parallel()
testCases := []struct {
name, a, b, errorContains string
equal bool
}{
{
name: "empty policies",
a: "",
b: "",
equal: true,
errorContains: "",
},
{
name: "enabled policies",
a: "<![CDATA[<enabled/>]]>",
b: "&lt;Enabled/&gt;",
equal: true,
errorContains: "",
},
{
name: "disabled policies",
a: "<![CDATA[<disabled/>]]>",
b: "&lt;Disabled/&gt;",
equal: true,
errorContains: "",
},
{
name: "unequal policies",
a: "<![CDATA[<disabled/>]]>",
b: "&lt;enabled/&gt;",
equal: false,
errorContains: "",
},
{
name: "enabled policies with data",
a: `<![CDATA[<enabled/>
<data id="ExecutionPolicy" value="AllSigned"/>
<data id="Listbox_ModuleNames" value="*"/>
<data id="OutputDirectory" value="false"/>
<data id="EnableScriptBlockInvocationLogging" value="true"/>
<data id="SourcePathForUpdateHelp" value="false"/>]]>`,
b: "&lt;Enabled/&gt;&lt;Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/&gt;&lt;Data id=\"ExecutionPolicy\" value=\"AllSigned\"/&gt;&lt;Data id=\"Listbox_ModuleNames\" value=\"*\"/&gt;&lt;Data id=\"OutputDirectory\" value=\"false\"/&gt;&lt;Data id=\"SourcePathForUpdateHelp\" value=\"false\"/&gt;",
equal: true,
errorContains: "",
},
{
name: "disabled policies with data",
a: `<![CDATA[<disabled/>
<data id="ExecutionPolicy" value="AllSigned"/>
<data id="SourcePathForUpdateHelp" value="false"/>]]>`,
b: "&lt;Disabled/&gt;&lt;Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/&gt;&lt;Data id=\"ExecutionPolicy\" value=\"AllSigned\"/&gt;&lt;Data id=\"Listbox_ModuleNames\" value=\"*\"/&gt;&lt;Data id=\"OutputDirectory\" value=\"false\"/&gt;&lt;Data id=\"SourcePathForUpdateHelp\" value=\"false\"/&gt;",
equal: true,
errorContains: "",
},
{
name: "unparsable policy",
a: "",
b: "<bozo",
equal: false,
errorContains: "unmarshalling ADMX policy",
},
{
name: "multiple CDATA",
a: "<![CDATA[<enabled/>]]><![CDATA[<bozo>]]>",
b: "",
equal: false,
errorContains: "multiple CDATA matches found",
},
{
name: "unequal policies with missing enable",
a: `<![CDATA[<Xenabled/>
<data id="ExecutionPolicy" value="AllSigned"/>
<data id="Listbox_ModuleNames" value="*"/>
<data id="OutputDirectory" value="false"/>
<data id="EnableScriptBlockInvocationLogging" value="true"/>
<data id="SourcePathForUpdateHelp" value="false"/>]]>`,
b: "&lt;Enabled/&gt;&lt;Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/&gt;&lt;Data id=\"ExecutionPolicy\" value=\"AllSigned\"/&gt;&lt;Data id=\"Listbox_ModuleNames\" value=\"*\"/&gt;&lt;Data id=\"OutputDirectory\" value=\"false\"/&gt;&lt;Data id=\"SourcePathForUpdateHelp\" value=\"false\"/&gt;",
equal: false,
errorContains: "",
},
{
name: "unequal policies with data 1",
a: `<![CDATA[<enabled/>
<data id="EnableScriptBlockInvocationLogging" value="true"/>
<data id="SourcePathForUpdateHelp" value="false"/>]]>`,
b: "&lt;Enabled/&gt;&lt;Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/&gt;&lt;Data id=\"ExecutionPolicy\" value=\"AllSigned\"/&gt;&lt;Data id=\"Listbox_ModuleNames\" value=\"*\"/&gt;&lt;Data id=\"OutputDirectory\" value=\"false\"/&gt;&lt;Data id=\"SourcePathForUpdateHelp\" value=\"false\"/&gt;",
equal: false,
errorContains: "",
},
{
name: "unequal policies with data 2",
a: `<![CDATA[<enabled/>
<data id="ExecutionPolicy" value="XXXX"/>
<data id="Listbox_ModuleNames" value="*"/>
<data id="OutputDirectory" value="false"/>
<data id="EnableScriptBlockInvocationLogging" value="true"/>
<data id="SourcePathForUpdateHelp" value="false"/>]]>`,
b: "&lt;Enabled/&gt;&lt;Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/&gt;&lt;Data id=\"ExecutionPolicy\" value=\"AllSigned\"/&gt;&lt;Data id=\"Listbox_ModuleNames\" value=\"*\"/&gt;&lt;Data id=\"OutputDirectory\" value=\"false\"/&gt;&lt;Data id=\"SourcePathForUpdateHelp\" value=\"false\"/&gt;",
equal: false,
errorContains: "",
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
equal, err := Equal(tt.a, tt.b)
if tt.errorContains == "" {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tt.errorContains)
}
assert.Equal(t, tt.equal, equal)
})
}
}
32 changes: 25 additions & 7 deletions server/mdm/microsoft/profile_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mdm"
"github.com/fleetdm/fleet/v4/server/mdm/microsoft/admx"
)

// LoopOverExpectedHostProfiles loops all the <LocURI> values on all the profiles for a
Expand Down Expand Up @@ -72,20 +73,24 @@ func HashLocURI(profileName, locURI string) string {
func VerifyHostMDMProfiles(ctx context.Context, ds fleet.ProfileVerificationStore, host *fleet.Host, rawPolicyResultsSyncML []byte) error {
policyResults, err := transformPolicyResults(rawPolicyResultsSyncML)
if err != nil {
return err
return ctxerr.Wrap(ctx, err, "transforming policy results")
}

verified, missing, err := compareResultsToExpectedProfiles(ctx, ds, host, policyResults)
if err != nil {
return err
return ctxerr.Wrap(ctx, err, "comparing results to expected profiles")
}

toFail, toRetry, err := splitMissingProfilesIntoFailAndRetryBuckets(ctx, ds, host, missing, verified)
if err != nil {
return err
return ctxerr.Wrap(ctx, err, "splitting missing profiles into fail and retry buckets")
}

return ds.UpdateHostMDMProfilesVerification(ctx, host, slices.Collect(maps.Keys(verified)), toFail, toRetry)
err = ds.UpdateHostMDMProfilesVerification(ctx, host, slices.Collect(maps.Keys(verified)), toFail, toRetry)
if err != nil {
return ctxerr.Wrap(ctx, err, "updating host mdm profiles during verification")
}
return nil
}

func splitMissingProfilesIntoFailAndRetryBuckets(ctx context.Context, ds fleet.ProfileVerificationStore, host *fleet.Host,
Expand Down Expand Up @@ -127,19 +132,32 @@ func compareResultsToExpectedProfiles(ctx context.Context, ds fleet.ProfileVerif
missing = map[string]struct{}{}
verified = map[string]struct{}{}
err = LoopOverExpectedHostProfiles(ctx, ds, host, func(profile *fleet.ExpectedMDMProfile, ref, locURI, wantData string) {
// if we didn't get a status for a LocURI, mark the profile as
// missing.
// if we didn't get a status for a LocURI, mark the profile as missing.
gotStatus, ok := policyResults.cmdRefToStatus[ref]
if !ok {
missing[profile.Name] = struct{}{}
return
}
// it's okay if we didn't get a result
gotResults := policyResults.cmdRefToResult[ref]
// non-200 status don't have results. Consider it failed
// TODO: should we be more granular instead? eg: special case
// `4xx` responses? I'm sure there are edge cases we're not
// accounting for here, but it's unclear at this moment.
if !strings.HasPrefix(gotStatus, "2") || wantData != gotResults {
var equal bool
switch {
case !strings.HasPrefix(gotStatus, "2"):
equal = false
case wantData == gotResults:
equal = true
case admx.IsADMX(wantData):
equal, err = admx.Equal(wantData, gotResults)
if err != nil {
err = fmt.Errorf("comparing ADMX policies: %w", err)
return
}
}
if !equal {
withinGracePeriod := profile.IsWithinGracePeriod(host.DetailUpdatedAt)
if !withinGracePeriod {
missing[profile.Name] = struct{}{}
Expand Down
35 changes: 35 additions & 0 deletions server/mdm/microsoft/profile_verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,41 @@ func TestVerifyHostMDMProfilesHappyPaths(t *testing.T) {
toFail: []string{"N2", "N4"},
toRetry: []string{"N1"},
},
{
name: "single profile with CDATA reported and verified",
hostProfiles: []hostProfile{
{"N1", syncml.ForTestWithData(map[string]string{
"L1": `
<![CDATA[<enabled/><data id="ExecutionPolicy" value="AllSigned"/>
<data id="Listbox_ModuleNames" value="*"/>
<data id="OutputDirectory" value="false"/>
<data id="EnableScriptBlockInvocationLogging" value="true"/>
<data id="SourcePathForUpdateHelp" value="false"/>]]>`,
}), 0},
},
report: []osqueryReport{{"N1", "200", "L1",
"&lt;Enabled/&gt;&lt;Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/&gt;&lt;Data id=\"ExecutionPolicy\" value=\"AllSigned\"/&gt;&lt;Data id=\"Listbox_ModuleNames\" value=\"*\"/&gt;&lt;Data id=\"OutputDirectory\" value=\"false\"/&gt;&lt;Data id=\"SourcePathForUpdateHelp\" value=\"false\"/&gt;",
}},
toVerify: []string{"N1"},
toFail: []string{},
toRetry: []string{},
},
{
name: "single profile with CDATA to retry",
hostProfiles: []hostProfile{
{"N1", syncml.ForTestWithData(map[string]string{
"L1": `
<![CDATA[<enabled/><data id="ExecutionPolicy" value="AllSigned"/>
<data id="SourcePathForUpdateHelp" value="false"/>]]>`,
}), 0},
},
report: []osqueryReport{{"N1", "200", "L1",
"&lt;Enabled/&gt;&lt;Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/&gt;&lt;Data id=\"ExecutionPolicy\" value=\"AllSigned\"/&gt;&lt;Data id=\"Listbox_ModuleNames\" value=\"*\"/&gt;&lt;Data id=\"OutputDirectory\" value=\"false\"/&gt;&lt;Data id=\"SourcePathForUpdateHelp\" value=\"false\"/&gt;",
}},
toVerify: []string{},
toFail: []string{},
toRetry: []string{"N1"},
},
}

for _, tt := range cases {
Expand Down

0 comments on commit 9e9972c

Please sign in to comment.