Coverage Report

Created: 2022-07-08 09:39

/home/mdboom/Work/builds/cpython/Objects/unionobject.c
Line
Count
Source (jump to first uncovered line)
1
// types.UnionType -- used to represent e.g. Union[int, str], int | str
2
#include "Python.h"
3
#include "pycore_object.h"  // _PyObject_GC_TRACK/UNTRACK
4
#include "pycore_unionobject.h"
5
#include "structmember.h"
6
7
8
static PyObject *make_union(PyObject *);
9
10
11
typedef struct {
12
    PyObject_HEAD
13
    PyObject *args;
14
    PyObject *parameters;
15
} unionobject;
16
17
static void
18
unionobject_dealloc(PyObject *self)
19
{
20
    unionobject *alias = (unionobject *)self;
21
22
    _PyObject_GC_UNTRACK(self);
23
24
    Py_XDECREF(alias->args);
25
    Py_XDECREF(alias->parameters);
26
    Py_TYPE(self)->tp_free(self);
27
}
28
29
static int
30
union_traverse(PyObject *self, visitproc visit, void *arg)
31
{
32
    unionobject *alias = (unionobject *)self;
33
    Py_VISIT(alias->args);
34
    Py_VISIT(alias->parameters);
35
    return 0;
36
}
37
38
static Py_hash_t
39
union_hash(PyObject *self)
40
{
41
    unionobject *alias = (unionobject *)self;
42
    PyObject *args = PyFrozenSet_New(alias->args);
43
    if (args == NULL) {
  Branch (43:9): [True: 0, False: 19]
44
        return (Py_hash_t)-1;
45
    }
46
    Py_hash_t hash = PyObject_Hash(args);
47
    Py_DECREF(args);
48
    return hash;
49
}
50
51
/* tp_richcompare: unions compare by the *set* of their members, ignoring
 * order. Only == and != against another union are handled; any other
 * operand or operator yields NotImplemented. */
static PyObject *
union_richcompare(PyObject *a, PyObject *b, int op)
{
    int is_equality_op = (op == Py_EQ || op == Py_NE);
    if (!_PyUnion_Check(b) || !is_equality_op) {
        Py_RETURN_NOTIMPLEMENTED;
    }

    PyObject *lhs = PySet_New(((unionobject *)a)->args);
    if (lhs == NULL) {
        return NULL;
    }

    PyObject *res = NULL;
    PyObject *rhs = PySet_New(((unionobject *)b)->args);
    if (rhs != NULL) {
        res = PyObject_RichCompare(lhs, rhs, op);
        Py_DECREF(rhs);
    }
    Py_DECREF(lhs);
    return res;
}
72
73
static int
74
is_same(PyObject *left, PyObject *right)
75
{
76
    int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
77
    return is_ga ? 
PyObject_RichCompareBool(left, right, 48
Py_EQ48
) :
left == right400
;
  Branch (77:12): [True: 48, False: 400]
78
}
79
80
static int
81
contains(PyObject **items, Py_ssize_t size, PyObject *obj)
82
{
83
    for (int i = 0; i < size; 
i++432
) {
  Branch (83:21): [True: 448, False: 358]
84
        int is_duplicate = is_same(items[i], obj);
85
        if (is_duplicate) {  // -1 or 1
  Branch (85:13): [True: 16, False: 432]
86
            return is_duplicate;
87
        }
88
    }
89
    return 0;
90
}
91
92
static PyObject *
93
merge(PyObject **items1, Py_ssize_t size1,
94
      PyObject **items2, Py_ssize_t size2)
95
{
96
    PyObject *tuple = NULL;
97
    Py_ssize_t pos = 0;
98
99
    for (int i = 0; i < size2; 
i++373
) {
  Branch (99:21): [True: 374, False: 364]
100
        PyObject *arg = items2[i];
101
        int is_duplicate = contains(items1, size1, arg);
102
        if (is_duplicate < 0) {
  Branch (102:13): [True: 1, False: 373]
103
            Py_XDECREF(tuple);
104
            return NULL;
105
        }
106
        if (is_duplicate) {
  Branch (106:13): [True: 15, False: 358]
107
            continue;
108
        }
109
110
        if (tuple == NULL) {
  Branch (110:13): [True: 353, False: 5]
111
            tuple = PyTuple_New(size1 + size2 - i);
112
            if (tuple == NULL) {
  Branch (112:17): [True: 0, False: 353]
113
                return NULL;
114
            }
115
            
for (; 353
pos < size1;
pos++416
) {
  Branch (115:20): [True: 416, False: 353]
116
                PyObject *a = items1[pos];
117
                Py_INCREF(a);
118
                PyTuple_SET_ITEM(tuple, pos, a);
119
            }
120
        }
121
        Py_INCREF(arg);
122
        PyTuple_SET_ITEM(tuple, pos, arg);
123
        pos++;
124
    }
125
126
    if (tuple) {
  Branch (126:9): [True: 353, False: 11]
127
        (void) _PyTuple_Resize(&tuple, pos);
128
    }
129
    return tuple;
130
}
131
132
static PyObject **
133
get_types(PyObject **obj, Py_ssize_t *size)
134
{
135
    if (*obj == Py_None) {
  Branch (135:9): [True: 17, False: 713]
136
        *obj = (PyObject *)&_PyNone_Type;
137
    }
138
    if (_PyUnion_Check(*obj)) {
139
        PyObject *args = ((unionobject *) *obj)->args;
140
        *size = PyTuple_GET_SIZE(args);
141
        return &PyTuple_GET_ITEM(args, 0);
142
    }
143
    else {
144
        *size = 1;
145
        return obj;
146
    }
147
}
148
149
static int
150
is_unionable(PyObject *obj)
151
{
152
    return (obj == Py_None ||
  Branch (152:13): [True: 17, False: 775]
153
        
PyType_Check775
(obj) ||
154
        _PyGenericAlias_Check(obj) ||
155
        _PyUnion_Check(obj));
156
}
157
158
/* nb_or implementation producing `self | other`.
 *
 * Returns NotImplemented when either operand is not unionable. When every
 * member of `other` is already present in `self`, `self` is returned as-is
 * (after get_types() has normalized None to NoneType). Otherwise a new
 * union is created from the merged member tuple. */
PyObject *
_Py_union_type_or(PyObject* self, PyObject* other)
{
    if (!is_unionable(self) || !is_unionable(other)) {
        Py_RETURN_NOTIMPLEMENTED;
    }

    Py_ssize_t n_left, n_right;
    PyObject **left = get_types(&self, &n_left);
    PyObject **right = get_types(&other, &n_right);

    PyObject *merged = merge(left, n_left, right, n_right);
    if (merged != NULL) {
        PyObject *result = make_union(merged);
        Py_DECREF(merged);
        return result;
    }
    if (PyErr_Occurred()) {
        return NULL;
    }
    /* merge() found nothing new: `self` already covers `other`. */
    Py_INCREF(self);
    return self;
}
181
182
static int
183
union_repr_item(_PyUnicodeWriter *writer, PyObject *p)
184
{
185
    PyObject *qualname = NULL;
186
    PyObject *module = NULL;
187
    PyObject *tmp;
188
    PyObject *r = NULL;
189
    int err;
190
191
    if (p == (PyObject *)&_PyNone_Type) {
  Branch (191:9): [True: 2, False: 50]
192
        return _PyUnicodeWriter_WriteASCIIString(writer, "None", 4);
193
    }
194
195
    if (_PyObject_LookupAttr(p, &_Py_ID(__origin__), &tmp) < 0) {
  Branch (195:9): [True: 0, False: 50]
196
        goto exit;
197
    }
198
199
    if (tmp) {
  Branch (199:9): [True: 11, False: 39]
200
        Py_DECREF(tmp);
201
        if (_PyObject_LookupAttr(p, &_Py_ID(__args__), &tmp) < 0) {
  Branch (201:13): [True: 0, False: 11]
202
            goto exit;
203
        }
204
        if (tmp) {
  Branch (204:13): [True: 11, False: 0]
205
            // It looks like a GenericAlias
206
            Py_DECREF(tmp);
207
            goto use_repr;
208
        }
209
    }
210
211
    if (_PyObject_LookupAttr(p, &_Py_ID(__qualname__), &qualname) < 0) {
  Branch (211:9): [True: 0, False: 39]
212
        goto exit;
213
    }
214
    if (qualname == NULL) {
  Branch (214:9): [True: 0, False: 39]
215
        goto use_repr;
216
    }
217
    if (_PyObject_LookupAttr(p, &_Py_ID(__module__), &module) < 0) {
  Branch (217:9): [True: 0, False: 39]
218
        goto exit;
219
    }
220
    if (module == NULL || module == Py_None) {
  Branch (220:9): [True: 0, False: 39]
  Branch (220:27): [True: 0, False: 39]
221
        goto use_repr;
222
    }
223
224
    // Looks like a class
225
    if (PyUnicode_Check(module) &&
226
        _PyUnicode_EqualToASCIIString(module, "builtins"))
  Branch (226:9): [True: 39, False: 0]
227
    {
228
        // builtins don't need a module name
229
        r = PyObject_Str(qualname);
230
        goto exit;
231
    }
232
    else {
233
        r = PyUnicode_FromFormat("%S.%S", module, qualname);
234
        goto exit;
235
    }
236
237
use_repr:
238
    r = PyObject_Repr(p);
239
exit:
240
    Py_XDECREF(qualname);
241
    Py_XDECREF(module);
242
    if (r == NULL) {
  Branch (242:9): [True: 0, False: 50]
243
        return -1;
244
    }
245
    err = _PyUnicodeWriter_WriteStr(writer, r);
246
    Py_DECREF(r);
247
    return err;
248
}
249
250
/* tp_repr: each member rendered by union_repr_item(), joined with " | ". */
static PyObject *
union_repr(PyObject *self)
{
    PyObject *members = ((unionobject *)self)->args;
    Py_ssize_t count = PyTuple_GET_SIZE(members);

    _PyUnicodeWriter writer;
    _PyUnicodeWriter_Init(&writer);
    for (Py_ssize_t i = 0; i < count; i++) {
        PyObject *item = PyTuple_GET_ITEM(members, i);
        /* Short-circuit: a failed separator write skips the item. */
        int failed =
            (i > 0 &&
             _PyUnicodeWriter_WriteASCIIString(&writer, " | ", 3) < 0) ||
            union_repr_item(&writer, item) < 0;
        if (failed) {
            _PyUnicodeWriter_Dealloc(&writer);
            return NULL;
        }
    }
    return _PyUnicodeWriter_Finish(&writer);
}
272
273
static PyMemberDef union_members[] = {
274
        {"__args__", T_OBJECT, offsetof(unionobject, args), READONLY},
275
        {0}
276
};
277
278
/* mp_subscript: substitute `item` for the union's type parameters via
 * _Py_subs_parameters(), then recombine the substituted members with `|`
 * (PyNumber_Or). The result may therefore be a single non-union object;
 * an empty substitution result is wrapped in an empty union. */
static PyObject *
union_getitem(PyObject *self, PyObject *item)
{
    unionobject *alias = (unionobject *)self;

    /* __parameters__ is computed lazily; ensure it is available. */
    if (alias->parameters == NULL) {
        alias->parameters = _Py_make_parameters(alias->args);
        if (alias->parameters == NULL) {
            return NULL;
        }
    }

    PyObject *newargs = _Py_subs_parameters(self, alias->args,
                                            alias->parameters, item);
    if (newargs == NULL) {
        return NULL;
    }

    Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
    PyObject *res;
    if (nargs == 0) {
        res = make_union(newargs);
    }
    else {
        /* Fold left: newargs[0] | newargs[1] | ... , stopping on error. */
        res = PyTuple_GET_ITEM(newargs, 0);
        Py_INCREF(res);
        Py_ssize_t k = 1;
        while (res != NULL && k < nargs) {
            Py_SETREF(res, PyNumber_Or(res, PyTuple_GET_ITEM(newargs, k)));
            k++;
        }
    }
    Py_DECREF(newargs);
    return res;
}
314
315
static PyMappingMethods union_as_mapping = {
316
    .mp_subscript = union_getitem,
317
};
318
319
/* Getter for __parameters__. The value is computed from args on first
 * access and cached on the object. Returns a new reference, or NULL on
 * error. */
static PyObject *
union_parameters(PyObject *self, void *Py_UNUSED(unused))
{
    unionobject *u = (unionobject *)self;
    if (u->parameters != NULL) {
        Py_INCREF(u->parameters);
        return u->parameters;
    }

    u->parameters = _Py_make_parameters(u->args);
    if (u->parameters == NULL) {
        return NULL;
    }
    Py_INCREF(u->parameters);
    return u->parameters;
}
332
333
static PyGetSetDef union_properties[] = {
334
    {"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.UnionType.", NULL},
335
    {0}
336
};
337
338
static PyNumberMethods union_as_number = {
339
        .nb_or = _Py_union_type_or, // Add __or__ function
340
};
341
342
static const char* const cls_attrs[] = {
343
        "__module__",  // Required for compatibility with typing module
344
        NULL,
345
};
346
347
/* tp_getattro: names listed in cls_attrs are looked up on the union's type
 * instead of the instance; all other names use generic attribute lookup. */
static PyObject *
union_getattro(PyObject *self, PyObject *name)
{
    if (PyUnicode_Check(name)) {
        for (const char * const *attr = cls_attrs; *attr != NULL; attr++) {
            if (_PyUnicode_EqualToASCIIString(name, *attr)) {
                return PyObject_GetAttr((PyObject *)Py_TYPE(self), name);
            }
        }
    }
    return PyObject_GenericGetAttr(self, name);
}
363
364
/* Return a borrowed reference to the args tuple of the union `self`.
 * The caller must pass a union (checked only by assert). */
PyObject *
_Py_union_args(PyObject *self)
{
    assert(_PyUnion_Check(self));
    unionobject *u = (unionobject *)self;
    return u->args;
}
370
371
PyTypeObject _PyUnion_Type = {
372
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
373
    .tp_name = "types.UnionType",
374
    .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
375
              "\n"
376
              "E.g. for int | str"),
377
    .tp_basicsize = sizeof(unionobject),
378
    .tp_dealloc = unionobject_dealloc,
379
    .tp_alloc = PyType_GenericAlloc,
380
    .tp_free = PyObject_GC_Del,
381
    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
382
    .tp_traverse = union_traverse,
383
    .tp_hash = union_hash,
384
    .tp_getattro = union_getattro,
385
    .tp_members = union_members,
386
    .tp_richcompare = union_richcompare,
387
    .tp_as_mapping = &union_as_mapping,
388
    .tp_as_number = &union_as_number,
389
    .tp_repr = union_repr,
390
    .tp_getset = union_properties,
391
};
392
393
/* Create a new union wrapping `args`, which must be an exact tuple.
 * Takes its own reference to args and leaves __parameters__ unset (NULL)
 * for lazy computation. Returns a new reference, or NULL if allocation
 * fails. */
static PyObject *
make_union(PyObject *args)
{
    assert(PyTuple_CheckExact(args));

    unionobject *u = PyObject_GC_New(unionobject, &_PyUnion_Type);
    if (u == NULL) {
        return NULL;
    }
    u->parameters = NULL;
    Py_INCREF(args);
    u->args = args;
    _PyObject_GC_TRACK(u);
    return (PyObject *)u;
}