Skip to content

Commit

Permalink
[release/9.0] Return null when the type is nullable for Cosmos Max/Mi…
Browse files Browse the repository at this point in the history
…n/Average (#35216)

* Return null when the type is nullable for Cosmos Max/Min/Average (#35173)

* Return null when the type is nullable for Cosmos Max/Min/Average

Fixes #35094

This was a regression resulting from the major Cosmos query refactoring that happened in EF9. In EF8, the functions Min, Max, and Average would return null if the return type was nullable or was cast to a nullable when the collection is empty. In EF9, this started throwing, which is correct for non-nullable types, but a regression for nullable types.

* Added notes

* Added quirks

* Fix tests.
  • Loading branch information
ajcvickers authored Nov 27, 2024
1 parent 08b4d43 commit 59e92ae
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 157 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
/// </summary>
public class CosmosQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor
{
private static readonly bool UseOldBehavior35094 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35094", out var enabled) && enabled;

private readonly CosmosQueryCompilationContext _queryCompilationContext;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly ITypeMappingSource _typeMappingSource;
Expand Down Expand Up @@ -445,23 +448,29 @@ private ShapedQueryExpression CreateShapedQueryExpression(SelectExpression selec
/// </summary>
protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
if (UseOldBehavior35094)
{
return null;
}
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
{
return null;
}

if (selector != null)
{
source = TranslateSelect(source, selector);
}
if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
projection = _sqlExpressionFactory.Function("AVG", new[] { projection }, projection.Type, projection.TypeMapping);
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
projection = _sqlExpressionFactory.Function("AVG", new[] { projection }, projection.Type, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);

}

return TranslateAggregate(source, selector, resultType, "AVG");
}

/// <summary>
Expand Down Expand Up @@ -843,24 +852,29 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
/// </summary>
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
if (UseOldBehavior35094)
{
return null;
}
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
{
return null;
}

if (selector != null)
{
source = TranslateSelect(source, selector);
}
if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());

projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);
projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}

return TranslateAggregate(source, selector, resultType, "MAX");
}

/// <summary>
Expand All @@ -871,24 +885,29 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
/// </summary>
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
if (UseOldBehavior35094)
{
return null;
}
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
{
return null;
}

if (selector != null)
{
source = TranslateSelect(source, selector);
}
if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());

projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);
projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}

return TranslateAggregate(source, selector, resultType, "MIN");
}

/// <summary>
Expand Down Expand Up @@ -1520,6 +1539,35 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s

#endregion Queryable collection support

private ShapedQueryExpression? TranslateAggregate(ShapedQueryExpression source, LambdaExpression? selector, Type resultType, string functionName)
{
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
{
return null;
}

if (selector != null)
{
source = TranslateSelect(source, selector);
}

if (!_subquery && resultType.IsNullableType())
{
// For nullable types, we want to return null from Max, Min, and Average, rather than throwing. See Issue #35094.
// Note that relational databases typically return null, which propagates. Cosmos will instead return no elements,
// and hence for Cosmos only we need to change no elements into null.
source = source.UpdateResultCardinality(ResultCardinality.SingleOrDefault);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
projection = _sqlExpressionFactory.Function(functionName, [projection], resultType, _typeMappingSource.FindMapping(resultType));

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}

private bool TryApplyPredicate(ShapedQueryExpression source, LambdaExpression predicate)
{
var select = (SelectExpression)source.QueryExpression;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.ComponentModel.DataAnnotations.Schema;

namespace Microsoft.EntityFrameworkCore.Query;

#nullable disable
Expand Down Expand Up @@ -50,6 +52,115 @@ public enum MemberType

#endregion 34911

#region 35094

// TODO: Move these tests to a better location. They require nullable properties with nulls in the database.

[ConditionalFact]
public virtual async Task Min_over_value_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().MinAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Min_over_value_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).MinAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Min_over_reference_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().MinAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Min_over_reference_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableRef == null).MinAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Min_over_reference_type_containing_no_data()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.Id < 0).MinAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Max_over_value_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Equal(3.14, await context.Set<Context35094.Product>().MaxAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Max_over_value_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).MaxAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Max_over_reference_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Equal("Value", await context.Set<Context35094.Product>().MaxAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Max_over_reference_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableRef == null).MaxAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Max_over_reference_type_containing_no_data()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.Id < 0).MaxAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Average_over_value_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().AverageAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Average_over_value_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).AverageAsync(p => p.NullableVal));
}

protected class Context35094(DbContextOptions options) : DbContext(options)
{
public DbSet<Product> Products { get; set; }

protected override void OnModelCreating(ModelBuilder modelBuilder)
=> modelBuilder.Entity<Product>().HasData(
new Product { Id = 1, NullableRef = "Value", NullableVal = 3.14 },
new Product { Id = 2, NullableVal = 3.14 },
new Product { Id = 3, NullableRef = "Value" });

public class Product
{
[DatabaseGenerated(DatabaseGeneratedOption.None)]
public int Id { get; set; }
public double? NullableVal { get; set; }
public string NullableRef { get; set; }
}
}

#endregion 35094

protected override string StoreName
=> "AdHocMiscellaneousQueryTests";

Expand Down
Loading

0 comments on commit 59e92ae

Please sign in to comment.