Skip to content

Commit

Permalink
Use System.Linq.Expressions.ExpressionVisitor where available (#182)
Browse files Browse the repository at this point in the history
* Use System.Linq.Expressions.ExpressionVisitor where available

* Add tests for Block and Throw expressions
  • Loading branch information
TheConstructor authored Mar 18, 2023
1 parent d2af653 commit e978e34
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 52 deletions.
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
#if NOEF
using System;
#if NET35
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq.Expressions;
using JetBrains.Annotations;

// ReSharper disable once CheckNamespace
namespace LinqKit
namespace System.Linq.Expressions
{
/// <summary>
/// This comes from Matt Warren's sample:
/// http://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx
/// </summary>
public abstract class ExpressionVisitor
internal abstract class ExpressionVisitor
{
/// <summary> Visit expression tree </summary>
[Pure]
Expand Down Expand Up @@ -69,7 +67,7 @@ public virtual Expression Visit(Expression exp)
case ExpressionType.Parameter:
return VisitParameter((ParameterExpression)exp);
case ExpressionType.MemberAccess:
return VisitMemberAccess((MemberExpression)exp);
return VisitMember((MemberExpression)exp);
case ExpressionType.Call:
return VisitMethodCall((MethodCallExpression)exp);
case ExpressionType.Lambda:
Expand All @@ -85,32 +83,11 @@ public virtual Expression Visit(Expression exp)
return VisitMemberInit((MemberInitExpression)exp);
case ExpressionType.ListInit:
return VisitListInit((ListInitExpression)exp);
#if !NET35
case ExpressionType.Index:
return VisitIndex((IndexExpression)exp);

case ExpressionType.Extension:
return VisitExtension(exp);
#endif
default:
throw new Exception($"Unhandled expression type: '{exp.NodeType}'");
}
}

/// <summary>
/// Visit Extension expression to fix bugs:
/// - https://github.com/scottksmith95/LINQKit/issues/116
/// - https://github.com/scottksmith95/LINQKit/issues/118
///
/// TODO (2020-07-16) I'm not sure if just returning the expression will work in all cases...
///
/// See also https://nejcskofic.github.io/2017/07/30/extending-linq-expressions/
/// </summary>
protected virtual Expression VisitExtension(Expression extensionExpression)
{
return extensionExpression;
}

/// <summary> Visit member binding </summary>
protected virtual MemberBinding VisitBinding(MemberBinding binding)
{
Expand Down Expand Up @@ -214,7 +191,7 @@ protected virtual Expression VisitParameter(ParameterExpression p)
}

/// <summary> Visit member access </summary>
protected virtual Expression VisitMemberAccess(MemberExpression m)
protected virtual Expression VisitMember(MemberExpression m)
{
Expression exp = Visit(m.Expression);
if (exp != m.Expression)
Expand Down Expand Up @@ -260,11 +237,7 @@ protected virtual ReadOnlyCollection<Expression> VisitExpressionList(ReadOnlyCol

if (list != null)
{
#if (PORTABLE || PORTABLE40)
return new ReadOnlyCollection<Expression>(list);
#else
return list.AsReadOnly();
#endif
}

return original;
Expand Down Expand Up @@ -405,20 +378,6 @@ protected virtual Expression VisitInvocation(InvocationExpression iv)
Expression expr = Visit(iv.Expression);
return args != iv.Arguments || expr != iv.Expression ? Expression.Invoke(expr, args) : iv;
}

#if !NET35
/// <summary> Visit index expression </summary>
protected virtual Expression VisitIndex(IndexExpression exp)
{
var obj = Visit(exp.Object);
var args = VisitExpressionList(exp.Arguments);
if (obj != exp.Object || args != exp.Arguments)
{
return Expression.MakeIndex(obj, exp.Indexer, args);
}
return exp;
}
#endif
}
}
#endif
4 changes: 2 additions & 2 deletions src/LinqKit.Core/ExpressionExpander.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ protected override Expression VisitMethodCall(MethodCallExpression m)
return base.VisitMethodCall(m);
}

protected override Expression VisitMemberAccess(MemberExpression m)
protected override Expression VisitMember(MemberExpression m)
{
if (GetExpandLambda(m.Member, out var methodLambda))
{
Expand All @@ -186,7 +186,7 @@ protected override Expression VisitMemberAccess(MemberExpression m)
// Strip out any references to expressions captured by outer variables - LINQ to SQL can't handle these:
return m.Member.DeclaringType != null && m.Member.DeclaringType.Name.StartsWith("<>") ?
TransformExpr(m)
: base.VisitMemberAccess(m);
: base.VisitMember(m);
}

Expression TransformExpr(MemberExpression input)
Expand Down
2 changes: 1 addition & 1 deletion src/LinqKit.Net35/LinqKit.Net35.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
<Compile Include="..\LinqKit.Core\ExpressionStarter.cs">
<Link>ExpressionStarter.cs</Link>
</Compile>
<Compile Include="..\LinqKit.Core\ExpressionVisitor.cs">
<Compile Include="..\LinqKit.Core\Compatibility\ExpressionVisitor.cs">
<Link>ExpressionVisitor.cs</Link>
</Compile>
<Compile Include="..\LinqKit.Core\Extensions.cs">
Expand Down
3 changes: 0 additions & 3 deletions src/LinqKit.Net45/LinqKit.Net45.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,6 @@
<Compile Include="..\LinqKit.Core\ExpressionStarter.cs">
<Link>ExpressionStarter.cs</Link>
</Compile>
<Compile Include="..\LinqKit.Core\ExpressionVisitor.cs">
<Link>ExpressionVisitor.cs</Link>
</Compile>
<Compile Include="..\LinqKit.Core\Extensions.cs">
<Link>Extensions.cs</Link>
</Compile>
Expand Down
46 changes: 46 additions & 0 deletions tests/LinqKit.Tests.Net452/ExpressionExpanderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,51 @@ public void ExpressionExpander_Expression_Index()
var executed = Linq.Expr((List<string> l) => expression2.Invoke(expression1.Invoke(l))).Expand().ToString();
Assert.Equal(expression1.ToString(), executed);
}

[Fact]
public void ExpressionExpander_Expression_Block()
{
var objParameter = Expression.Parameter(typeof(object), "o");
var objVar = Expression.Variable(typeof(object));

var lambda = Expression.Lambda<Func<object, string>>(Expression.Block(new[] {objVar},
Expression.Assign(objVar, objParameter),
Expression.Call(objVar, nameof(ToString), Type.EmptyTypes)),
objParameter);

var expandedLambda = Linq.Expr((object o) => lambda.Invoke(o))
.Expand();
Assert.Equal(lambda.ToString(), expandedLambda.ToString());
Assert.Equal(lambda.Invoke(42), expandedLambda.Invoke(42));
}

[Fact]
public void ExpressionExpander_Expression_Throw()
{
var objParameter = Expression.Parameter(typeof(object), "o");
var msgParameter = Expression.Parameter(typeof(string), "msg");

var exceptionConstructor = typeof(ArgumentNullException).GetConstructor(new []{typeof(string)});

var lambda = Expression.Lambda<Func<object, string, object>>(
Expression.Condition(
Expression.Equal(objParameter, Expression.Constant(null)),
Expression.Throw(Expression.New(exceptionConstructor, msgParameter), typeof(object)),
objParameter),
objParameter,
msgParameter);

var expandedLambda = Linq.Expr((object o, string msg) => lambda.Invoke(o, msg))
.Expand();
Assert.Equal(lambda.ToString(), expandedLambda.ToString());
Assert.Equal("x",
Assert.Throws<ArgumentNullException>(() => lambda.Invoke(null, "x"))
.ParamName);
Assert.Equal("x",
Assert.Throws<ArgumentNullException>(() => expandedLambda.Invoke(null, "x"))
.ParamName);
var obj = new object();
Assert.Same(lambda.Invoke(obj, "x"), expandedLambda.Invoke(obj, "x"));
}
}
}

0 comments on commit e978e34

Please sign in to comment.