@@ -1407,6 +1407,44 @@ def test_variables(self):
14071407        os .remove ('variables1.json' )
14081408        os .remove ('variables2.json' )
14091409
1410+ class  CSRFTests (unittest .TestCase ):
1411+     def  setUp (self ):
1412+         configuration .load_test_config ()
1413+         configuration .conf .set ("webserver" , "authenticate" , "False" )
1414+         configuration .conf .set ("webserver" , "expose_config" , "True" )
1415+         app  =  application .create_app ()
1416+         app .config ['TESTING' ] =  True 
1417+         self .app  =  app .test_client ()
1418+ 
1419+         self .dagbag  =  models .DagBag (
1420+             dag_folder = DEV_NULL , include_examples = True )
1421+         self .dag_bash  =  self .dagbag .dags ['example_bash_operator' ]
1422+         self .runme_0  =  self .dag_bash .get_task ('runme_0' )
1423+ 
1424+     def  get_csrf (self , response ):
1425+         tree  =  html .fromstring (response .data )
1426+         form  =  tree .find ('.//form' )
1427+ 
1428+         return  form .find ('.//input[@name="_csrf_token"]' ).value 
1429+ 
1430+     def  test_csrf_rejection (self ):
1431+         endpoints  =  ([
1432+             "/admin/queryview/" ,
1433+             "/admin/airflow/paused?dag_id=example_python_operator&is_paused=false" ,
1434+         ])
1435+         for  endpoint  in  endpoints :
1436+             response  =  self .app .post (endpoint )
1437+             self .assertIn ('CSRF token is missing' , response .data .decode ('utf-8' ))
1438+ 
1439+     def  test_csrf_acceptance (self ):
1440+         response  =  self .app .get ("/admin/queryview/" )
1441+         csrf  =  self .get_csrf (response )
1442+         response  =  self .app .post ("/admin/queryview/" , data = dict (csrf_token = csrf ))
1443+         self .assertEqual (200 , response .status_code )
1444+ 
1445+     def  tearDown (self ):
1446+         configuration .conf .set ("webserver" , "expose_config" , "False" )
1447+         self .dag_bash .clear (start_date = DEFAULT_DATE , end_date = datetime .now ())
14101448
14111449class  WebUiTests (unittest .TestCase ):
14121450    def  setUp (self ):
@@ -1415,6 +1453,7 @@ def setUp(self):
14151453        configuration .conf .set ("webserver" , "expose_config" , "True" )
14161454        app  =  application .create_app ()
14171455        app .config ['TESTING' ] =  True 
1456+         app .config ['WTF_CSRF_METHODS' ] =  []
14181457        self .app  =  app .test_client ()
14191458
14201459        self .dagbag  =  models .DagBag (include_examples = True )
@@ -1445,10 +1484,10 @@ def test_index(self):
14451484    def  test_query (self ):
14461485        response  =  self .app .get ('/admin/queryview/' )
14471486        self .assertIn ("Ad Hoc Query" , response .data .decode ('utf-8' ))
1448-         response  =  self .app .get (
1449-             "/admin/queryview/?"  
1450-             " conn_id=airflow_db&" 
1451-             " sql=SELECT+COUNT%281%29+as+TEST+FROM+task_instance"
1487+         response  =  self .app .post (
1488+             "/admin/queryview/"  ,  data = dict ( 
1489+             conn_id = " airflow_db" , 
1490+             sql = " SELECT+COUNT%281%29+as+TEST+FROM+task_instance") )
14521491        self .assertIn ("TEST" , response .data .decode ('utf-8' ))
14531492
14541493    def  test_health (self ):
@@ -1563,9 +1602,10 @@ def test_dag_views(self):
15631602        response  =  self .app .get (
15641603            "/admin/airflow/refresh?dag_id=example_bash_operator" )
15651604        response  =  self .app .get ("/admin/airflow/refresh_all" )
1566-         response  =  self .app .get (
1605+         response  =  self .app .post (
15671606            "/admin/airflow/paused?" 
15681607            "dag_id=example_python_operator&is_paused=false" )
1608+         self .assertIn ("OK" , response .data .decode ('utf-8' ))
15691609        response  =  self .app .get ("/admin/xcom" , follow_redirects = True )
15701610        self .assertIn ("Xcoms" , response .data .decode ('utf-8' ))
15711611
0 commit comments