Skip to content

Commit 747b960

Browse files
committed
Refactor code for consistency and readability
- Updated string formatting from single quotes to double quotes in several files for uniformity. - Added newlines for improved readability in multiple functions and classes across various modules. - Enhanced error messages and print statements for better clarity during execution.
1 parent bcc59ea commit 747b960

6 files changed

Lines changed: 104 additions & 66 deletions

File tree

src/zklora/__init__.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '0.1.2'
1+
__version__ = "0.1.2"
22

33
from .zk_proof_generator import batch_verify_proofs
44
from .lora_contributor_mpi import LoRAServer, LoRAServerSocket
@@ -7,11 +7,11 @@
77

88

99
__all__ = [
10-
'batch_verify_proofs',
11-
'LoRAServer',
12-
'LoRAServerSocket',
13-
'BaseModelClient',
14-
'commit_activations',
15-
'verify_commitment',
16-
'__version__',
17-
]
10+
"batch_verify_proofs",
11+
"LoRAServer",
12+
"LoRAServerSocket",
13+
"BaseModelClient",
14+
"commit_activations",
15+
"verify_commitment",
16+
"__version__",
17+
]

src/zklora/activations_commit.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,28 @@
22
import json
33
import numpy as np
44

5+
56
def get_merkle_root(activations_path: str) -> str:
67
"""
78
Calculate the Merkle root hash of model activations stored in a JSON file.
8-
9+
910
Args:
1011
activations_path: Path to JSON file containing model activations under "input_data" key
11-
12+
1213
Returns:
1314
str: Hexadecimal string of the Merkle root hash, prefixed with "0x"
1415
"""
1516
# Load the intermediate activations from JSON file
16-
with open(activations_path, 'r') as f:
17+
with open(activations_path, "r") as f:
1718
activations = json.load(f)
1819

1920
# Convert nested data to numpy array and flatten
2021
flattened_np = np.array(activations["input_data"]).reshape(-1)
21-
22+
2223
# Get and return the Merkle root hash
2324
return merkle.insert_values(flattened_np.tolist())
2425

26+
2527
if __name__ == "__main__":
2628
activations_path = "intermediate_activations/base_model_model_lm_head.json"
2729
merkle_root = get_merkle_root(activations_path)

src/zklora/base_model_user_mpi/__init__.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch.nn as nn
99
from transformers import AutoModelForCausalLM, AutoTokenizer
1010

11+
1112
class BaseModelToLoRAComm:
1213
def __init__(self, host_a="127.0.0.1", port_a=30000):
1314
self.host_a = host_a
@@ -20,16 +21,18 @@ def init_request(self):
2021

2122
def lora_forward(self, sub_name, arr):
2223
req = {
23-
"request_type":"lora_forward",
24+
"request_type": "lora_forward",
2425
"submodule_name": sub_name,
25-
"input_array": arr
26+
"input_array": arr,
2627
}
2728
resp = self.send_and_recv(req)
2829
return resp.get("output_array", None)
2930

3031
def end_inference(self):
3132
req = {"request_type": "end_inference"}
32-
resp = self.send_and_recv(req)#, timeout=600.0) # might be slower if proof gen is big
33+
resp = self.send_and_recv(
34+
req
35+
) # , timeout=600.0) # might be slower if proof gen is big
3336
return resp
3437

3538
def send_and_recv(self, data_dict):
@@ -52,13 +55,18 @@ def send_and_recv(self, data_dict):
5255
s.close()
5356

5457
if not buffer:
55-
raise RuntimeError("[B] No data from A (EOF). Possibly A took too long or closed early.")
58+
raise RuntimeError(
59+
"[B] No data from A (EOF). Possibly A took too long or closed early."
60+
)
5661

5762
resp = pickle.loads(buffer)
5863
return resp
5964

65+
6066
class RemoteLoRAWrappedModule(nn.Module):
61-
def __init__(self, sub_name, local_sub, comm: BaseModelToLoRAComm, combine_mode="replace"):
67+
def __init__(
68+
self, sub_name, local_sub, comm: BaseModelToLoRAComm, combine_mode="replace"
69+
):
6270
super().__init__()
6371
self.sub_name = sub_name
6472
self.local_sub = local_sub
@@ -77,6 +85,7 @@ def forward(self, x: torch.Tensor):
7785
return base_out + out_t
7886
return out_t
7987

88+
8089
class BaseModelClient:
8190
def __init__(
8291
self,
@@ -127,9 +136,13 @@ def init_and_patch(self):
127136
*parents, child = path_parts
128137
m = self._navigate(self.model, parents)
129138
orig_sub = getattr(m, child)
130-
wrapped = RemoteLoRAWrappedModule(full_name, orig_sub, comm, self.combine_mode)
139+
wrapped = RemoteLoRAWrappedModule(
140+
full_name, orig_sub, comm, self.combine_mode
141+
)
131142
setattr(m, child, wrapped)
132-
print(f"[B] Patched submodule '{full_name}' from {comm.host_a}:{comm.port_a}.")
143+
print(
144+
f"[B] Patched submodule '{full_name}' from {comm.host_a}:{comm.port_a}."
145+
)
133146
except Exception as e:
134147
print(f"[B] Could not patch '{full_name}': {e}")
135148

@@ -144,4 +157,6 @@ def end_inference(self):
144157
"""Notify all contributors that inference is finished."""
145158
for comm in self.comms:
146159
resp = comm.end_inference()
147-
print("[B] end_inference => got ack from", comm.host_a, comm.port_a, ":", resp)
160+
print(
161+
"[B] end_inference => got ack from", comm.host_a, comm.port_a, ":", resp
162+
)

src/zklora/lora_contributor_mpi/__init__.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
from ..zk_proof_generator import generate_proofs, resolve_proof_paths
1818
from ..mpi_lora_onnx_exporter import export_lora_onnx_json_mpi
1919

20+
2021
def read_file_as_bytes(path: str) -> bytes:
2122
with open(path, "rb") as f:
2223
return f.read()
2324

25+
2426
def strip_prefix(raw_name: str) -> str:
2527
"""
2628
Remove 'base_model.model.', 'base_model.', 'model.' from the submodule name.
@@ -30,9 +32,10 @@ def strip_prefix(raw_name: str) -> str:
3032
name2 = raw_name
3133
for pfx in ["base_model.model.", "base_model.", "model."]:
3234
if name2.startswith(pfx):
33-
name2 = name2[len(pfx):]
35+
name2 = name2[len(pfx) :]
3436
return name2.strip()
3537

38+
3639
class LoRAServer:
3740
def __init__(self, base_model_name: str, lora_model_id: str, out_dir: str):
3841
self.out_dir = out_dir
@@ -91,18 +94,20 @@ def finalize_proofs_and_collect(self):
9194
x_data=last_in,
9295
submodule=mod,
9396
output_dir=self.out_dir,
94-
verbose=True
97+
verbose=True,
9598
)
9699
self.session_data.clear()
97100

98101
# generate proofs synchronously
99-
print("[A] Running generate_proofs(...) via asyncio.run(...) in the same thread.")
102+
print(
103+
"[A] Running generate_proofs(...) via asyncio.run(...) in the same thread."
104+
)
100105
proof_res = asyncio.run(
101106
generate_proofs(
102107
onnx_dir=self.out_dir,
103108
json_dir=self.out_dir,
104109
output_dir=self.out_dir,
105-
verbose=True
110+
verbose=True,
106111
)
107112
)
108113

@@ -113,6 +118,7 @@ def finalize_proofs_and_collect(self):
113118

114119
return
115120

121+
116122
class LoRAServerSocket(threading.Thread):
117123
def __init__(self, host, port, lora_server: LoRAServer, stop_event):
118124
super().__init__()
@@ -123,13 +129,16 @@ def __init__(self, host, port, lora_server: LoRAServer, stop_event):
123129

124130
def run(self):
125131
import socket
132+
126133
print(f"[A-Server] listening on {self.host}:{self.port}")
127134
srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
128135
srv.bind((self.host, self.port))
129136
srv.listen(5)
130137
srv.settimeout(1200.0)
131138

132-
print(f"[A-Server] Running on {self.host}:{self.port}, local artifacts in '{self.lora_server.out_dir}'")
139+
print(
140+
f"[A-Server] Running on {self.host}:{self.port}, local artifacts in '{self.lora_server.out_dir}'"
141+
)
133142
try:
134143
while not self.stop_event.is_set():
135144
try:
@@ -147,28 +156,28 @@ def handle_conn(self, conn, addr):
147156
if not data:
148157
return
149158
req = pickle.loads(data)
150-
rtype = req.get("request_type","lora_forward")
159+
rtype = req.get("request_type", "lora_forward")
151160

152161
if rtype == "init_request":
153162
submods = self.lora_server.list_lora_injection_points()
154-
resp = {"response_type":"init_response","injection_points": submods}
163+
resp = {"response_type": "init_response", "injection_points": submods}
155164

156165
elif rtype == "lora_forward":
157166
sname = req["submodule_name"]
158167
arr = req["input_array"]
159168
tin = torch.tensor(arr, dtype=torch.float32)
160169
out = self.lora_server.apply_lora(sname, tin)
161170
resp = {
162-
"response_type":"lora_forward_response",
163-
"output_array": out.cpu().numpy()
171+
"response_type": "lora_forward_response",
172+
"output_array": out.cpu().numpy(),
164173
}
165174

166175
elif rtype == "end_inference":
167176
# generate proofs locally
168177
self.lora_server.finalize_proofs_and_collect()
169178
resp = {
170179
"response_type": "end_inference_ack",
171-
"message": "A finished proof generation locally. B can close."
180+
"message": "A finished proof generation locally. B can close.",
172181
}
173182

174183
else:
@@ -191,4 +200,4 @@ def recv_all(self, conn, chunk_size=4096):
191200
if not chunk:
192201
break
193202
buffer += chunk
194-
return buffer
203+
return buffer

src/zklora/mpi_lora_onnx_exporter.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# zklora/mpi_lora_onnx_exporter.py
22
"""
3-
New code specifically for 'split inference' (MPI) scenario,
3+
New code specifically for 'split inference' (MPI) scenario,
44
similar to lora_onnx_exporter but with different approach or naming to avoid collisions.
55
"""
66

@@ -9,14 +9,13 @@
99
import torch
1010
import numpy as np
1111
import torch.nn as nn
12-
from peft import PeftModel
1312

1413

1514
def normalize_lora_matrices_mpi(
1615
A: torch.Tensor, B: torch.Tensor, x_data: np.ndarray
1716
) -> tuple[torch.Tensor, torch.Tensor, int, int, int]:
1817
"""
19-
Same shape logic as the older function, but with a new name
18+
Same shape logic as the older function, but with a new name
2019
to avoid collisions with the old version.
2120
x_data => (batch, seq_len, hidden_dim).
2221
"""
@@ -43,7 +42,7 @@ def normalize_lora_matrices_mpi(
4342

4443
class LoraShapeTransformerMPI(nn.Module):
4544
"""
46-
Variation of LoraShapeTransformer used specifically for
45+
Variation of LoraShapeTransformer used specifically for
4746
the split-inference approach, with a new class name to avoid collisions.
4847
"""
4948

@@ -72,7 +71,7 @@ def export_lora_onnx_json_mpi(
7271
verbose: bool = False,
7372
):
7473
"""
75-
The 'split inference' version of the ONNX+JSON exporter.
74+
The 'split inference' version of the ONNX+JSON exporter.
7675
Similar logic but a different name to avoid collisions with the old function.
7776
"""
7877
import torch.onnx
@@ -85,13 +84,17 @@ def export_lora_onnx_json_mpi(
8584
# If the submodule doesn't have lora_A/lora_B, skip
8685
if not (hasattr(submodule, "lora_A") and hasattr(submodule, "lora_B")):
8786
if verbose:
88-
print(f"[export_lora_onnx_json_mpi] No lora_A/B in submodule '{sub_name}', skipping.")
87+
print(
88+
f"[export_lora_onnx_json_mpi] No lora_A/B in submodule '{sub_name}', skipping."
89+
)
8990
return
9091

9192
a_keys = list(submodule.lora_A.keys()) if hasattr(submodule.lora_A, "keys") else []
9293
if not a_keys:
9394
if verbose:
94-
print(f"[export_lora_onnx_json_mpi] No adapter keys in submodule.lora_A for '{sub_name}'.")
95+
print(
96+
f"[export_lora_onnx_json_mpi] No adapter keys in submodule.lora_A for '{sub_name}'."
97+
)
9598
return
9699

97100
A_mod = submodule.lora_A[a_keys[0]]
@@ -102,14 +105,19 @@ def export_lora_onnx_json_mpi(
102105

103106
try:
104107
from .mpi_lora_onnx_exporter import normalize_lora_matrices_mpi
105-
A_fixed, B_fixed, in_dim, rank, out_dim = normalize_lora_matrices_mpi(A, B, x_data)
108+
109+
A_fixed, B_fixed, in_dim, rank, out_dim = normalize_lora_matrices_mpi(
110+
A, B, x_data
111+
)
106112
except ValueError as e:
107113
if verbose:
108114
print(f"Shape fix error for '{sub_name}': {e}")
109115
return
110116

111117
# Build the shape-transformer
112-
lora_transformer = LoraShapeTransformerMPI(A_fixed, B_fixed, batch_size, seq_len, hidden_dim).eval()
118+
lora_transformer = LoraShapeTransformerMPI(
119+
A_fixed, B_fixed, batch_size, seq_len, hidden_dim
120+
).eval()
113121

114122
safe_name = sub_name.replace(".", "_").replace("/", "_")
115123
os.makedirs(output_dir, exist_ok=True)
@@ -135,7 +143,6 @@ def export_lora_onnx_json_mpi(
135143
print(f"Export error for '{sub_name}': {e}")
136144

137145
# Save JSON
138-
import json
139146
json_path = os.path.join(output_dir, f"{safe_name}.json")
140147
with open(json_path, "w") as f:
141148
row_data = x_1d.numpy().tolist()

0 commit comments

Comments
 (0)