From 2b93ee552c6d49315394d954cd5cf44093a1c275 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Sat, 28 Jul 2018 09:37:05 -0400 Subject: [PATCH] Implement Record.get() `Record.get()` allows record objects to better masquerade as dicts. Fixes: #330. --- asyncpg/protocol/record/recordobj.c | 116 ++++++++++++++++++++-------- docs/api/index.rst | 7 ++ tests/test_record.py | 13 +++- 3 files changed, 100 insertions(+), 36 deletions(-) diff --git a/asyncpg/protocol/record/recordobj.c b/asyncpg/protocol/record/recordobj.c index d272c8d1..e1de8b01 100644 --- a/asyncpg/protocol/record/recordobj.c +++ b/asyncpg/protocol/record/recordobj.c @@ -254,6 +254,62 @@ record_item(ApgRecordObject *o, Py_ssize_t i) } +typedef enum item_by_name_result { + APG_ITEM_FOUND = 0, + APG_ERROR = -1, + APG_ITEM_NOT_FOUND = -2 +} item_by_name_result_t; + + +/* Lookup a record value by its name. Return 0 on success, -2 if the + * value was not found (with KeyError set), and -1 on all other errors. + */ +static item_by_name_result_t +record_item_by_name(ApgRecordObject *o, PyObject *item, PyObject **result) +{ + PyObject *mapped; + PyObject *val; + Py_ssize_t i; + + mapped = PyObject_GetItem(o->desc->mapping, item); + if (mapped == NULL) { + goto noitem; + } + + if (!PyIndex_Check(mapped)) { + Py_DECREF(mapped); + goto error; + } + + i = PyNumber_AsSsize_t(mapped, PyExc_IndexError); + Py_DECREF(mapped); + + if (i < 0) { + if (PyErr_Occurred()) + PyErr_Clear(); + goto error; + } + + val = record_item(o, i); + if (val == NULL) { + PyErr_Clear(); + goto error; + } + + *result = val; + + return APG_ITEM_FOUND; + +noitem: + PyErr_SetObject(PyExc_KeyError, item); + return APG_ITEM_NOT_FOUND; + +error: + PyErr_SetString(PyExc_RuntimeError, "invalid record descriptor"); + return APG_ERROR; +} + + static PyObject * record_subscript(ApgRecordObject* o, PyObject* item) { @@ -299,42 +355,13 @@ record_subscript(ApgRecordObject* o, PyObject* item) } } else { - PyObject *mapped; - mapped = PyObject_GetItem(o->desc->mapping, item); - if (mapped != NULL) { - Py_ssize_t i; - PyObject *result; - - if (!PyIndex_Check(mapped)) { - Py_DECREF(mapped); - goto noitem; - } - - i = PyNumber_AsSsize_t(mapped, PyExc_IndexError); - Py_DECREF(mapped); - - if (i < 0) { - if (PyErr_Occurred()) { - PyErr_Clear(); - } - goto noitem; - } + PyObject* result; - result = record_item(o, i); - if (result == NULL) { - PyErr_Clear(); - goto noitem; - } + if (record_item_by_name(o, item, &result) < 0) + return NULL; + else return result; - } - else { - goto noitem; - } } - -noitem: - _PyErr_SetKeyError(item); - return NULL; } @@ -483,6 +510,28 @@ record_contains(ApgRecordObject *o, PyObject *arg) } +static PyObject * +record_get(ApgRecordObject* o, PyObject* args) +{ + PyObject *key; + PyObject *defval = Py_None; + PyObject *val = NULL; + int res; + + if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &defval)) + return NULL; + + res = record_item_by_name(o, key, &val); + if (res == APG_ITEM_NOT_FOUND) { + PyErr_Clear(); + Py_INCREF(defval); + val = defval; + } + + return val; +} + + static PySequenceMethods record_as_sequence = { (lenfunc)record_length, /* sq_length */ 0, /* sq_concat */ @@ -506,6 +555,7 @@ static PyMethodDef record_methods[] = { {"values", (PyCFunction)record_values, METH_NOARGS}, {"keys", (PyCFunction)record_keys, METH_NOARGS}, {"items", (PyCFunction)record_items, METH_NOARGS}, + {"get", (PyCFunction)record_get, METH_VARARGS}, {NULL, NULL} /* sentinel */ }; diff --git a/docs/api/index.rst b/docs/api/index.rst index 1b29d1f4..7e97a96c 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -302,6 +302,13 @@ items either by a numeric index or by a field name: Return an iterator over the *values* of the record *r*. + .. describe:: get(name[, default]) + + Return the value for *name* if the record has a field named *name*, + else return *default*. If *default* is not given, return ``None``. + + .. versionadded:: 0.18 + .. method:: values() Return an iterator over the record values. diff --git a/tests/test_record.py b/tests/test_record.py index cca74dd0..e9abab45 100644 --- a/tests/test_record.py +++ b/tests/test_record.py @@ -49,7 +49,7 @@ def test_record_gc(self): mapping = {key: val} with self.checkref(key, val): r = Record(mapping, (0,)) - with self.assertRaises(KeyError): + with self.assertRaises(RuntimeError): r[key] del r @@ -58,7 +58,7 @@ def test_record_gc(self): mapping = {key: val} with self.checkref(key, val): r = Record(mapping, (0,)) - with self.assertRaises(KeyError): + with self.assertRaises(RuntimeError): r[key] del r @@ -90,7 +90,7 @@ def test_record_len_getindex(self): with self.assertRaisesRegex(KeyError, 'spam'): Record(None, (1,))['spam'] - with self.assertRaisesRegex(KeyError, 'spam'): + with self.assertRaisesRegex(RuntimeError, 'invalid record descriptor'): Record({'spam': 123}, (1,))['spam'] def test_record_slice(self): @@ -272,6 +272,13 @@ def test_record_cmp(self): sorted([r1, r2, r3, r4, r5, r6, r7]), [r1, r2, r3, r6, r7, r4, r5]) + def test_record_get(self): + r = Record(R_AB, (42, 43)) + with self.checkref(r): + self.assertEqual(r.get('a'), 42) + self.assertEqual(r.get('nonexistent'), None) + self.assertEqual(r.get('nonexistent', 'default'), 'default') + def test_record_not_pickleable(self): r = Record(R_A, (42,)) with self.assertRaises(Exception):