11

On a Django 2.0 project, I have the following issue in my unit tests and I can't find the cause.

-- UPDATE: I am using Postgres 10.1. The problem doesn't occur when I switch to sqlite3.

I am implementing a model which tracks every change to another model:

from django.db import models
from django.contrib.auth.models import User
from django.db.models.signals import post_save
from django.dispatch import receiver

class Investment(models.Model):
    """Primary model; every save records its status in InvestmentStatusTrack."""

    status = models.IntegerField()


class InvestmentStatusTrack(models.Model):
    """One history entry: the status an investment had, when, and who set it."""

    # Deleting the investment deletes its whole status history.
    investment = models.ForeignKey(Investment, on_delete=models.CASCADE)
    status = models.IntegerField()
    modified_on = models.DateTimeField(
        blank=True, null=True, default=None, verbose_name=_('modified on'), db_index=True
    )
    # SET_NULL rather than CASCADE: the column is already nullable, and deleting
    # a user should not erase the status history that user took part in.
    # NOTE: run makemigrations after changing on_delete.
    modified_by = models.ForeignKey(
        User, blank=True, null=True, default=None, verbose_name=_('modified by'), on_delete=models.SET_NULL
    )

    class Meta:
        # Newest entries first.
        ordering = ('-modified_on', )

    def __str__(self):
        return '{0} - {1}'.format(self.investment, self.status)


@receiver(post_save, sender=Investment)
def handle_status_track(sender, instance, created, **kwargs):
    """Add an InvestmentStatusTrack entry after each save of an Investment.

    Connected to ``post_save`` for ``Investment``, so an entry is created on
    every save (create or update), whether or not ``status`` changed.
    ``modified_by`` is the user of the request returned by ``get_request()``
    when that user is authenticated; otherwise it is ``None``.
    """
    from django.utils import timezone

    request = get_request()  # a way to get the current request
    # getattr also copes with no request (None) or a request without `user`.
    user = getattr(request, 'user', None)
    modified_by = user if user and user.is_authenticated else None
    InvestmentStatusTrack.objects.create(
        investment=instance,
        status=instance.status,
        # timezone.now() is aware when USE_TZ is enabled (naive otherwise), so
        # Django does not warn about a naive datetime as with datetime.now().
        modified_on=timezone.now(),
        modified_by=modified_by,
    )

Most of my unit tests fail with the following traceback:

Traceback (most recent call last):
  File "/env/lib/python3.6/site-packages/django/test/testcases.py", line 209, in __call__
    self._post_teardown()
  File "/env/lib/python3.6/site-packages/django/test/testcases.py", line 893, in _post_teardown
    self._fixture_teardown()
  File "/env/lib/python3.6/site-packages/django/test/testcases.py", line 1041, in _fixture_teardown
    connections[db_name].check_constraints()
  File "/env/lib/python3.6/site-packages/django/db/backends/postgresql/base.py", line 235, in check_constraints
    self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
  File "/env/lib/python3.6/site-packages/django/db/backends/utils.py", line 68, in execute
    return self._execute_with_wrappers(sql, params, many=False, executor=self._execute)
  File "/env/lib/python3.6/site-packages/django/db/backends/utils.py", line 77, in _execute_with_wrappers
    return executor(sql, params, many, context)
  File "/env/lib/python3.6/site-packages/django/db/backends/utils.py", line 85, in _execute
    return self.cursor.execute(sql, params)
  File "/env/lib/python3.6/site-packages/django/db/utils.py", line 89, in __exit__
    raise dj_exc_value.with_traceback(traceback) from exc_value
  File "/env/lib/python3.6/site-packages/django/db/backends/utils.py", line 83, in _execute
    return self.cursor.execute(sql)
django.db.utils.IntegrityError: insert or update on table "investments_investmentstatustrack" violates foreign key constraint "investments_investme_modified_by_id_3a12fb21_fk_auth_user"
DETAIL:  Key (modified_by_id)=(1) is not present in table "auth_user".

Any idea how to fix this problem?

-- UPDATE: two unit tests which show the problem.

Both pass when executed alone. It seems that the problem occurs during the unit test tearDown: the foreign key constraint fails at that moment because the User has already been deleted.

class TrackInvestmentStatusTest(ApiTestCase):
    """Creating or updating an Investment must add an InvestmentStatusTrack row."""

    def login(self, is_staff=False):
        """Create an active user and log the test client in as that user.

        ``self.client.login()`` does not go through the request middleware, so
        it does not register a request (or user) for the current thread.
        """
        password = "abc123"
        self.user = mommy.make(User, is_staff=is_staff, is_active=True)
        self.user.set_password(password)
        self.user.save()
        self.assertTrue(self.client.login(username=self.user.username, password=password))

    def test_add_investment(self):
        """it should add a new investment and add a track"""
        self.login()

        url = reverse('investments:investments-list')

        data = {}

        # This request goes through the middleware, so the track created by the
        # save is attributed to self.user.
        response = self.client.post(url, data=data)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        self.assertEqual(1, Investment.objects.count())
        investment = Investment.objects.all()[0]
        self.assertEqual(investment.status, Investment.STATUS_IN_PROJECT)

        self.assertEqual(1, InvestmentStatusTrack.objects.count())
        track = InvestmentStatusTrack.objects.all()[0]
        self.assertEqual(track.status, investment.status)
        self.assertEqual(track.investment, investment)
        self.assertEqual(track.modified_by, self.user)
        self.assertEqual(track.modified_on.date(), date.today())

    def test_save_status(self):
        """it should modify the investment and add a track"""

        self.login()

        investment_status = Investment.STATUS_IN_PROJECT

        # Saved outside any request: the track must have no author. This holds
        # only if the request middleware cleared the request stored by an
        # earlier test; otherwise that stale request (and its deleted user)
        # would be picked up.
        investment = mommy.make(Investment, status=investment_status)
        investment_id = investment.id

        self.assertEqual(1, InvestmentStatusTrack.objects.count())
        track = InvestmentStatusTrack.objects.all()[0]
        self.assertEqual(track.status, investment.status)
        self.assertEqual(track.investment, investment)
        self.assertEqual(track.modified_by, None)
        self.assertEqual(track.modified_on.date(), date.today())

        url = reverse('investments:investments-detail', args=[investment.id])

        data = {
            'status': Investment.STATUS_ACCEPTED
        }

        response = self.client.patch(url, data=data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)

        self.assertEqual(1, Investment.objects.count())
        investment = Investment.objects.all()[0]
        self.assertEqual(investment.id, investment_id)
        self.assertEqual(investment.status, Investment.STATUS_ACCEPTED)

        # Tracks are ordered newest first ('-modified_on'): index 0 is the update.
        self.assertEqual(2, InvestmentStatusTrack.objects.count())
        track = InvestmentStatusTrack.objects.all()[0]
        self.assertEqual(track.status, Investment.STATUS_ACCEPTED)
        self.assertEqual(track.investment, investment)
        self.assertEqual(track.modified_by, self.user)
        self.assertEqual(track.modified_on.date(), date.today())

        track = InvestmentStatusTrack.objects.all()[1]
        self.assertEqual(track.status, Investment.STATUS_IN_PROJECT)
        self.assertEqual(track.investment, investment)
        self.assertEqual(track.modified_by, None)
        self.assertEqual(track.modified_on.date(), date.today())
12
  • Did you make all your Makemigrations/Migrates? Commented Feb 16, 2018 at 17:03
  • Yes, it is up to date with the model Commented Feb 16, 2018 at 17:03
  • Either i am blind, or your code seems to be fine. Could you add one of your unit tests which throw this Traceback? Commented Feb 16, 2018 at 17:09
  • @Sativa I've updated my code with unit test. Commented Feb 16, 2018 at 17:16
  • The problem with some database drivers is that after an IntegrityError the connection stays in invalid state and every query after that will fail as well. You probably should issue a transaction.rollback in the tear-down step for the test just to be safe. Commented Feb 16, 2018 at 17:21

3 Answers 3

5
+50

So I debugged through the tests and I found the issue happening here.

The middleware you use to capture the request doesn't run during self.client.login, because login never goes through the middleware. In your first test you call

# Goes through the middleware, which stores this request for the current thread.
response = self.client.post(url, data=data)

This calls the middleware and sets the thread request and the current user. But in your next test you have a

# post_save fires handle_status_track, which reads whatever request is stored for this thread.
investment = mommy.make(Investment, status=investment_status)

This fires handle_status_track, which then gets the stale request left over from your previous test, whose user has id=1. But the current user has id=2; the id=1 user was created and destroyed in test 1 itself.

So your middleware trick for capturing the request is the culprit in this case.

Edit-1

The issue will only happen in tests and won't happen in production. One simple fix is to add a set_user function to the middleware module:

def set_user(user):
    """Attach *user* to the request currently tracked for this thread, if there is one."""
    request = get_request()
    if not request:
        # Nothing is tracked for this thread, so there is nothing to update.
        return
    request.user = user

And then update your login method to below

def login(self, is_staff=False):
    """Create an active user, log the test client in as that user, and put the
    user on the request currently tracked for this thread (via set_user)."""
    password = "abc123"
    user = self.user = mommy.make(User, is_staff=is_staff, is_active=True)
    user.set_password(password)
    user.save()
    self.assertTrue(self.client.login(username=user.username, password=password))
    # client.login() bypasses the request middleware, so the tracked request
    # must be told about the new user explicitly.
    set_user(user)

This makes sure that, in each test, the tracked request carries the correct user.

The wrong exception stack trace

Your traceback contains the line below:

  File "/env/lib/python3.6/site-packages/django/test/testcases.py", line 1041, in _fixture_teardown
    connections[db_name].check_constraints()

Now if you look at the code on that line

def _fixture_teardown(self):
    # When connections_support_transactions() is false, defer to the parent
    # class's teardown instead.
    if not connections_support_transactions():
        return super()._fixture_teardown()
    try:
        # Walk the databases in reverse order of self._databases_names().
        for db_name in reversed(self._databases_names()):
            # On PostgreSQL, check_constraints() executes
            # "SET CONSTRAINTS ALL IMMEDIATE" (see the question's traceback),
            # forcing deferred foreign-key checks to run now.
            if self._should_check_constraints(connections[db_name]):
                connections[db_name].check_constraints()
    finally:
        # try/finally has no except clause: the rollback always runs, but an
        # exception raised by check_constraints() still propagates to the caller.
        self._rollback_atomics(self.atomics)

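There is a try block, so how can an exception escape? It is only a try/finally with no except clause, so an exception raised by check_constraints() still propagates after the rollback runs. On line 188 of testcases.py, you have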

def __call__(self, result=None):
    """
    Wrapper around default __call__ method to perform common Django test
    set up. This means that user-defined Test Cases aren't required to
    include a call to super().setUp().
    """
    testMethod = getattr(self, self._testMethodName)
    # Skipped if either the test class or the test method is marked skipped.
    skipped = (
        getattr(self.__class__, "__unittest_skip__", False) or
        getattr(testMethod, "__unittest_skip__", False)
    )

    if not skipped:
        try:
            self._pre_setup()
        except Exception:
            # A setup failure is recorded as an error and the test is not run.
            result.addError(self, sys.exc_info())
            return
    super().__call__(result)
    if not skipped:
        try:
            # Per the question's traceback, this reaches _fixture_teardown(),
            # where the deferred constraint checks run.
            self._post_teardown()
        except Exception:
            # Any exception escaping teardown is recorded as an error for this
            # test, which is why the reported traceback starts at
            # _post_teardown() rather than at the code that saved the row.
            result.addError(self, sys.exc_info())
            return

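The result.addError(self, sys.exc_info()) call records the exception that escaped self._post_teardown(). So the traceback points at teardown rather than at the code that saved the bad row. On PostgreSQL, Django creates foreign keys as DEFERRABLE INITIALLY DEFERRED, so the violation is only detected when check_constraints() runs SET CONSTRAINTS ALL IMMEDIATE during teardown. Not sure if it is a bug or an edge case, but that is my analysis.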

Sign up to request clarification or add additional context in comments.

2 Comments

You are absolutely right. I didn't think that the request middleware could be the root cause. Do you think it is just a unit test issue, or can it also happen in the app?
I've fixed the middleware by clearing the request once it is processed. I think this was mainly a unit test issue but it could also be a problem in some situations.
2

I fixed the problem by refactoring my code.

Now I don't create the track inside the save method of the investment or inside a post_save signal handler, but in a method which is called explicitly.

My code looks like:

models.py

class Investment(models.Model):
    """the main model"""
    status = models.IntegerField()

    def handle_status_track(self):
        """Add an InvestmentStatusTrack entry recording the current status.

        This is not hooked to save(): callers invoke it explicitly after saving
        (as InvestmentViewSet does). It creates a new entry on every call,
        whether or not the status changed. ``modified_by`` is the authenticated
        user of the request returned by ``get_request()``, otherwise ``None``.
        """
        from django.utils import timezone

        request = get_request()  # a way to get the current request
        # getattr also copes with no request (None) or a request without `user`.
        user = getattr(request, 'user', None)
        modified_by = user if user and user.is_authenticated else None
        InvestmentStatusTrack.objects.create(
            investment=self,
            status=self.status,
            # timezone.now() is aware when USE_TZ is enabled (naive otherwise),
            # so Django does not warn about a naive datetime as with datetime.now().
            modified_on=timezone.now(),
            modified_by=modified_by,
        )


class InvestmentStatusTrack(models.Model):
    """One history entry: the status an investment had, when, and who set it."""

    # Deleting the investment deletes its whole status history.
    investment = models.ForeignKey(Investment, on_delete=models.CASCADE)
    status = models.IntegerField()
    modified_on = models.DateTimeField(
        blank=True, null=True, default=None, verbose_name=_('modified on'), db_index=True
    )
    # SET_NULL rather than CASCADE: the column is already nullable, and deleting
    # a user should not erase the status history that user took part in.
    # NOTE: run makemigrations after changing on_delete.
    modified_by = models.ForeignKey(
        User, blank=True, null=True, default=None, verbose_name=_('modified by'), on_delete=models.SET_NULL
    )

    class Meta:
        # Newest entries first.
        ordering = ('-modified_on',)

views.py

class InvestmentViewSet(ViewSet):
    model = Investment
    serializer_class = InvestmentSerializer

    def _save_and_track(self, serializer):
        """Save through *serializer*, then record a status track for the saved investment."""
        saved_investment = serializer.save()
        saved_investment.handle_status_track()

    def perform_create(self, serializer):
        """Create the investment and record its initial status."""
        self._save_and_track(serializer)

    def perform_update(self, serializer):
        """Update the investment and record its (possibly new) status."""
        self._save_and_track(serializer)

The problem is that it doesn't behave exactly the same: I have to remember to call the method every time the object is saved. I am still wondering why the post_save signal causes this error.

4 Comments

Don't know if it's related, but depending on how you set up your signals, I had to pass weak=False (docs.djangoproject.com/en/2.0/topics/signals/…) when running unit tests to avoid them being garbage collected and causing weird problems.
@DanielBackman Thanks for this information. I didn't know this. I will check.
@DanielBackman Thanks for your suggestion. Unfortunately, passing weak=False when connecting the signal didn't solve the problem.
@luc, I found the issue with your tests: it's the use of get_request with the custom middleware, which is polluting the system state.
1

As mentioned by @Tarun Lalwani, the root cause of the problem is bad management of the request in the middleware.

Here is the fixed code for clarification:

from threading import current_thread

class RequestManager(object):
    """Per-thread registry of the current Django request, reachable from anywhere.

    All instances share state (Borg pattern): ``__init__`` points the instance
    ``__dict__`` at the class-level ``_shared`` dict, so ``RequestManager()``
    can be created on demand. Requests are kept in a dict keyed by
    ``current_thread()``.
    """
    _shared = {}

    def __init__(self):
        """Share state with every other instance (Borg pattern)."""
        self.__dict__ = RequestManager._shared

    def _get_request_dict(self):
        """Return the shared ``{thread: request}`` dict, creating it on first use."""
        # One setdefault call instead of hasattr + assignment: in CPython two
        # threads racing on first use cannot each install their own dict and
        # silently lose the other's entry.
        return self.__dict__.setdefault('_request', {})

    def clean(self):
        """Forget the stored requests of all threads."""
        if hasattr(self, '_request'):
            del self._request

    def get_request(self):
        """Return the request stored for the calling thread, or None."""
        return self._get_request_dict().get(current_thread())

    def set_request(self, request):
        """Store *request* for the calling thread; passing ``None`` clears it.

        Clearing removes the key rather than storing ``None``, so the dict does
        not keep a reference to every thread that has ever served a request.
        """
        requests = self._get_request_dict()
        if request is None:
            requests.pop(current_thread(), None)
        else:
            requests[current_thread()] = request


class RequestMiddleware:
    """Make the request being processed available to ``get_request()``.

    The request is registered for the current thread before the rest of the
    chain and the view run, and is always cleared afterwards. Without that,
    later code on the same thread (e.g. the next test case, or a ``post_save``
    handler fired outside any request) would see a stale request and its user.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        RequestManager().set_request(request)
        try:
            return self.get_response(request)
        finally:
            # Clear even if get_response() raises; clearing only on the success
            # path would leave this request registered for the thread.
            RequestManager().set_request(None)

    def process_exception(self, request, exception):
        """Clear the stored request when the view raises.

        Returns None, so Django's normal exception handling continues.
        """
        RequestManager().set_request(None)


def get_request():
    """Return the request registered for the calling thread, or None if there is none."""
    manager = RequestManager()
    current = manager.get_request()
    return current

Comments

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.