LCOV - code coverage report
Current view: top level - Objects - unionobject.c (source / functions) Hit Total Coverage
Test: CPython lcov report Lines: 175 200 87.5 %
Date: 2022-07-07 18:19:46 Functions: 17 17 100.0 %

          Line data    Source code
       1             : // types.UnionType -- used to represent e.g. Union[int, str], int | str
       2             : #include "Python.h"
       3             : #include "pycore_object.h"  // _PyObject_GC_TRACK/UNTRACK
       4             : #include "pycore_unionobject.h"
       5             : #include "structmember.h"
       6             : 
       7             : 
       8             : static PyObject *make_union(PyObject *);
       9             : 
      10             : 
      11             : typedef struct {
      12             :     PyObject_HEAD
      13             :     PyObject *args;
      14             :     PyObject *parameters;
      15             : } unionobject;
      16             : 
      17             : static void
      18        9383 : unionobject_dealloc(PyObject *self)
      19             : {
      20        9383 :     unionobject *alias = (unionobject *)self;
      21             : 
      22        9383 :     _PyObject_GC_UNTRACK(self);
      23             : 
      24        9383 :     Py_XDECREF(alias->args);
      25        9383 :     Py_XDECREF(alias->parameters);
      26        9383 :     Py_TYPE(self)->tp_free(self);
      27        9383 : }
      28             : 
      29             : static int
      30      242786 : union_traverse(PyObject *self, visitproc visit, void *arg)
      31             : {
      32      242786 :     unionobject *alias = (unionobject *)self;
      33      242786 :     Py_VISIT(alias->args);
      34      242786 :     Py_VISIT(alias->parameters);
      35      242786 :     return 0;
      36             : }
      37             : 
      38             : static Py_hash_t
      39          19 : union_hash(PyObject *self)
      40             : {
      41          19 :     unionobject *alias = (unionobject *)self;
      42          19 :     PyObject *args = PyFrozenSet_New(alias->args);
      43          19 :     if (args == NULL) {
      44           0 :         return (Py_hash_t)-1;
      45             :     }
      46          19 :     Py_hash_t hash = PyObject_Hash(args);
      47          19 :     Py_DECREF(args);
      48          19 :     return hash;
      49             : }
      50             : 
      51             : static PyObject *
      52         173 : union_richcompare(PyObject *a, PyObject *b, int op)
      53             : {
      54         173 :     if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) {
      55         144 :         Py_RETURN_NOTIMPLEMENTED;
      56             :     }
      57             : 
      58          29 :     PyObject *a_set = PySet_New(((unionobject*)a)->args);
      59          29 :     if (a_set == NULL) {
      60           0 :         return NULL;
      61             :     }
      62          29 :     PyObject *b_set = PySet_New(((unionobject*)b)->args);
      63          29 :     if (b_set == NULL) {
      64           0 :         Py_DECREF(a_set);
      65           0 :         return NULL;
      66             :     }
      67          29 :     PyObject *result = PyObject_RichCompare(a_set, b_set, op);
      68          29 :     Py_DECREF(b_set);
      69          29 :     Py_DECREF(a_set);
      70          29 :     return result;
      71             : }
      72             : 
      73             : static int
      74        9958 : is_same(PyObject *left, PyObject *right)
      75             : {
      76        9958 :     int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
      77        9958 :     return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
      78             : }
      79             : 
      80             : static int
      81        9404 : contains(PyObject **items, Py_ssize_t size, PyObject *obj)
      82             : {
      83       19346 :     for (int i = 0; i < size; i++) {
      84        9958 :         int is_duplicate = is_same(items[i], obj);
      85        9958 :         if (is_duplicate) {  // -1 or 1
      86          16 :             return is_duplicate;
      87             :         }
      88             :     }
      89        9388 :     return 0;
      90             : }
      91             : 
      92             : static PyObject *
      93        9395 : merge(PyObject **items1, Py_ssize_t size1,
      94             :       PyObject **items2, Py_ssize_t size2)
      95             : {
      96        9395 :     PyObject *tuple = NULL;
      97        9395 :     Py_ssize_t pos = 0;
      98             : 
      99       18798 :     for (int i = 0; i < size2; i++) {
     100        9404 :         PyObject *arg = items2[i];
     101        9404 :         int is_duplicate = contains(items1, size1, arg);
     102        9404 :         if (is_duplicate < 0) {
     103           1 :             Py_XDECREF(tuple);
     104           1 :             return NULL;
     105             :         }
     106        9403 :         if (is_duplicate) {
     107          15 :             continue;
     108             :         }
     109             : 
     110        9388 :         if (tuple == NULL) {
     111        9383 :             tuple = PyTuple_New(size1 + size2 - i);
     112        9383 :             if (tuple == NULL) {
     113           0 :                 return NULL;
     114             :             }
     115       19309 :             for (; pos < size1; pos++) {
     116        9926 :                 PyObject *a = items1[pos];
     117        9926 :                 Py_INCREF(a);
     118        9926 :                 PyTuple_SET_ITEM(tuple, pos, a);
     119             :             }
     120             :         }
     121        9388 :         Py_INCREF(arg);
     122        9388 :         PyTuple_SET_ITEM(tuple, pos, arg);
     123        9388 :         pos++;
     124             :     }
     125             : 
     126        9394 :     if (tuple) {
     127        9383 :         (void) _PyTuple_Resize(&tuple, pos);
     128             :     }
     129        9394 :     return tuple;
     130             : }
     131             : 
     132             : static PyObject **
     133       18790 : get_types(PyObject **obj, Py_ssize_t *size)
     134             : {
     135       18790 :     if (*obj == Py_None) {
     136        4313 :         *obj = (PyObject *)&_PyNone_Type;
     137             :     }
     138       18790 :     if (_PyUnion_Check(*obj)) {
     139         533 :         PyObject *args = ((unionobject *) *obj)->args;
     140         533 :         *size = PyTuple_GET_SIZE(args);
     141         533 :         return &PyTuple_GET_ITEM(args, 0);
     142             :     }
     143             :     else {
     144       18257 :         *size = 1;
     145       18257 :         return obj;
     146             :     }
     147             : }
     148             : 
     149             : static int
     150       20676 : is_unionable(PyObject *obj)
     151             : {
     152       16363 :     return (obj == Py_None ||
     153       24137 :         PyType_Check(obj) ||
     154       44813 :         _PyGenericAlias_Check(obj) ||
     155        1480 :         _PyUnion_Check(obj));
     156             : }
     157             : 
     158             : PyObject *
     159       10339 : _Py_union_type_or(PyObject* self, PyObject* other)
     160             : {
     161       10339 :     if (!is_unionable(self) || !is_unionable(other)) {
     162         944 :         Py_RETURN_NOTIMPLEMENTED;
     163             :     }
     164             : 
     165             :     Py_ssize_t size1, size2;
     166        9395 :     PyObject **items1 = get_types(&self, &size1);
     167        9395 :     PyObject **items2 = get_types(&other, &size2);
     168        9395 :     PyObject *tuple = merge(items1, size1, items2, size2);
     169        9395 :     if (tuple == NULL) {
     170          12 :         if (PyErr_Occurred()) {
     171           1 :             return NULL;
     172             :         }
     173          11 :         Py_INCREF(self);
     174          11 :         return self;
     175             :     }
     176             : 
     177        9383 :     PyObject *new_union = make_union(tuple);
     178        9383 :     Py_DECREF(tuple);
     179        9383 :     return new_union;
     180             : }
     181             : 
     182             : static int
     183          52 : union_repr_item(_PyUnicodeWriter *writer, PyObject *p)
     184             : {
     185          52 :     PyObject *qualname = NULL;
     186          52 :     PyObject *module = NULL;
     187             :     PyObject *tmp;
     188          52 :     PyObject *r = NULL;
     189             :     int err;
     190             : 
     191          52 :     if (p == (PyObject *)&_PyNone_Type) {
     192           2 :         return _PyUnicodeWriter_WriteASCIIString(writer, "None", 4);
     193             :     }
     194             : 
     195          50 :     if (_PyObject_LookupAttr(p, &_Py_ID(__origin__), &tmp) < 0) {
     196           0 :         goto exit;
     197             :     }
     198             : 
     199          50 :     if (tmp) {
     200          11 :         Py_DECREF(tmp);
     201          11 :         if (_PyObject_LookupAttr(p, &_Py_ID(__args__), &tmp) < 0) {
     202           0 :             goto exit;
     203             :         }
     204          11 :         if (tmp) {
     205             :             // It looks like a GenericAlias
     206          11 :             Py_DECREF(tmp);
     207          11 :             goto use_repr;
     208             :         }
     209             :     }
     210             : 
     211          39 :     if (_PyObject_LookupAttr(p, &_Py_ID(__qualname__), &qualname) < 0) {
     212           0 :         goto exit;
     213             :     }
     214          39 :     if (qualname == NULL) {
     215           0 :         goto use_repr;
     216             :     }
     217          39 :     if (_PyObject_LookupAttr(p, &_Py_ID(__module__), &module) < 0) {
     218           0 :         goto exit;
     219             :     }
     220          39 :     if (module == NULL || module == Py_None) {
     221           0 :         goto use_repr;
     222             :     }
     223             : 
     224             :     // Looks like a class
     225          78 :     if (PyUnicode_Check(module) &&
     226          39 :         _PyUnicode_EqualToASCIIString(module, "builtins"))
     227             :     {
     228             :         // builtins don't need a module name
     229          39 :         r = PyObject_Str(qualname);
     230          39 :         goto exit;
     231             :     }
     232             :     else {
     233           0 :         r = PyUnicode_FromFormat("%S.%S", module, qualname);
     234           0 :         goto exit;
     235             :     }
     236             : 
     237          11 : use_repr:
     238          11 :     r = PyObject_Repr(p);
     239          50 : exit:
     240          50 :     Py_XDECREF(qualname);
     241          50 :     Py_XDECREF(module);
     242          50 :     if (r == NULL) {
     243           0 :         return -1;
     244             :     }
     245          50 :     err = _PyUnicodeWriter_WriteStr(writer, r);
     246          50 :     Py_DECREF(r);
     247          50 :     return err;
     248             : }
     249             : 
     250             : static PyObject *
     251          24 : union_repr(PyObject *self)
     252             : {
     253          24 :     unionobject *alias = (unionobject *)self;
     254          24 :     Py_ssize_t len = PyTuple_GET_SIZE(alias->args);
     255             : 
     256             :     _PyUnicodeWriter writer;
     257          24 :     _PyUnicodeWriter_Init(&writer);
     258          76 :      for (Py_ssize_t i = 0; i < len; i++) {
     259          52 :         if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " | ", 3) < 0) {
     260           0 :             goto error;
     261             :         }
     262          52 :         PyObject *p = PyTuple_GET_ITEM(alias->args, i);
     263          52 :         if (union_repr_item(&writer, p) < 0) {
     264           0 :             goto error;
     265             :         }
     266             :     }
     267          24 :     return _PyUnicodeWriter_Finish(&writer);
     268           0 : error:
     269           0 :     _PyUnicodeWriter_Dealloc(&writer);
     270           0 :     return NULL;
     271             : }
     272             : 
     273             : static PyMemberDef union_members[] = {
     274             :         {"__args__", T_OBJECT, offsetof(unionobject, args), READONLY},
     275             :         {0}
     276             : };
     277             : 
     278             : static PyObject *
     279           5 : union_getitem(PyObject *self, PyObject *item)
     280             : {
     281           5 :     unionobject *alias = (unionobject *)self;
     282             :     // Populate __parameters__ if needed.
     283           5 :     if (alias->parameters == NULL) {
     284           3 :         alias->parameters = _Py_make_parameters(alias->args);
     285           3 :         if (alias->parameters == NULL) {
     286           0 :             return NULL;
     287             :         }
     288             :     }
     289             : 
     290           5 :     PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
     291           5 :     if (newargs == NULL) {
     292           0 :         return NULL;
     293             :     }
     294             : 
     295             :     PyObject *res;
     296           5 :     Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
     297           5 :     if (nargs == 0) {
     298           0 :         res = make_union(newargs);
     299             :     }
     300             :     else {
     301           5 :         res = PyTuple_GET_ITEM(newargs, 0);
     302           5 :         Py_INCREF(res);
     303          10 :         for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
     304           5 :             PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
     305           5 :             Py_SETREF(res, PyNumber_Or(res, arg));
     306           5 :             if (res == NULL) {
     307           0 :                 break;
     308             :             }
     309             :         }
     310             :     }
     311           5 :     Py_DECREF(newargs);
     312           5 :     return res;
     313             : }
     314             : 
     315             : static PyMappingMethods union_as_mapping = {
     316             :     .mp_subscript = union_getitem,
     317             : };
     318             : 
     319             : static PyObject *
     320          26 : union_parameters(PyObject *self, void *Py_UNUSED(unused))
     321             : {
     322          26 :     unionobject *alias = (unionobject *)self;
     323          26 :     if (alias->parameters == NULL) {
     324          16 :         alias->parameters = _Py_make_parameters(alias->args);
     325          16 :         if (alias->parameters == NULL) {
     326           0 :             return NULL;
     327             :         }
     328             :     }
     329          26 :     Py_INCREF(alias->parameters);
     330          26 :     return alias->parameters;
     331             : }
     332             : 
     333             : static PyGetSetDef union_properties[] = {
     334             :     {"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.UnionType.", NULL},
     335             :     {0}
     336             : };
     337             : 
     338             : static PyNumberMethods union_as_number = {
     339             :         .nb_or = _Py_union_type_or, // Add __or__ function
     340             : };
     341             : 
     342             : static const char* const cls_attrs[] = {
     343             :         "__module__",  // Required for compatibility with typing module
     344             :         NULL,
     345             : };
     346             : 
     347             : static PyObject *
     348         748 : union_getattro(PyObject *self, PyObject *name)
     349             : {
     350         748 :     unionobject *alias = (unionobject *)self;
     351         748 :     if (PyUnicode_Check(name)) {
     352        1493 :         for (const char * const *p = cls_attrs; ; p++) {
     353        1493 :             if (*p == NULL) {
     354         745 :                 break;
     355             :             }
     356         748 :             if (_PyUnicode_EqualToASCIIString(name, *p)) {
     357           3 :                 return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name);
     358             :             }
     359             :         }
     360             :     }
     361         745 :     return PyObject_GenericGetAttr(self, name);
     362             : }
     363             : 
     364             : PyObject *
     365          38 : _Py_union_args(PyObject *self)
     366             : {
     367          38 :     assert(_PyUnion_Check(self));
     368          38 :     return ((unionobject *) self)->args;
     369             : }
     370             : 
     371             : PyTypeObject _PyUnion_Type = {
     372             :     PyVarObject_HEAD_INIT(&PyType_Type, 0)
     373             :     .tp_name = "types.UnionType",
     374             :     .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
     375             :               "\n"
     376             :               "E.g. for int | str"),
     377             :     .tp_basicsize = sizeof(unionobject),
     378             :     .tp_dealloc = unionobject_dealloc,
     379             :     .tp_alloc = PyType_GenericAlloc,
     380             :     .tp_free = PyObject_GC_Del,
     381             :     .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
     382             :     .tp_traverse = union_traverse,
     383             :     .tp_hash = union_hash,
     384             :     .tp_getattro = union_getattro,
     385             :     .tp_members = union_members,
     386             :     .tp_richcompare = union_richcompare,
     387             :     .tp_as_mapping = &union_as_mapping,
     388             :     .tp_as_number = &union_as_number,
     389             :     .tp_repr = union_repr,
     390             :     .tp_getset = union_properties,
     391             : };
     392             : 
     393             : static PyObject *
     394        9383 : make_union(PyObject *args)
     395             : {
     396        9383 :     assert(PyTuple_CheckExact(args));
     397             : 
     398        9383 :     unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
     399        9383 :     if (result == NULL) {
     400           0 :         return NULL;
     401             :     }
     402             : 
     403        9383 :     Py_INCREF(args);
     404        9383 :     result->parameters = NULL;
     405        9383 :     result->args = args;
     406        9383 :     _PyObject_GC_TRACK(result);
     407        9383 :     return (PyObject*)result;
     408             : }

Generated by: LCOV version 1.14