1111import re
1212import sys
1313import traceback
14+ import types
15+ import typing
1416
1517import isort
1618import black
1719
20+ import circuitpython_typing
21+ import circuitpython_typing .socket
22+
1823
1924IMPORTS_IGNORE = frozenset (
2025 {
21- "int" ,
22- "float" ,
26+ "array" ,
2327 "bool" ,
24- "str" ,
28+ "buffer" ,
29+ "bytearray" ,
2530 "bytes" ,
26- "tuple" ,
27- "list" ,
28- "set" ,
2931 "dict" ,
30- "bytearray" ,
31- "slice" ,
3232 "file" ,
33- "buffer" ,
33+ "float" ,
34+ "int" ,
35+ "list" ,
3436 "range" ,
35- "array" ,
37+ "set" ,
38+ "slice" ,
39+ "str" ,
3640 "struct_time" ,
41+ "tuple" ,
3742 }
3843)
39- IMPORTS_TYPING = frozenset (
40- {
41- "Any" ,
42- "Dict" ,
43- "Optional" ,
44- "Union" ,
45- "Tuple" ,
46- "List" ,
47- "Sequence" ,
48- "NamedTuple" ,
49- "Iterable" ,
50- "Iterator" ,
51- "Callable" ,
52- "AnyStr" ,
53- "overload" ,
54- "Type" ,
55- }
56- )
57- IMPORTS_TYPES = frozenset ({"TracebackType" })
58- CPY_TYPING = frozenset (
59- {"ReadableBuffer" , "WriteableBuffer" , "AudioSample" , "FrameBuffer" , "Alarm" }
60- )
44+
45+ # Include all definitions in these type modules, minus some name conflicts.
46+ AVAILABLE_TYPE_MODULE_IMPORTS = {
47+ "types" : frozenset (types .__all__ ),
48+ # Conflicts: countio.Counter, canio.Match
49+ "typing" : frozenset (typing .__all__ ) - set (["Counter" , "Match" ]),
50+ "circuitpython_typing" : frozenset (circuitpython_typing .__all__ ),
51+ "circuitpython_typing.socket" : frozenset (circuitpython_typing .socket .__all__ ),
52+ }
6153
6254
6355def is_typed (node , allow_any = False ):
@@ -116,9 +108,7 @@ def find_stub_issues(tree):
116108
117109def extract_imports (tree ):
118110 modules = set ()
119- typing = set ()
120- types = set ()
121- cpy_typing = set ()
111+ used_type_module_imports = {module : set () for module in AVAILABLE_TYPE_MODULE_IMPORTS .keys ()}
122112
123113 def collect_annotations (anno_tree ):
124114 if anno_tree is None :
@@ -127,12 +117,9 @@ def collect_annotations(anno_tree):
127117 if isinstance (node , ast .Name ):
128118 if node .id in IMPORTS_IGNORE :
129119 continue
130- elif node .id in IMPORTS_TYPING :
131- typing .add (node .id )
132- elif node .id in IMPORTS_TYPES :
133- types .add (node .id )
134- elif node .id in CPY_TYPING :
135- cpy_typing .add (node .id )
120+ for module , imports in AVAILABLE_TYPE_MODULE_IMPORTS .items ():
121+ if node .id in imports :
122+ used_type_module_imports [module ].add (node .id )
136123 elif isinstance (node , ast .Attribute ):
137124 if isinstance (node .value , ast .Name ):
138125 modules .add (node .value .id )
@@ -145,15 +132,12 @@ def collect_annotations(anno_tree):
145132 elif isinstance (node , ast .FunctionDef ):
146133 collect_annotations (node .returns )
147134 for deco in node .decorator_list :
148- if isinstance (deco , ast .Name ) and (deco .id in IMPORTS_TYPING ):
149- typing .add (deco .id )
150-
151- return {
152- "modules" : sorted (modules ),
153- "typing" : sorted (typing ),
154- "types" : sorted (types ),
155- "cpy_typing" : sorted (cpy_typing ),
156- }
135+ if isinstance (deco , ast .Name ) and (
136+ deco .id in AVAILABLE_TYPE_MODULE_IMPORTS ["typing" ]
137+ ):
138+ used_type_module_imports ["typing" ].add (deco .id )
139+
140+ return (modules , used_type_module_imports )
157141
158142
159143def find_references (tree ):
@@ -237,15 +221,11 @@ def convert_folder(top_level, stub_directory):
237221 ok += 1
238222
239223 # Add import statements
240- imports = extract_imports (tree )
224+ imports , type_imports = extract_imports (tree )
241225 import_lines = ["from __future__ import annotations" ]
242- if imports ["types" ]:
243- import_lines .append ("from types import " + ", " .join (imports ["types" ]))
244- if imports ["typing" ]:
245- import_lines .append ("from typing import " + ", " .join (imports ["typing" ]))
246- if imports ["cpy_typing" ]:
247- import_lines .append ("from circuitpython_typing import " + ", " .join (imports ["cpy_typing" ]))
248- import_lines .extend (f"import { m } " for m in imports ["modules" ])
226+ for type_module , used_types in type_imports .items ():
227+ import_lines .append (f"from { type_module } import { ', ' .join (sorted (used_types ))} " )
228+ import_lines .extend (f"import { m } " for m in sorted (imports ))
249229 import_body = "\n " .join (import_lines )
250230 m = re .match (r'(\s*""".*?""")' , stub_contents , flags = re .DOTALL )
251231 if m :
0 commit comments