3 changes: 3 additions & 0 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -27,6 +27,7 @@
from airflow.api_connexion.types import APIResponse
from airflow.models import Trigger, Variable, XCom
from airflow.models.dagwarning import DagWarning
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.serialized_objects import BaseSerialization

log = logging.getLogger(__name__)
@@ -46,6 +47,8 @@ def _initialize_map() -> dict[str, Callable]:
DagModel.get_paused_dag_ids,
DagFileProcessorManager.clear_nonexistent_import_errors,
DagWarning.purge_inactive_dag_warnings,
TaskInstance.check_and_change_state_before_execution,
TaskInstance.retrieve_from_db,
XCom.get_value,
XCom.get_one,
XCom.get_many,
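The two new entries extend the allow-list of callables that the internal API endpoint will execute over RPC. As a rough sketch of that pattern (assumed names, not Airflow's actual endpoint code), such a map keys each callable by its qualified name and rejects anything unregistered:

from __future__ import annotations

from typing import Any, Callable


def build_method_map(methods: list[Callable]) -> dict[str, Callable]:
    # Key each callable by "module.qualname" so the wire format is a stable string.
    return {f"{m.__module__}.{m.__qualname__}": m for m in methods}


def dispatch(method_map: dict[str, Callable], name: str, **params: Any) -> Any:
    # The map doubles as a security boundary: unregistered methods cannot be invoked.
    handler = method_map.get(name)
    if handler is None:
        raise ValueError(f"Method {name!r} is not registered for the internal API")
    return handler(**params)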
10 changes: 8 additions & 2 deletions airflow/jobs/local_task_job.py
@@ -138,7 +138,12 @@ def sigusr2_debug_handler(signum, frame):
# This is not supported on Windows systems
signal.signal(signal.SIGUSR2, sigusr2_debug_handler)

if not self.task_instance.check_and_change_state_before_execution(
self.task_instance = TaskInstance.check_and_change_state_before_execution(
self.task_instance.dag_id,
self.task_instance.run_id,
self.task_instance.task_id,
self.task_instance.map_index,
self.task_instance.task,
mark_success=self.mark_success,
ignore_all_deps=self.ignore_all_deps,
ignore_depends_on_past=self.ignore_depends_on_past,
@@ -148,7 +153,8 @@
job_id=self.id,
pool=self.pool,
external_executor_id=self.external_executor_id,
):
)
if not self.task_instance.state == State.RUNNING:
self.log.info("Task is not able to be run")
return None

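The shape of the call changes here: the method is no longer invoked on a live, session-bound TaskInstance and answered with a bool; instead the TI is addressed by its composite primary key and the caller keeps the refreshed instance the call returns. A hedged sketch of the new calling convention (start_if_ready is an illustrative helper, not part of the PR):

from __future__ import annotations

from airflow.models.taskinstance import TaskInstance
from airflow.utils.state import State


def start_if_ready(ti: TaskInstance, job_id: str | None = None) -> TaskInstance | None:
    # check_and_change_state_before_execution may be routed over the internal
    # API, so the TI is addressed by plain serializable key values plus the task.
    refreshed = TaskInstance.check_and_change_state_before_execution(
        ti.dag_id,     # composite primary key, part 1
        ti.run_id,     # part 2
        ti.task_id,    # part 3
        ti.map_index,  # part 4 (-1 for an unmapped task)
        ti.task,
        job_id=job_id,
    )
    # The old API returned a bool; callers now re-check state on the returned,
    # refreshed instance instead.
    return refreshed if refreshed.state == State.RUNNING else None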
169 changes: 121 additions & 48 deletions airflow/models/taskinstance.py
@@ -59,13 +59,14 @@
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import reconstructor, relationship
from sqlalchemy.orm import make_transient, reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.sql.expression import ColumnOperators, case

from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.datasets import Dataset
@@ -787,6 +788,51 @@ def error(self, session: Session = NEW_SESSION) -> None:
session.merge(self)
session.commit()

@classmethod
@internal_api_call
@provide_session
def retrieve_from_db(
cls,
dag_id: str,
run_id: str,
task_id: str,
map_index: int,
session: Session = NEW_SESSION,
lock_for_update: bool = False,
) -> TaskInstance | None:
"""
Retrieve the task instance from the database based on the primary key.

:param dag_id: The Dag ID
:param run_id: The Dag run ID
:param task_id: The Task ID
:param map_index: The map index
:param session: SQLAlchemy ORM Session
:param lock_for_update: if True, indicates that the database should
lock the TaskInstance (issuing a FOR UPDATE clause) until the
session is committed.
:return: The TaskInstance object retrieved from the database.
"""
logger = cls.logger()
logger.debug(
f"Retrieving TaskInstance from DB with primary key: ({dag_id}, {task_id}, {run_id}, {map_index})"
)

qry = session.query(TaskInstance).filter(
TaskInstance.dag_id == dag_id,
TaskInstance.task_id == task_id,
TaskInstance.run_id == run_id,
TaskInstance.map_index == map_index,
)

if lock_for_update:
for attempt in run_with_db_retries(logger=logger):
with attempt:
ti: TaskInstance | None = qry.with_for_update().one_or_none()
else:
ti = qry.one_or_none()
return ti

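A hedged usage sketch for the new classmethod (the identifiers are made up; when no session is passed, @provide_session supplies one, and lock_for_update issues SELECT ... FOR UPDATE held until commit):

from airflow.models.taskinstance import TaskInstance

ti = TaskInstance.retrieve_from_db(
    dag_id="example_dag",                            # hypothetical identifiers
    run_id="scheduled__2023-01-01T00:00:00+00:00",
    task_id="extract",
    map_index=-1,             # -1 denotes an unmapped task instance
    lock_for_update=True,     # the row stays locked until the session commits
)
if ti is None:
    # No DB record yet; check_and_change_state_before_execution falls back to
    # constructing an in-memory TaskInstance in this case.
    ...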
@provide_session
def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
"""
@@ -1120,7 +1166,6 @@ def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session
dep_context = dep_context or DepContext()
for dep in dep_context.deps | self.task.deps:
for dep_status in dep.get_dep_statuses(self, session, dep_context):

self.log.debug(
"%s dependency '%s' PASSED: %s, %s",
self,
@@ -1207,9 +1252,15 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:

return dr

@staticmethod
@internal_api_call
@provide_session
def check_and_change_state_before_execution(
self,
dag_id: str,
Review comment (PR author): The primary key for a task instance is composite (dag_id, run_id, task_id, map_index), so the call is a little more verbose.

run_id: str,
task_id: str,
map_index: int,
task: Operator,
verbose: bool = True,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
@@ -1222,13 +1273,18 @@
pool: str | None = None,
external_executor_id: str | None = None,
session: Session = NEW_SESSION,
) -> bool:
"""
Checks dependencies and then sets state to RUNNING if they are met. Returns
True if and only if state is set to RUNNING, which implies that task should be
executed, in preparation for _run_raw_task

:param verbose: whether to turn on more verbose logging
) -> TaskInstance:
"""
Retrieve the TI based on its composite primary key. Checks dependencies and then sets state to RUNNING if they
are met. Returns an updated version of the retrieved TI. If state is set to RUNNING, the task should be
executed, in preparation for _run_raw_task.

:param dag_id: The Dag ID
:param run_id: The Dag run ID
:param task_id: The Task ID
:param map_index: The map index
:param task: The task object
:param verbose: Whether to turn on more verbose logging
:param ignore_all_deps: Ignore all of the non-critical dependencies, just runs
:param ignore_depends_on_past: Ignore depends_on_past DAG attribute
:param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped
@@ -1237,25 +1293,32 @@
:param mark_success: Don't run the task, mark its state as success
:param test_mode: Doesn't record success or failure in the DB
:param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID
:param pool: specifies the pool to use to run the task instance
:param pool: Specifies the pool to use to run the task instance
:param external_executor_id: The identifier of the celery executor
:param session: SQLAlchemy ORM Session
:return: The updated TaskInstance; its state is RUNNING if the task should be executed
"""
task = self.task
self.refresh_from_task(task, pool_override=pool)
self.test_mode = test_mode
self.refresh_from_db(session=session, lock_for_update=True)
self.job_id = job_id
self.hostname = get_hostname()
self.pid = None
ti = TaskInstance.retrieve_from_db(
dag_id, run_id, task_id, map_index, session=session, lock_for_update=True
)
if ti is None:
Review comment (PR author): Sometimes this function is called before a DB record for the task instance exists; test_check_and_change_state_before_execution_dep_not_met highlights this case, hence the in-memory fallback below.

ti = TaskInstance(
task=task,
run_id=run_id,
map_index=map_index,
)
ti.refresh_from_task(task, pool_override=pool)
ti.test_mode = test_mode
ti.job_id = job_id
ti.hostname = get_hostname()
ti.pid = None

if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
if not ignore_all_deps and not ignore_ti_state and ti.state == State.SUCCESS:
Stats.incr(
"previously_succeeded",
1,
1,
tags={"dag_id": self.dag_id, "run_id": self.run_id, "task_id": self.task_id},
tags={"dag_id": ti.dag_id, "run_id": ti.run_id, "task_id": ti.task_id},
)

if not mark_success:
@@ -1270,23 +1333,24 @@
ignore_task_deps=ignore_task_deps,
description="non-requeueable deps",
)
if not self.are_dependencies_met(
if not ti.are_dependencies_met(
dep_context=non_requeueable_dep_context, session=session, verbose=True
):
session.commit()
return False
make_transient(ti)
Review comment (PR author): The task instance is expired when the session is not passed down from the parent caller but supplied by @provide_session. We need make_transient so the caller can still do ti = TI.check_and_change_state_before_execution(). Only one caller has no session of its own: LocalTaskJob._execute.

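To illustrate the comment above, here is a minimal, generic SQLAlchemy sketch (assumed model and engine, not Airflow code) of detaching a loaded object with make_transient so it survives its session closing:

from sqlalchemy.orm import Session, make_transient


def load_detached(engine, model, pk):
    with Session(engine) as session:
        obj = session.get(model, pk)  # attributes are fully loaded at this point
        make_transient(obj)           # drop identity-map and session bookkeeping
    # The session is closed here; without make_transient, touching an expired
    # attribute on the detached instance would raise DetachedInstanceError.
    return obj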
return ti

# For reporting purposes, we report based on 1-indexed,
# not 0-indexed lists (i.e. Attempt 1 instead of
# Attempt 0 for the first attempt).
# Set the task start date. In case it was re-scheduled use the initial
# start date that is recorded in task_reschedule table
# If the task continues after being deferred (next_method is set), use the original start_date
self.start_date = self.start_date if self.next_method else timezone.utcnow()
if self.state == State.UP_FOR_RESCHEDULE:
task_reschedule: TR = TR.query_for_task_instance(self, session=session).first()
ti.start_date = ti.start_date if ti.next_method else timezone.utcnow()
if ti.state == State.UP_FOR_RESCHEDULE:
task_reschedule: TR = TR.query_for_task_instance(ti, session=session).first()
if task_reschedule:
self.start_date = task_reschedule.start_date
ti.start_date = task_reschedule.start_date

# Secondly we find non-runnable but requeueable tis. We reset its state.
# This is because we might have hit concurrency limits,
@@ -1300,44 +1364,47 @@
ignore_ti_state=ignore_ti_state,
description="requeueable deps",
)
if not self.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
self.state = State.NONE
self.log.warning(
if not ti.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
ti.state = State.NONE
ti.log.warning(
"Rescheduling due to concurrency limits reached "
"at task runtime. Attempt %s of "
"%s. State set to NONE.",
self.try_number,
self.max_tries + 1,
ti.try_number,
ti.max_tries + 1,
)
self.queued_dttm = timezone.utcnow()
session.merge(self)
ti.queued_dttm = timezone.utcnow()
session.merge(ti)
session.commit()
return False
make_transient(ti)
return ti

if self.next_kwargs is not None:
self.log.info("Resuming after deferral")
if ti.next_kwargs is not None:
ti.log.info("Resuming after deferral")
else:
self.log.info("Starting attempt %s of %s", self.try_number, self.max_tries + 1)
self._try_number += 1
ti.log.info("Starting attempt %s of %s", ti.try_number, ti.max_tries + 1)
ti._try_number += 1

if not test_mode:
session.add(Log(State.RUNNING, self))
self.state = State.RUNNING
self.external_executor_id = external_executor_id
self.end_date = None
session.add(Log(State.RUNNING, ti))
ti.state = State.RUNNING
ti.external_executor_id = external_executor_id
ti.end_date = None
if not test_mode:
session.merge(self).task = task
session.merge(ti).task = task
session.commit()

# Closing all pooled connections to prevent
# "max number of connections reached"
settings.engine.dispose() # type: ignore
if verbose:
if mark_success:
self.log.info("Marking success for %s on %s", self.task, self.execution_date)
ti.log.info("Marking success for %s on %s", ti.task, ti.execution_date)
else:
self.log.info("Executing %s on %s", self.task, self.execution_date)
return True
ti.log.info("Executing %s on %s", ti.task, ti.execution_date)

make_transient(ti)
return ti

def _date_or_empty(self, attr: str) -> str:
result: datetime | None = getattr(self, attr, None)
@@ -1715,7 +1782,12 @@ def run(
session: Session = NEW_SESSION,
) -> None:
"""Run TaskInstance"""
res = self.check_and_change_state_before_execution(
ti_before_execution = TaskInstance.check_and_change_state_before_execution(
self.dag_id,
self.run_id,
self.task_id,
self.map_index,
self.task,
verbose=verbose,
ignore_all_deps=ignore_all_deps,
ignore_depends_on_past=ignore_depends_on_past,
Expand All @@ -1728,7 +1800,8 @@ def run(
pool=pool,
session=session,
)
if not res:
self.state = ti_before_execution.state
if not self.state == State.RUNNING:
return

self._run_raw_task(
5 changes: 0 additions & 5 deletions tests/jobs/test_local_task_job.py
@@ -279,7 +279,6 @@ def test_heartbeat_failed_fast(self):
dag_id = "test_heartbeat_failed_fast"
task_id = "test_heartbeat_failed_fast_op"
with create_session() as session:

dag_id = "test_heartbeat_failed_fast"
task_id = "test_heartbeat_failed_fast_op"
dag = self.dagbag.get_dag(dag_id)
@@ -341,7 +340,6 @@ def test_mark_success_no_kill(self, caplog, get_test_dag, session):
)

def test_localtaskjob_double_trigger(self):

dag = self.dagbag.dags.get("test_localtaskjob_double_trigger")
task = dag.get_task("test_localtaskjob_double_trigger_task")

@@ -379,7 +377,6 @@ def test_localtaskjob_double_trigger(self):
@patch.object(StandardTaskRunner, "return_code")
@mock.patch("airflow.jobs.scheduler_job.Stats.incr", autospec=True)
def test_local_task_return_code_metric(self, mock_stats_incr, mock_return_code, create_dummy_dag):

_, task = create_dummy_dag("test_localtaskjob_code")

ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
@@ -400,7 +397,6 @@ def test_local_task_return_code_metric(self, mock_stats_incr, mock_return_code,

@patch.object(StandardTaskRunner, "return_code")
def test_localtaskjob_maintain_heart_rate(self, mock_return_code, caplog, create_dummy_dag):

_, task = create_dummy_dag("test_localtaskjob_double_trigger")

ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
@@ -685,7 +681,6 @@ def test_fast_follow(
get_test_dag,
):
with conf_vars(conf):

dag = get_test_dag(
"test_dagrun_fast_follow",
)