3737 get_cluster_table ,
3838 get_res_during_tuning ,
3939 is_valid_task ,
40+ is_valid_uuid ,
4041 list_to_string ,
4142 serialize ,
4243)
@@ -97,7 +98,8 @@ def ping():
9798 msg = "Ping fail! Make sure Neural Solution runner is running!"
9899 break
99100 except Exception as e :
100- msg = "Ping fail! {}" .format (e )
101+ print (e )
102+ msg = "Ping fail!"
101103 break
102104 sock .close ()
103105 return {"status" : "Healthy" , "msg" : msg } if count == 2 else {"status" : "Failed" , "msg" : msg }
@@ -167,26 +169,31 @@ async def submit_task(task: Task):
167169 cursor = conn .cursor ()
168170 task_id = str (uuid .uuid4 ()).replace ("-" , "" )
169171 sql = (
170- r"insert into task(id, script_url, optimized, arguments, approach, requirements, workers, status)"
171- + r" values ('{}', '{}', {}, '{}', '{}', '{}', {}, 'pending')" .format (
172- task_id ,
173- task .script_url ,
174- task .optimized ,
175- list_to_string (task .arguments ),
176- task .approach ,
177- list_to_string (task .requirements ),
178- task .workers ,
179- )
172+ "INSERT INTO task "
173+ "(id, script_url, optimized, arguments, approach, requirements, workers, status) "
174+ "VALUES (?, ?, ?, ?, ?, ?, ?, 'pending')"
180175 )
181- cursor .execute (sql )
176+
177+ task_params = (
178+ task_id ,
179+ task .script_url ,
180+ task .optimized ,
181+ list_to_string (task .arguments ),
182+ task .approach ,
183+ list_to_string (task .requirements ),
184+ task .workers ,
185+ )
186+
187+ conn .execute (sql , task_params )
182188 conn .commit ()
183189 try :
184190 task_submitter .submit_task (task_id )
185191 except ConnectionRefusedError :
186192 msg = "Task Submitted fail! Make sure Neural Solution runner is running!"
187193 status = "failed"
188194 except Exception as e :
189- msg = "Task Submitted fail! {}" .format (e )
195+ msg = "Task Submitted fail!"
196+ print (e )
190197 status = "failed"
191198 conn .close ()
192199 else :
@@ -205,6 +212,8 @@ def get_task_by_id(task_id: str):
205212 Returns:
206213 json: task status, result, quantized model path
207214 """
215+ if not is_valid_uuid (task_id ):
216+ raise HTTPException (status_code = 422 , detail = "Invalid task id" )
208217 res = None
209218 db_path = get_db_path (config .workspace )
210219 if os .path .isfile (db_path ):
@@ -246,6 +255,8 @@ def get_task_status_by_id(request: Request, task_id: str):
246255 Returns:
247256 json: task status and information
248257 """
258+ if not is_valid_uuid (task_id ):
259+ raise HTTPException (status_code = 422 , detail = "Invalid task id" )
249260 status = "unknown"
250261 tuning_info = {}
251262 optimization_result = {}
@@ -290,7 +301,13 @@ async def read_logs(task_id: str):
290301 Yields:
291302 str: log lines
292303 """
293- log_path = "{}/task_{}.txt" .format (get_task_log_workspace (config .workspace ), task_id )
304+ if not is_valid_uuid (task_id ):
305+ raise HTTPException (status_code = 422 , detail = "Invalid task id" )
306+ log_path = os .path .normpath (os .path .join (get_task_log_workspace (config .workspace ), "task_{}.txt" .format (task_id )))
307+
308+ if not log_path .startswith (os .path .normpath (config .workspace )):
309+ return {"error" : "Logfile not found." }
310+
294311 if not os .path .exists (log_path ):
295312 return {"error" : "Logfile not found." }
296313
@@ -388,12 +405,17 @@ async def websocket_endpoint(websocket: WebSocket, task_id: str):
388405 Raises:
389406 HTTPException: exception
390407 """
408+ if not is_valid_uuid (task_id ):
409+ raise HTTPException (status_code = 422 , detail = "Invalid task id" )
391410 if not check_log_exists (task_id = task_id , task_log_path = get_task_log_workspace (config .workspace )):
392411 raise HTTPException (status_code = 404 , detail = "Task not found" )
393412 await websocket .accept ()
394413
395414 # send the log that has been written
396- log_path = "{}/task_{}.txt" .format (get_task_log_workspace (config .workspace ), task_id )
415+ log_path = os .path .normpath (os .path .join (get_task_log_workspace (config .workspace ), "task_{}.txt" .format (task_id )))
416+
417+ if not log_path .startswith (os .path .normpath (config .workspace )):
418+ return {"error" : "Logfile not found." }
397419 last_position = 0
398420 previous_log = []
399421 if os .path .exists (log_path ):
@@ -429,6 +451,8 @@ async def download_file(task_id: str):
429451 Returns:
430452 FileResponse: quantized model of zip file format
431453 """
454+ if not is_valid_uuid (task_id ):
455+ raise HTTPException (status_code = 422 , detail = "Invalid task id" )
432456 db_path = get_db_path (config .workspace )
433457 if os .path .isfile (db_path ):
434458 conn = sqlite3 .connect (db_path )
@@ -444,6 +468,9 @@ async def download_file(task_id: str):
444468 path = res [2 ]
445469 zip_filename = "quantized_model.zip"
446470 zip_filepath = os .path .abspath (os .path .join (get_task_workspace (config .workspace ), task_id , zip_filename ))
471+
472+ if not zip_filepath .startswith (os .path .normpath (os .path .abspath (get_task_workspace (config .workspace )))):
473+ raise HTTPException (status_code = 422 , detail = "Invalid File" )
447474 # create zipfile and add file
448475 with zipfile .ZipFile (zip_filepath , "w" , zipfile .ZIP_DEFLATED ) as zip_file :
449476 for root , dirs , files in os .walk (path ):
0 commit comments