[Python] Creating a Django custom field that accepts an Enum in `choices`

The `enum` module added in Python 3.4 (backported to Python 2.4 and later as the `enum34` package) is convenient, so I tried passing an Enum class directly to a `CharField`'s `choices` option. That does not work out of the box, so I wrote a custom field that does.

Reference URL

Operating environment


# -*- encoding: utf-8 -*-
from django.db import models
from django.utils.functional import curry
import enum

class ChoiceField(models.CharField):
    u"""CharField that accepts an ``enum.Enum`` subclass as ``choices``.

    When ``choices`` is an Enum subclass the field:

    * builds Django ``(member, member.value)`` choice pairs from it,
    * stores each member in the DB as ``unicode(member)`` (e.g. ``u"Gender.Male"``),
    * sets ``max_length`` to the longest such representation (any value
      passed by the caller is replaced),
    * converts stored strings back into members in ``to_python``,
    * makes ``get_<name>_display()`` return the member's ``value``.

    Any other ``choices`` (or none) is handed to ``CharField`` unchanged.
    """
    __metaclass__ = models.SubfieldBase

    def __init__(self, *args, **kwargs):
        choices = kwargs.get("choices", None)
        # issubclass() raises TypeError for non-class arguments, so check for a
        # class first; ordinary list/tuple choices are passed through untouched.
        if isinstance(choices, type) and issubclass(choices, enum.Enum):
            self.enum = choices
            kwargs["choices"] = self._from_enum()
            kwargs["max_length"] = self._calc_max_length()
        else:
            self.enum = None
        super(ChoiceField, self).__init__(*args, **kwargs)

    def _calc_max_length(self):
        u"""Return the column width needed to store every ``unicode(member)``."""
        return max([len(unicode(item)) for item in self.enum])

    def _from_enum(self):
        u"""Return ``(member, member.value)`` pairs usable as Django ``choices``."""
        return [(item, item.value) for item in self.enum]

    @staticmethod
    def _get_display(instance, field):
        u"""Return the display label of ``field`` on model ``instance``.

        Installed on the model class as ``get_<name>_display``. It must be a
        staticmethod: ``curry`` then wraps a plain function, which the model
        binds to its instance. A bound field method would instead receive the
        model instance as ``field`` and fail on the duplicate ``field=`` keyword.
        """
        value = getattr(instance, field.attname)
        # Empty values (None / u"") pass through to_python unchanged and have
        # no ``.value``; return them as-is instead of raising AttributeError.
        return getattr(value, "value", value)

    def contribute_to_class(self, cls, name, virtual_only=False):
        super(ChoiceField, self).contribute_to_class(cls, name, virtual_only)
        if self.enum is None:
            # Plain choices: keep Django's own get_<name>_display.
            return
        setattr(cls, 'get_%s_display' % self.name,
                curry(self._get_display, field=self))
        # MaxLengthValidator calls len() on the value, which is an Enum member
        # here; give members the length of their stored representation.
        setattr(self.enum, '__len__', lambda member: len(unicode(member)))

    def get_prep_value(self, value):
        u"""Return the DB representation: strings as-is, members as ``unicode(member)``."""
        if isinstance(value, basestring):
            return value
        return None if value is None else unicode(value)

    def to_python(self, value):
        u"""Convert a stored or assigned value to an Enum member.

        Empty and non-string values (members, ``None``, ``u""``) are returned
        unchanged. A string is looked up first as a member name, then compared
        with each member, its ``value``, and its name. For the name comparison,
        only the part after the last ``.`` is used, which covers the stored
        ``unicode(member)`` form such as ``u"Gender.Male"``.

        Raises:
            ValidationError: if the string matches no member.
        """
        if not value or not isinstance(value, basestring):
            return value
        if self.enum is None:
            # Ordinary (non-Enum) choices: the string is already the value.
            return value
        try:
            return self.enum[value]
        except KeyError:
            pass
        # rpartition() yields the whole string when there is no dot, so this
        # also covers a bare member name. Comparing the full segment avoids the
        # ambiguity of endswith(), where u"X.BA" would match a member named "A".
        suffix = value.rpartition(u".")[2]
        for member in self.enum:
            if value == member or value == member.value or suffix == member.name:
                return member
        # Django's to_python() contract is to raise ValidationError on failure.
        from django.core.exceptions import ValidationError
        raise ValidationError(
            '%s is not a valid value for enum %s' % (value, self.enum))

The tests look like this:


# -*- encoding: utf-8 -*-
from django import test
from django.db import models
import enum
import os

class Gender(enum.Enum):
    u"""Sample enum used as ``choices`` for ``Person.gender`` in these tests."""
    # enum34 on Python 2 does not record definition order; __order__ fixes
    # the member iteration order explicitly.
    __order__ = "Male Female"
    Male = u"male"
    # NOTE(review): "Female" is capitalised while "male" is not. These values are
    # what get_gender_display() returns; confirm the inconsistency is intended.
    Female = u"Female"

class Person(models.Model):
    u"""Minimal test model with an enum-backed ChoiceField."""
    name = models.CharField(u"name", max_length=255)
    # NOTE(review): `choices` (the module defining ChoiceField) is not imported
    # in this snippet; add the matching import for your app layout.
    # ChoiceField overwrites max_length with the length computed from Gender,
    # so the 255 passed here has no effect.
    gender = choices.ChoiceField(u"sex", choices=Gender, max_length=255)

    class Meta:
        # The app label is the name of the directory one level above this
        # file's directory (presumably the test module lives in a subpackage
        # of the app; confirm).
        app_label = os.path.basename(os.path.abspath(os.path.join(os.path.split(__file__)[0], os.pardir)))

class PersonTest(test.TestCase):
    u"""Round-trip tests for an Enum-backed ChoiceField on Person.gender."""

    def test_choices(self):
        u"""A stored member reads back as the member; display is its value."""
        john = Person.objects.create(name="John", gender=Gender.Male)
        self.assertEqual(Gender.Male, john.gender)
        self.assertEqual(u"male", john.get_gender_display())

    def test_get(self):
        u"""Looking up by an enum member finds the row stored with it."""
        created = Person.objects.create(name="John", gender=Gender.Male)
        self.assertEqual(
            created,
            Person.objects.get(name="John", gender=Gender.Male),
        )

    def test_filter_in(self):
        u"""An __in lookup accepts a list of enum members."""
        expected = [
            Person.objects.create(name="John", gender=Gender.Male),
            Person.objects.create(name="Jane", gender=Gender.Female),
        ]
        wanted = [Gender.Male, Gender.Female]
        self.assertItemsEqual(expected, Person.objects.filter(gender__in=wanted))

