Line data Source code
1 : // types.UnionType -- used to represent e.g. Union[int, str], int | str
2 : #include "Python.h"
3 : #include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK
4 : #include "pycore_unionobject.h"
5 : #include "structmember.h"
6 :
7 :
8 : static PyObject *make_union(PyObject *);
9 :
10 :
11 : typedef struct {
12 : PyObject_HEAD
13 : PyObject *args;
14 : PyObject *parameters;
15 : } unionobject;
16 :
17 : static void
18 9383 : unionobject_dealloc(PyObject *self)
19 : {
20 9383 : unionobject *alias = (unionobject *)self;
21 :
22 9383 : _PyObject_GC_UNTRACK(self);
23 :
24 9383 : Py_XDECREF(alias->args);
25 9383 : Py_XDECREF(alias->parameters);
26 9383 : Py_TYPE(self)->tp_free(self);
27 9383 : }
28 :
29 : static int
30 242786 : union_traverse(PyObject *self, visitproc visit, void *arg)
31 : {
32 242786 : unionobject *alias = (unionobject *)self;
33 242786 : Py_VISIT(alias->args);
34 242786 : Py_VISIT(alias->parameters);
35 242786 : return 0;
36 : }
37 :
38 : static Py_hash_t
39 19 : union_hash(PyObject *self)
40 : {
41 19 : unionobject *alias = (unionobject *)self;
42 19 : PyObject *args = PyFrozenSet_New(alias->args);
43 19 : if (args == NULL) {
44 0 : return (Py_hash_t)-1;
45 : }
46 19 : Py_hash_t hash = PyObject_Hash(args);
47 19 : Py_DECREF(args);
48 19 : return hash;
49 : }
50 :
51 : static PyObject *
52 173 : union_richcompare(PyObject *a, PyObject *b, int op)
53 : {
54 173 : if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) {
55 144 : Py_RETURN_NOTIMPLEMENTED;
56 : }
57 :
58 29 : PyObject *a_set = PySet_New(((unionobject*)a)->args);
59 29 : if (a_set == NULL) {
60 0 : return NULL;
61 : }
62 29 : PyObject *b_set = PySet_New(((unionobject*)b)->args);
63 29 : if (b_set == NULL) {
64 0 : Py_DECREF(a_set);
65 0 : return NULL;
66 : }
67 29 : PyObject *result = PyObject_RichCompare(a_set, b_set, op);
68 29 : Py_DECREF(b_set);
69 29 : Py_DECREF(a_set);
70 29 : return result;
71 : }
72 :
73 : static int
74 9958 : is_same(PyObject *left, PyObject *right)
75 : {
76 9958 : int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
77 9958 : return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
78 : }
79 :
80 : static int
81 9404 : contains(PyObject **items, Py_ssize_t size, PyObject *obj)
82 : {
83 19346 : for (int i = 0; i < size; i++) {
84 9958 : int is_duplicate = is_same(items[i], obj);
85 9958 : if (is_duplicate) { // -1 or 1
86 16 : return is_duplicate;
87 : }
88 : }
89 9388 : return 0;
90 : }
91 :
92 : static PyObject *
93 9395 : merge(PyObject **items1, Py_ssize_t size1,
94 : PyObject **items2, Py_ssize_t size2)
95 : {
96 9395 : PyObject *tuple = NULL;
97 9395 : Py_ssize_t pos = 0;
98 :
99 18798 : for (int i = 0; i < size2; i++) {
100 9404 : PyObject *arg = items2[i];
101 9404 : int is_duplicate = contains(items1, size1, arg);
102 9404 : if (is_duplicate < 0) {
103 1 : Py_XDECREF(tuple);
104 1 : return NULL;
105 : }
106 9403 : if (is_duplicate) {
107 15 : continue;
108 : }
109 :
110 9388 : if (tuple == NULL) {
111 9383 : tuple = PyTuple_New(size1 + size2 - i);
112 9383 : if (tuple == NULL) {
113 0 : return NULL;
114 : }
115 19309 : for (; pos < size1; pos++) {
116 9926 : PyObject *a = items1[pos];
117 9926 : Py_INCREF(a);
118 9926 : PyTuple_SET_ITEM(tuple, pos, a);
119 : }
120 : }
121 9388 : Py_INCREF(arg);
122 9388 : PyTuple_SET_ITEM(tuple, pos, arg);
123 9388 : pos++;
124 : }
125 :
126 9394 : if (tuple) {
127 9383 : (void) _PyTuple_Resize(&tuple, pos);
128 : }
129 9394 : return tuple;
130 : }
131 :
132 : static PyObject **
133 18790 : get_types(PyObject **obj, Py_ssize_t *size)
134 : {
135 18790 : if (*obj == Py_None) {
136 4313 : *obj = (PyObject *)&_PyNone_Type;
137 : }
138 18790 : if (_PyUnion_Check(*obj)) {
139 533 : PyObject *args = ((unionobject *) *obj)->args;
140 533 : *size = PyTuple_GET_SIZE(args);
141 533 : return &PyTuple_GET_ITEM(args, 0);
142 : }
143 : else {
144 18257 : *size = 1;
145 18257 : return obj;
146 : }
147 : }
148 :
149 : static int
150 20676 : is_unionable(PyObject *obj)
151 : {
152 16363 : return (obj == Py_None ||
153 24137 : PyType_Check(obj) ||
154 44813 : _PyGenericAlias_Check(obj) ||
155 1480 : _PyUnion_Check(obj));
156 : }
157 :
158 : PyObject *
159 10339 : _Py_union_type_or(PyObject* self, PyObject* other)
160 : {
161 10339 : if (!is_unionable(self) || !is_unionable(other)) {
162 944 : Py_RETURN_NOTIMPLEMENTED;
163 : }
164 :
165 : Py_ssize_t size1, size2;
166 9395 : PyObject **items1 = get_types(&self, &size1);
167 9395 : PyObject **items2 = get_types(&other, &size2);
168 9395 : PyObject *tuple = merge(items1, size1, items2, size2);
169 9395 : if (tuple == NULL) {
170 12 : if (PyErr_Occurred()) {
171 1 : return NULL;
172 : }
173 11 : Py_INCREF(self);
174 11 : return self;
175 : }
176 :
177 9383 : PyObject *new_union = make_union(tuple);
178 9383 : Py_DECREF(tuple);
179 9383 : return new_union;
180 : }
181 :
182 : static int
183 52 : union_repr_item(_PyUnicodeWriter *writer, PyObject *p)
184 : {
185 52 : PyObject *qualname = NULL;
186 52 : PyObject *module = NULL;
187 : PyObject *tmp;
188 52 : PyObject *r = NULL;
189 : int err;
190 :
191 52 : if (p == (PyObject *)&_PyNone_Type) {
192 2 : return _PyUnicodeWriter_WriteASCIIString(writer, "None", 4);
193 : }
194 :
195 50 : if (_PyObject_LookupAttr(p, &_Py_ID(__origin__), &tmp) < 0) {
196 0 : goto exit;
197 : }
198 :
199 50 : if (tmp) {
200 11 : Py_DECREF(tmp);
201 11 : if (_PyObject_LookupAttr(p, &_Py_ID(__args__), &tmp) < 0) {
202 0 : goto exit;
203 : }
204 11 : if (tmp) {
205 : // It looks like a GenericAlias
206 11 : Py_DECREF(tmp);
207 11 : goto use_repr;
208 : }
209 : }
210 :
211 39 : if (_PyObject_LookupAttr(p, &_Py_ID(__qualname__), &qualname) < 0) {
212 0 : goto exit;
213 : }
214 39 : if (qualname == NULL) {
215 0 : goto use_repr;
216 : }
217 39 : if (_PyObject_LookupAttr(p, &_Py_ID(__module__), &module) < 0) {
218 0 : goto exit;
219 : }
220 39 : if (module == NULL || module == Py_None) {
221 0 : goto use_repr;
222 : }
223 :
224 : // Looks like a class
225 78 : if (PyUnicode_Check(module) &&
226 39 : _PyUnicode_EqualToASCIIString(module, "builtins"))
227 : {
228 : // builtins don't need a module name
229 39 : r = PyObject_Str(qualname);
230 39 : goto exit;
231 : }
232 : else {
233 0 : r = PyUnicode_FromFormat("%S.%S", module, qualname);
234 0 : goto exit;
235 : }
236 :
237 11 : use_repr:
238 11 : r = PyObject_Repr(p);
239 50 : exit:
240 50 : Py_XDECREF(qualname);
241 50 : Py_XDECREF(module);
242 50 : if (r == NULL) {
243 0 : return -1;
244 : }
245 50 : err = _PyUnicodeWriter_WriteStr(writer, r);
246 50 : Py_DECREF(r);
247 50 : return err;
248 : }
249 :
250 : static PyObject *
251 24 : union_repr(PyObject *self)
252 : {
253 24 : unionobject *alias = (unionobject *)self;
254 24 : Py_ssize_t len = PyTuple_GET_SIZE(alias->args);
255 :
256 : _PyUnicodeWriter writer;
257 24 : _PyUnicodeWriter_Init(&writer);
258 76 : for (Py_ssize_t i = 0; i < len; i++) {
259 52 : if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " | ", 3) < 0) {
260 0 : goto error;
261 : }
262 52 : PyObject *p = PyTuple_GET_ITEM(alias->args, i);
263 52 : if (union_repr_item(&writer, p) < 0) {
264 0 : goto error;
265 : }
266 : }
267 24 : return _PyUnicodeWriter_Finish(&writer);
268 0 : error:
269 0 : _PyUnicodeWriter_Dealloc(&writer);
270 0 : return NULL;
271 : }
272 :
273 : static PyMemberDef union_members[] = {
274 : {"__args__", T_OBJECT, offsetof(unionobject, args), READONLY},
275 : {0}
276 : };
277 :
278 : static PyObject *
279 5 : union_getitem(PyObject *self, PyObject *item)
280 : {
281 5 : unionobject *alias = (unionobject *)self;
282 : // Populate __parameters__ if needed.
283 5 : if (alias->parameters == NULL) {
284 3 : alias->parameters = _Py_make_parameters(alias->args);
285 3 : if (alias->parameters == NULL) {
286 0 : return NULL;
287 : }
288 : }
289 :
290 5 : PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
291 5 : if (newargs == NULL) {
292 0 : return NULL;
293 : }
294 :
295 : PyObject *res;
296 5 : Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
297 5 : if (nargs == 0) {
298 0 : res = make_union(newargs);
299 : }
300 : else {
301 5 : res = PyTuple_GET_ITEM(newargs, 0);
302 5 : Py_INCREF(res);
303 10 : for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
304 5 : PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
305 5 : Py_SETREF(res, PyNumber_Or(res, arg));
306 5 : if (res == NULL) {
307 0 : break;
308 : }
309 : }
310 : }
311 5 : Py_DECREF(newargs);
312 5 : return res;
313 : }
314 :
315 : static PyMappingMethods union_as_mapping = {
316 : .mp_subscript = union_getitem,
317 : };
318 :
319 : static PyObject *
320 26 : union_parameters(PyObject *self, void *Py_UNUSED(unused))
321 : {
322 26 : unionobject *alias = (unionobject *)self;
323 26 : if (alias->parameters == NULL) {
324 16 : alias->parameters = _Py_make_parameters(alias->args);
325 16 : if (alias->parameters == NULL) {
326 0 : return NULL;
327 : }
328 : }
329 26 : Py_INCREF(alias->parameters);
330 26 : return alias->parameters;
331 : }
332 :
333 : static PyGetSetDef union_properties[] = {
334 : {"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.UnionType.", NULL},
335 : {0}
336 : };
337 :
338 : static PyNumberMethods union_as_number = {
339 : .nb_or = _Py_union_type_or, // Add __or__ function
340 : };
341 :
342 : static const char* const cls_attrs[] = {
343 : "__module__", // Required for compatibility with typing module
344 : NULL,
345 : };
346 :
347 : static PyObject *
348 748 : union_getattro(PyObject *self, PyObject *name)
349 : {
350 748 : unionobject *alias = (unionobject *)self;
351 748 : if (PyUnicode_Check(name)) {
352 1493 : for (const char * const *p = cls_attrs; ; p++) {
353 1493 : if (*p == NULL) {
354 745 : break;
355 : }
356 748 : if (_PyUnicode_EqualToASCIIString(name, *p)) {
357 3 : return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name);
358 : }
359 : }
360 : }
361 745 : return PyObject_GenericGetAttr(self, name);
362 : }
363 :
364 : PyObject *
365 38 : _Py_union_args(PyObject *self)
366 : {
367 38 : assert(_PyUnion_Check(self));
368 38 : return ((unionobject *) self)->args;
369 : }
370 :
371 : PyTypeObject _PyUnion_Type = {
372 : PyVarObject_HEAD_INIT(&PyType_Type, 0)
373 : .tp_name = "types.UnionType",
374 : .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
375 : "\n"
376 : "E.g. for int | str"),
377 : .tp_basicsize = sizeof(unionobject),
378 : .tp_dealloc = unionobject_dealloc,
379 : .tp_alloc = PyType_GenericAlloc,
380 : .tp_free = PyObject_GC_Del,
381 : .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
382 : .tp_traverse = union_traverse,
383 : .tp_hash = union_hash,
384 : .tp_getattro = union_getattro,
385 : .tp_members = union_members,
386 : .tp_richcompare = union_richcompare,
387 : .tp_as_mapping = &union_as_mapping,
388 : .tp_as_number = &union_as_number,
389 : .tp_repr = union_repr,
390 : .tp_getset = union_properties,
391 : };
392 :
393 : static PyObject *
394 9383 : make_union(PyObject *args)
395 : {
396 9383 : assert(PyTuple_CheckExact(args));
397 :
398 9383 : unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
399 9383 : if (result == NULL) {
400 0 : return NULL;
401 : }
402 :
403 9383 : Py_INCREF(args);
404 9383 : result->parameters = NULL;
405 9383 : result->args = args;
406 9383 : _PyObject_GC_TRACK(result);
407 9383 : return (PyObject*)result;
408 : }
|