Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
weight.quant_state.code = cast_to_device(weight.quant_state.code, x.device)

bias = cast_to_device(self.bias, x.device)
if x.numel() == x.shape[-1]:
# bitsandbytes routes single-vector inputs through gemv_4bit, which can fail with CPU-stored,
# device-autocasted Params4bit weights on some CUDA/bnb combinations. Use the same dequantized
# matmul path that bnb.matmul_4bit uses for batched inputs.
dequantized_weight = bnb.functional.dequantize_4bit(weight, weight.quant_state).to(x.dtype)
return torch.nn.functional.linear(x, dequantized_weight, bias).to(inp_dtype)
return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,30 +316,45 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
# Move the original layer to the CPU.
layer_to_device_via_state_dict(orig_layer, "cpu")

# Inference should fail with an input on the device.
with pytest.raises(RuntimeError):
_ = orig_layer(x)
is_nf4_layer = type(orig_layer).__name__ == "InvokeLinearNF4"
# Inference should fail with an input on the device. Do not probe raw NF4 here: with CPU-stored weights and a
# single-row CUDA input, some bitsandbytes versions hit an unsafe gemv_4bit path instead of raising safely.
if not is_nf4_layer:
with pytest.raises((RuntimeError, ValueError)):
_ = orig_layer(x)

# Wrap the original layer.
custom_layer = copy.deepcopy(orig_layer)
custom_layer = wrap_single_custom_layer(custom_layer)

# Inference should still fail with autocasting disabled.
# Inference should still fail with autocasting disabled. See the raw NF4 note above.
custom_layer.set_device_autocasting_enabled(False)
with pytest.raises(RuntimeError):
_ = custom_layer(x)
if not is_nf4_layer:
with pytest.raises((RuntimeError, ValueError)):
_ = custom_layer(x)

# Run inference with the wrapped layer on the device.
custom_layer.set_device_autocasting_enabled(True)
custom_output = custom_layer(x)
assert custom_output.device.type == device

assert torch.allclose(orig_output, custom_output)
if is_nf4_layer:
assert torch.allclose(orig_output, custom_output, atol=1e-5)
else:
assert torch.allclose(orig_output, custom_output)


PatchUnderTest = tuple[list[tuple[BaseLayerPatch, float]], torch.Tensor]


def _has_dora_patch(patches: list[tuple[BaseLayerPatch, float]]) -> bool:
return any(isinstance(patch, DoRALayer) for patch, _ in patches)


def _is_bnb_quantized_linear(layer: torch.nn.Module) -> bool:
return type(layer).__name__ in {"InvokeLinear8bitLt", "InvokeLinearNF4"}


@pytest.fixture(
params=[
"single_lora",
Expand Down Expand Up @@ -564,6 +579,8 @@ def test_quantized_linear_sidecar_patches(
patches, input = patch_under_test

linear_layer, quantized_linear_layer = quantized_linear_layer_under_test
if _is_bnb_quantized_linear(quantized_linear_layer) and _has_dora_patch(patches):
pytest.skip("DoRA patches require readable base weights and are not compatible with bnb quantized layers.")

# Move everything to the device.
layer_to_device_via_state_dict(linear_layer, device)
Expand Down Expand Up @@ -598,6 +615,8 @@ def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device(
patches, input = patch_under_test

_, quantized_linear_layer = quantized_linear_layer_under_test
if _is_bnb_quantized_linear(quantized_linear_layer) and _has_dora_patch(patches):
pytest.skip("DoRA patches require readable base weights and are not compatible with bnb quantized layers.")

# Move everything to the device.
layer_to_device_via_state_dict(quantized_linear_layer, device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: I
linear_8bit_lt_layer.load_state_dict(state_dict)

# Inference of the original layer should fail.
with pytest.raises(RuntimeError):
with pytest.raises((RuntimeError, ValueError)):
linear_8bit_lt_layer(x)

# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLin
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
linear_nf4_layer.load_state_dict(state_dict)

# Inference of the original layer should fail.
with pytest.raises(RuntimeError):
linear_nf4_layer(x)
# Do not call the raw bitsandbytes NF4 layer here. With CPU-stored weights and a single-row CUDA input, some
# bitsandbytes versions hit an unsafe gemv_4bit path instead of raising a Python exception. The custom layer below
# is the behavior under test.

# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
custom_linear_nf4_layer = wrap_custom_layer(linear_nf4_layer, CustomInvokeLinearNF4)
Expand Down
Loading