Skip to content

Commit

Permalink
Reformat new files with Spotless.
Browse files Browse the repository at this point in the history
Signed-off-by: currantw <[email protected]>
  • Loading branch information
currantw committed Oct 29, 2024
1 parent bef53cc commit 8533862
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 175 deletions.
156 changes: 81 additions & 75 deletions core/src/main/java/org/opensearch/sql/expression/ip/IPFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@

package org.opensearch.sql.expression.ip;

import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.function.FunctionDSL.*;

import com.google.common.net.InetAddresses;
import java.math.BigInteger;
import java.net.InetAddress;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.experimental.UtilityClass;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
Expand All @@ -14,87 +22,85 @@
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.DefaultFunctionResolver;

import java.math.BigInteger;
import java.net.InetAddress;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.function.FunctionDSL.*;

/**
* Utility class that defines and registers IP functions.
*/
/** Utility class that defines and registers IP functions. */
@UtilityClass
public class IPFunction {

private static final Pattern cidrPattern = Pattern.compile("(?<address>.+)[/](?<networkLength>[0-9]+)");

public void register(BuiltinFunctionRepository repository) {
repository.register(cidr());
private static final Pattern cidrPattern =
Pattern.compile("(?<address>.+)[/](?<networkLength>[0-9]+)");

public void register(BuiltinFunctionRepository repository) {
repository.register(cidr());
}

private DefaultFunctionResolver cidr() {
return define(
BuiltinFunctionName.CIDR.getName(),
impl(nullMissingHandling(IPFunction::exprCidr), BOOLEAN, STRING, STRING));
}

/**
* Returns whether the given IP address is within the specified IP address range. Supports both
* IPv4 and IPv6 addresses.
*
* @param addressExprValue IP address (e.g. "198.51.100.14" or "2001:0db8::ff00:42:8329").
* @param rangeExprValue IP address range in CIDR notation (e.g. "198.51.100.0/24" or
* "2001:0db8::/32")
* @return null if the address is not valid; true if the address is in the range; otherwise false.
* @throws SemanticCheckException if the range is not valid
*/
private ExprValue exprCidr(ExprValue addressExprValue, ExprValue rangeExprValue) {

// Get address
String addressString = addressExprValue.stringValue();
if (!InetAddresses.isInetAddress(addressString)) {
return ExprValueUtils.nullValue();
}

private DefaultFunctionResolver cidr() {
return define(BuiltinFunctionName.CIDR.getName(), impl(nullMissingHandling(IPFunction::exprCidr), BOOLEAN, STRING, STRING));
InetAddress address = InetAddresses.forString(addressString);

// Get range and network length
String rangeString = rangeExprValue.stringValue();

Matcher cidrMatcher = cidrPattern.matcher(rangeString);
if (!cidrMatcher.matches())
throw new SemanticCheckException(
String.format("CIDR notation '%s' in not valid", rangeString));

String rangeAddressString = cidrMatcher.group("address");
if (!InetAddresses.isInetAddress(rangeAddressString))
throw new SemanticCheckException(
String.format("IP address '%s' in not valid", rangeAddressString));

InetAddress rangeAddress = InetAddresses.forString(rangeAddressString);

// Address and range must use the same IP version (IPv4 or IPv6).
if (!address.getClass().equals(rangeAddress.getClass())) {
return ExprValueUtils.booleanValue(false);
}

/**
* Returns whether the given IP address is within the specified IP address range.
* Supports both IPv4 and IPv6 addresses.
*
* @param addressExprValue IP address (e.g. "198.51.100.14" or "2001:0db8::ff00:42:8329").
* @param rangeExprValue IP address range in CIDR notation (e.g. "198.51.100.0/24" or "2001:0db8::/32")
* @return null if the address is not valid; true if the address is in the range; otherwise false.
* @throws SemanticCheckException if the range is not valid
*/
private ExprValue exprCidr(ExprValue addressExprValue, ExprValue rangeExprValue) {

// Get address
String addressString = addressExprValue.stringValue();
if (!InetAddresses.isInetAddress(addressString)) {
return ExprValueUtils.nullValue();
}

InetAddress address = InetAddresses.forString(addressString);

// Get range and network length
String rangeString = rangeExprValue.stringValue();

Matcher cidrMatcher = cidrPattern.matcher(rangeString);
if (!cidrMatcher.matches())
throw new SemanticCheckException(String.format("CIDR notation '%s' in not valid", rangeString));

String rangeAddressString = cidrMatcher.group("address");
if (!InetAddresses.isInetAddress(rangeAddressString))
throw new SemanticCheckException(String.format("IP address '%s' in not valid", rangeAddressString));

InetAddress rangeAddress = InetAddresses.forString(rangeAddressString);

// Address and range must use the same IP version (IPv4 or IPv6).
if (!address.getClass().equals(rangeAddress.getClass())) {
return ExprValueUtils.booleanValue(false);
}

int networkLengthBits = Integer.parseInt(cidrMatcher.group("networkLength"));
int addressLengthBits = address.getAddress().length * Byte.SIZE;

if (networkLengthBits > addressLengthBits)
throw new SemanticCheckException(String.format("Network length of '%s' bits is not valid", networkLengthBits));

// Build bounds by converting the address to an integer, setting all the non-significant bits to
// zero for the lower bounds and one for the upper bounds, and then converting back to addresses.
BigInteger lowerBoundInt = InetAddresses.toBigInteger(rangeAddress);
BigInteger upperBoundInt = InetAddresses.toBigInteger(rangeAddress);

int hostLengthBits = addressLengthBits - networkLengthBits;
for (int bit = 0; bit < hostLengthBits; bit++) {
lowerBoundInt = lowerBoundInt.clearBit(bit);
upperBoundInt = upperBoundInt.setBit(bit);
}

// Convert the address to an integer and compare it to the bounds.
BigInteger addressInt = InetAddresses.toBigInteger(address);
return ExprValueUtils.booleanValue((addressInt.compareTo(lowerBoundInt) >= 0) && (addressInt.compareTo(upperBoundInt) <= 0));
int networkLengthBits = Integer.parseInt(cidrMatcher.group("networkLength"));
int addressLengthBits = address.getAddress().length * Byte.SIZE;

if (networkLengthBits > addressLengthBits)
throw new SemanticCheckException(
String.format("Network length of '%s' bits is not valid", networkLengthBits));

// Build bounds by converting the address to an integer, setting all the non-significant bits to
// zero for the lower bounds and one for the upper bounds, and then converting back to
// addresses.
BigInteger lowerBoundInt = InetAddresses.toBigInteger(rangeAddress);
BigInteger upperBoundInt = InetAddresses.toBigInteger(rangeAddress);

int hostLengthBits = addressLengthBits - networkLengthBits;
for (int bit = 0; bit < hostLengthBits; bit++) {
lowerBoundInt = lowerBoundInt.clearBit(bit);
upperBoundInt = upperBoundInt.setBit(bit);
}

// Convert the address to an integer and compare it to the bounds.
BigInteger addressInt = InetAddresses.toBigInteger(address);
return ExprValueUtils.booleanValue(
(addressInt.compareTo(lowerBoundInt) >= 0) && (addressInt.compareTo(upperBoundInt) <= 0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@

package org.opensearch.sql.expression.ip;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.data.model.ExprValueUtils.*;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand All @@ -17,77 +23,78 @@
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.env.Environment;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.data.model.ExprValueUtils.*;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;

@ExtendWith(MockitoExtension.class)
public class IPFunctionTest {

// IP range and address constants for testing.
private static final ExprValue IPv4Range = ExprValueUtils.stringValue("198.51.100.0/24");
private static final ExprValue IPv6Range = ExprValueUtils.stringValue("2001:0db8::/32");

private static final ExprValue IPv4AddressBelow = ExprValueUtils.stringValue("198.51.99.1");
private static final ExprValue IPv4AddressWithin = ExprValueUtils.stringValue("198.51.100.1");
private static final ExprValue IPv4AddressAbove = ExprValueUtils.stringValue("198.51.101.2");

private static final ExprValue IPv6AddressBelow = ExprValueUtils.stringValue("2001:0db7::ff00:42:8329");
private static final ExprValue IPv6AddressWithin = ExprValueUtils.stringValue("2001:0db8::ff00:42:8329");
private static final ExprValue IPv6AddressAbove = ExprValueUtils.stringValue("2001:0db9::ff00:42:8329");

// Mock value environment for testing.
@Mock
private Environment<Expression, ExprValue> env;

@Test
public void cidr_invalid_address() {
assertEquals(LITERAL_NULL, execute(ExprValueUtils.stringValue("INVALID"), IPv4Range));
}

@Test
public void cidr_invalid_range() {
assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID")));
assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID/32")));
assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("198.51.100.0/33")));
}

@Test
public void cidr_valid_ipv4() {
assertEquals(LITERAL_FALSE, execute(IPv4AddressBelow, IPv4Range));
assertEquals(LITERAL_TRUE, execute(IPv4AddressWithin, IPv4Range));
assertEquals(LITERAL_FALSE, execute(IPv4AddressAbove, IPv4Range));
}

@Test
public void cidr_valid_ipv6() {
assertEquals(LITERAL_FALSE, execute(IPv6AddressBelow, IPv6Range));
assertEquals(LITERAL_TRUE, execute(IPv6AddressWithin, IPv6Range));
assertEquals(LITERAL_FALSE, execute(IPv6AddressAbove, IPv6Range));
}

@Test
public void cidr_valid_different_versions() {
assertEquals(LITERAL_FALSE, execute(IPv4AddressWithin, IPv6Range));
assertEquals(LITERAL_FALSE, execute(IPv6AddressWithin, IPv4Range));
}

/**
* Builds and evaluates a CIDR function expression with the given field
* and range expression values, and returns the resulting value.
*/
private ExprValue execute(ExprValue field, ExprValue range) {

final String fieldName = "ip_address";
FunctionExpression exp = DSL.cidr(DSL.ref(fieldName, STRING), DSL.literal(range));

// Mock the value environment to return the specified field
// expression as the value for the "ip_address" field.
when(DSL.ref(fieldName, STRING).valueOf(env)).thenReturn(field);

return exp.valueOf(env);
}

// IP range and address constants for testing.
private static final ExprValue IPv4Range = ExprValueUtils.stringValue("198.51.100.0/24");
private static final ExprValue IPv6Range = ExprValueUtils.stringValue("2001:0db8::/32");

private static final ExprValue IPv4AddressBelow = ExprValueUtils.stringValue("198.51.99.1");
private static final ExprValue IPv4AddressWithin = ExprValueUtils.stringValue("198.51.100.1");
private static final ExprValue IPv4AddressAbove = ExprValueUtils.stringValue("198.51.101.2");

private static final ExprValue IPv6AddressBelow =
ExprValueUtils.stringValue("2001:0db7::ff00:42:8329");
private static final ExprValue IPv6AddressWithin =
ExprValueUtils.stringValue("2001:0db8::ff00:42:8329");
private static final ExprValue IPv6AddressAbove =
ExprValueUtils.stringValue("2001:0db9::ff00:42:8329");

// Mock value environment for testing.
@Mock private Environment<Expression, ExprValue> env;

@Test
public void cidr_invalid_address() {
assertEquals(LITERAL_NULL, execute(ExprValueUtils.stringValue("INVALID"), IPv4Range));
}

@Test
public void cidr_invalid_range() {
assertThrows(
SemanticCheckException.class,
() -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID")));
assertThrows(
SemanticCheckException.class,
() -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID/32")));
assertThrows(
SemanticCheckException.class,
() -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("198.51.100.0/33")));
}

@Test
public void cidr_valid_ipv4() {
assertEquals(LITERAL_FALSE, execute(IPv4AddressBelow, IPv4Range));
assertEquals(LITERAL_TRUE, execute(IPv4AddressWithin, IPv4Range));
assertEquals(LITERAL_FALSE, execute(IPv4AddressAbove, IPv4Range));
}

@Test
public void cidr_valid_ipv6() {
assertEquals(LITERAL_FALSE, execute(IPv6AddressBelow, IPv6Range));
assertEquals(LITERAL_TRUE, execute(IPv6AddressWithin, IPv6Range));
assertEquals(LITERAL_FALSE, execute(IPv6AddressAbove, IPv6Range));
}

@Test
public void cidr_valid_different_versions() {
assertEquals(LITERAL_FALSE, execute(IPv4AddressWithin, IPv6Range));
assertEquals(LITERAL_FALSE, execute(IPv6AddressWithin, IPv4Range));
}

/**
* Builds and evaluates a CIDR function expression with the given field and range expression
* values, and returns the resulting value.
*/
private ExprValue execute(ExprValue field, ExprValue range) {

final String fieldName = "ip_address";
FunctionExpression exp = DSL.cidr(DSL.ref(fieldName, STRING), DSL.literal(range));

// Mock the value environment to return the specified field
// expression as the value for the "ip_address" field.
when(DSL.ref(fieldName, STRING).valueOf(env)).thenReturn(field);

return exp.valueOf(env);
}
}
Loading

0 comments on commit 8533862

Please sign in to comment.