首页 > 解决方案 > Can't save related objects in django models using pre_save signal

问题描述

I have to implement UML multi-aspect inheritance in the Django ORM. I have a Contract data type which, depending on the type of customer (regular or business), can be classified as RegularContract or BusinessContract. A contract can also have an expiration date or be non-expiring (with no specified validity period), so it can also be of type ExpiringContract or NonExpiringContract. This is what the concept diagram looks like: [image: concept diagram]

And this is how I've implemented it: [image: implementation diagram]

models.py code:

class Contract(models.Model):
    """Base contract record shared by every contract variant.

    The expiration-specific data lives in a separate ExpiringContract or
    NonExpiringContract row whose one-to-one ``base`` field points back to
    this record. Attribute lookups that fail on this model are forwarded to
    whichever of those two extensions is attached (see ``__getattr__``).
    """

    # Date the contract was approved; required (null=False).
    approval_date = models.DateTimeField(null=False)

    def __getattr__(self, item):
        # Python calls this only after normal attribute lookup has failed;
        # it forwards the lookup to the attached expiration extension.
        # NOTE(review): ``self.expiringcontract`` / ``self.nonexpiringcontract``
        # are presumably Django's reverse one-to-one accessors. If that lookup
        # itself fails with an AttributeError (e.g. no related row), this
        # method is re-entered for "expiringcontract", which looks like it can
        # recurse without bound — confirm.
        # NOTE(review): when neither branch matches, this falls through and
        # returns None instead of raising AttributeError, so hasattr() and
        # getattr(obj, name, default) never see the attribute as missing.
        if self.expiringcontract:
            return getattr(self.expiringcontract, item)
        elif self.nonexpiringcontract:
            return getattr(self.nonexpiringcontract, item)


class ContractExpirationExtension(models.Model):
    """Abstract base for the expiration aspect of a Contract.

    Concrete subclasses (ExpiringContract, NonExpiringContract) inherit the
    one-to-one ``base`` link to the Contract they extend.
    """

    # One-to-one link to the extended Contract; deleting that Contract
    # cascades to this extension row.
    base = models.OneToOneField("website.Contract",
                                on_delete=models.CASCADE)

    class Meta:
        # Abstract model: Django creates no table for this class; its fields
        # are added to each concrete subclass instead.
        abstract = True


class ExpiringContract(ContractExpirationExtension):
    """Expiration aspect for contracts that end on a fixed date."""

    # Date on which the contract stops being valid.
    termination_date = models.DateTimeField()

    @property
    def duration(self):
        """Return the time between the base contract's approval and termination.

        Computed as ``termination_date - base.approval_date``. It reads
        ``self.base``, so the linked Contract must be set.
        """
        return self.termination_date - self.base.approval_date


class NonExpiringContract(ContractExpirationExtension):
    """Expiration aspect for contracts without a fixed end date."""

    @property
    def duration(self):
        """Return a fixed duration of 100 days.

        NOTE(review): the 100-day value is hard-coded and its meaning for an
        open-ended contract is not documented here — confirm intent.
        """
        return timedelta(days=100)


class ContractTypeExtension(models.Model):
    """Abstract base for the customer-type aspect (regular vs. business).

    Holds a one-to-one link to the base Contract and provides ``create()``,
    which builds the whole (unsaved) object graph for a new contract.
    """

    # One-to-one link to the extended Contract (cascade on delete).
    base = models.OneToOneField("website.Contract", on_delete=models.CASCADE)
    # Defaults to 30 here; the concrete subclasses redefine this field with
    # their own validators.
    termination_delay = models.PositiveSmallIntegerField(default=30)

    class Meta:
        # Abstract model: no table of its own; fields go to each subclass.
        abstract = True

    @classmethod
    def create(cls, approval_date, contract_expiration_type, termination_delay, **kwargs):
        """Build an unsaved contract of type ``cls`` with its related objects.

        Args:
            approval_date: Approval date for the new base Contract.
            contract_expiration_type: The class ExpiringContract or
                NonExpiringContract (a class, not an instance).
            termination_delay: Value for this extension's termination_delay.
            **kwargs: Passed to ``contract_expiration_type(...)``, e.g.
                ``termination_date`` for ExpiringContract.

        Returns:
            The new ``cls`` instance. Nothing is saved; the Contract and the
            expiration object are only linked in memory.
        """
        type_extension = cls(termination_delay=termination_delay)
        base = Contract(approval_date=approval_date)
        expiration_type = contract_expiration_type(**kwargs)
        expiration_type.base = base
        # The Contract assigned here has never been saved; a later
        # type_extension.save() rejects that with
        # "unsaved related object 'base'".
        type_extension.base = base
        # Picks the reverse accessor by comparing class names; any other type
        # is silently left unattached.
        if contract_expiration_type.__name__ == ExpiringContract.__name__:
            type_extension.base.expiringcontract = expiration_type
        elif contract_expiration_type.__name__ == NonExpiringContract.__name__:
            type_extension.base.nonexpiringcontract = expiration_type
        return type_extension

    def __getattr__(self, item):
        # Forward failed attribute lookups to the base Contract.
        # NOTE(review): if reading ``self.base`` itself raises AttributeError,
        # this method is re-entered for "base" and looks like it can recurse
        # without bound. When base is falsy, it returns None instead of raising
        # AttributeError — confirm both are intended.
        if self.base:
            return getattr(self.base,item)


class RegularContract(ContractTypeExtension):
    """Contract for a regular customer."""

    # Overrides the inherited field to attach the regular-customer validator
    # and mark it required (blank=False). Unlike the abstract base, this
    # field has no default of 30.
    termination_delay = models.PositiveSmallIntegerField(validators=[validate_term_delay_regular], blank=False)


class BusinessContract(ContractTypeExtension):
    """Contract for a business customer."""

    # Overrides the inherited field to attach the business-customer validator
    # and mark it required (blank=False). Unlike the abstract base, this
    # field has no default of 30.
    termination_delay = models.PositiveSmallIntegerField(validators=[validate_term_delay_business], blank=False)

When we need to create a new contract model instance, we use the create() method of the classes that inherit from the abstract ContractTypeExtension class. In create(), I create the base Contract instance and the appropriate expiring or non-expiring contract instance, based on the class object I pass to create() as an argument:

@classmethod
def create(cls, approval_date, contract_expiration_type, termination_delay, **kwargs):
    """Build an unsaved contract of type ``cls`` with its related objects.

    Args:
        approval_date: Approval date for the new base Contract.
        contract_expiration_type: The class ExpiringContract or
            NonExpiringContract (a class, not an instance).
        termination_delay: Value for the new instance's termination_delay.
        **kwargs: Passed to ``contract_expiration_type(...)``.

    Returns:
        The new ``cls`` instance; nothing has been saved to the database.
    """
    type_extension = cls(termination_delay=termination_delay)
    base = Contract(approval_date=approval_date)
    expiration_type = contract_expiration_type(**kwargs)
    expiration_type.base = base
    # ``base`` is unsaved at this point, which is what a later
    # type_extension.save() rejects ("unsaved related object 'base'").
    type_extension.base = base
    # Chooses the reverse accessor by class name; any other type is left
    # unattached without an error.
    if contract_expiration_type.__name__ == ExpiringContract.__name__:
        type_extension.base.expiringcontract = expiration_type
    elif contract_expiration_type.__name__ == NonExpiringContract.__name__:
        type_extension.base.nonexpiringcontract = expiration_type
    return type_extension

Because my regular or business contract instance contains other model instances, I can't save it without first saving the base and expiration_type instances, so I decided to create a pre_save signal handler that does exactly that:

signals.py:

from django.db.models.signals import pre_save, pre_delete
from django.dispatch import receiver

from .models import RegularContract, BusinessContract

@receiver(pre_save, sender=RegularContract)
@receiver(pre_save, sender=BusinessContract)
def pre_save_contract(sender, instance, *args, **kwargs):
    """Save the base Contract and its expiration extension before *instance*.

    Meant to run for new RegularContract / BusinessContract instances (no
    ``id`` yet) built by ``create()``.

    NOTE(review): the traceback in the question shows the "unsaved related
    object 'base'" ValueError being raised inside Model.save() itself. That
    means it fires before this pre_save handler gets a chance to save
    ``instance.base``.
    """
    print("Pre_save")
    if not instance.id:
        instance.base.save()
        try:
            instance.base.expiringcontract.save()
        # NOTE(review): a missing reverse one-to-one relation presumably raises
        # a DoesNotExist/AttributeError subclass rather than TypeError or
        # ValueError — confirm which exception this is meant to catch.
        except (TypeError, ValueError):
            instance.base.nonexpiringcontract.save()

I registered my signals module in the app's __init__.py and in the apps.py config:

apps.py:

from django.apps import AppConfig


class WebsiteConfig(AppConfig):
    """App configuration for the ``website`` app."""

    name = 'website'

    def ready(self):
       """Django startup hook: import the app's signals module.

       The import is done for its side effect. Loading website.signals runs
       its @receiver decorators, which connect the pre_save handler.
       """
       import website.signals

website.__init__.py:

# Points Django at WebsiteConfig for the ``website`` package, so that its
# ready() hook (which imports website.signals) runs.
# NOTE(review): newer Django versions deprecate default_app_config in favour
# of automatic AppConfig discovery — confirm for the Django version in use.
default_app_config = 'website.apps.WebsiteConfig'

To test my code I've written simple test cases:

class BusinessContractTestCase(TestCase):
    """Checks that a BusinessContract built by create() can be saved."""

    def setUp(self):
        # No shared fixtures.
        pass

    def test_exprirating_creation(self):
        # Build an expiring business contract approved today and terminating
        # 720 days later, save it, then compare the termination date read
        # through the contract's attribute forwarding with the date stored on
        # the first ExpiringContract row (date parts only).
        approval_date = datetime.today()
        termination_delay = 30
        termination_date = approval_date+timedelta(days=720)
        contract = BusinessContract.create(approval_date=approval_date,                                                         contract_expiration_type=ExpiringContract,
                                      termination_delay=termination_delay,
                                       termination_date=termination_date)
        contract.save()
        self.assertEqual(contract.termination_date.date(), ExpiringContract.objects.first().termination_date.date())


class RegularContractTestCase(TestCase):
    """Checks that a RegularContract built by create() can be saved."""

    def test_exprirating_creation(self):
        # Same scenario as the business-contract test: expiring contract,
        # terminating 720 days after approval. Save it and compare the
        # termination date seen through the contract with the stored one.
        approval_date = datetime.today()
        termination_delay = 30
        termination_date = approval_date + timedelta(days=720)
        contract = RegularContract.create(approval_date=approval_date,
                                      contract_expiration_type=ExpiringContract,
                                      termination_delay=termination_delay,
                                      termination_date=termination_date)
        contract.save()
        self.assertEqual(contract.termination_date.date(),
                           ExpiringContract.objects.first().termination_date.date())

But when I run these tests, they fail with this error:

Error
Traceback (most recent call last):
  File "/home/ubuntu/workspace/webapp/website/tests.py", line 21, in test_exprirating_creation
    contract.save()
  File "/home/ubuntu/workspace/venv/lib/python3.5/site-packages/django/db/models/base.py", line 685, in save
    "unsaved related object '%s'." % field.name
ValueError: save() prohibited to prevent data loss due to unsaved related object 'base'.

So why isn't the pre_save signal triggered in my code?

标签: python, django, python-3.x

解决方案


After some debugging I understood my problem (thanks to Willem Van Onsem for pointing out the detail about pre_save). Model.save() checks for unsaved related objects before it sends the pre_save signal, so the ValueError is raised before my handler ever runs. This is how I solved it. I slightly modified the create() method: instead of assigning base directly to the newly created instance and expiration_type to base, I store them in temporary attributes that I can later use in my signal handler:

@classmethod
def create(cls, approval_date, contract_expiration_type, termination_delay, **kwargs):
    """Build an unsaved contract of type ``cls`` plus its unsaved related objects.

    Nothing is linked through relation fields or saved here. Model.save()
    rejects unsaved related objects before pre_save is sent, so the new base
    Contract and the expiration object are only stashed on the returned
    instance, as ``temp_base`` and ``temp_expiringcontract`` /
    ``temp_nonexpiringcontract``. The pre_save handler saves them and wires up
    the relations.

    Args:
        approval_date: Approval date for the new base Contract.
        contract_expiration_type: ExpiringContract or NonExpiringContract (or a
            subclass of either) -- the class itself, not an instance.
        termination_delay: Value for the new instance's termination_delay.
        **kwargs: Passed to ``contract_expiration_type(...)``, e.g.
            ``termination_date`` for ExpiringContract.

    Returns:
        The new, unsaved ``cls`` instance.

    Raises:
        TypeError: If contract_expiration_type is not one of the two
            expiration classes.
    """
    # Decide where the expiration object is stashed. Identity/subclass checks
    # replace the former class-name string comparison.
    if issubclass(contract_expiration_type, ExpiringContract):
        temp_attr = "temp_expiringcontract"
    elif issubclass(contract_expiration_type, NonExpiringContract):
        temp_attr = "temp_nonexpiringcontract"
    else:
        # Previously an unknown type was silently dropped and save() failed
        # later with an unrelated error; fail here instead.
        raise TypeError(
            "contract_expiration_type must be ExpiringContract or "
            "NonExpiringContract, got %r" % (contract_expiration_type,)
        )

    type_extension = cls(termination_delay=termination_delay)
    type_extension.temp_base = Contract(approval_date=approval_date)
    # Stash on the type extension itself, never on ``type_extension.base``:
    # the instance has no base yet, and the pre_save handler reads
    # ``instance.temp_expiringcontract`` / ``instance.temp_nonexpiringcontract``.
    setattr(type_extension, temp_attr, contract_expiration_type(**kwargs))
    return type_extension

Then, in the pre_save handler in signals.py, I save the base from its temporary attribute and assign it to the instance's base. I then assign the expiring/non-expiring contract instance from its temporary attribute to the base and save it:

@receiver(pre_save, sender=RegularContract)
@receiver(pre_save, sender=BusinessContract)
def pre_save_contract(sender, instance, *args, **kwargs):
    """Save the objects stashed by ``create()`` and link them to *instance*.

    ``create()`` leaves the unsaved base Contract in ``temp_base`` and the
    unsaved expiration object in ``temp_expiringcontract`` or
    ``temp_nonexpiringcontract``. This handler saves the base, assigns it to
    ``instance.base``, then attaches and saves the expiration object.

    Instances without a stash (e.g. loaded from the database, or already saved
    once) are left untouched, so ordinary updates no longer fail or re-save
    the related rows.
    """
    # Read the stash straight from the instance __dict__. The contract models
    # define __getattr__ hooks that forward missing attribute names to
    # ``base``, so hasattr()/getattr() on a missing name do not simply fail.
    # pop() also consumes the stash, so a second save() does not redo this work.
    stash = vars(instance)
    base = stash.pop("temp_base", None)
    if base is None:
        return

    base.save()
    instance.base = base

    expiring = stash.pop("temp_expiringcontract", None)
    non_expiring = stash.pop("temp_nonexpiringcontract", None)
    # Assigning through the reverse accessor also sets the extension's own
    # ``base`` before it is saved.
    if expiring is not None:
        base.expiringcontract = expiring
        base.expiringcontract.save()
    elif non_expiring is not None:
        base.nonexpiringcontract = non_expiring
        base.nonexpiringcontract.save()

This is probably not the best solution, but at least it works.


推荐阅读