diff --git a/core/src/main/java/org/opensearch/sql/expression/ip/IPFunction.java b/core/src/main/java/org/opensearch/sql/expression/ip/IPFunction.java index ce5e6b8c92..9349ed02db 100644 --- a/core/src/main/java/org/opensearch/sql/expression/ip/IPFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/ip/IPFunction.java @@ -5,7 +5,15 @@ package org.opensearch.sql.expression.ip; +import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.expression.function.FunctionDSL.*; + import com.google.common.net.InetAddresses; +import java.math.BigInteger; +import java.net.InetAddress; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import lombok.experimental.UtilityClass; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -14,87 +22,85 @@ import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.DefaultFunctionResolver; -import java.math.BigInteger; -import java.net.InetAddress; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; -import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import static org.opensearch.sql.expression.function.FunctionDSL.*; - -/** - * Utility class that defines and registers IP functions. - */ +/** Utility class that defines and registers IP functions. */ @UtilityClass public class IPFunction { - private static final Pattern cidrPattern = Pattern.compile("(?
.+)[/](?[0-9]+)"); - - public void register(BuiltinFunctionRepository repository) { - repository.register(cidr()); + private static final Pattern cidrPattern = + Pattern.compile("(?
.+)[/](?[0-9]+)"); + + public void register(BuiltinFunctionRepository repository) { + repository.register(cidr()); + } + + private DefaultFunctionResolver cidr() { + return define( + BuiltinFunctionName.CIDR.getName(), + impl(nullMissingHandling(IPFunction::exprCidr), BOOLEAN, STRING, STRING)); + } + + /** + * Returns whether the given IP address is within the specified IP address range. Supports both + * IPv4 and IPv6 addresses. + * + * @param addressExprValue IP address (e.g. "198.51.100.14" or "2001:0db8::ff00:42:8329"). + * @param rangeExprValue IP address range in CIDR notation (e.g. "198.51.100.0/24" or + * "2001:0db8::/32") + * @return null if the address is not valid; true if the address is in the range; otherwise false. + * @throws SemanticCheckException if the range is not valid + */ + private ExprValue exprCidr(ExprValue addressExprValue, ExprValue rangeExprValue) { + + // Get address + String addressString = addressExprValue.stringValue(); + if (!InetAddresses.isInetAddress(addressString)) { + return ExprValueUtils.nullValue(); } - private DefaultFunctionResolver cidr() { - return define(BuiltinFunctionName.CIDR.getName(), impl(nullMissingHandling(IPFunction::exprCidr), BOOLEAN, STRING, STRING)); + InetAddress address = InetAddresses.forString(addressString); + + // Get range and network length + String rangeString = rangeExprValue.stringValue(); + + Matcher cidrMatcher = cidrPattern.matcher(rangeString); + if (!cidrMatcher.matches()) + throw new SemanticCheckException( + String.format("CIDR notation '%s' in not valid", rangeString)); + + String rangeAddressString = cidrMatcher.group("address"); + if (!InetAddresses.isInetAddress(rangeAddressString)) + throw new SemanticCheckException( + String.format("IP address '%s' in not valid", rangeAddressString)); + + InetAddress rangeAddress = InetAddresses.forString(rangeAddressString); + + // Address and range must use the same IP version (IPv4 or IPv6). + if (!address.getClass().equals(rangeAddress.getClass())) { + return ExprValueUtils.booleanValue(false); } - /** - * Returns whether the given IP address is within the specified IP address range. - * Supports both IPv4 and IPv6 addresses. - * - * @param addressExprValue IP address (e.g. "198.51.100.14" or "2001:0db8::ff00:42:8329"). - * @param rangeExprValue IP address range in CIDR notation (e.g. "198.51.100.0/24" or "2001:0db8::/32") - * @return null if the address is not valid; true if the address is in the range; otherwise false. - * @throws SemanticCheckException if the range is not valid - */ - private ExprValue exprCidr(ExprValue addressExprValue, ExprValue rangeExprValue) { - - // Get address - String addressString = addressExprValue.stringValue(); - if (!InetAddresses.isInetAddress(addressString)) { - return ExprValueUtils.nullValue(); - } - - InetAddress address = InetAddresses.forString(addressString); - - // Get range and network length - String rangeString = rangeExprValue.stringValue(); - - Matcher cidrMatcher = cidrPattern.matcher(rangeString); - if (!cidrMatcher.matches()) - throw new SemanticCheckException(String.format("CIDR notation '%s' in not valid", rangeString)); - - String rangeAddressString = cidrMatcher.group("address"); - if (!InetAddresses.isInetAddress(rangeAddressString)) - throw new SemanticCheckException(String.format("IP address '%s' in not valid", rangeAddressString)); - - InetAddress rangeAddress = InetAddresses.forString(rangeAddressString); - - // Address and range must use the same IP version (IPv4 or IPv6). - if (!address.getClass().equals(rangeAddress.getClass())) { - return ExprValueUtils.booleanValue(false); - } - - int networkLengthBits = Integer.parseInt(cidrMatcher.group("networkLength")); - int addressLengthBits = address.getAddress().length * Byte.SIZE; - - if (networkLengthBits > addressLengthBits) - throw new SemanticCheckException(String.format("Network length of '%s' bits is not valid", networkLengthBits)); - - // Build bounds by converting the address to an integer, setting all the non-significant bits to - // zero for the lower bounds and one for the upper bounds, and then converting back to addresses. - BigInteger lowerBoundInt = InetAddresses.toBigInteger(rangeAddress); - BigInteger upperBoundInt = InetAddresses.toBigInteger(rangeAddress); - - int hostLengthBits = addressLengthBits - networkLengthBits; - for (int bit = 0; bit < hostLengthBits; bit++) { - lowerBoundInt = lowerBoundInt.clearBit(bit); - upperBoundInt = upperBoundInt.setBit(bit); - } - - // Convert the address to an integer and compare it to the bounds. - BigInteger addressInt = InetAddresses.toBigInteger(address); - return ExprValueUtils.booleanValue((addressInt.compareTo(lowerBoundInt) >= 0) && (addressInt.compareTo(upperBoundInt) <= 0)); + int networkLengthBits = Integer.parseInt(cidrMatcher.group("networkLength")); + int addressLengthBits = address.getAddress().length * Byte.SIZE; + + if (networkLengthBits > addressLengthBits) + throw new SemanticCheckException( + String.format("Network length of '%s' bits is not valid", networkLengthBits)); + + // Build bounds by converting the address to an integer, setting all the non-significant bits to + // zero for the lower bounds and one for the upper bounds, and then converting back to + // addresses. + BigInteger lowerBoundInt = InetAddresses.toBigInteger(rangeAddress); + BigInteger upperBoundInt = InetAddresses.toBigInteger(rangeAddress); + + int hostLengthBits = addressLengthBits - networkLengthBits; + for (int bit = 0; bit < hostLengthBits; bit++) { + lowerBoundInt = lowerBoundInt.clearBit(bit); + upperBoundInt = upperBoundInt.setBit(bit); } + + // Convert the address to an integer and compare it to the bounds. + BigInteger addressInt = InetAddresses.toBigInteger(address); + return ExprValueUtils.booleanValue( + (addressInt.compareTo(lowerBoundInt) >= 0) && (addressInt.compareTo(upperBoundInt) <= 0)); + } } diff --git a/core/src/test/java/org/opensearch/sql/expression/ip/IPFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/ip/IPFunctionTest.java index cd1c78287c..149a6b87c1 100644 --- a/core/src/test/java/org/opensearch/sql/expression/ip/IPFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/ip/IPFunctionTest.java @@ -5,6 +5,12 @@ package org.opensearch.sql.expression.ip; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.*; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -17,77 +23,78 @@ import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.env.Environment; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.Mockito.when; -import static org.opensearch.sql.data.model.ExprValueUtils.*; -import static org.opensearch.sql.data.type.ExprCoreType.STRING; - @ExtendWith(MockitoExtension.class) public class IPFunctionTest { - // IP range and address constants for testing. - private static final ExprValue IPv4Range = ExprValueUtils.stringValue("198.51.100.0/24"); - private static final ExprValue IPv6Range = ExprValueUtils.stringValue("2001:0db8::/32"); - - private static final ExprValue IPv4AddressBelow = ExprValueUtils.stringValue("198.51.99.1"); - private static final ExprValue IPv4AddressWithin = ExprValueUtils.stringValue("198.51.100.1"); - private static final ExprValue IPv4AddressAbove = ExprValueUtils.stringValue("198.51.101.2"); - - private static final ExprValue IPv6AddressBelow = ExprValueUtils.stringValue("2001:0db7::ff00:42:8329"); - private static final ExprValue IPv6AddressWithin = ExprValueUtils.stringValue("2001:0db8::ff00:42:8329"); - private static final ExprValue IPv6AddressAbove = ExprValueUtils.stringValue("2001:0db9::ff00:42:8329"); - - // Mock value environment for testing. - @Mock - private Environment env; - - @Test - public void cidr_invalid_address() { - assertEquals(LITERAL_NULL, execute(ExprValueUtils.stringValue("INVALID"), IPv4Range)); - } - - @Test - public void cidr_invalid_range() { - assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID"))); - assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID/32"))); - assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("198.51.100.0/33"))); - } - - @Test - public void cidr_valid_ipv4() { - assertEquals(LITERAL_FALSE, execute(IPv4AddressBelow, IPv4Range)); - assertEquals(LITERAL_TRUE, execute(IPv4AddressWithin, IPv4Range)); - assertEquals(LITERAL_FALSE, execute(IPv4AddressAbove, IPv4Range)); - } - - @Test - public void cidr_valid_ipv6() { - assertEquals(LITERAL_FALSE, execute(IPv6AddressBelow, IPv6Range)); - assertEquals(LITERAL_TRUE, execute(IPv6AddressWithin, IPv6Range)); - assertEquals(LITERAL_FALSE, execute(IPv6AddressAbove, IPv6Range)); - } - - @Test - public void cidr_valid_different_versions() { - assertEquals(LITERAL_FALSE, execute(IPv4AddressWithin, IPv6Range)); - assertEquals(LITERAL_FALSE, execute(IPv6AddressWithin, IPv4Range)); - } - - /** - * Builds and evaluates a CIDR function expression with the given field - * and range expression values, and returns the resulting value. - */ - private ExprValue execute(ExprValue field, ExprValue range) { - - final String fieldName = "ip_address"; - FunctionExpression exp = DSL.cidr(DSL.ref(fieldName, STRING), DSL.literal(range)); - - // Mock the value environment to return the specified field - // expression as the value for the "ip_address" field. - when(DSL.ref(fieldName, STRING).valueOf(env)).thenReturn(field); - - return exp.valueOf(env); - } - + // IP range and address constants for testing. + private static final ExprValue IPv4Range = ExprValueUtils.stringValue("198.51.100.0/24"); + private static final ExprValue IPv6Range = ExprValueUtils.stringValue("2001:0db8::/32"); + + private static final ExprValue IPv4AddressBelow = ExprValueUtils.stringValue("198.51.99.1"); + private static final ExprValue IPv4AddressWithin = ExprValueUtils.stringValue("198.51.100.1"); + private static final ExprValue IPv4AddressAbove = ExprValueUtils.stringValue("198.51.101.2"); + + private static final ExprValue IPv6AddressBelow = + ExprValueUtils.stringValue("2001:0db7::ff00:42:8329"); + private static final ExprValue IPv6AddressWithin = + ExprValueUtils.stringValue("2001:0db8::ff00:42:8329"); + private static final ExprValue IPv6AddressAbove = + ExprValueUtils.stringValue("2001:0db9::ff00:42:8329"); + + // Mock value environment for testing. + @Mock private Environment env; + + @Test + public void cidr_invalid_address() { + assertEquals(LITERAL_NULL, execute(ExprValueUtils.stringValue("INVALID"), IPv4Range)); + } + + @Test + public void cidr_invalid_range() { + assertThrows( + SemanticCheckException.class, + () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID"))); + assertThrows( + SemanticCheckException.class, + () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID/32"))); + assertThrows( + SemanticCheckException.class, + () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("198.51.100.0/33"))); + } + + @Test + public void cidr_valid_ipv4() { + assertEquals(LITERAL_FALSE, execute(IPv4AddressBelow, IPv4Range)); + assertEquals(LITERAL_TRUE, execute(IPv4AddressWithin, IPv4Range)); + assertEquals(LITERAL_FALSE, execute(IPv4AddressAbove, IPv4Range)); + } + + @Test + public void cidr_valid_ipv6() { + assertEquals(LITERAL_FALSE, execute(IPv6AddressBelow, IPv6Range)); + assertEquals(LITERAL_TRUE, execute(IPv6AddressWithin, IPv6Range)); + assertEquals(LITERAL_FALSE, execute(IPv6AddressAbove, IPv6Range)); + } + + @Test + public void cidr_valid_different_versions() { + assertEquals(LITERAL_FALSE, execute(IPv4AddressWithin, IPv6Range)); + assertEquals(LITERAL_FALSE, execute(IPv6AddressWithin, IPv4Range)); + } + + /** + * Builds and evaluates a CIDR function expression with the given field and range expression + * values, and returns the resulting value. + */ + private ExprValue execute(ExprValue field, ExprValue range) { + + final String fieldName = "ip_address"; + FunctionExpression exp = DSL.cidr(DSL.ref(fieldName, STRING), DSL.literal(range)); + + // Mock the value environment to return the specified field + // expression as the value for the "ip_address" field. + when(DSL.ref(fieldName, STRING).valueOf(env)).thenReturn(field); + + return exp.valueOf(env); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/IPFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/IPFunctionIT.java index 88b9948404..fee60b23a7 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/IPFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/IPFunctionIT.java @@ -5,39 +5,50 @@ package org.opensearch.sql.ppl; -import org.json.JSONObject; -import org.junit.jupiter.api.Test; - -import java.io.IOException; - import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WEBLOG; import static org.opensearch.sql.util.MatcherUtils.*; -public class IPFunctionIT extends PPLIntegTestCase { - - @Override - public void init() throws IOException { - loadIndex(Index.WEBLOG); - } - - @Test - public void testCidr() throws IOException { - - JSONObject result; - - // No matches - result = executeQuery(String.format("source=%s | where cidr(host, '199.120.111.0/24') | fields url", TEST_INDEX_WEBLOG)); - verifySchema(result, schema("url", null, "boolean")); - verifyDataRows(result, rows("/shuttle/missions/sts-73/mission-sts-73.html")); +import java.io.IOException; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; - // One match - result = executeQuery(String.format("source=%s | where cidr(host, '199.120.110.0/24') | fields url", TEST_INDEX_WEBLOG)); - verifySchema(result, schema("url", null, "boolean")); - verifyDataRows(result, rows("/shuttle/missions/sts-73/mission-sts-73.html")); +public class IPFunctionIT extends PPLIntegTestCase { - // Multiple matches - result = executeQuery(String.format("source=%s | where cidr(host, '199.0.0.0/8') | fields url", TEST_INDEX_WEBLOG)); - verifySchema(result, schema("url", null, "boolean")); - verifyDataRows(result, rows("/history/apollo/"), rows("/shuttle/missions/sts-73/mission-sts-73.html")); - } + @Override + public void init() throws IOException { + loadIndex(Index.WEBLOG); + } + + @Test + public void testCidr() throws IOException { + + JSONObject result; + + // No matches + result = + executeQuery( + String.format( + "source=%s | where cidr(host, '199.120.111.0/24') | fields url", + TEST_INDEX_WEBLOG)); + verifySchema(result, schema("url", null, "boolean")); + verifyDataRows(result, rows("/shuttle/missions/sts-73/mission-sts-73.html")); + + // One match + result = + executeQuery( + String.format( + "source=%s | where cidr(host, '199.120.110.0/24') | fields url", + TEST_INDEX_WEBLOG)); + verifySchema(result, schema("url", null, "boolean")); + verifyDataRows(result, rows("/shuttle/missions/sts-73/mission-sts-73.html")); + + // Multiple matches + result = + executeQuery( + String.format( + "source=%s | where cidr(host, '199.0.0.0/8') | fields url", TEST_INDEX_WEBLOG)); + verifySchema(result, schema("url", null, "boolean")); + verifyDataRows( + result, rows("/history/apollo/"), rows("/shuttle/missions/sts-73/mission-sts-73.html")); + } }