Add bundle_name to ParseImportError #45480

Draft · wants to merge 9 commits into base: main

This PR adds a bundle_name field to ParseImportError and threads it through DAG deletion, the import-error API endpoints, and the DAG-parsing collection code, so import errors are keyed by (filename, bundle_name) rather than by filename alone. Two DAG bundles can ship a file at the same path, so the filename by itself no longer identifies a unique error source.
5 changes: 4 additions & 1 deletion airflow/api/common/delete_dag.py
@@ -75,7 +75,10 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session =
     # This handles the case when the dag_id is changed in the file
     session.execute(
         delete(ParseImportError)
-        .where(ParseImportError.filename == dag.fileloc)
+        .where(
+            ParseImportError.filename == dag.fileloc,
+            ParseImportError.bundle_name == dag.get_bundle_name(session),
+        )
         .execution_options(synchronize_session="fetch")
     )
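For reviewers, a minimal standalone sketch (plain SQLAlchemy with a stand-in model and example values, not Airflow code) of why the extra bundle_name predicate matters: two bundles may ship a file at the same path, and filtering on filename alone would delete both bundles' error rows.

```python
# Sketch: composite (filename, bundle_name) delete, using a stand-in model.
from sqlalchemy import Column, Integer, String, create_engine, delete
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class ParseImportErrorStub(Base):  # stand-in, not airflow.models.errors.ParseImportError
    __tablename__ = "import_error"
    id = Column(Integer, primary_key=True)
    filename = Column(String)
    bundle_name = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([
        ParseImportErrorStub(filename="dags/x.py", bundle_name="bundle1"),
        ParseImportErrorStub(filename="dags/x.py", bundle_name="bundle2"),
    ])
    session.commit()
    # Filename alone would match both rows; the bundle_name predicate scopes
    # the delete to the bundle that owns this copy of the file.
    session.execute(
        delete(ParseImportErrorStub)
        .where(
            ParseImportErrorStub.filename == "dags/x.py",
            ParseImportErrorStub.bundle_name == "bundle1",
        )
        .execution_options(synchronize_session="fetch")
    )
    session.commit()
    assert session.query(ParseImportErrorStub).count() == 1  # bundle2 row survives
```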

27 changes: 21 additions & 6 deletions airflow/api_connexion/endpoints/import_error_endpoint.py
@@ -19,7 +19,7 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING

-from sqlalchemy import func, select
+from sqlalchemy import func, select, tuple_

 from airflow.api_connexion import security
 from airflow.api_connexion.exceptions import NotFound, PermissionDenied
@@ -61,7 +61,9 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) ->
     readable_dag_ids = security.get_readable_dags()
     file_dag_ids = {
         dag_id[0]
-        for dag_id in session.query(DagModel.dag_id).filter(DagModel.fileloc == error.filename).all()
+        for dag_id in session.query(DagModel.dag_id)
+        .filter(DagModel.fileloc == error.filename, DagModel.bundle_name == error.bundle_name)
+        .all()
     }

     # Can the user read any DAGs in the file?
@@ -98,9 +100,17 @@ def get_import_errors(
     if not can_read_all_dags:
         # if the user doesn't have access to all DAGs, only display errors from visible DAGs
         readable_dag_ids = security.get_readable_dags()
-        dagfiles_stmt = select(DagModel.fileloc).distinct().where(DagModel.dag_id.in_(readable_dag_ids))
-        query = query.where(ParseImportError.filename.in_(dagfiles_stmt))
-        count_query = count_query.where(ParseImportError.filename.in_(dagfiles_stmt))
+        dagfiles_stmt = (
+            select(DagModel.fileloc, DagModel.bundle_name)
+            .distinct()
+            .where(DagModel.dag_id.in_(readable_dag_ids))
+        )
+        query = query.where(
+            tuple_(ParseImportError.filename, ParseImportError.bundle_name).in_(dagfiles_stmt)
+        )
+        count_query = count_query.where(
+            tuple_(ParseImportError.filename, ParseImportError.bundle_name).in_(dagfiles_stmt)
+        )

     total_entries = session.scalars(count_query).one()
     import_errors = session.scalars(query.offset(offset).limit(limit)).all()
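The tuple_ filter above relies on row-value IN with a subquery. A small self-contained sketch (stand-in models, in-memory SQLite, which has supported row values since 3.15) showing that only (filename, bundle_name) pairs belonging to readable DAGs match:

```python
# Sketch: tuple_(...).in_(subquery) with stand-in models, not Airflow's.
from sqlalchemy import Column, Integer, String, create_engine, select, tuple_
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class DagStub(Base):
    __tablename__ = "dag"
    dag_id = Column(String, primary_key=True)
    fileloc = Column(String)
    bundle_name = Column(String)

class ImportErrorStub(Base):
    __tablename__ = "import_error"
    id = Column(Integer, primary_key=True)
    filename = Column(String)
    bundle_name = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([
        DagStub(dag_id="dag_a", fileloc="dags/x.py", bundle_name="bundle1"),
        ImportErrorStub(filename="dags/x.py", bundle_name="bundle1"),
        # Same path in another bundle: must not leak through dag_a's visibility.
        ImportErrorStub(filename="dags/x.py", bundle_name="bundle2"),
    ])
    session.commit()

    # Subquery of (fileloc, bundle_name) pairs for DAGs the user may read.
    readable = select(DagStub.fileloc, DagStub.bundle_name).where(
        DagStub.dag_id.in_(["dag_a"])
    )
    errors = session.scalars(
        select(ImportErrorStub).where(
            tuple_(ImportErrorStub.filename, ImportErrorStub.bundle_name).in_(readable)
        )
    ).all()
    assert [(e.filename, e.bundle_name) for e in errors] == [("dags/x.py", "bundle1")]
```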
@@ -109,7 +119,12 @@
     for import_error in import_errors:
         # Check if user has read access to all the DAGs defined in the file
         file_dag_ids = (
-            session.query(DagModel.dag_id).filter(DagModel.fileloc == import_error.filename).all()
+            session.query(DagModel.dag_id)
+            .filter(
+                DagModel.fileloc == import_error.filename,
+                DagModel.bundle_name == import_error.bundle_name,
+            )
+            .all()
         )
         requests: Sequence[IsAuthorizedDagRequest] = [
             {
1 change: 1 addition & 0 deletions airflow/api_connexion/schemas/error_schema.py
@@ -35,6 +35,7 @@ class Meta:
     import_error_id = auto_field("id", dump_only=True)
     timestamp = auto_field(format="iso", dump_only=True)
     filename = auto_field(dump_only=True)
+    bundle_name = auto_field(dump_only=True)
     stack_trace = auto_field("stacktrace", dump_only=True)
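A self-contained sketch of the marshmallow-sqlalchemy pattern this schema uses, with a stand-in model rather than Airflow's, showing the new field appear in dumped output:

```python
# Sketch: auto_field(dump_only=True) surfaces a column on the serialized payload.
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class ImportErrorRow(Base):  # stand-in model, not Airflow's ParseImportError
    __tablename__ = "import_error"
    id = Column(Integer, primary_key=True)
    filename = Column(String)
    bundle_name = Column(String)
    stacktrace = Column(String)

class ImportErrorSchema(SQLAlchemySchema):
    class Meta:
        model = ImportErrorRow

    import_error_id = auto_field("id", dump_only=True)
    filename = auto_field(dump_only=True)
    bundle_name = auto_field(dump_only=True)  # the newly exposed field
    stack_trace = auto_field("stacktrace", dump_only=True)

payload = ImportErrorSchema().dump(
    ImportErrorRow(id=1, filename="dags/x.py", bundle_name="bundle1", stacktrace="...")
)
assert payload["bundle_name"] == "bundle1"
```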


(file name not captured in this view)
@@ -74,6 +74,7 @@ def get_import_errors(
"id",
"timestamp",
"filename",
"bundle_name",
"stacktrace",
],
ParseImportError,
(file name not captured in this view)
@@ -39,7 +39,6 @@ def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner:
         job=Job(),
         processor=DagFileProcessorManager(
             processor_timeout=processor_timeout_seconds,
-            dag_directory=args.subdir,
             max_runs=args.num_runs,
         ),
     )
50 changes: 39 additions & 11 deletions airflow/dag_processing/collection.py
@@ -74,11 +74,14 @@
 log = logging.getLogger(__name__)


-def _create_orm_dags(dags: Iterable[MaybeSerializedDAG], *, session: Session) -> Iterator[DagModel]:
+def _create_orm_dags(
+    bundle_name: str, dags: Iterable[MaybeSerializedDAG], *, session: Session
+) -> Iterator[DagModel]:
     for dag in dags:
         orm_dag = DagModel(dag_id=dag.dag_id)
         if dag.is_paused_upon_creation is not None:
             orm_dag.is_paused = dag.is_paused_upon_creation
+        orm_dag.bundle_name = bundle_name
         log.info("Creating ORM DAG for %s", dag.dag_id)
         session.add(orm_dag)
         yield orm_dag
@@ -238,38 +241,58 @@ def _update_dag_warnings(
         session.merge(warning_to_add)


-def _update_import_errors(files_parsed: set[str], import_errors: dict[str, str], session: Session):
+def _update_import_errors(
+    files_parsed: set[str], bundle_name: str, import_errors: dict[str, str], session: Session
+):
     from airflow.listeners.listener import get_listener_manager

     # We can remove anything from files parsed in this batch that doesn't have an error. We need to remove old
     # errors (i.e. from files that are removed) separately

-    session.execute(delete(ParseImportError).where(ParseImportError.filename.in_(list(files_parsed))))
+    session.execute(
+        delete(ParseImportError).where(
+            ParseImportError.filename.in_(list(files_parsed)), ParseImportError.bundle_name == bundle_name
+        )
+    )

-    existing_import_error_files = set(session.scalars(select(ParseImportError.filename)))
+    existing_import_error_files = set(
+        session.execute(select(ParseImportError.filename, ParseImportError.bundle_name))
+    )

     # Add the errors of the processed files
     for filename, stacktrace in import_errors.items():
-        if filename in existing_import_error_files:
-            session.query(ParseImportError).where(ParseImportError.filename == filename).update(
-                {"filename": filename, "timestamp": utcnow(), "stacktrace": stacktrace},
+        if (filename, bundle_name) in existing_import_error_files:
+            session.query(ParseImportError).where(
+                ParseImportError.filename == filename, ParseImportError.bundle_name == bundle_name
+            ).update(
+                {
+                    "filename": filename,
+                    "bundle_name": bundle_name,
+                    "timestamp": utcnow(),
+                    "stacktrace": stacktrace,
+                },
             )
             # sending notification when an existing dag import error occurs
             get_listener_manager().hook.on_existing_dag_import_error(filename=filename, stacktrace=stacktrace)
         else:
             session.add(
                 ParseImportError(
                     filename=filename,
+                    bundle_name=bundle_name,
                     timestamp=utcnow(),
                     stacktrace=stacktrace,
                 )
             )
             # sending notification when a new dag import error occurs
             get_listener_manager().hook.on_new_dag_import_error(filename=filename, stacktrace=stacktrace)
-        session.query(DagModel).filter(DagModel.fileloc == filename).update({"has_import_errors": True})
+        session.query(DagModel).filter(
+            DagModel.fileloc == filename, DagModel.bundle_name == bundle_name
+        ).update({"has_import_errors": True})


 def update_dag_parsing_results_in_db(
+    bundle_name: str,
+    bundle_version: str | None,
     dags: Collection[MaybeSerializedDAG],
     import_errors: dict[str, str],
     warnings: set[DagWarning],
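In plain terms, _update_import_errors now reconciles per bundle: clear errors for files this bundle parsed cleanly, then upsert errors keyed by (filename, bundle_name). A minimal dict-based sketch of those semantics (plain Python, not the actual implementation; listener notifications and the has_import_errors flag are omitted):

```python
def reconcile_import_errors(
    existing: dict[tuple[str, str], str],  # (filename, bundle_name) -> stacktrace
    files_parsed: set[str],
    bundle_name: str,
    import_errors: dict[str, str],  # filename -> stacktrace from this parse
) -> dict[tuple[str, str], str]:
    # Drop errors for files that parsed cleanly in this batch -- but only for
    # this bundle, so an identically named file in another bundle keeps its error.
    kept = {
        (fname, bname): trace
        for (fname, bname), trace in existing.items()
        if not (fname in files_parsed and bname == bundle_name)
    }
    # Upsert errors produced by the current parse, keyed per bundle.
    for fname, trace in import_errors.items():
        kept[(fname, bundle_name)] = trace
    return kept

state = {("dags/x.py", "bundle1"): "old trace", ("dags/x.py", "bundle2"): "other"}
state = reconcile_import_errors(
    state, files_parsed={"dags/x.py"}, bundle_name="bundle1", import_errors={}
)
assert state == {("dags/x.py", "bundle2"): "other"}  # bundle2's error survives
```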
@@ -307,8 +330,7 @@ def update_dag_parsing_results_in_db(
     )
     log.debug("Calling the DAG.bulk_sync_to_db method")
     try:
-        DAG.bulk_write_to_db(dags, session=session)
+        DAG.bulk_write_to_db(bundle_name, bundle_version, dags, session=session)
         # Write Serialized DAGs to DB, capturing errors
         for dag in dags:
             serialize_errors.extend(_serialize_dag_capturing_errors(dag, session))
@@ -327,6 +349,7 @@
     good_dag_filelocs = {dag.fileloc for dag in dags if dag.fileloc not in import_errors}
     _update_import_errors(
         files_parsed=good_dag_filelocs,
+        bundle_name=bundle_name,
         import_errors=import_errors,
         session=session,
     )
Expand All @@ -346,6 +369,8 @@ class DagModelOperation(NamedTuple):
"""Collect DAG objects and perform database operations for them."""

dags: dict[str, MaybeSerializedDAG]
bundle_name: str
bundle_version: str | None

def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]:
"""Find existing DagModel objects from DAG objects."""
@@ -365,7 +390,8 @@ def add_dags(self, *, session: Session) -> dict[str, DagModel]:
         orm_dags.update(
             (model.dag_id, model)
             for model in _create_orm_dags(
-                (dag for dag_id, dag in self.dags.items() if dag_id not in orm_dags),
+                bundle_name=self.bundle_name,
+                dags=(dag for dag_id, dag in self.dags.items() if dag_id not in orm_dags),
                 session=session,
             )
         )
@@ -430,6 +456,8 @@ def update_dags(
             dm.timetable_summary = dag.timetable.summary
             dm.timetable_description = dag.timetable.description
             dm.asset_expression = dag.timetable.asset_condition.as_expression()
+            dm.bundle_name = self.bundle_name
+            dm.latest_bundle_version = self.bundle_version

             last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id)
             if last_automated_run is None: