@@ -30,18 +30,21 @@ def fetch_fw(path, name, sha256):
3030from tinygrad .device import Device
3131from tinygrad .engine .jit import TinyJit
3232
33- from openpilot .common .file_chunker import read_file_chunked
34- from openpilot .system .hardware .hw import Paths
35-
3633
# Description of one NV12 camera frame. The script entry point builds it as
# NV12Frame(cam_w, cam_h, *get_nv12_info(cam_w, cam_h)).
NV12Frame = namedtuple("NV12Frame", ['width', 'height', 'stride', 'y_height', 'uv_height', 'size'])
# Input-queue keys forwarded to the warp JIT and to the policy JIT when each is
# exercised by compile_jit; any other entries of the input queues are not passed.
WARP_INPUTS = ['img_q', 'big_img_q', 'tfm', 'big_tfm']
POLICY_INPUTS = ['feat_q', 'desire_q', 'desire', 'traffic_convention']

# 3x3 transform that halves the first two coordinates and leaves the third alone,
# together with its inverse.
# NOTE(review): presumably used for the half-resolution UV plane of NV12 -- confirm at the use site.
UV_SCALE_MATRIX = np.array([[0.5, 0, 0], [0, 0.5, 0], [0, 0, 1]], dtype=np.float32)
UV_SCALE_MATRIX_INV = np.linalg.inv(UV_SCALE_MATRIX)

# Device name for the warp step, read from the WARP_DEV environment variable
# (None when the variable is unset).
WARP_DEV = os.getenv('WARP_DEV')
4342
4443
def make_random_images(keys, shape, device=None):
  """Return a dict mapping each key to a freshly realized random uint8 tensor.

  Args:
    keys: iterable of names to generate a tensor for; an empty iterable yields {}.
    shape: forwarded unchanged to Tensor.randint (callers in this file pass
      both a flat byte count and a model input shape).
    device: forwarded to Tensor.randint; None leaves the choice to tinygrad.

  Values are drawn from [0, 256), i.e. the full uint8 range, and each tensor is
  realized immediately so the random data exists before it is used.
  """
  return {k: Tensor.randint(shape, low=0, high=256, dtype='uint8', device=device).realize() for k in keys}
46+
47+
4548def warp_perspective_tinygrad (src_flat , M_inv , dst_shape , src_shape , stride_pad , border_fill_val = None ):
4649 w_dst , h_dst = dst_shape
4750 h_src , w_src = src_shape
@@ -148,55 +151,49 @@ def sample_desire(buf, frame_skip):
148151 return buf .reshape (- 1 , frame_skip , * buf .shape [1 :]).max (1 ).flatten (0 , 1 ).unsqueeze (0 )
149152
150153
151- def make_run_policy (vision_runner , policy_runner , nv12 : NV12Frame , model_w , model_h ,
152- vision_features_slice , frame_skip , prepare_only = False ):
def make_warp(nv12, model_w, model_h, frame_skip):
  """Build the warp-only step: warp two camera frames and push them into their queues.

  Args:
    nv12: frame description handed to make_frame_prepare.
    model_w, model_h: model input size handed to make_frame_prepare.
    frame_skip: bound into sample_skip, which is the sampler given to shift_and_sample.

  Returns:
    warp_enqueue(img_q, big_img_q, tfm, big_tfm, frame, big_frame) -> (img, big_img),
    a plain function (callers wrap it in TinyJit themselves).
  """
  frame_prepare = make_frame_prepare(nv12, model_w, model_h)
  sample_skip_fn = partial(sample_skip, frame_skip=frame_skip)

  def warp_enqueue(img_q, big_img_q, tfm, big_tfm, frame, big_frame):
    # The transforms are moved to the warp device and realized before they are used.
    tfm = tfm.to(WARP_DEV)
    big_tfm = big_tfm.to(WARP_DEV)
    Tensor.realize(tfm, big_tfm)

    # Warp each frame, add a leading axis, then move the result to the default device.
    warped_frame = frame_prepare(frame, tfm).unsqueeze(0).to(Device.DEFAULT)
    warped_big_frame = frame_prepare(big_frame, big_tfm).unsqueeze(0).to(Device.DEFAULT)
    img = shift_and_sample(img_q, warped_frame, sample_skip_fn)
    big_img = shift_and_sample(big_img_q, warped_big_frame, sample_skip_fn)
    return img, big_img
  return warp_enqueue
168169
169- if prepare_only :
170- return img , big_img
171170
def make_run_policy(vision_runner, policy_runner, vision_features_slice, frame_skip):
  """Build the model step: run the vision model, then the policy model.

  Args:
    vision_runner: callable taking {'img', 'big_img'} and returning a dict of outputs.
    policy_runner: callable taking {'features_buffer', 'desire_pulse',
      'traffic_convention'} and returning a dict of outputs.
    vision_features_slice: index applied to the second axis of the vision output
      to pick the features that are fed to the policy.
    frame_skip: bound into sample_skip / sample_desire, the samplers given to
      shift_and_sample.

  Returns:
    run_policy(img, big_img, feat_q, desire_q, desire, traffic_convention)
    -> (vision_out, policy_out), both cast to float32. Only the first output of
    each runner is used.
  """
  sample_desire_fn = partial(sample_desire, frame_skip=frame_skip)
  sample_skip_fn = partial(sample_skip, frame_skip=frame_skip)

  def run_policy(img, big_img, feat_q, desire_q, desire, traffic_convention):
    # Small host-provided inputs are moved to the default device and realized up front.
    desire = desire.to(Device.DEFAULT)
    traffic_convention = traffic_convention.to(Device.DEFAULT)
    Tensor.realize(desire, traffic_convention)
    # NOTE(review): the desire queue is updated before the vision model runs;
    # keep this order unless a reorder is confirmed safe for the captured JIT.
    desire_buf = shift_and_sample(desire_q, desire.reshape(1, 1, -1), sample_desire_fn)
    vision_out = next(iter(vision_runner({'img': img, 'big_img': big_img}).values())).cast('float32')

    # Slice the features out of the vision output and reshape them to (1, 1, N) for the queue.
    new_feat = vision_out[:, vision_features_slice].reshape(1, -1).unsqueeze(0)
    feat_buf = shift_and_sample(feat_q, new_feat, sample_skip_fn)

    inputs = {'features_buffer': feat_buf, 'desire_pulse': desire_buf, 'traffic_convention': traffic_convention}
    policy_out = next(iter(policy_runner(inputs).values())).cast('float32')
    return vision_out, policy_out
  return run_policy
183189
184190
185- def compile_modeld (nv12 : NV12Frame , model_w , model_h , prepare_only , frame_skip ,
186- vision_runner , policy_runner , vision_metadata , policy_metadata ):
187- print (f"Compiling combined policy JIT for { nv12 .width } x{ nv12 .height } (prepare_only={ prepare_only } )..." )
188-
189- vision_features_slice = vision_metadata ['output_slices' ]['hidden_state' ]
191+ def compile_jit (jit , make_random_inputs , input_keys , frame_skip , vision_metadata , policy_metadata ):
190192 vision_input_shapes = vision_metadata ['input_shapes' ]
191193 policy_input_shapes = policy_metadata ['input_shapes' ]
192194
193- _run = make_run_policy (vision_runner , policy_runner , nv12 , model_w , model_h ,
194- vision_features_slice , frame_skip , prepare_only )
195- run_policy_jit = TinyJit (_run , prune = True )
196-
197195 SEED = 42
198-
199- def random_inputs_run_fn (fn , seed , test_val = None , test_buffers = None , expect_match = True ):
196+ def random_inputs_run (fn , seed , test_val = None , test_buffers = None , expect_match = True ):
200197 input_queues , npy = make_input_queues (vision_input_shapes , policy_input_shapes , frame_skip , Device .DEFAULT )
201198 np .random .seed (seed )
202199 Tensor .manual_seed (seed )
@@ -205,13 +202,12 @@ def random_inputs_run_fn(fn, seed, test_val=None, test_buffers=None, expect_matc
205202 n_runs = 1 if testing else 3
206203
207204 for i in range (n_runs ):
208- frame = Tensor .randint (nv12 .size , low = 0 , high = 256 , dtype = 'uint8' , device = WARP_DEV ).realize ()
209- big_frame = Tensor .randint (nv12 .size , low = 0 , high = 256 , dtype = 'uint8' , device = WARP_DEV ).realize ()
210205 for v in npy .values ():
211206 v [:] = np .random .randn (* v .shape ).astype (v .dtype )
212207 Device .default .synchronize ()
208+ random_inputs = make_random_inputs ()
213209 st = time .perf_counter ()
214- outs = fn (** input_queues , frame = frame , big_frame = big_frame )
210+ outs = fn (** { k : input_queues [ k ] for k in input_keys }, ** random_inputs )
215211 mt = time .perf_counter ()
216212 Device .default .synchronize ()
217213 et = time .perf_counter ()
@@ -227,16 +223,15 @@ def random_inputs_run_fn(fn, seed, test_val=None, test_buffers=None, expect_matc
227223 if test_buffers is not None :
228224 match = all (np .array_equal (a , b ) for a , b in zip (buffers , test_buffers , strict = True ))
229225 assert match == expect_match , f"buffers { 'differ from' if expect_match else 'match' } baseline (seed={ seed } )"
230- return fn , val , buffers
226+ return val , buffers
231227
232228 print ('capture + replay' )
233- run_policy_jit , test_val , test_buffers = random_inputs_run_fn (run_policy_jit , SEED )
234-
229+ test_val , test_buffers = random_inputs_run (jit , SEED )
235230 print ('pickle round trip' )
236- run_policy_jit = pickle .loads (pickle .dumps (run_policy_jit ))
237- random_inputs_run_fn ( run_policy_jit , SEED , test_val , test_buffers , expect_match = True )
238- random_inputs_run_fn ( run_policy_jit , SEED + 1 , test_val , test_buffers , expect_match = False )
239- return run_policy_jit
231+ jit = pickle .loads (pickle .dumps (jit ))
232+ random_inputs_run ( jit , SEED , test_val , test_buffers , expect_match = True )
233+ random_inputs_run ( jit , SEED + 1 , test_val , test_buffers , expect_match = False )
234+ return jit
240235
241236
242237def _parse_size (s ):
@@ -245,6 +240,8 @@ def _parse_size(s):
245240
246241
247242def read_file_chunked_to_shm (path ):
243+ from openpilot .common .file_chunker import read_file_chunked
244+ from openpilot .system .hardware .hw import Paths
248245 shm_path = os .path .join (Paths .shm_path (), os .path .basename (path ))
249246 atexit .register (lambda : os .path .exists (shm_path ) and os .remove (shm_path ))
250247 with open (shm_path , 'wb' ) as f :
@@ -255,6 +252,7 @@ def read_file_chunked_to_shm(path):
255252if __name__ == "__main__" :
256253 from tinygrad .nn .onnx import OnnxRunner
257254 from openpilot .system .camerad .cameras .nv12_info import get_nv12_info
255+ from openpilot .selfdrive .modeld .get_model_metadata import make_metadata_dict
258256 p = argparse .ArgumentParser ()
259257 p .add_argument ('--model-size' , type = _parse_size , required = True , help = 'model input WxH' )
260258 p .add_argument ('--camera-resolutions' , type = _parse_size , nargs = '+' , required = True ,
@@ -266,23 +264,26 @@ def read_file_chunked_to_shm(path):
266264 args = p .parse_args ()
267265
268266 out = defaultdict (dict )
269- # init runners once so weights are shared
270- from get_model_metadata import make_metadata_dict
271267 vision_path , policy_path = read_file_chunked_to_shm (args .vision_onnx ), read_file_chunked_to_shm (args .policy_onnx )
268+ model_w , model_h = args .model_size
269+
272270 vision_runner = OnnxRunner (vision_path )
273271 policy_runner = OnnxRunner (policy_path )
274- out ['metadata' ]['vision' ] = make_metadata_dict (vision_path )
275- out ['metadata' ]['policy' ] = make_metadata_dict (policy_path )
272+ vision_metadata , policy_metadata = make_metadata_dict (vision_path ), make_metadata_dict (policy_path )
273+
274+ run_policy_jit = TinyJit (make_run_policy (vision_runner , policy_runner , vision_metadata ['output_slices' ]['hidden_state' ], args .frame_skip ), prune = True )
275+
276+ out ['metadata' ]['vision' ], out ['metadata' ]['policy' ] = vision_metadata , policy_metadata
277+
278+ make_random_model_inputs = partial (make_random_images , keys = ['img' , 'big_img' ], shape = vision_metadata ['input_shapes' ]['img' ])
279+ out ['run_policy' ] = compile_jit (run_policy_jit , make_random_model_inputs , POLICY_INPUTS , args .frame_skip , vision_metadata , policy_metadata )
276280
277281 for cam_w , cam_h in args .camera_resolutions :
278282 nv12 = NV12Frame (cam_w , cam_h , * get_nv12_info (cam_w , cam_h ))
279- model_w , model_h = args .model_size
280- out [(cam_w ,cam_h )] = {
281- name : compile_modeld (nv12 , model_w , model_h , prepare_only , args .frame_skip ,
282- vision_runner , policy_runner , out ['metadata' ]['vision' ], out ['metadata' ]['policy' ])
283- for name , prepare_only in [('warp_enqueue' , True ), ('run_policy' , False )]
284- }
283+ make_random_warp_inputs = partial (make_random_images , keys = ['frame' , 'big_frame' ], shape = nv12 .size , device = WARP_DEV )
284+ warp_enqueue = TinyJit (make_warp (nv12 , model_w , model_h , args .frame_skip ), prune = True )
285+ out [(cam_w ,cam_h )] = compile_jit (warp_enqueue , make_random_warp_inputs , WARP_INPUTS , args .frame_skip , vision_metadata , policy_metadata )
285286
286287 with open (args .output , "wb" ) as f :
287288 pickle .dump (out , f )
288- print (f"Saved combined JIT to { args .output } ({ os .path .getsize (args .output ) / 1e6 :.2f} MB)" )
289+ print (f"Saved JITs to { args .output } ({ os .path .getsize (args .output ) / 1e6 :.2f} MB)" )
0 commit comments