Skip to content

Commit 0ad6c87

Browse files
authored
feat: add enforce URI query params with a specific for MySQL (#23723)
1 parent e9b4022 commit 0ad6c87

File tree

4 files changed

+59
-7
lines changed

4 files changed

+59
-7
lines changed

superset/db_engine_specs/base.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
357357
top_keywords: Set[str] = {"TOP"}
358358
# A set of disallowed connection query parameters
359359
disallow_uri_query_params: Set[str] = set()
360+
# A Dict of query parameters that will always be used on every connection
361+
enforce_uri_query_params: Dict[str, Any] = {}
360362

361363
force_column_alias_quotes = False
362364
arraysize = 0
@@ -1089,11 +1091,12 @@ def adjust_engine_params( # pylint: disable=unused-argument
10891091
``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
10901092
given query is running in order to enforce permissions (see #23385 and #23401).
10911093
1092-
Currently, changing the catalog is not supported. The method acceps a catalog so
1093-
that when catalog support is added to Superse the interface remains the same. This
1094-
is important because DB engine specs can be installed from 3rd party packages.
1094+
Currently, changing the catalog is not supported. The method accepts a catalog so
1095+
that when catalog support is added to Superset the interface remains the same.
1096+
This is important because DB engine specs can be installed from 3rd party
1097+
packages.
10951098
"""
1096-
return uri, connect_args
1099+
return uri, {**connect_args, **cls.enforce_uri_query_params}
10971100

10981101
@classmethod
10991102
def patch(cls) -> None:

superset/db_engine_specs/mysql.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
176176
),
177177
}
178178
disallow_uri_query_params = {"local_infile"}
179+
enforce_uri_query_params = {"local_infile": 0}
179180

180181
@classmethod
181182
def convert_dttm(
@@ -198,10 +199,13 @@ def adjust_engine_params(
198199
catalog: Optional[str] = None,
199200
schema: Optional[str] = None,
200201
) -> Tuple[URL, Dict[str, Any]]:
202+
uri, new_connect_args = super(
203+
MySQLEngineSpec, MySQLEngineSpec
204+
).adjust_engine_params(uri, connect_args, catalog, schema)
201205
if schema:
202206
uri = uri.set(database=parse.quote(schema, safe=""))
203207

204-
return uri, connect_args
208+
return uri, new_connect_args
205209

206210
@classmethod
207211
def get_schema_from_engine_params(

tests/integration_tests/model_tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,21 @@ def test_impersonate_user_presto(self, mocked_create_engine):
188188
"password": "original_user_password",
189189
}
190190

191+
@unittest.skipUnless(
192+
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
193+
)
194+
@mock.patch("superset.models.core.create_engine")
195+
def test_adjust_engine_params_mysql(self, mocked_create_engine):
196+
model = Database(
197+
database_name="test_database",
198+
sqlalchemy_uri="mysql://user:password@localhost",
199+
)
200+
model._get_sqla_engine()
201+
call_args = mocked_create_engine.call_args
202+
203+
assert str(call_args[0][0]) == "mysql://user:password@localhost"
204+
assert call_args[1]["connect_args"]["local_infile"] == 0
205+
191206
@mock.patch("superset.models.core.create_engine")
192207
def test_impersonate_user_trino(self, mocked_create_engine):
193208
principal_user = security_manager.find_user(username="gamma")

tests/unit_tests/db_engine_specs/test_mysql.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717

1818
from datetime import datetime
19-
from typing import Any, Dict, Optional, Type
19+
from typing import Any, Dict, Optional, Tuple, Type
2020
from unittest.mock import Mock, patch
2121

2222
import pytest
@@ -33,7 +33,7 @@
3333
TINYINT,
3434
TINYTEXT,
3535
)
36-
from sqlalchemy.engine.url import make_url
36+
from sqlalchemy.engine.url import make_url, URL
3737

3838
from superset.utils.core import GenericDataType
3939
from tests.unit_tests.db_engine_specs.utils import (
@@ -119,6 +119,36 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
119119
MySQLEngineSpec.validate_database_uri(url)
120120

121121

122+
@pytest.mark.parametrize(
123+
"sqlalchemy_uri,connect_args,returns",
124+
[
125+
("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
126+
("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
127+
("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
128+
(
129+
"mysql://user:password@host/db1",
130+
{"param1": "some_value"},
131+
{"local_infile": 0, "param1": "some_value"},
132+
),
133+
(
134+
"mysql://user:password@host/db1",
135+
{"local_infile": 1, "param1": "some_value"},
136+
{"local_infile": 0, "param1": "some_value"},
137+
),
138+
],
139+
)
140+
def test_adjust_engine_params(
141+
sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any]
142+
) -> None:
143+
from superset.db_engine_specs.mysql import MySQLEngineSpec
144+
145+
url = make_url(sqlalchemy_uri)
146+
returned_url, returned_connect_args = MySQLEngineSpec.adjust_engine_params(
147+
url, connect_args
148+
)
149+
assert returned_connect_args == returns
150+
151+
122152
@patch("sqlalchemy.engine.Engine.connect")
123153
def test_get_cancel_query_id(engine_mock: Mock) -> None:
124154
from superset.db_engine_specs.mysql import MySQLEngineSpec

0 commit comments

Comments
 (0)