Django's CommaSeparatedIntegerField
It doesn't do what you think it does, but it can still be useful.
If you ever need to store lists of IDs in a field in Django, you've probably
come across the CommaSeparatedIntegerField. It sounds ideal - you can save
lists to and from the database. Brilliant.
Except that that's not how it works. The CommaSeparatedIntegerField is in fact nothing more than a CharField with a regex validator (re.compile('^[\d,]+$')) over the top that ensures that only commas and digits are input. That's it. If you save a list of integers into the database and then retrieve it again, you get a string back out, as that's what is stored in the database. This partial implementation is arguably less useful than a DIY field, as it's only half there.
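To see how loose that check is, here is a small stand-alone sketch that applies the same pattern with Python's re module, so no Django is needed. looks_valid is a hypothetical helper, not part of Django. A string made only of commas passes, while the bracketed form you get from calling str() on a list does not.

```python
import re

# The pattern quoted above: one or more digits and/or commas, nothing else.
csv_int_re = re.compile(r'^[\d,]+$')


def looks_valid(value):
    """Hypothetical helper: would value pass the digits-and-commas check?"""
    return csv_int_re.match(value) is not None


print(looks_valid("1,2,3"))         # True
print(looks_valid(",,,"))           # True: commas alone are accepted
print(looks_valid(str([1, 2, 3])))  # False: "[1, 2, 3]" has brackets and spaces
print(looks_valid("1, 2"))          # False: spaces are rejected
```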
So how can we patch this up?
Validation of the field contents when submitted aside, there are two main places
that you need to watch when using CommaSeparatedIntegerField
values - on saving the object, and on rehydrating it via the ORM (e.g. objects.get(id=1)).
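Stripped of Django, the two conversions are small. The sketch below uses two hypothetical helpers (to_db_string and from_db_string are illustrative names, not a Django API): one turns a list of ints into the digits-and-commas string the validator expects before saving, the other turns that string back into a list after loading.

```python
def to_db_string(ids):
    """Hypothetical helper: [1, 2, 3] -> "1,2,3" (validator-friendly for non-negative ints)."""
    return ",".join(str(int(i)) for i in ids)


def from_db_string(value):
    """Hypothetical helper: "1,2,3" -> [1, 2, 3]; empty string or None -> []."""
    if not value:
        return []
    # Skip empty pieces so "1,,2" or a trailing comma doesn't break int().
    return [int(part) for part in value.split(",") if part]


stored = to_db_string([1, 2, 3])
print(stored)                  # 1,2,3
print(from_db_string(stored))  # [1, 2, 3]
```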
The easiest way to add the conversion to both of these activities is to hook into the model's pre_save and post_init signals (Django sends post_init at the end of every model __init__, which includes every instance the ORM builds from a database row).
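Why post_init covers rehydration: the signal is sent at the end of the model's constructor, and instances loaded by the ORM are built through that constructor too. The toy below mimics this in plain Python so it runs without Django. FakeSignal and Record are made-up stand-ins, not Django classes, and the real signal machinery does considerably more.

```python
class FakeSignal:
    """Toy stand-in for a Django signal: just a list of receiver functions."""

    def __init__(self):
        self.receivers = []

    def connect(self, func):
        self.receivers.append(func)
        return func  # returning func lets connect be used as a decorator

    def send(self, sender, **kwargs):
        for func in self.receivers:
            func(sender=sender, **kwargs)


post_init = FakeSignal()


class Record:
    """Toy model: like a Django model, it fires post_init at the end of __init__."""

    def __init__(self, included_ids):
        self.included_ids = included_ids
        post_init.send(sender=type(self), instance=self)


@post_init.connect
def _clean_ids(sender, instance, **kwargs):
    # A value read from the database arrives as a string; turn it into ints.
    if isinstance(instance.included_ids, str):
        instance.included_ids = [
            int(part) for part in instance.included_ids.split(",") if part
        ]


print(Record("1,2,3").included_ids)  # [1, 2, 3] (as if loaded from a row)
print(Record([4, 5]).included_ids)   # [4, 5] (already a list, left alone)
```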
This is the pattern we use for models with CommaSeparatedIntegerFields:
- Add a clean_csv_fields() method
- Register a signal handler for the model's pre_save and post_init signals
- Call clean_csv_fields() from that signal handler
The clean_csv_fields() method simply ensures that the internal attributes of the model instance are valid lists. Here's one I prepared earlier:
```python
from django.db import models
from django.db.models.signals import pre_save, post_init
from django.dispatch import receiver

# smart_list is defined further down this post; "myapp.utils" is just a
# hypothetical place to keep it.
from myapp.utils import smart_list


class MyModel(models.Model):
    # model fields go here
    included_ids = models.CommaSeparatedIntegerField(max_length=100)
    excluded_ids = models.CommaSeparatedIntegerField(max_length=100)

    def clean_csv_fields(self):
        """Convert all the CommaSeparatedIntegerField values to lists."""
        # func is applied to each element in the list; here it just casts
        # each element to an integer, so we get a list of ints back out
        self.included_ids = smart_list(self.included_ids, func=int)
        self.excluded_ids = smart_list(self.excluded_ids, func=int)


# listen for the signals, and ensure that fields are cleaned
@receiver([pre_save, post_init], sender=MyModel)
def _on_model_signal(sender, instance, **kwargs):
    instance.clean_csv_fields()
```
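One consequence of converting to lists in pre_save is that the column ends up holding whatever the list looks like as a string. As far as we can tell, CharField coerces a non-string value with str() when preparing it for the database, so [1, 2, 3] would be stored as "[1, 2, 3]" rather than "1,2,3", which would explain why smart_list strips square brackets. The plain-Python sketch below (parse_stored is a hypothetical helper, not part of Django) shows that this form still round-trips back to a list of ints.

```python
def parse_stored(value):
    """Hypothetical helper: parse "[1, 2, 3]" or "1,2,3" back into [1, 2, 3]."""
    value = value.strip().lstrip("[").rstrip("]").strip()
    if not value:
        return []
    # int() ignores surrounding whitespace, so " 2" from "1, 2" is fine
    return [int(part) for part in value.split(",")]


ids = [1, 2, 3]
stored = str(ids)  # the string form a list takes when coerced with str()
print(stored)                # [1, 2, 3]
print(parse_stored(stored))  # [1, 2, 3]
```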
This will ensure that you have lists everywhere you expect them, whenever
you expect them.
What's this smart_list function call?
The one unknown in the code snippet above is the reference to a smart_list function, which does everything it can to parse what it's given into a list, whether that's a string, a list (used as-is), a tuple, a single int or None.
The function is shown below - the important thing is that in addition to parsing the input parameter it also applies a function to each element in the resulting list, which can be used (as shown above) to cast elements to integers. It's pretty straightforward.
```python
def smart_list(value, delimiter=",", func=None):
    """Convert a value to a list, if possible.

    Args:
        value: the value to be parsed. Ideally a string of comma separated
            values - e.g. "1,2,3,4", but could be a list, a tuple, ...

    Kwargs:
        delimiter: string, the delimiter used to split the value argument,
            if it's a string / unicode. Defaults to ','.
        func: a function applied to each individual element in the list
            once the value arg is split. e.g. lambda x: int(x) would return
            a list of integers. Defaults to None - in which case you just
            get the list.

    Returns: a list if one can be parsed out of the input value. If the
        value input is an empty string or None, returns an empty
        list. If the split or func parsing fails, raises a ValueError.

    This is mainly used for ensuring the CSV model fields are properly
    formatted. Use this function in the pre_save and post_init model
    signals to ensure that you always get a list back from the field.
    """
    if value in ["", u"", "[]", u"[]", u"[ ]", None]:
        return []

    if isinstance(value, list):
        l = value
    elif isinstance(value, tuple):
        l = list(value)
    elif isinstance(value, basestring):
        # Python 2: basestring covers both str and unicode.
        # TODO: regex this.
        value = value.lstrip('[').rstrip(']').strip(' ')
        if len(value) == 0:
            return []
        else:
            l = value.split(delimiter)
    elif isinstance(value, int):
        l = [value]
    else:
        raise ValueError(u"Unparseable smart_list value: %s" % value)

    try:
        func = func or (lambda x: x)
        return [func(e) for e in l]
    except Exception as ex:
        raise ValueError(u"Unable to parse value '%s': %s" % (value, ex))
```
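The function above targets Python 2 (it relies on basestring). Python 3 has no basestring, so here is a minimal port, offered as a sketch rather than a drop-in replacement: str stands in for basestring, and the explicit list of "empty" values is folded into the string branch, because "", " ", "[]" and "[ ]" all strip down to nothing anyway. It aims to keep the same behaviour for the inputs exercised by the tests in this post.

```python
def smart_list(value, delimiter=",", func=None):
    """Python 3 sketch of smart_list: parse value into a list, applying func."""
    if value is None:
        return []

    if isinstance(value, list):
        items = value
    elif isinstance(value, tuple):
        items = list(value)
    elif isinstance(value, str):
        # Drop optional brackets and spaces, e.g. "[1, 2]" -> "1, 2".
        value = value.lstrip("[").rstrip("]").strip(" ")
        if not value:
            return []  # "", " ", "[]" and "[ ]" all end up here
        items = value.split(delimiter)
    elif isinstance(value, int):
        items = [value]
    else:
        raise ValueError("Unparseable smart_list value: %r" % (value,))

    func = func or (lambda x: x)
    try:
        return [func(item) for item in items]
    except Exception as ex:
        raise ValueError("Unable to parse value %r: %s" % (value, ex)) from ex


print(smart_list("[1, 2, 3]", func=int))  # [1, 2, 3]
print(smart_list((4, 5)))                 # [4, 5]
print(smart_list("1 2", " "))             # ['1', '2']
```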
And some tests for you - just to demonstrate what output you get from the
input. The original is in this gist - https://gist.github.com/hugorodgerbrown/557e08f8800997866b8d
```python
# coding=utf-8
import unittest

# Assumes smart_list lives in an importable module; "myapp.utils" is a
# hypothetical location.
from myapp.utils import smart_list


class SmartListTests(unittest.TestCase):

    def test_smart_list_empty(self):
        "Test the smart_list function with empty input."
        self.assertEqual(smart_list(""), [])
        self.assertEqual(smart_list(" "), [])
        self.assertEqual(smart_list(" "), [])
        self.assertEqual(smart_list("[]"), [])
        self.assertEqual(smart_list(u"[]"), [])
        self.assertEqual(smart_list(u"[ ]"), [])
        self.assertEqual(smart_list("[ ]"), [])
        self.assertEqual(smart_list("[21]"), ["21"])
        self.assertEqual(smart_list(None), [])

    def test_smart_list_single_values(self):
        "Test the smart_list function with valid single value input."
        self.assertEqual(smart_list("1"), ["1"])
        self.assertEqual(smart_list(u"ß "), [u"ß"])
        self.assertEqual(smart_list(1), [1])
        self.assertEqual(smart_list((1,)), [1])
        self.assertEqual(smart_list([1]), [1])

    def test_smart_list_errors(self):
        "Test the smart_list function raises expected ValueErrors."
        self.assertRaises(ValueError, smart_list, {"1": None})
        self.assertRaises(ValueError, smart_list, object())

    def test_smart_list_delimiter(self):
        "Test the smart_list delimiter works."
        self.assertEqual(smart_list("1,2"), ["1", "2"])
        self.assertEqual(smart_list("1 2"), ["1 2"])
        self.assertEqual(smart_list("1 2", " "), ["1", "2"])

    def test_smart_list_func(self):
        "Test the smart_list func parameter works."
        self.assertEqual(smart_list("1,2", func=lambda x: int(x)), [1, 2])
        self.assertEqual(smart_list("1,2", func=lambda x: int(x) * 2), [2, 4])
        self.assertEqual(smart_list("[21]", func=lambda x: int(x)), [21])
        self.assertRaises(
            ValueError,
            smart_list,
            "1,A",
            func=lambda x: int(x) * 2
        )


if __name__ == "__main__":
    unittest.main()
```