Skip to content

Commit bea2288

Browse files
committed
cli/conn/preview: Migrate multiprocess->multithreaded
Requires recent patches to cyanodbc releasing GIL.
1 parent 6815728 commit bea2288

File tree

6 files changed

+144
-297
lines changed

6 files changed

+144
-297
lines changed

odbcli/__main__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from multiprocessing import set_start_method
21
from .cli import main
3-
#main()
42

53
if __name__ == "__main__":
6-
set_start_method('spawn')
74
main()

odbcli/cli.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
from prompt_toolkit.utils import get_cwidth
1313
from .app import sqlApp, ExitEX
1414
from .layout import sqlAppLayout
15-
from .conn import connStatus
16-
from .executor import cmsg, commandStatus
15+
from .conn import connStatus, executionStatus
1716

1817

1918
def main():
@@ -31,42 +30,40 @@ def main():
3130
# If it's a preview query we need an indication
3231
# of where to run the query
3332
if(app_res[0] == "preview"):
34-
sqlConn = my_app.selected_object.conn
33+
sql_conn = my_app.selected_object.conn
3534
else:
36-
sqlConn = my_app.active_conn
37-
if sqlConn is not None:
35+
sql_conn = my_app.active_conn
36+
if sql_conn is not None:
3837
#TODO also check that it is connected
3938
try:
4039
secho("Executing query...Ctrl-c to cancel", err = False)
4140
start = time()
42-
res = sqlConn.async_execute(app_res[1])
41+
crsr = sql_conn.async_execute(app_res[1])
4342
execution = time() - start
44-
sqlConn.status = connStatus.IDLE
4543
secho("Query execution...done", err = False)
4644
if(app_res[0] == "preview"):
4745
continue
4846
if my_app.timing_enabled:
4947
print("Time: %0.03fs" % execution)
50-
if res.status == commandStatus.OKWRESULTS:
51-
ht = my_app.application.output.get_size()[0]
52-
formatted = sqlConn.formatted_fetch(ht - 3 - my_app.pager_reserve_lines, my_app.table_format)
53-
sqlConn.status = connStatus.FETCHING
54-
echo_via_pager(formatted)
55-
elif res.status == commandStatus.OK:
56-
secho("No rows returned\n", err = False)
48+
49+
if sql_conn.execution_status == executionStatus.FAIL:
50+
err = sql_conn.execution_err
51+
secho("Query error: %s\n" % err, err = True, fg = "red")
5752
else:
58-
secho("Query error: %s\n" % res.payload, err = True, fg = "red")
59-
except BrokenPipeError:
60-
my_app.logger.debug('BrokenPipeError caught. Recovering...', file = stderr)
53+
if crsr.description:
54+
cols = [col.name for col in crsr.description]
55+
else:
56+
cols = []
57+
if len(cols):
58+
ht = my_app.application.output.get_size()[0]
59+
formatted = sql_conn.formatted_fetch(ht - 3 - my_app.pager_reserve_lines, cols, my_app.table_format)
60+
sql_conn.status = connStatus.FETCHING
61+
echo_via_pager(formatted)
62+
else:
63+
secho("No rows returned\n", err = False)
6164
except KeyboardInterrupt:
62-
secho("Cancelling query...", err = True, fg = 'red')
63-
sqlConn.executor.terminate()
64-
sqlConn.executor.join()
65-
secho("Query cancelled.", err = True, fg='red')
66-
#TODO: catch ConnectError
67-
sqlConn.connect(start_executor = True)
68-
sqlConn.status = connStatus.IDLE
69-
# TODO check status of return
70-
sqlConn.async_fetchdone()
71-
# sqlConn.parent_chan.send(cmsg("fetchdone", None, None))
72-
# sqlConn.parent_chan.recv()
65+
secho("Cancelling query...", err = True, fg = "red")
66+
sql_conn.cancel()
67+
secho("Query cancelled.", err = True, fg = "red")
68+
sql_conn.status = connStatus.IDLE
69+
sql_conn.close_cursor()

odbcli/conn.py

Lines changed: 93 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
from cyanodbc import connect, Connection, SQLGetInfo, Cursor, DatabaseError, ConnectError
33
from typing import Optional
44
from cli_helpers.tabular_output import TabularOutputFormatter
5-
from multiprocessing import Process, Pipe
65
from logging import getLogger
76
from re import sub
8-
from threading import Lock
9-
from .executor import executor_process, cmsg, commandStatus
7+
from threading import Lock, Event, Thread
8+
from enum import IntEnum
109

1110
formatter = TabularOutputFormatter()
1211

@@ -17,6 +16,11 @@ class connStatus(Enum):
1716
FETCHING = 3
1817
ERROR = 4
1918

19+
class executionStatus(IntEnum):
20+
OK = 0
21+
FAIL = 1
22+
OKWRESULTS = 2
23+
2024
class sqlConnection:
2125
def __init__(
2226
self,
@@ -32,8 +36,6 @@ def __init__(
3236
self.username = username
3337
self.password = password
3438
self.status = connStatus.DISCONNECTED
35-
self.executor: Process = None
36-
self.parent_chan, self.child_chan = Pipe()
3739
self.logger = getLogger(__name__)
3840
self._quotechar = None
3941
self._search_escapechar = None
@@ -48,6 +50,26 @@ def __init__(
4850
# multiple auto-completion result queries before each has had a chance
4951
# to return.
5052
self._lock = Lock()
53+
self._fetch_res: list = None
54+
self._execution_status: executionStatus = executionStatus.OK
55+
self._execution_err: str = None
56+
57+
@property
58+
def execution_status(self) -> executionStatus:
59+
""" Hold the lock here since it gets assigned in execute
60+
which can be called in a different thread """
61+
with self._lock:
62+
res = self._execution_status
63+
return res
64+
65+
@property
66+
def execution_err(self) -> str:
67+
""" Last execution error: Cleared prior to every execution.
68+
Hold the lock here since it gets assigned in execute
69+
which can be called in a different thread """
70+
with self._lock:
71+
res = self._execution_err
72+
return res
5173

5274
@property
5375
def quotechar(self) -> str:
@@ -85,8 +107,7 @@ def connect(
85107
self,
86108
username: str = "",
87109
password: str = "",
88-
force: bool = False,
89-
start_executor: bool = False) -> None:
110+
force: bool = False) -> None:
90111
uid = username or self.username
91112
pwd = password or self.password
92113
conn_str = "DSN=" + self.dsn + ";"
@@ -103,97 +124,65 @@ def connect(
103124
except ConnectError as e:
104125
self.logger.error("Error while connecting: %s", str(e))
105126
raise ConnectError(e)
106-
if start_executor:
107-
self.executor = Process(
108-
target = executor_process,
109-
args=(self.child_chan, self.logger.getEffectiveLevel(),))
110-
self.executor.start()
111-
self.logger.info("Started executor process: %d", self.executor.pid)
112-
self.parent_chan.send(cmsg("connect", conn_str, None))
113-
resp = self.parent_chan.recv()
114-
# How do you handle failure here?
115-
if not resp.status == commandStatus.OK:
116-
self.logger.error("Error atempting to connect in executor process")
117-
self.executor.terminate()
118-
self.executor.join()
119-
raise ConnectError("Connection failure in executor")
120-
121-
def async_lastresponse(self) -> cmsg:
122-
if self.executor and self.executor.is_alive():
123-
self.logger.debug("Asking for last message, executor pid %d",
124-
self.executor.pid)
125-
self.parent_chan.send(cmsg("lastresponse", None, None))
126-
resp = self.parent_chan.recv()
127-
# Above should never fail
128-
return resp
129-
130-
def async_execute(self, query) -> cmsg:
131-
if self.executor and self.executor.is_alive():
132-
self.logger.debug("Sending query %s to pid %d",
133-
query, self.executor.pid)
134-
# TODO: message should carry
135-
# current catalog. One might
136-
# think that the main process
137-
# connection always "follows"
138-
# database changes since all
139-
# main queries get executed
140-
# against executor thread
141-
# and main process conn only
142-
# gets used for sidebar/auto
143-
# completion. But, for
144-
# example the MYSQL driver
145-
# if starting without a
146-
# declared database will just
147-
# switch to the first db
148-
# when running find_columns
149-
self.parent_chan.send(
150-
cmsg("execute", query, None))
151-
# Will block but can be interrupted
152-
res = self.parent_chan.recv()
153-
self.logger.debug("Execution done")
154-
self.query = query
155-
# Check if catalog has changed in which case
156-
# execute query locally
157-
self.parent_chan.send(cmsg("currentcatalog", None, None))
158-
rescat = self.parent_chan.recv()
159-
if rescat.status == commandStatus.FAIL:
160-
# TODO raise exception here since
161-
# connection catalogs are possibly out of sync
162-
# and we don't have a way of knowing
163-
res = cmsg("execute", "", commandStatus.FAIL)
164-
elif not rescat.payload == self.current_catalog():
165-
# query changed the catalog
166-
# so let's change the database locally
167-
self.logger.debug("Execution changed catalog")
168-
self.execute("USE " + rescat.payload)
169-
else:
170-
res = cmsg("execute", "", commandStatus.FAIL)
171-
return res
172127

173-
def async_fetch(self, size) -> cmsg:
174-
if self.executor and self.executor.is_alive():
175-
self.logger.debug("Fetching size %d from pid %d",
176-
size, self.executor.pid)
177-
self.parent_chan.send(cmsg("fetch", size, None))
178-
res = self.parent_chan.recv()
179-
self.logger.debug("Fetching done")
180-
else:
181-
res = cmsg("fetch", "", commandStatus.FAIL)
182-
return res
183-
184-
def async_fetchdone(self) -> cmsg:
185-
if self.executor and self.executor.is_alive():
186-
self.parent_chan.send(cmsg("fetchdone", None, None))
187-
res = self.parent_chan.recv()
128+
def fetchmany(self, size, event: Event = None) -> list:
129+
if self.cursor:
130+
self._fetch_res = self.cursor.fetchmany(size)
188131
else:
189-
res = cmsg("fetchdone", "", commandStatus.FAIL)
190-
return res
191-
192-
def execute(self, query, parameters = None) -> Cursor:
132+
self._fetch_res = []
133+
if event is not None:
134+
event.set()
135+
return self._fetch_res
136+
137+
def async_fetchmany(self, size) -> list:
138+
""" async_ is a misnomer here. It does execute fetch in a new thread
139+
however it will also wait for execution to complete. At this time
140+
this helps us with registering KeyboardInterrupt during cyanodbc.
141+
fetchmany only; it may evolve to have more true async-like behavior.
142+
"""
143+
exec_event = Event()
144+
t = Thread(
145+
target = self.fetchmany,
146+
kwargs = {"size": size, "event": exec_event},
147+
daemon = True)
148+
t.start()
149+
# Will block but can be interrupted
150+
exec_event.wait()
151+
return self._fetch_res
152+
153+
def execute(self, query, parameters = None, event: Event = None) -> Cursor:
193154
with self._lock:
155+
self.close_cursor()
194156
self.cursor = self.conn.cursor()
195-
self.cursor.execute(query, parameters)
196-
self.query = query
157+
try:
158+
self._execution_err = None
159+
self.status = connStatus.EXECUTING
160+
self.cursor.execute(query, parameters)
161+
self.status = connStatus.IDLE
162+
self._execution_status = executionStatus.OK
163+
self.query = query
164+
except DatabaseError as e:
165+
self._execution_status = executionStatus.FAIL
166+
self._execution_err = str(e)
167+
self.logger.warning("Execution error: %s", str(e))
168+
if event is not None:
169+
event.set()
170+
return self.cursor
171+
172+
def async_execute(self, query) -> Cursor:
173+
""" async_ is a misnomer here. It does execute fetch in a new thread
174+
however it will also wait for execution to complete. At this time
175+
this helps us with registering KeyboardInterrupt during cyanodbc.
176+
execute only; it may evolve to have more true async-like behavior.
177+
"""
178+
exec_event = Event()
179+
t = Thread(
180+
target = self.execute,
181+
kwargs = {"query": query, "parameters": None, "event": exec_event},
182+
daemon = True)
183+
t.start()
184+
# Will block but can be interrupted
185+
exec_event.wait()
197186
return self.cursor
198187

199188
def list_catalogs(self) -> list:
@@ -296,9 +285,6 @@ def get_info(self, code: int) -> str:
296285
return self.conn.get_info(code)
297286

298287
def close(self) -> None:
299-
if self.executor and self.executor.is_alive():
300-
self.executor.terminate()
301-
self.executor.join()
302288
# TODO: When disconnecting
303289
# We likely don't want to allow any exception to
304290
# propagate. Catch DatabaseError?
@@ -311,24 +297,27 @@ def close_cursor(self) -> None:
311297
self.cursor = None
312298
self.query = None
313299

300+
def cancel(self) -> None:
301+
if self.cursor:
302+
self.cursor.cancel()
303+
self.query = None
304+
314305
def preview_query(self, table, filter_query = "", limit = -1) -> str:
315306
qry = "SELECT * FROM " + table + " " + filter_query
316307
if limit > 0:
317308
qry = qry + " LIMIT " + str(limit)
318309
return qry
319310

320-
def formatted_fetch(self, size, format_name = "psql"):
311+
def formatted_fetch(self, size, cols, format_name = "psql"):
321312
while True:
322-
res = self.async_fetch(size)
323-
if (res.status == commandStatus.FAIL) or (not res.type == "fetch"):
324-
return "Encountered a problem while fetching"
325-
elif len(res.payload[1]) == 0:
313+
res = self.async_fetchmany(size)
314+
if len(res) < 1:
326315
break
327316
else:
328317
yield "\n".join(
329318
formatter.format_output(
330-
res.payload[1],
331-
res.payload[0],
319+
res,
320+
cols,
332321
format_name = format_name))
333322

334323
connWrappers = {}

0 commit comments

Comments
 (0)