-
Couldn't load subscription status.
- Fork 15.9k
AIP-44 Migrate TaskInstance.check_and_change_state_before_execution to Internal API #29513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,13 +59,14 @@ | |
| ) | ||
| from sqlalchemy.ext.associationproxy import association_proxy | ||
| from sqlalchemy.ext.mutable import MutableDict | ||
| from sqlalchemy.orm import reconstructor, relationship | ||
| from sqlalchemy.orm import make_transient, reconstructor, relationship | ||
| from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value | ||
| from sqlalchemy.orm.session import Session | ||
| from sqlalchemy.sql.elements import BooleanClauseList | ||
| from sqlalchemy.sql.expression import ColumnOperators, case | ||
|
|
||
| from airflow import settings | ||
| from airflow.api_internal.internal_api_call import internal_api_call | ||
| from airflow.compat.functools import cache | ||
| from airflow.configuration import conf | ||
| from airflow.datasets import Dataset | ||
|
|
@@ -787,6 +788,51 @@ def error(self, session: Session = NEW_SESSION) -> None: | |
| session.merge(self) | ||
| session.commit() | ||
|
|
||
| @classmethod | ||
| @internal_api_call | ||
| @provide_session | ||
| def retrieve_from_db( | ||
| cls, | ||
| dag_id: str, | ||
| run_id: str, | ||
| task_id: str, | ||
| map_index: int, | ||
| session: Session = NEW_SESSION, | ||
| lock_for_update: bool = False, | ||
| ) -> TaskInstance | None: | ||
| """ | ||
| Retrieve the task instance from the database based on the primary key | ||
|
|
||
| :param dag_id: The Dag ID | ||
| :param run_id: The Dag run ID | ||
| :param task_id: The Task ID | ||
| :param map_index: The map index | ||
| :param session: SQLAlchemy ORM Session | ||
| :param lock_for_update: if True, indicates that the database should | ||
| lock the TaskInstance (issuing a FOR UPDATE clause) until the | ||
| session is committed. | ||
| :return: The TaskInstance object retrieved from the database. | ||
| """ | ||
| logger = cls.logger() | ||
| logger.debug( | ||
| f"Retrieving TaskInstance from DB with primary key: ({dag_id}, {task_id}, {run_id}, {map_index})" | ||
| ) | ||
|
|
||
| qry = session.query(TaskInstance).filter( | ||
| TaskInstance.dag_id == dag_id, | ||
| TaskInstance.task_id == task_id, | ||
| TaskInstance.run_id == run_id, | ||
| TaskInstance.map_index == map_index, | ||
| ) | ||
|
|
||
| if lock_for_update: | ||
| for attempt in run_with_db_retries(logger=logger): | ||
| with attempt: | ||
| ti: TaskInstance | None = qry.with_for_update().one_or_none() | ||
| else: | ||
| ti = qry.one_or_none() | ||
| return ti | ||
|
|
||
| @provide_session | ||
| def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None: | ||
| """ | ||
|
|
@@ -1120,7 +1166,6 @@ def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session | |
| dep_context = dep_context or DepContext() | ||
| for dep in dep_context.deps | self.task.deps: | ||
| for dep_status in dep.get_dep_statuses(self, session, dep_context): | ||
|
|
||
| self.log.debug( | ||
| "%s dependency '%s' PASSED: %s, %s", | ||
| self, | ||
|
|
@@ -1207,9 +1252,15 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun: | |
|
|
||
| return dr | ||
|
|
||
| @staticmethod | ||
| @internal_api_call | ||
| @provide_session | ||
| def check_and_change_state_before_execution( | ||
| self, | ||
| dag_id: str, | ||
| run_id: str, | ||
| task_id: str, | ||
| map_index: int, | ||
| task: Operator, | ||
| verbose: bool = True, | ||
| ignore_all_deps: bool = False, | ||
| ignore_depends_on_past: bool = False, | ||
|
|
@@ -1222,13 +1273,18 @@ def check_and_change_state_before_execution( | |
| pool: str | None = None, | ||
| external_executor_id: str | None = None, | ||
| session: Session = NEW_SESSION, | ||
| ) -> bool: | ||
| """ | ||
| Checks dependencies and then sets state to RUNNING if they are met. Returns | ||
| True if and only if state is set to RUNNING, which implies that task should be | ||
| executed, in preparation for _run_raw_task | ||
|
|
||
| :param verbose: whether to turn on more verbose logging | ||
| ) -> TaskInstance: | ||
| """ | ||
| Retrieve the TI based on its primary keys. Checks dependencies and then sets state to RUNNING if they | ||
| are met. Returns an updated version of the retrieved TI. If state is set to RUNNING, it implies | ||
| that task should be executed, in preparation for _run_raw_task. | ||
|
|
||
| :param dag_id: The Dag ID | ||
| :param run_id: The Dag run ID | ||
| :param task_id: The Task ID | ||
| :param map_index: The map index | ||
| :pram task: The task object | ||
| :param verbose: Whether to turn on more verbose logging | ||
| :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs | ||
| :param ignore_depends_on_past: Ignore depends_on_past DAG attribute | ||
| :param wait_for_past_depends_before_skipping: Wait for past depends before mark the ti as skipped | ||
|
|
@@ -1237,25 +1293,32 @@ def check_and_change_state_before_execution( | |
| :param mark_success: Don't run the task, mark its state as success | ||
| :param test_mode: Doesn't record success or failure in the DB | ||
| :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID | ||
| :param pool: specifies the pool to use to run the task instance | ||
| :param pool: Specifies the pool to use to run the task instance | ||
| :param external_executor_id: The identifier of the celery executor | ||
| :param session: SQLAlchemy ORM Session | ||
| :return: whether the state was changed to running or not | ||
| """ | ||
| task = self.task | ||
| self.refresh_from_task(task, pool_override=pool) | ||
| self.test_mode = test_mode | ||
| self.refresh_from_db(session=session, lock_for_update=True) | ||
| self.job_id = job_id | ||
| self.hostname = get_hostname() | ||
| self.pid = None | ||
| ti = TaskInstance.retrieve_from_db( | ||
| dag_id, run_id, task_id, map_index, session=session, lock_for_update=True | ||
| ) | ||
| if ti is None: | ||
|
||
| ti = TaskInstance( | ||
| task=task, | ||
| run_id=run_id, | ||
| map_index=map_index, | ||
| ) | ||
| ti.refresh_from_task(task, pool_override=pool) | ||
| ti.test_mode = test_mode | ||
| ti.job_id = job_id | ||
| ti.hostname = get_hostname() | ||
| ti.pid = None | ||
|
|
||
| if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS: | ||
| if not ignore_all_deps and not ignore_ti_state and ti.state == State.SUCCESS: | ||
| Stats.incr( | ||
| "previously_succeeded", | ||
| 1, | ||
| 1, | ||
| tags={"dag_id": self.dag_id, "run_id": self.run_id, "task_id": self.task_id}, | ||
| tags={"dag_id": ti.dag_id, "run_id": ti.run_id, "task_id": ti.task_id}, | ||
| ) | ||
|
|
||
| if not mark_success: | ||
|
|
@@ -1270,23 +1333,24 @@ def check_and_change_state_before_execution( | |
| ignore_task_deps=ignore_task_deps, | ||
| description="non-requeueable deps", | ||
| ) | ||
| if not self.are_dependencies_met( | ||
| if not ti.are_dependencies_met( | ||
| dep_context=non_requeueable_dep_context, session=session, verbose=True | ||
| ): | ||
| session.commit() | ||
| return False | ||
| make_transient(ti) | ||
|
||
| return ti | ||
|
|
||
| # For reporting purposes, we report based on 1-indexed, | ||
| # not 0-indexed lists (i.e. Attempt 1 instead of | ||
| # Attempt 0 for the first attempt). | ||
| # Set the task start date. In case it was re-scheduled use the initial | ||
| # start date that is recorded in task_reschedule table | ||
| # If the task continues after being deferred (next_method is set), use the original start_date | ||
| self.start_date = self.start_date if self.next_method else timezone.utcnow() | ||
| if self.state == State.UP_FOR_RESCHEDULE: | ||
| task_reschedule: TR = TR.query_for_task_instance(self, session=session).first() | ||
| ti.start_date = ti.start_date if ti.next_method else timezone.utcnow() | ||
| if ti.state == State.UP_FOR_RESCHEDULE: | ||
| task_reschedule: TR = TR.query_for_task_instance(ti, session=session).first() | ||
| if task_reschedule: | ||
| self.start_date = task_reschedule.start_date | ||
| ti.start_date = task_reschedule.start_date | ||
|
|
||
| # Secondly we find non-runnable but requeueable tis. We reset its state. | ||
| # This is because we might have hit concurrency limits, | ||
|
|
@@ -1300,44 +1364,47 @@ def check_and_change_state_before_execution( | |
| ignore_ti_state=ignore_ti_state, | ||
| description="requeueable deps", | ||
| ) | ||
| if not self.are_dependencies_met(dep_context=dep_context, session=session, verbose=True): | ||
| self.state = State.NONE | ||
| self.log.warning( | ||
| if not ti.are_dependencies_met(dep_context=dep_context, session=session, verbose=True): | ||
| ti.state = State.NONE | ||
| ti.log.warning( | ||
| "Rescheduling due to concurrency limits reached " | ||
| "at task runtime. Attempt %s of " | ||
| "%s. State set to NONE.", | ||
| self.try_number, | ||
| self.max_tries + 1, | ||
| ti.try_number, | ||
| ti.max_tries + 1, | ||
| ) | ||
| self.queued_dttm = timezone.utcnow() | ||
| session.merge(self) | ||
| ti.queued_dttm = timezone.utcnow() | ||
| session.merge(ti) | ||
| session.commit() | ||
| return False | ||
| make_transient(ti) | ||
| return ti | ||
|
|
||
| if self.next_kwargs is not None: | ||
| self.log.info("Resuming after deferral") | ||
| if ti.next_kwargs is not None: | ||
| ti.log.info("Resuming after deferral") | ||
| else: | ||
| self.log.info("Starting attempt %s of %s", self.try_number, self.max_tries + 1) | ||
| self._try_number += 1 | ||
| ti.log.info("Starting attempt %s of %s", ti.try_number, ti.max_tries + 1) | ||
| ti._try_number += 1 | ||
|
|
||
| if not test_mode: | ||
| session.add(Log(State.RUNNING, self)) | ||
| self.state = State.RUNNING | ||
| self.external_executor_id = external_executor_id | ||
| self.end_date = None | ||
| session.add(Log(State.RUNNING, ti)) | ||
| ti.state = State.RUNNING | ||
| ti.external_executor_id = external_executor_id | ||
| ti.end_date = None | ||
| if not test_mode: | ||
| session.merge(self).task = task | ||
| session.merge(ti).task = task | ||
| session.commit() | ||
|
|
||
| # Closing all pooled connections to prevent | ||
| # "max number of connections reached" | ||
| settings.engine.dispose() # type: ignore | ||
| if verbose: | ||
| if mark_success: | ||
| self.log.info("Marking success for %s on %s", self.task, self.execution_date) | ||
| ti.log.info("Marking success for %s on %s", ti.task, ti.execution_date) | ||
| else: | ||
| self.log.info("Executing %s on %s", self.task, self.execution_date) | ||
| return True | ||
| ti.log.info("Executing %s on %s", ti.task, ti.execution_date) | ||
|
|
||
| make_transient(ti) | ||
| return ti | ||
|
|
||
| def _date_or_empty(self, attr: str) -> str: | ||
| result: datetime | None = getattr(self, attr, None) | ||
|
|
@@ -1715,7 +1782,12 @@ def run( | |
| session: Session = NEW_SESSION, | ||
| ) -> None: | ||
| """Run TaskInstance""" | ||
| res = self.check_and_change_state_before_execution( | ||
| ti_before_execution = TaskInstance.check_and_change_state_before_execution( | ||
| self.dag_id, | ||
| self.run_id, | ||
| self.task_id, | ||
| self.map_index, | ||
| self.task, | ||
| verbose=verbose, | ||
| ignore_all_deps=ignore_all_deps, | ||
| ignore_depends_on_past=ignore_depends_on_past, | ||
|
|
@@ -1728,7 +1800,8 @@ def run( | |
| pool=pool, | ||
| session=session, | ||
| ) | ||
| if not res: | ||
| self.state = ti_before_execution.state | ||
| if not self.state == State.RUNNING: | ||
| return | ||
|
|
||
| self._run_raw_task( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Primary key for task instance is composite with
dag_id,run_id,task_id,map_index. So the call is a little bit more verbose