2 changes: 1 addition & 1 deletion airflow/cli/commands/dag_command.py
@@ -97,7 +97,7 @@ def dag_backfill(args, dag=None):
         dr = DagRun(dag.dag_id, execution_date=args.start_date)
         for task in dag.tasks:
             print(f"Task {task.task_id} located in DAG {dag.dag_id}")
-            ti = TaskInstance(task, run_id=None)
+            ti = TaskInstance.from_task(task, run_id=None)
             ti.dag_run = dr
             ti.dry_run()
     else:
2 changes: 1 addition & 1 deletion airflow/cli/commands/kubernetes_command.py
@@ -43,7 +43,7 @@ def generate_pod_yaml(args):
     dr = DagRun(dag.dag_id, execution_date=execution_date)
     kube_config = KubeConfig()
     for task in dag.tasks:
-        ti = TaskInstance(task, None)
+        ti = TaskInstance.from_task(task, None)
         ti.dag_run = dr
         pod = PodGenerator.construct_pod(
             dag_id=args.dag_id,
2 changes: 1 addition & 1 deletion airflow/cli/commands/task_command.py
@@ -176,7 +176,7 @@ def _get_ti(
            f"run_id or execution_date of {exec_date_or_run_id!r} not found"
        )
        # TODO: Validate map_index is in range?
-        ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index)
+        ti = TaskInstance.from_task(task, run_id=dag_run.run_id, map_index=map_index)
        ti.dag_run = dag_run
    else:
        ti = ti_or_none
2 changes: 1 addition & 1 deletion airflow/models/abstractoperator.py
@@ -491,7 +491,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
-            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
+            ti = TaskInstance.from_task(self, run_id=run_id, map_index=index, state=state)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
4 changes: 2 additions & 2 deletions airflow/models/baseoperator.py
@@ -1278,7 +1278,7 @@ def run(
                )
                .one()
            )
-            ti = TaskInstance(self, run_id=dag_run.run_id)
+            ti = TaskInstance.from_task(self, run_id=dag_run.run_id)
        except NoResultFound:
            # This is _mostly_ only used in tests
            dr = DagRun(
@@ -1288,7 +1288,7 @@
                execution_date=info.logical_date,
                data_interval=info.data_interval,
            )
-            ti = TaskInstance(self, run_id=dr.run_id)
+            ti = TaskInstance.from_task(self, run_id=dr.run_id)
            ti.dag_run = dr
            session.add(dr)
            session.flush()
4 changes: 2 additions & 2 deletions airflow/models/dagrun.py
@@ -1069,7 +1069,7 @@ def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[s

        def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
            for map_index in indexes:
-                ti = TI(task, run_id=self.run_id, map_index=map_index)
+                ti = TI.from_task(task, run_id=self.run_id, map_index=map_index)
                ti_mutation_hook(ti)
                created_counts[ti.operator] += 1
                yield ti
@@ -1185,7 +1185,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
        for index in range(total_length):
            if index in existing_indexes:
                continue
-            ti = TI(task, run_id=self.run_id, map_index=index, state=None)
+            ti = TI.from_task(task, run_id=self.run_id, map_index=index, state=None)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
83 changes: 79 additions & 4 deletions airflow/models/taskinstance.py
@@ -32,7 +32,7 @@
 from functools import partial
 from pathlib import PurePath
 from types import TracebackType
-from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, NamedTuple, Tuple
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Generator, Iterable, NamedTuple, Tuple
 from urllib.parse import quote

 import dill
@@ -140,6 +140,13 @@


 PAST_DEPENDS_MET = "past_depends_met"
+# List of fields which should be excluded from serialization.
+TASK_INSTANCE_SERIALIZE_EXCLUDES = [
+    "_log",
+    "_sa_instance_state",
+    "dag_run",
+    "test_mode",
+]


 class TaskReturnCode(Enum):
@@ -374,6 +381,7 @@ class TaskInstance(Base, LoggingMixin):
     a TI with mapped tasks that expanded to an empty list (state=skipped).
     """

+    __version__: ClassVar[int] = 1
     __tablename__ = "task_instance"
     task_id = Column(StringID(), primary_key=True, nullable=False)
     dag_id = Column(StringID(), primary_key=True, nullable=False)
@@ -461,7 +469,6 @@ class TaskInstance(Base, LoggingMixin):
     task_instance_note = relationship("TaskInstanceNote", back_populates="task_instance", uselist=False)
     note = association_proxy("task_instance_note", "content", creator=_creator_note)
     task: Operator  # Not always set...
-
     is_trigger_log_context: bool = False
     """Indicate to FileTaskHandler that logging context should be set up for trigger logging.
@@ -470,13 +477,40 @@

     def __init__(
         self,
-        task: Operator,
+        task: Operator | None = None,
         execution_date: datetime | None = None,
         run_id: str | None = None,
         state: str | None = None,
         map_index: int = -1,
+        ti_dict: dict[str, Any] | None = None,
     ):
+        """
+        Construct a TaskInstance.
+
+        Deprecated; prefer the "from_task" or "deserialize" static methods.
+        """
+        if task is not None:
+            self._init_from_task(task, execution_date, run_id, state, map_index)
+        elif ti_dict is not None:
+            self.__dict__.update(ti_dict)
+        else:
+            raise AirflowException("Either task or ti_dict must be provided to construct a TaskInstance.")
+
+    def _init_from_task(
+        self,
+        task: Operator,
+        execution_date: datetime | None,
+        run_id: str | None,
+        state: str | None,
+        map_index: int = -1,
+    ):
+        """
+        Create a TaskInstance from a task operator.
+
+        :param task: the task's Operator object.
+        :param execution_date: Optional execution time of the task.
+        :param run_id: Optional DAG run ID for the task.
+        :param state: Optional state of the task.
+        :param map_index: Optional map index. Defaults to -1 (a non-mapped task).
+        """
+        super().__init__()
         self.dag_id = task.dag_id
         self.task_id = task.task_id
         self.map_index = map_index
@@ -532,6 +566,47 @@ def __init__(
         # can be changed when calling 'run'
         self.test_mode = False

+    @staticmethod
+    def from_task(
+        task: Operator,
+        execution_date: datetime | None = None,
+        run_id: str | None = None,
+        state: str | None = None,
+        map_index: int = -1,
+    ) -> TaskInstance:
+        """
+        Create a TaskInstance from a task operator.
+
+        :param task: the task's Operator object.
+        :param execution_date: Optional execution time of the task.
+        :param run_id: Optional DAG run ID for the task.
+        :param state: Optional state of the task.
+        :param map_index: Optional map index. Defaults to -1 (a non-mapped task).
+        """
+        return TaskInstance(task, execution_date, run_id, state, map_index)
+
+    @staticmethod
+    def deserialize(ti_dict: dict[str, Any], version: int) -> TaskInstance:
+        """Deserialize a TaskInstance from a dictionary."""
+        if version > TaskInstance.__version__:
+            raise TypeError(
+                f"Version {version} is too new, don't know how to deserialize; "
+                f"latest supported version: {TaskInstance.__version__}"
+            )
+        ti = TaskInstance(ti_dict=ti_dict)
+        ti.init_on_load()
+        return ti
+
+    def serialize(self) -> dict[str, Any]:
+        """Serialize this TaskInstance to a dictionary, dropping fields that cannot be serialized."""
+        ti_dict = self.__dict__.copy()
+        for field in TASK_INSTANCE_SERIALIZE_EXCLUDES:
+            ti_dict.pop(field, None)
+        return ti_dict
+
+    def to_dict(self) -> dict[str, Any]:
+        return self.serialize()
+
     @staticmethod
     def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any]:
         """:meta private:"""
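Taken together, the changes above leave three ways to build a TaskInstance: the deprecated direct constructor, from_task, and deserialize. A minimal round-trip sketch of the new API (the DAG, operator, and run_id below are illustrative, not taken from this PR):

import datetime

from airflow.models import DAG
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator

with DAG(dag_id="example", start_date=datetime.datetime(2023, 1, 1)) as dag:
    task = EmptyOperator(task_id="hello")

# Preferred over calling the (now deprecated) constructor directly.
ti = TaskInstance.from_task(task, run_id="manual__2023-01-01")

# serialize() copies __dict__ minus TASK_INSTANCE_SERIALIZE_EXCLUDES
# (_log, _sa_instance_state, dag_run, test_mode).
ti_dict = ti.serialize()

# deserialize() rebuilds the instance and re-runs init_on_load() to
# restore the derived, non-serialized state.
restored = TaskInstance.deserialize(ti_dict, version=TaskInstance.__version__)
assert restored.task_id == "hello"

Excluding dag_run and the SQLAlchemy instance state keeps the payload a plain dict of column values, which is what the Internal API needs to ship between processes.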
2 changes: 1 addition & 1 deletion airflow/sensors/external_task.py
@@ -52,7 +52,7 @@ class ExternalDagLink(BaseOperatorLink):
    name = "External DAG"

    def get_link(self, operator, dttm):
-        ti = TaskInstance(task=operator, execution_date=dttm)
+        ti = TaskInstance.from_task(task=operator, execution_date=dttm)
        operator.render_template_fields(ti.get_template_context())
        query = {"dag_id": operator.external_dag_id, "execution_date": dttm.isoformat()}
        return build_airflow_url_with_query(query)
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
@@ -51,3 +51,4 @@ class DagAttributeTypes(str, Enum):
    XCOM_REF = "xcomref"
    DATASET = "dataset"
    SIMPLE_TASK_INSTANCE = "simple_task_instance"
+    TASK_INSTANCE = "task_instance"
8 changes: 7 additions & 1 deletion airflow/serialization/serialized_objects.py
@@ -44,7 +44,7 @@
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.operator import Operator
 from airflow.models.param import Param, ParamsDict
-from airflow.models.taskinstance import SimpleTaskInstance
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
 from airflow.models.taskmixin import DAGNode
 from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
 from airflow.providers_manager import ProvidersManager
@@ -448,6 +448,10 @@ def serialize(
             return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET)
         elif isinstance(var, SimpleTaskInstance):
             return cls._encode(cls.serialize(var.__dict__, strict=strict), type_=DAT.SIMPLE_TASK_INSTANCE)
+        elif isinstance(var, TaskInstance):
+            # FIXME: We can't use var.serialize() here due to problems in the test
+            # test_recursive_serialize_calls_must_forward_kwargs
+            return cls._encode(cls.serialize(var=var.to_dict(), strict=strict), type_=DAT.TASK_INSTANCE)
         else:
             log.debug("Cast type %s to str in serialization.", type(var))
             if strict:
@@ -502,6 +506,8 @@ def deserialize(cls, encoded_var: Any) -> Any:
             return Dataset(**var)
         elif type_ == DAT.SIMPLE_TASK_INSTANCE:
             return SimpleTaskInstance(**cls.deserialize(var))
+        elif type_ == DAT.TASK_INSTANCE:
Contributor: Here you are extending legacy code; I suggest using the more generic serialization code from serde.

Contributor Author: We rely on serialized_objects in the Internal API in general:
https://github.com/apache/airflow/blob/main/airflow/api_internal/internal_api_call.py#L107
Do you think we could switch to serde now? Is it compatible with what serialized_objects offers?

Contributor: Given that you are going to rely on a lot of serialization, and the new serializer/deserializer is significantly faster and more future-proof (versioning), I think everything that does not currently have a schema (that's everything except a DAG*) should switch.

I am willing to help out to ease the migration if required.

  • I am working on DAG serialization/deserialization, but untangling how it is done now and improving the structure is taking time, especially with all the edge cases.

Contributor Author: Great. I think we will migrate to the new serializer/deserializer in this case, but probably outside of this PR.
If you believe it would be better to migrate first, then I can revert this change and get back to it when using the new way.
WDYT?

bolkedebruin (Contributor), Feb 9, 2023: Am I incorrect in thinking that the migration is a two-line change in internal_api_call, and that you are not relying on any of the other serialized_objects (basically DAG)? If so, then I would say do not add technical debt and migrate now. This would allow us to declare serialized_objects stale and soon to be deprecated.

Otherwise, keep it and add it to the todo of AIP-44?
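For reference, the serde migration under discussion would amount to roughly the sketch below: airflow.serialization.serde dispatches on the same protocol this PR already adds to TaskInstance (a __version__ ClassVar plus serialize()/deserialize() methods). This is a sketch under that assumption only; whether serde's entry points cover every Internal API case is exactly the open question in this thread.

from airflow.serialization.serde import deserialize, serialize

# `ti` is assumed to be a TaskInstance, e.g. built with TaskInstance.from_task().
encoded = serialize(ti)          # versioned, plain-data representation
restored = deserialize(encoded)  # dispatches back through TaskInstance.deserialize()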

Contributor Author: Yes, we need to make a change there and on the server side:
https://github.com/apache/airflow/blob/main/airflow/api_internal/endpoints/rpc_api_endpoint.py#L76

There are more methods with the internal_api_call decorator. I did a quick check, and I see that (besides primitives) we already need serialization for DAG, DagRun, BaseXCom, and CallbackRequest (and probably more soon).

Member: Question (and really sorry I have not looked at it earlier), also as a follow-up on #29513.

Are we REALLY sure we need to serialize whole TaskInstance (and other) objects to be passed via internal_api_call?

For me this is an indication that we have too narrow a scope for an @internal_api call: generally speaking, the whole internal_api_call should span the whole DB transaction. Since we are trying to pass an ORM object (TaskInstance, DagRun, etc.), that object must have been retrieved earlier within a transaction, which means our internal_api_call should wrap the retrieval as well. That might simply mean we need to do some refactoring and extract new methods (and then decorate them).

That's of course a general statement and approach, and there might be cases where this requires a bit deeper refactoring.

Which methods are affected, @mhenc (besides the #29513 one)? Maybe we can look together and figure out an approach for all of them?

Member: For example, in #29513 (comment) I proposed a solution that would avoid TaskInstance serialization altogether. I am reasonably convinced that a similar approach can be taken for all of our ORM objects and that we do not need to serialize any of them (in which case this whole PR might not be needed).

mhenc (Contributor Author), Feb 21, 2023: But this works only for the client -> Internal API side.

What about Internal API -> client, e.g. the worker? We need the TaskInstance object in the worker to run the task, e.g.
https://github.com/apache/airflow/blob/main/airflow/cli/commands/task_command.py#L187

Unless, of course, we are able to refactor it completely.

potiuk (Member), Feb 21, 2023: That's precisely what I am refactoring now (with my POC).

+            return TaskInstance.deserialize(ti_dict=cls.deserialize(var), version=1)
         else:
             raise TypeError(f"Invalid type {type_!s} in deserialization.")
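As with the other DAT branches, the encoder wraps the payload in the __type/__var envelope that BaseSerialization already uses, so an encoded task instance looks roughly as follows. The field values here are illustrative; the exact field set is whatever TaskInstance.serialize() leaves after the excludes.

encoded = {
    "__type": "task_instance",  # DAT.TASK_INSTANCE
    "__var": {
        "task_id": "hello",
        "dag_id": "example",
        "run_id": "manual__2023-01-01",
        "map_index": -1,
        # ...remaining serializable attributes...
    },
}
# BaseSerialization.deserialize(encoded) then routes the payload through
# TaskInstance.deserialize(ti_dict=..., version=1), as added above.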
2 changes: 1 addition & 1 deletion airflow/www/views.py
@@ -1343,7 +1343,7 @@ def rendered_templates(self, session):
             # make sense in this situation, but "works" prior to AIP-39. This
             # "fakes" a temporary DagRun-TaskInstance association (not saved to
             # database) for presentation only.
-            ti = TaskInstance(raw_task, map_index=map_index)
+            ti = TaskInstance.from_task(raw_task, map_index=map_index)
             ti.dag_run = DagRun(dag_id=dag_id, execution_date=dttm)
         else:
             ti = dag_run.get_task_instance(task_id=task_id, map_index=map_index, session=session)
2 changes: 1 addition & 1 deletion kubernetes_tests/test_kubernetes_pod_operator.py
@@ -57,7 +57,7 @@ def create_context(task) -> Context:
         execution_date=execution_date,
         run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
     )
-    task_instance = TaskInstance(task=task)
+    task_instance = TaskInstance.from_task(task=task)
     task_instance.dag_run = dag_run
     task_instance.dag_id = dag.dag_id
     task_instance.xcom_push = mock.Mock()  # type: ignore
2 changes: 1 addition & 1 deletion tests/api/common/test_delete_dag.py
@@ -81,7 +81,7 @@ def setup_dag_models(self, for_sub_dag=False):
         with create_session() as session:
             session.add(DM(dag_id=self.key, fileloc=self.dag_file_path, is_subdag=for_sub_dag))
             dr = DR(dag_id=self.key, run_type=DagRunType.MANUAL, run_id="test", execution_date=test_date)
-            ti = TI(task=task, state=State.SUCCESS)
+            ti = TI.from_task(task=task, state=State.SUCCESS)
             ti.dag_run = dr
             session.add_all((dr, ti))
             # flush to ensure task instance is written before
@@ -120,17 +120,23 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags={}):

         index = 0
         for i in range(dags[dag_id]["success"]):
-            ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS)
+            ti = TaskInstance.from_task(
+                mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS
+            )
             setattr(ti, "start_date", DEFAULT_DATETIME_1)
             session.add(ti)
             index += 1
         for i in range(dags[dag_id]["failed"]):
-            ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.FAILED)
+            ti = TaskInstance.from_task(
+                mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.FAILED
+            )
             setattr(ti, "start_date", DEFAULT_DATETIME_1)
             session.add(ti)
             index += 1
         for i in range(dags[dag_id]["running"]):
-            ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.RUNNING)
+            ti = TaskInstance.from_task(
+                mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.RUNNING
+            )
             setattr(ti, "start_date", DEFAULT_DATETIME_1)
             session.add(ti)
             index += 1
8 changes: 4 additions & 4 deletions tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -159,7 +159,7 @@ def create_task_instances(
                 state=dag_run_state,
             )
             session.add(dr)
-            ti = TaskInstance(task=tasks[i], **self.ti_init)
+            ti = TaskInstance.from_task(task=tasks[i], **self.ti_init)
             ti.dag_run = dr
             ti.note = "placeholder-note"

@@ -385,7 +385,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session):
         tis = self.create_task_instances(session)
         old_ti = tis[0]
         for idx in (1, 2):
-            ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
+            ti = TaskInstance.from_task(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
             ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
             for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]:
                 setattr(ti, attr, getattr(old_ti, attr))
@@ -1697,7 +1697,7 @@ def test_should_update_mapped_task_instance_state(self, session):
         NEW_STATE = "failed"
         map_index = 1
         tis = self.create_task_instances(session)
-        ti = TaskInstance(task=tis[0].task, run_id=tis[0].run_id, map_index=map_index)
+        ti = TaskInstance.from_task(task=tis[0].task, run_id=tis[0].run_id, map_index=map_index)
         ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
         session.add(ti)
         session.commit()
@@ -1871,7 +1871,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session):
         tis = self.create_task_instances(session)
         old_ti = tis[0]
         for idx in (1, 2):
-            ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
+            ti = TaskInstance.from_task(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
             ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
             for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]:
                 setattr(ti, attr, getattr(old_ti, attr))
8 changes: 4 additions & 4 deletions tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -171,7 +171,7 @@ def _create_xcom_entry(self, dag_id, run_id, execution_date, task_id, xcom_key,
             run_type=DagRunType.MANUAL,
         )
         session.add(dagrun)
-        ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
+        ti = TaskInstance.from_task(EmptyOperator(task_id=task_id), run_id=run_id)
         ti.dag_id = dag_id
         session.add(ti)
         backend.set(
@@ -365,7 +365,7 @@ def _create_xcom_entries(self, dag_id, run_id, execution_date, task_id):
             run_type=DagRunType.MANUAL,
         )
         session.add(dagrun)
-        ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
+        ti = TaskInstance.from_task(EmptyOperator(task_id=task_id), run_id=run_id)
         ti.dag_id = dag_id
         session.add(ti)

@@ -401,7 +401,7 @@ def _create_invalid_xcom_entries(self, execution_date):
             run_type=DagRunType.MANUAL,
         )
         session.add(dagrun1)
-        ti = TaskInstance(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id")
+        ti = TaskInstance.from_task(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id")
         ti.dag_id = "invalid_dag"
         session.add(ti)
         for i in [1, 2]:
@@ -486,7 +486,7 @@ def test_handle_limit_offset(self, query_params, expected_xcom_ids):
             run_type=DagRunType.MANUAL,
         )
         session.add(dagrun)
-        ti = TaskInstance(EmptyOperator(task_id=self.task_id), run_id=self.run_id)
+        ti = TaskInstance.from_task(EmptyOperator(task_id=self.task_id), run_id=self.run_id)
         ti.dag_id = self.dag_id
         session.add(ti)
4 changes: 2 additions & 2 deletions tests/api_connexion/schemas/test_task_instance_schema.py
@@ -63,7 +63,7 @@ def set_attrs(self, session, dag_maker):
         session.rollback()

     def test_task_instance_schema_without_sla_and_rendered(self, session):
-        ti = TI(task=self.task, **self.default_ti_init)
+        ti = TI.from_task(task=self.task, **self.default_ti_init)
         for key, value in self.default_ti_extras.items():
             setattr(ti, key, value)
         serialized_ti = task_instance_schema.dump((ti, None, None))
@@ -105,7 +105,7 @@ def test_task_instance_schema_with_sla_and_rendered(self, session):
         )
         session.add(sla_miss)
         session.flush()
-        ti = TI(task=self.task, **self.default_ti_init)
+        ti = TI.from_task(task=self.task, **self.default_ti_init)
         for key, value in self.default_ti_extras.items():
             setattr(ti, key, value)
         self.task.template_fields = ["partitions"]