Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions crosshair/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,33 +490,50 @@ def get_constructor_signature(cls: Type) -> Optional[inspect.Signature]:


def proxy_for_class(typ: Type, varname: str) -> object:
Comment thread
rambip marked this conversation as resolved.
data_members = _TYPE_HINTS.get(typ, None)
# Unwrap parameterized generics (e.g. Container[int] → Container) so that
# get_type_hints and inspect-based helpers receive a plain class.
cls = origin_of(typ)
if not isinstance(cls, type):
cls = typ
data_members = _TYPE_HINTS.get(cls, None)
if data_members is None:
data_members = get_type_hints(typ)
_TYPE_HINTS[typ] = data_members

if sys.version_info >= (3, 8) and type(typ) is typing._TypedDictMeta: # type: ignore
try:
data_members = get_type_hints(cls)
except (AttributeError, NameError):
# Forward references that can't be resolved outside the defining module
# (e.g. torch.utils.data.DataLoader) cause NameError/AttributeError here.
data_members = {}
_TYPE_HINTS[cls] = data_members

if sys.version_info >= (3, 8) and type(cls) is typing._TypedDictMeta: # type: ignore
# Handling for TypedDict
optional_keys = getattr(typ, "__optional_keys__", ())
optional_keys = getattr(cls, "__optional_keys__", ())
keys = (
k
for k in data_members.keys()
if k not in optional_keys or context_statespace().smt_fork()
)
return {k: proxy_for_type(data_members[k], varname + "." + k) for k in keys}

constructor_sig = get_constructor_signature(typ)
constructor_sig = get_constructor_signature(cls)
if constructor_sig is None:
raise CrosshairUnsupported(
f"unable to create concrete instance of {typ} due to bad constructor"
)
# TODO: use dynamic_typing.get_bindings_from_type_arguments(typ) to instantiate
# type variables in `constructor_sig`
bindings = dynamic_typing.get_bindings_from_type_arguments(typ)
Comment thread
rambip marked this conversation as resolved.
if bindings:
new_params = []
for p in constructor_sig.parameters.values():
if p.annotation != inspect.Parameter.empty:
resolved = dynamic_typing.realize(p.annotation, bindings)
p = p.replace(annotation=resolved)
new_params.append(p)
constructor_sig = constructor_sig.replace(parameters=new_params)
args = gen_args(constructor_sig)
typename = name_of_type(typ)
try:
with ResumedTracing():
obj = WithEnforcement(typ)(*args.args, **args.kwargs)
obj = WithEnforcement(cls)(*args.args, **args.kwargs)
except (PreconditionFailed, PostconditionFailed):
# preconditions can be invalidated when the __init__ method has preconditions.
# postconditions can be invalidated when the class has invariants.
Expand Down
43 changes: 43 additions & 0 deletions crosshair/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,49 @@ def f(x: B) -> int:
check_states(f, CONFIRMED)


def test_proxy_for_parameterized_generic() -> None:
T = TypeVar("T")

class Container(Generic[T]):
def __init__(self, value: T) -> None:
self.value = value

with standalone_statespace:
with NoTracing():
obj = proxy_for_class(Container[int], "x")
assert isinstance(obj.value, SymbolicInt) # type: ignore[attr-defined]


def test_proxy_for_multi_typevar_generic() -> None:
A = TypeVar("A")
B = TypeVar("B")

class Pair(Generic[A, B]):
def __init__(self, first: A, second: B) -> None:
self.first = first
self.second = second

with standalone_statespace:
with NoTracing():
obj = proxy_for_class(Pair[int, str], "x")
assert isinstance(obj.first, SymbolicInt) # type: ignore[attr-defined]
assert isinstance(obj.second, LazyIntSymbolicStr) # type: ignore[attr-defined]


def test_proxy_for_class_with_unresolvable_forward_ref() -> None:
class Broken:
value: "NonExistentType" # type: ignore

def __init__(self) -> None:
pass

# Should not raise; falls back to {} for data_members
with standalone_statespace:
with NoTracing():
obj = proxy_for_class(Broken, "x")
assert isinstance(obj, Broken)


def test_any() -> None:
def f(x: Any) -> bool:
"""post: True"""
Expand Down
4 changes: 2 additions & 2 deletions crosshair/dynamic_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def get_bindings_from_type_arguments(pytype: Type) -> Mapping[object, type]:

def realize(pytype: Type, bindings: Mapping[object, type]) -> object:
if typing_inspect.is_typevar(pytype):
return bindings[pytype]
return bindings.get(pytype, pytype)
if not hasattr(pytype, "__args__"):
return pytype
newargs: List = []
Expand All @@ -246,7 +246,7 @@ def realize(pytype: Type, bindings: Mapping[object, type]) -> object:

def realize(pytype: Type, bindings: Mapping[object, type]) -> object:
if typing_inspect.is_typevar(pytype):
return bindings[pytype]
return bindings.get(pytype, pytype)
if not hasattr(pytype, "__args__"):
return pytype
newargs: List = []
Expand Down
1 change: 1 addition & 0 deletions doc/source/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,4 @@ In order of initial commit. Many thanks!
* `Kevin Turcios <https://github.com/KRRT7>`_
* `Tony Dang <https://github.com/Dang-Hoang-Tung>`_
* `Michael Schvarcz <https://github.com/michael-schvarcz>`_
* `Antonin Peronnet <https://github.com/rambip>`_ (with help from Claude)
Loading