Skip to content

Commit

Permalink
Changes:
Browse files Browse the repository at this point in the history
- fix foreign keys with tenancy
  • Loading branch information
devkral committed Jan 13, 2025
1 parent d24d558 commit 87773c8
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 23 deletions.
3 changes: 2 additions & 1 deletion docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ hide:
For models this can be passed : `class Foo(edgy.Model, on_conflict="keep"): ...`.
- Passing a tuple or list of types to `replace_related_field` is now allowed.
- Add `through_registry` to ManyToMany.
- Add `no_copy` to MetaInfo.
- Add `no_copy` to models MetaInfo.
- Add `ModelCollisionError` exception.

### Changed
Expand All @@ -32,6 +32,7 @@ hide:
- Fix deleting (clearing cache) of BaseForeignKey target.
- Creating two models with the same name did lead to silent replacements.
- Invalidating causes schema errors.
- ManyToMany and ForeignKey fields in connection with tenancy.

## 0.24.2

Expand Down
23 changes: 14 additions & 9 deletions edgy/core/connection/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy.exc import DBAPIError, ProgrammingError

from edgy.core.connection.database import Database
from edgy.core.db.context_vars import NO_GLOBAL_FIELD_CONSTRAINTS
from edgy.exceptions import SchemaError

if TYPE_CHECKING:
Expand Down Expand Up @@ -71,14 +72,16 @@ async def create_schema(
if init_models:
for model_class in self.registry.models.values():
model_class.table_schema(schema=schema, update_cache=update_cache)
if init_tenant_models and init_models:
for model_class in self.registry.tenant_models.values():
model_class.table_schema(schema=schema, update_cache=update_cache)
elif init_tenant_models:
if init_tenant_models:
token = NO_GLOBAL_FIELD_CONSTRAINTS.set(True)
try:
for model_class in self.registry.tenant_models.values():
tenant_tables.append(model_class.build(schema=schema))
finally:
NO_GLOBAL_FIELD_CONSTRAINTS.reset(token)
# we need two passes
for model_class in self.registry.tenant_models.values():
tenant_tables.append(
model_class.table_schema(schema=schema, update_cache=update_cache)
)
model_class.add_global_field_constraints(schema=schema)

def execute_create(connection: sqlalchemy.Connection, name: Optional[str]) -> None:
try:
Expand All @@ -87,8 +90,10 @@ def execute_create(connection: sqlalchemy.Connection, name: Optional[str]) -> No
)
except ProgrammingError as e:
raise SchemaError(detail=e.orig.args[0]) from e
for table in tenant_tables:
table.create(connection, checkfirst=if_not_exists)
if tenant_tables:
self.registry.metadata_by_name[name].create_all(
connection, checkfirst=if_not_exists, tables=tenant_tables
)
if init_models:
self.registry.metadata_by_name[name].create_all(
connection, checkfirst=if_not_exists
Expand Down
3 changes: 3 additions & 0 deletions edgy/core/db/context_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
"CURRENT_MODEL_INSTANCE", default=None
)
CURRENT_PHASE: ContextVar[str] = ContextVar("CURRENT_PHASE", default="")
NO_GLOBAL_FIELD_CONSTRAINTS: ContextVar[bool] = ContextVar(
"NO_GLOBAL_FIELD_CONSTRAINTS", default=False
)
EXPLICIT_SPECIFIED_VALUES: ContextVar[Optional[set[str]]] = ContextVar(
"EXPLICIT_SPECIFIED_VALUES", default=None
)
Expand Down
16 changes: 14 additions & 2 deletions edgy/core/db/fields/many_to_many.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,13 @@ def create_through_model(
Generates a middle model based on the owner of the field and the field itself and adds
it to the main registry to make sure it generates the proper models and migrations.
"""
from edgy.contrib.multi_tenancy.base import TenantModel
from edgy.contrib.multi_tenancy.metaclasses import TenantMeta
from edgy.core.db.models.metaclasses import MetaInfo

__bases__: tuple[type[BaseModelType], ...] = ()
__bases__: tuple[type[BaseModelType], ...] = (
(TenantModel,) if getattr(self.owner.meta, "is_tenant", False) else ()
)
pknames = set()
if self.through:
through = self.through
Expand Down Expand Up @@ -243,7 +247,15 @@ def callback(model_class: type["BaseModelType"]) -> None:
if has_pknames:
meta_args["unique_together"] = [(self.from_foreign_key, self.to_foreign_key)]

new_meta: MetaInfo = MetaInfo(None, registry=False, no_copy=True, **meta_args)
# TenantMeta is compatible to normal meta
new_meta: MetaInfo = TenantMeta(
None,
registry=False,
no_copy=True,
is_tenant=getattr(self.owner.meta, "is_tenant", False),
register_default=getattr(self.owner.meta, "register_default", False),
**meta_args,
)

to_related_name: Union[str, Literal[False]]
if self.related_name is False:
Expand Down
5 changes: 4 additions & 1 deletion edgy/core/db/fields/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ def to_model(
return {field_name: value}

def get_global_constraints(
self, name: str, columns: Sequence[sqlalchemy.Column], schemes: Sequence[str] = ()
self,
name: str,
columns: Sequence[sqlalchemy.Column],
schemes: Sequence[str] = (),
) -> Sequence[Union[sqlalchemy.Constraint, sqlalchemy.Index]]:
"""Return global constraints and indexes.
Useful for multicolumn fields
Expand Down
53 changes: 49 additions & 4 deletions edgy/core/db/models/mixins/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
CURRENT_INSTANCE,
EXPLICIT_SPECIFIED_VALUES,
MODEL_GETATTR_BEHAVIOR,
NO_GLOBAL_FIELD_CONSTRAINTS,
get_schema,
)
from edgy.core.db.datastructures import Index, UniqueConstraint
Expand Down Expand Up @@ -382,7 +383,7 @@ def copy_edgy_model(

@property
def table(self) -> sqlalchemy.Table:
if getattr(self, "_table", None) is None:
if self.__dict__.get("_table", None) is None:
schema = self.get_active_instance_schema()
return cast(
"sqlalchemy.Table",
Expand All @@ -391,13 +392,22 @@ def table(self) -> sqlalchemy.Table:
return self._table

@table.setter
def table(self, value: sqlalchemy.Table) -> None:
def table(self, value: Optional[sqlalchemy.Table]) -> None:
with contextlib.suppress(AttributeError):
del self._pknames
with contextlib.suppress(AttributeError):
del self._pkcolumns
self._table = value

@table.deleter
def table(self) -> None:
with contextlib.suppress(AttributeError):
del self._pknames
with contextlib.suppress(AttributeError):
del self._pkcolumns
with contextlib.suppress(AttributeError):
del self._table

@property
def pkcolumns(self) -> Sequence[str]:
if self.__dict__.get("_pkcolumns", None) is None:
Expand Down Expand Up @@ -669,7 +679,9 @@ async def save(

@classmethod
def build(
cls, schema: Optional[str] = None, metadata: Optional[sqlalchemy.MetaData] = None
cls,
schema: Optional[str] = None,
metadata: Optional[sqlalchemy.MetaData] = None,
) -> sqlalchemy.Table:
"""
Builds the SQLAlchemy table representation from the loaded fields.
Expand Down Expand Up @@ -697,7 +709,10 @@ def build(
for name, field in cls.meta.fields.items():
current_columns = field.get_columns(name)
columns.extend(current_columns)
global_constraints.extend(field.get_global_constraints(name, current_columns, schemes))
if not NO_GLOBAL_FIELD_CONSTRAINTS.get():
global_constraints.extend(
field.get_global_constraints(name, current_columns, schemes)
)

# Handle the uniqueness together
uniques = []
Expand All @@ -723,6 +738,36 @@ def build(
else cls.get_active_class_schema(check_schema=False, check_tenant=False),
)

@classmethod
def add_global_field_constraints(
cls,
schema: Optional[str] = None,
metadata: Optional[sqlalchemy.MetaData] = None,
) -> sqlalchemy.Table:
"""
Add global constraints to table. Required for tenants.
"""
tablename: str = cls.meta.tablename
registry = cls.meta.registry
assert registry, "registry is not set"
if metadata is None:
metadata = registry.metadata_by_url[str(cls.database.url)]
schemes: list[str] = []
if schema:
schemes.append(schema)
if cls.__using_schema__ is not Undefined:
schemes.append(cls.__using_schema__)
db_schema = cls.get_db_schema() or ""
schemes.append(db_schema)
table = metadata.tables[tablename if not schema else f"{schema}.{tablename}"]
for name, field in cls.meta.fields.items():
current_columns: list[sqlalchemy.Column] = []
for column_name in cls.meta.field_to_column_names[name]:
current_columns.append(table.columns[column_name])
for constraint in field.get_global_constraints(name, current_columns, schemes):
table.append_constraint(constraint)
return table

@classmethod
def _get_unique_constraints(
cls, fields: Union[Sequence, str, sqlalchemy.UniqueConstraint]
Expand Down
4 changes: 3 additions & 1 deletion edgy/core/db/models/mixins/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def real_add_to_registry(cls: type["BaseModelType"], **kwargs: Any) -> type["Bas

@classmethod
def build(
cls, schema: Optional[str] = None, metadata: Optional[sqlalchemy.MetaData] = None
cls,
schema: Optional[str] = None,
metadata: Optional[sqlalchemy.MetaData] = None,
) -> Any:
"""
The inspect is done in an async manner and reflects the objects from the database.
Expand Down
4 changes: 3 additions & 1 deletion edgy/core/db/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ def model_dump(self, show_pk: Union[bool, None] = None, **kwargs: Any) -> dict[s
@classmethod
@abstractmethod
def build(
cls, schema: Optional[str] = None, metadata: Optional[sqlalchemy.MetaData] = None
cls,
schema: Optional[str] = None,
metadata: Optional[sqlalchemy.MetaData] = None,
) -> sqlalchemy.Table:
"""
Builds the SQLAlchemy table representation from the loaded fields.
Expand Down
10 changes: 7 additions & 3 deletions edgy/core/db/relationships/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def expand_relationship(self, value: Any) -> Any:
**{self.from_foreign_key: self.instance, self.to_foreign_key: value}
)
instance.identifying_db_fields = [self.from_foreign_key, self.to_foreign_key] # type: ignore
if getattr(through.meta, "is_tenant", False):
instance.__using_schema__ = self.instance.get_active_instance_schema() # type: ignore
return instance

def stage(self, *children: "BaseModelType") -> None:
Expand Down Expand Up @@ -150,9 +152,9 @@ async def remove(self, child: Optional["BaseModelType"] = None) -> None:
try:
child = await self.get()
except ObjectNotFound:
raise RelationshipNotFound(detail="no child found") from None
raise RelationshipNotFound(detail="No child found.") from None
else:
raise RelationshipNotFound(detail="no child specified")
raise RelationshipNotFound(detail="No child specified.")
if not isinstance(
child,
(self.to, self.to.proxy_model, self.through, self.through.proxy_model), # type: ignore
Expand All @@ -164,7 +166,7 @@ async def remove(self, child: Optional["BaseModelType"] = None) -> None:
count = await child.query.filter(*child.identifying_clauses()).count()
if count == 0:
raise RelationshipNotFound(
detail=f"There is no relationship between '{self.from_foreign_key}' and '{self.to_foreign_key}: {getattr(child,self.to_foreign_key).pk}'."
detail=f"There is no relationship between '{self.from_foreign_key}' and '{self.to_foreign_key}: {getattr(child, self.to_foreign_key).pk}'."
)
else:
await child.delete()
Expand Down Expand Up @@ -242,6 +244,8 @@ def expand_relationship(self, value: Any) -> Any:
value = {next(iter(related_columns)): value}
instance = target.proxy_model(**value)
instance.identifying_db_fields = related_columns # type: ignore
if getattr(target.meta, "is_tenant", False):
instance.__using_schema__ = self.instance.get_active_instance_schema() # type: ignore
return instance

def stage(self, *children: "BaseModelType") -> None:
Expand Down
27 changes: 26 additions & 1 deletion tests/contrib/multi_tenancy/test_tenant_models_using.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,29 +52,54 @@ class Meta:
is_tenant = True


class Cart(TenantModel):
products = fields.ManyToMany(Product)

class Meta:
registry = models
is_tenant = True


@pytest.mark.parametrize("use_copy", ["false", "instant", "after"])
async def test_schema_with_using_in_different_place(use_copy):
if use_copy == "instant":
copied = models.__copy__()
NewTenant = copied.get_model("Tenant")
NewProduct = copied.get_model("Product")
NewCart = copied.get_model("Cart")
else:
NewTenant = Tenant
NewProduct = Product
NewCart = Cart
tenant = await NewTenant.query.create(
schema_name="edgy", domain_url="https://edgy.dymmond.com", tenant_name="edgy"
)
if use_copy == "after":
copied = models.__copy__()
NewTenant = copied.get_model("Tenant")
NewProduct = copied.get_model("Product")
NewCart = copied.get_model("Cart")
cart = await NewCart.query.using(schema=tenant.schema_name).create()
assert cart.__using_schema__ == tenant.schema_name
for i in range(5):
await NewProduct.query.using(schema=tenant.schema_name).create(name=f"product-{i}")
product = await NewProduct.query.using(schema=tenant.schema_name).create(
name=f"product-{i}"
)
if i % 2 == 0:
product_through = cart.products.through(cart=cart, product=product)
product_through.__using_schema__ = tenant.schema_name
assert await cart.products.add(product_through)
else:
assert await cart.products.add(product)

total = await NewProduct.query.filter().using(schema=tenant.schema_name).all()

assert len(total) == 5

total = await cart.products.filter().using(schema=tenant.schema_name).all()

assert len(total) == 5

total = await NewProduct.query.all()

assert len(total) == 0
Expand Down

0 comments on commit 87773c8

Please sign in to comment.