1717from ..zk_proof_generator import generate_proofs , resolve_proof_paths
1818from ..mpi_lora_onnx_exporter import export_lora_onnx_json_mpi
1919
20+
def read_file_as_bytes(path: str) -> bytes:
    """Return the entire contents of the file at ``path`` as raw bytes.

    The file is opened in binary mode, so no text decoding or newline
    translation is applied. Errors raised by ``open`` propagate to the caller.
    """
    with open(path, mode="rb") as handle:
        payload = handle.read()
    return payload
2324
25+
def strip_prefix(raw_name: str) -> str:
    """
    Remove 'base_model.model.', 'base_model.', 'model.' from the submodule name.

    The prefixes are tested in that order and each is removed at most once,
    so they cascade: "base_model.model.model.layers.0" becomes "layers.0".

    Surrounding whitespace is trimmed both before and after the prefixes are
    removed, so a padded name such as " base_model.model.fc " yields "fc"
    rather than keeping its prefix.

    Args:
        raw_name: Submodule name, possibly carrying wrapper prefixes.

    Returns:
        The name without the leading wrapper prefixes or outer whitespace.
    """
    # Trim first: leading whitespace would otherwise hide every prefix from
    # startswith() and the name would come back with its prefix intact.
    name = raw_name.strip()
    for prefix in ("base_model.model.", "base_model.", "model."):
        if name.startswith(prefix):
            name = name[len(prefix):]
    # Trim again in case whitespace followed the removed prefix.
    return name.strip()
3537
38+
3639class LoRAServer :
3740 def __init__ (self , base_model_name : str , lora_model_id : str , out_dir : str ):
3841 self .out_dir = out_dir
@@ -91,18 +94,20 @@ def finalize_proofs_and_collect(self):
9194 x_data = last_in ,
9295 submodule = mod ,
9396 output_dir = self .out_dir ,
94- verbose = True
97+ verbose = True ,
9598 )
9699 self .session_data .clear ()
97100
98101 # generate proofs synchronously
99- print ("[A] Running generate_proofs(...) via asyncio.run(...) in the same thread." )
102+ print (
103+ "[A] Running generate_proofs(...) via asyncio.run(...) in the same thread."
104+ )
100105 proof_res = asyncio .run (
101106 generate_proofs (
102107 onnx_dir = self .out_dir ,
103108 json_dir = self .out_dir ,
104109 output_dir = self .out_dir ,
105- verbose = True
110+ verbose = True ,
106111 )
107112 )
108113
@@ -113,6 +118,7 @@ def finalize_proofs_and_collect(self):
113118
114119 return
115120
121+
116122class LoRAServerSocket (threading .Thread ):
117123 def __init__ (self , host , port , lora_server : LoRAServer , stop_event ):
118124 super ().__init__ ()
@@ -123,13 +129,16 @@ def __init__(self, host, port, lora_server: LoRAServer, stop_event):
123129
124130 def run (self ):
125131 import socket
132+
126133 print (f"[A-Server] listening on { self .host } :{ self .port } " )
127134 srv = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
128135 srv .bind ((self .host , self .port ))
129136 srv .listen (5 )
130137 srv .settimeout (1200.0 )
131138
132- print (f"[A-Server] Running on { self .host } :{ self .port } , local artifacts in '{ self .lora_server .out_dir } '" )
139+ print (
140+ f"[A-Server] Running on { self .host } :{ self .port } , local artifacts in '{ self .lora_server .out_dir } '"
141+ )
133142 try :
134143 while not self .stop_event .is_set ():
135144 try :
@@ -147,28 +156,28 @@ def handle_conn(self, conn, addr):
147156 if not data :
148157 return
149158 req = pickle .loads (data )
150- rtype = req .get ("request_type" ,"lora_forward" )
159+ rtype = req .get ("request_type" , "lora_forward" )
151160
152161 if rtype == "init_request" :
153162 submods = self .lora_server .list_lora_injection_points ()
154- resp = {"response_type" :"init_response" ,"injection_points" : submods }
163+ resp = {"response_type" : "init_response" , "injection_points" : submods }
155164
156165 elif rtype == "lora_forward" :
157166 sname = req ["submodule_name" ]
158167 arr = req ["input_array" ]
159168 tin = torch .tensor (arr , dtype = torch .float32 )
160169 out = self .lora_server .apply_lora (sname , tin )
161170 resp = {
162- "response_type" :"lora_forward_response" ,
163- "output_array" : out .cpu ().numpy ()
171+ "response_type" : "lora_forward_response" ,
172+ "output_array" : out .cpu ().numpy (),
164173 }
165174
166175 elif rtype == "end_inference" :
167176 # generate proofs locally
168177 self .lora_server .finalize_proofs_and_collect ()
169178 resp = {
170179 "response_type" : "end_inference_ack" ,
171- "message" : "A finished proof generation locally. B can close."
180+ "message" : "A finished proof generation locally. B can close." ,
172181 }
173182
174183 else :
@@ -191,4 +200,4 @@ def recv_all(self, conn, chunk_size=4096):
191200 if not chunk :
192201 break
193202 buffer += chunk
194- return buffer
203+ return buffer
0 commit comments