|
66 | 66 | "metadata": {},
|
67 | 67 | "outputs": [],
|
68 | 68 | "source": [
|
69 |
| - "from neuraxle.base import BaseSaver, BaseStep, ExecutionContext, Identity\n", |
70 |
| - "from queue import Queue\n", |
| 69 | + "from multiprocessing import Queue\n", |
| 70 | + "from neuraxle.base import BaseSaver, BaseTransformer, ExecutionContext, Identity\n", |
| 71 | + "from neuraxle.base import ExecutionContext as CX\n", |
| 72 | + "from neuraxle.distributed.streaming import _ProducerConsumerMixin\n", |
71 | 73 | "\n",
|
72 |
| - "class ObservableQueueStepSaver(BaseSaver):\n", |
73 |
| - " def save_step(self, step: 'BaseStep', context: 'ExecutionContext') -> 'BaseStep':\n", |
74 |
| - " step.queue = None\n", |
75 |
| - " step.observers = []\n", |
| 74 | + "class _ProducerConsumerStepSaver(BaseSaver):\n", |
| 75 | + " \"\"\"\n", |
| 76 | + " Saver for :class:`_ProducerConsumerMixin`.\n", |
| 77 | + " This saver class makes sure that the non-picklable queue\n", |
| 78 | + " is deleted upon saving for multiprocessing steps.\n", |
| 79 | + " \"\"\"\n", |
| 80 | + "\n", |
| 81 | + " def save_step(self, step: BaseTransformer, context: 'CX') -> BaseTransformer:\n", |
| 82 | + " step: _ProducerConsumerMixin = step # typing.\n", |
| 83 | + " step._allow_exit_without_queue_flush()\n", |
| 84 | + " step.input_queue = None\n", |
| 85 | + " step.consumers = []\n", |
76 | 86 | " return step\n",
|
77 | 87 | "\n",
|
78 |
| - " def can_load(self, step: 'BaseStep', context: 'ExecutionContext'):\n", |
| 88 | + " def can_load(self, step: BaseTransformer, context: 'CX') -> bool:\n", |
79 | 89 | " return True\n",
|
80 | 90 | "\n",
|
81 |
| - " def load_step(self, step: 'BaseStep', context: 'ExecutionContext') -> 'BaseStep':\n", |
82 |
| - " step.queue = Queue()\n", |
| 91 | + " def load_step(self, step: 'BaseTransformer', context: 'CX') -> 'BaseTransformer':\n", |
| 92 | + " step: _ProducerConsumerMixin = step # typing.\n", |
| 93 | + " step.input_queue = None\n", |
83 | 94 | " return step"
|
84 | 95 | ]
|
85 | 96 | },
|
|
98 | 109 | "source": [
|
99 | 110 | "class IdentityWithQueue(Identity):\n",
|
100 | 111 | " def __init__(self):\n",
|
101 |
| - " super().__init__(savers=[ObservableQueueStepSaver()])\n", |
| 112 | + " super().__init__(savers=[_ProducerConsumerStepSaver()])\n", |
102 | 113 | "\n",
|
103 | 114 | " def setup(self, context=None):\n",
|
104 | 115 | " if not self.is_initialized:\n",
|
|
0 commit comments