Skip to content

Commit

Permalink
add support for EmbeddedModelField
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Nov 4, 2024
1 parent 087000e commit df74d7b
Show file tree
Hide file tree
Showing 5 changed files with 326 additions and 2 deletions.
2 changes: 1 addition & 1 deletion django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ def execute_sql(self, result_type):
elif hasattr(value, "prepare_database_save"):
if field.remote_field:
value = value.prepare_database_save(field)
else:
elif not hasattr(field, "embedded_model"):
raise TypeError(
f"Tried to update field {field} with a model "
f"instance, {value!r}. Use a value compatible with "
Expand Down
3 changes: 2 additions & 1 deletion django_mongodb/fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .auto import ObjectIdAutoField
from .duration import register_duration_field
from .embedded_model import EmbeddedModelField
from .json import register_json_field

__all__ = ["register_fields", "ObjectIdAutoField"]
__all__ = ["register_fields", "EmbeddedModelField", "ObjectIdAutoField"]


def register_fields():
Expand Down
166 changes: 166 additions & 0 deletions django_mongodb/fields/embedded_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from importlib import import_module

from django.db import IntegrityError, models
from django.db.models.fields.related import lazy_related_operation


class EmbeddedModelField(models.Field):
"""Field that stores a model instance."""

def __init__(self, embedded_model=None, *args, **kwargs):
"""
`embedded_model` is the model class of the instance that will be
stored. Like other relational fields, it may also be passed as a
string.
"""
self.embedded_model = embedded_model
super().__init__(*args, **kwargs)

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if path.startswith("django_mongodb.fields.embedded_model"):
path = path.replace("django_mongodb.fields.embedded_model", "django_mongodb.fields")
if self.embedded_model:
kwargs["embedded_model"] = self.embedded_model
return name, path, args, kwargs

def get_internal_type(self):
return "EmbeddedModelField"

def _set_model(self, model):
"""
Resolve embedded model class once the field knows the model it belongs
to.
If the model argument passed to __init__() was a string, resolve that
string to the corresponding model class, similar to relation fields.
However, we need to know our own model to generate a valid key
for the embedded model class lookup and EmbeddedModelFields are
not contributed_to_class if used in iterable fields. Thus the
collection field sets this field's "model" attribute in its
contribute_to_class().
"""
self._model = model
if model is not None and isinstance(self.embedded_model, str):

def _resolve_lookup(_, resolved_model):
self.embedded_model = resolved_model

lazy_related_operation(_resolve_lookup, model, self.embedded_model)

model = property(lambda self: self._model, _set_model)

def stored_model(self, column_values):
"""
Return the fixed embedded_model this field was initialized
with (typed embedding) or tries to determine the model from
_module / _model keys stored together with column_values
(untyped embedding).
Give precedence to the field's definition model, as silently using a
differing serialized one could hide some data integrity problems.
Note that a single untyped EmbeddedModelField may process
instances of different models (especially when used as a type
of a collection field).
"""
module = column_values.pop("_module", None)
model = column_values.pop("_model", None)
if self.embedded_model is not None:
return self.embedded_model
if module is not None:
return getattr(import_module(module), model)
raise IntegrityError(
"Untyped EmbeddedModelField trying to load data without serialized model class info."
)

def from_db_value(self, value, expression, connection):
return self.to_python(value)

def to_python(self, value):
"""
Passes embedded model fields' values through embedded fields
to_python methods and reinstiatates the embedded instance.
We expect to receive a field.attname => value dict together
with a model class from back-end database deconversion (which
needs to know fields of the model beforehand).
"""
# Either the model class has already been determined during
# deconverting values from the database or we've got a dict
# from a deserializer that may contain model class info.
if isinstance(value, tuple):
embedded_model, attribute_values = value
elif isinstance(value, dict):
embedded_model = self.stored_model(value)
attribute_values = value
else:
return value
# Create the model instance.
instance = embedded_model(
**{
# Pass values through respective fields' to_python(), leaving
# fields for which no value is specified uninitialized.
field.attname: field.to_python(attribute_values[field.attname])
for field in embedded_model._meta.fields
if field.attname in attribute_values
}
)
instance._state.adding = False
return instance

def get_db_prep_save(self, embedded_instance, connection):
"""
Apply pre_save() and get_db_prep_save() of embedded instance
fields and passes a field => value mapping down to database
type conversions.
The embedded instance will be saved as a column => value dict
in the end (possibly augmented with info about instance's model
for untyped embedding), but because we need to apply database
type conversions on embedded instance fields' values and for
these we need to know fields those values come from, we need to
entrust the database layer with creating the dict.
"""
if embedded_instance is None:
return None
# The field's value should be an instance of the model given in
# its declaration or at least of some model.
embedded_model = self.embedded_model or models.Model
if not isinstance(embedded_instance, embedded_model):
raise TypeError(
f"Expected instance of type {embedded_model!r}, not {type(embedded_instance)!r}."
)
# Apply pre_save() and get_db_prep_save() of embedded instance
# fields, create the field => value mapping to be passed to
# storage preprocessing.
field_values = {}
add = embedded_instance._state.adding
for field in embedded_instance._meta.fields:
value = field.get_db_prep_save(
field.pre_save(embedded_instance, add), connection=connection
)
# Exclude unset primary keys (e.g. {'id': None}).
if field.primary_key and value is None:
continue
field_values[field.attname] = value
if self.embedded_model is None:
# Untyped fields must store model info alongside values.
field_values.update(
(
("_module", embedded_instance.__class__.__module__),
("_model", embedded_instance.__class__.__name__),
)
)
# This instance will exist in the database soon.
# TODO.XXX: Ensure that this doesn't cause race conditions.
embedded_instance._state.adding = False
return field_values

def validate(self, value, model_instance):
super().validate(value, model_instance)
if self.embedded_model is None:
return
for field in self.embedded_model._meta.fields:
attname = field.attname
field.validate(getattr(value, attname), model_instance)
32 changes: 32 additions & 0 deletions tests/model_fields_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from django.db import models

from django_mongodb.fields import EmbeddedModelField


class Target(models.Model):
index = models.IntegerField()


class DecimalModel(models.Model):
decimal = models.DecimalField(max_digits=9, decimal_places=2)


class DecimalKey(models.Model):
decimal = models.DecimalField(max_digits=9, decimal_places=2, primary_key=True)


class DecimalParent(models.Model):
child = models.ForeignKey(DecimalKey, models.CASCADE)


class EmbeddedModelFieldModel(models.Model):
simple = EmbeddedModelField("EmbeddedModel", null=True, blank=True)
untyped = EmbeddedModelField(null=True, blank=True)
decimal_parent = EmbeddedModelField(DecimalParent, null=True, blank=True)


class EmbeddedModel(models.Model):
some_relation = models.ForeignKey(Target, models.CASCADE, null=True, blank=True)
someint = models.IntegerField(db_column="custom_column")
auto_now = models.DateTimeField(auto_now=True)
auto_now_add = models.DateTimeField(auto_now_add=True)
125 changes: 125 additions & 0 deletions tests/model_fields_/test_embedded_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import time
from decimal import Decimal

from django.core.exceptions import ValidationError
from django.db import models
from django.test import SimpleTestCase, TestCase

from django_mongodb.fields import EmbeddedModelField

from .models import (
DecimalKey,
DecimalParent,
EmbeddedModel,
EmbeddedModelFieldModel,
Target,
)


class MethodTests(SimpleTestCase):
def test_deconstruct(self):
field = EmbeddedModelField()
name, path, args, kwargs = field.deconstruct()
self.assertEqual(path, "django_mongodb.fields.EmbeddedModelField")
self.assertEqual(args, [])
self.assertEqual(kwargs, {})

def test_deconstruct_with_model(self):
field = EmbeddedModelField("EmbeddedModel", null=True)
name, path, args, kwargs = field.deconstruct()
self.assertEqual(path, "django_mongodb.fields.EmbeddedModelField")
self.assertEqual(args, [])
self.assertEqual(kwargs, {"embedded_model": "EmbeddedModel", "null": True})

def test_validate(self):
instance = EmbeddedModelFieldModel(simple=EmbeddedModel(someint=None))
# This isn't quite right because "someint" is the field that's non-null.
msg = "{'simple': ['This field cannot be null.']}"
with self.assertRaisesMessage(ValidationError, msg):
instance.full_clean()


class QueryingTests(TestCase):
def assertEqualDatetime(self, d1, d2):
"""Compares d1 and d2, ignoring microseconds."""
self.assertEqual(d1.replace(microsecond=0), d2.replace(microsecond=0))

def assertNotEqualDatetime(self, d1, d2):
self.assertNotEqual(d1.replace(microsecond=0), d2.replace(microsecond=0))

def test_save_load(self):
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5"))
instance = EmbeddedModelFieldModel.objects.get()
self.assertIsInstance(instance.simple, EmbeddedModel)
# Make sure get_prep_value is called.
self.assertEqual(instance.simple.someint, 5)
# Primary keys should not be populated...
self.assertEqual(instance.simple.id, None)
# ... unless set explicitly.
instance.simple.id = instance.id
instance.save()
instance = EmbeddedModelFieldModel.objects.get()
self.assertEqual(instance.simple.id, instance.id)

def test_save_load_untyped(self):
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5"))
instance = EmbeddedModelFieldModel.objects.get()
self.assertIsInstance(instance.simple, EmbeddedModel)
# Make sure get_prep_value is called.
self.assertEqual(instance.simple.someint, 5)
# Primary keys should not be populated...
self.assertEqual(instance.simple.id, None)
# ... unless set explicitly.
instance.simple.id = instance.id
instance.save()
instance = EmbeddedModelFieldModel.objects.get()
self.assertEqual(instance.simple.id, instance.id)

def _test_pre_save(self, instance, get_field):
# Field.pre_save() is called on embedded model fields.

instance.save()
auto_now = get_field(instance).auto_now
auto_now_add = get_field(instance).auto_now_add
self.assertNotEqual(auto_now, None)
self.assertNotEqual(auto_now_add, None)

time.sleep(1) # FIXME
instance.save()
self.assertNotEqualDatetime(get_field(instance).auto_now, get_field(instance).auto_now_add)

instance = EmbeddedModelFieldModel.objects.get()
instance.save()
# auto_now_add shouldn't have changed now, but auto_now should.
self.assertEqualDatetime(get_field(instance).auto_now_add, auto_now_add)
self.assertGreater(get_field(instance).auto_now, auto_now)

def test_pre_save(self):
obj = EmbeddedModelFieldModel(simple=EmbeddedModel())
self._test_pre_save(obj, lambda instance: instance.simple)

def test_pre_save_untyped(self):
obj = EmbeddedModelFieldModel(untyped=EmbeddedModel())
self._test_pre_save(obj, lambda instance: instance.untyped)

def test_error_messages(self):
for model_kwargs, expected in (
({"simple": 42}, EmbeddedModel),
({"untyped": 42}, models.Model),
):
msg = "Expected instance of type %r" % expected
with self.assertRaisesMessage(TypeError, msg):
EmbeddedModelFieldModel(**model_kwargs).save()

def test_foreign_key_in_embedded_object(self):
simple = EmbeddedModel(some_relation=Target.objects.create(index=1))
obj = EmbeddedModelFieldModel.objects.create(simple=simple)
simple = EmbeddedModelFieldModel.objects.get().simple
self.assertNotIn("some_relation", simple.__dict__)
self.assertIsInstance(simple.__dict__["some_relation_id"], type(obj.id))
self.assertIsInstance(simple.some_relation, Target)

def test_embedded_field_with_foreign_conversion(self):
decimal = DecimalKey.objects.create(decimal=Decimal("1.5"))
decimal_parent = DecimalParent.objects.create(child=decimal)
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent)

0 comments on commit df74d7b

Please sign in to comment.