Skip to content
Merged
40 changes: 30 additions & 10 deletions crosshair/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,33 +490,53 @@ def get_constructor_signature(cls: Type) -> Optional[inspect.Signature]:


def proxy_for_class(typ: Type, varname: str) -> object:
Comment thread
rambip marked this conversation as resolved.
data_members = _TYPE_HINTS.get(typ, None)
# Unwrap parameterized generics (e.g. Container[int] → Container) so that
# get_type_hints and inspect-based helpers receive a plain class.
cls = origin_of(typ)
if not isinstance(cls, type):
cls = typ
data_members = _TYPE_HINTS.get(cls, None)
if data_members is None:
data_members = get_type_hints(typ)
_TYPE_HINTS[typ] = data_members

if sys.version_info >= (3, 8) and type(typ) is typing._TypedDictMeta: # type: ignore
try:
data_members = get_type_hints(cls)
except Exception:
Comment thread
rambip marked this conversation as resolved.
Outdated
# Forward references that can't be resolved outside the defining module
# (e.g. torch.utils.data.DataLoader) cause NameError/AttributeError here.
data_members = {}
_TYPE_HINTS[cls] = data_members

if sys.version_info >= (3, 8) and type(cls) is typing._TypedDictMeta: # type: ignore
# Handling for TypedDict
optional_keys = getattr(typ, "__optional_keys__", ())
optional_keys = getattr(cls, "__optional_keys__", ())
keys = (
k
for k in data_members.keys()
if k not in optional_keys or context_statespace().smt_fork()
)
return {k: proxy_for_type(data_members[k], varname + "." + k) for k in keys}

constructor_sig = get_constructor_signature(typ)
constructor_sig = get_constructor_signature(cls)
if constructor_sig is None:
raise CrosshairUnsupported(
f"unable to create concrete instance of {typ} due to bad constructor"
)
# TODO: use dynamic_typing.get_bindings_from_type_arguments(typ) to instantiate
# type variables in `constructor_sig`
bindings = dynamic_typing.get_bindings_from_type_arguments(typ)
Comment thread
rambip marked this conversation as resolved.
if bindings:
new_params = []
for p in constructor_sig.parameters.values():
if p.annotation != inspect.Parameter.empty:
try:
resolved = dynamic_typing.realize(p.annotation, bindings)
p = p.replace(annotation=resolved)
except KeyError:
Comment thread
rambip marked this conversation as resolved.
Outdated
pass # unbound TypeVar — leave annotation as-is
new_params.append(p)
constructor_sig = constructor_sig.replace(parameters=new_params)
args = gen_args(constructor_sig)
typename = name_of_type(typ)
try:
with ResumedTracing():
obj = WithEnforcement(typ)(*args.args, **args.kwargs)
obj = WithEnforcement(cls)(*args.args, **args.kwargs)
except (PreconditionFailed, PostconditionFailed):
# preconditions can be invalidated when the __init__ method has preconditions.
# postconditions can be invalidated when the class has invariants.
Expand Down
Loading