Skip to content

Commit 1677d80

Browse files
bkossakowska (Beata Kossakowska)
and authored
Add deferrable mode to DataprocInstantiateWorkflowTemplateOperator (#28618)
Co-authored-by: Beata Kossakowska <bkossakowska@google.com>
1 parent c5d548b commit 1677d80

File tree

8 files changed

+303
-15
lines changed

8 files changed

+303
-15
lines changed

‎airflow/providers/google/cloud/hooks/dataproc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
2929
from google.api_core.operation import Operation
3030
from google.api_core.operation_async import AsyncOperation
31+
from google.api_core.operations_v1.operations_client import OperationsClient
3132
from google.api_core.retry import Retry
3233
from google.cloud.dataproc_v1 import (
3334
Batch,
@@ -1047,6 +1048,10 @@ def get_batch_client(self, region: str | None = None) -> BatchControllerAsyncCli
10471048
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
10481049
)
10491050

1051+
def get_operations_client(self, region: str) -> OperationsClient:
    """Return the ``OperationsClient`` exposed by the template client's transport for *region*."""
    template_client = self.get_template_client(region=region)
    return template_client.transport.operations_client
1054+
10501055
@GoogleBaseHook.fallback_to_default_project_id
10511056
async def create_cluster(
10521057
self,
@@ -1459,6 +1464,9 @@ async def instantiate_inline_workflow_template(
14591464
)
14601465
return operation
14611466

1467+
async def get_operation(self, region: str, operation_name: str):
    """
    Fetch the long-running operation identified by ``operation_name``.

    :param region: Dataproc region whose operations client should be used.
    :param operation_name: Fully-qualified name of the operation to fetch.
    :return: The operation object returned by the operations client
        (shape depends on the underlying transport — TODO confirm).
    """
    return await self.get_operations_client(region).get_operation(name=operation_name)
1469+
14621470
@GoogleBaseHook.fallback_to_default_project_id
14631471
async def get_job(
14641472
self,

‎airflow/providers/google/cloud/operators/dataproc.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
DataprocClusterTrigger,
5656
DataprocDeleteClusterTrigger,
5757
DataprocSubmitTrigger,
58+
DataprocWorkflowTrigger,
5859
)
5960
from airflow.utils import timezone
6061

@@ -1688,7 +1689,7 @@ class DataprocInstantiateWorkflowTemplateOperator(BaseOperator):
16881689
16891690
.. seealso::
16901691
Please refer to:
1691-
https://cloud.google.com/dataproc/docs/reference/rest/v1beta2/projects.regions.workflowTemplates/instantiate
1692+
https://cloud.google.com/dataproc/docs/reference/rest/v1/projects.regions.workflowTemplates/instantiate
16921693
16931694
:param template_id: The id of the template. (templated)
16941695
:param project_id: The ID of the google cloud project in which
@@ -1717,6 +1718,8 @@ class DataprocInstantiateWorkflowTemplateOperator(BaseOperator):
17171718
If set as a sequence, the identities from the list must grant
17181719
Service Account Token Creator IAM role to the directly preceding identity, with first
17191720
account from the list granting this role to the originating account (templated).
1721+
:param deferrable: Run operator in the deferrable mode.
1722+
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
17201723
"""
17211724

17221725
template_fields: Sequence[str] = ("template_id", "impersonation_chain", "request_id", "parameters")
@@ -1737,10 +1740,13 @@ def __init__(
17371740
metadata: Sequence[tuple[str, str]] = (),
17381741
gcp_conn_id: str = "google_cloud_default",
17391742
impersonation_chain: str | Sequence[str] | None = None,
1743+
deferrable: bool = False,
1744+
polling_interval_seconds: int = 10,
17401745
**kwargs,
17411746
) -> None:
17421747
super().__init__(**kwargs)
1743-
1748+
if deferrable and polling_interval_seconds <= 0:
1749+
raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
17441750
self.template_id = template_id
17451751
self.parameters = parameters
17461752
self.version = version
@@ -1752,6 +1758,8 @@ def __init__(
17521758
self.request_id = request_id
17531759
self.gcp_conn_id = gcp_conn_id
17541760
self.impersonation_chain = impersonation_chain
1761+
self.deferrable = deferrable
1762+
self.polling_interval_seconds = polling_interval_seconds
17551763

17561764
def execute(self, context: Context):
17571765
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
@@ -1772,8 +1780,34 @@ def execute(self, context: Context):
17721780
context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=self.workflow_id
17731781
)
17741782
self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
1775-
operation.result()
1776-
self.log.info("Workflow %s completed successfully", self.workflow_id)
1783+
if not self.deferrable:
1784+
hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
1785+
self.log.info("Workflow %s completed successfully", self.workflow_id)
1786+
else:
1787+
self.defer(
1788+
trigger=DataprocWorkflowTrigger(
1789+
template_name=self.template_id,
1790+
name=operation.operation.name,
1791+
project_id=self.project_id,
1792+
region=self.region,
1793+
gcp_conn_id=self.gcp_conn_id,
1794+
impersonation_chain=self.impersonation_chain,
1795+
polling_interval_seconds=self.polling_interval_seconds,
1796+
),
1797+
method_name="execute_complete",
1798+
)
1799+
1800+
def execute_complete(self, context, event=None) -> None:
    """
    Callback for when the trigger fires - returns immediately.

    Relies on trigger to throw an exception, otherwise it assumes execution was
    successful.

    :param context: Airflow task context.
    :param event: Payload emitted by the trigger; expected keys are
        ``status``, ``message`` and ``operation_name``.
    :raises AirflowException: if no event was supplied or the trigger
        reported a failed/errored operation.
    """
    if event is None:
        # Without this guard the subscripts below would raise a bare TypeError.
        raise AirflowException("Trigger fired without providing an event payload.")
    if event["status"] in ("failed", "error"):
        # log.error, not log.exception: we are not inside an ``except`` block,
        # so there is no active traceback to attach.
        self.log.error("Unexpected error in the operation.")
        raise AirflowException(event["message"])

    self.log.info("Workflow %s completed successfully", event["operation_name"])
17771811

17781812

17791813
class DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator):

‎airflow/providers/google/cloud/triggers/dataproc.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,93 @@ def _get_hook(self) -> DataprocAsyncHook:
310310
gcp_conn_id=self.gcp_conn_id,
311311
impersonation_chain=self.impersonation_chain,
312312
)
313+
314+
315+
class DataprocWorkflowTrigger(BaseTrigger):
    """
    Trigger that periodically polls information from Dataproc API to verify status.
    Implementation leverages asynchronous transport.

    :param template_name: Name of the workflow template whose run is monitored.
    :param name: Fully-qualified name of the long-running operation to poll.
    :param region: Dataproc region of the operation.
    :param project_id: Optional Google Cloud project id.
    :param gcp_conn_id: Airflow connection id used to build the hook.
    :param impersonation_chain: Optional service account(s) to impersonate.
    :param delegate_to: Deprecated; use ``impersonation_chain`` instead.
    :param polling_interval_seconds: Seconds to sleep between status polls.
    """

    def __init__(
        self,
        template_name: str,
        name: str,
        region: str,
        project_id: str | None = None,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        delegate_to: str | None = None,
        polling_interval_seconds: int = 10,
    ):
        super().__init__()
        self.gcp_conn_id = gcp_conn_id
        self.template_name = template_name
        self.name = name
        self.impersonation_chain = impersonation_chain
        self.project_id = project_id
        self.region = region
        self.polling_interval_seconds = polling_interval_seconds
        self.delegate_to = delegate_to
        if delegate_to:
            warnings.warn(
                "'delegate_to' parameter is deprecated, please use 'impersonation_chain'", DeprecationWarning
            )

    def serialize(self):
        """Serialize the trigger's construction arguments so it can be re-created by the triggerer."""
        return (
            "airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger",
            {
                "template_name": self.template_name,
                "name": self.name,
                "project_id": self.project_id,
                "region": self.region,
                "gcp_conn_id": self.gcp_conn_id,
                "delegate_to": self.delegate_to,
                "impersonation_chain": self.impersonation_chain,
                "polling_interval_seconds": self.polling_interval_seconds,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """Poll the operation until done, yielding exactly one terminal TriggerEvent."""
        hook = self._get_hook()
        while True:
            try:
                operation = await hook.get_operation(region=self.region, operation_name=self.name)
                if operation.done:
                    if operation.error.message:
                        yield TriggerEvent(
                            {
                                "operation_name": operation.name,
                                "operation_done": operation.done,
                                "status": "error",
                                "message": operation.error.message,
                            }
                        )
                        return
                    yield TriggerEvent(
                        {
                            "operation_name": operation.name,
                            "operation_done": operation.done,
                            "status": "success",
                            "message": "Operation is successfully ended.",
                        }
                    )
                    return
                self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
                await asyncio.sleep(self.polling_interval_seconds)
            except Exception as e:
                self.log.exception("Exception occurred while checking operation status.")
                yield TriggerEvent(
                    {
                        "status": "failed",
                        "message": str(e),
                    }
                )
                # BUG FIX: without this return, the ``while True`` loop would
                # resume polling even after the terminal "failed" event was emitted.
                return

    def _get_hook(self) -> DataprocAsyncHook:  # type: ignore[override]
        """Build the asynchronous Dataproc hook used for polling."""
        return DataprocAsyncHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

‎docs/apache-airflow-providers-google/operators/cloud/dataproc.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,14 @@ Once a workflow is created users can trigger it using
262262
:start-after: [START how_to_cloud_dataproc_trigger_workflow_template]
263263
:end-before: [END how_to_cloud_dataproc_trigger_workflow_template]
264264

265+
You can also perform this action in deferrable mode:
266+
267+
.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_workflow.py
268+
:language: python
269+
:dedent: 4
270+
:start-after: [START how_to_cloud_dataproc_trigger_workflow_template_async]
271+
:end-before: [END how_to_cloud_dataproc_trigger_workflow_template_async]
272+
265273
The inline operator is an alternative. It creates a workflow, run it, and delete it afterwards:
266274
:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateInlineWorkflowTemplateOperator`:
267275

‎tests/providers/google/cloud/hooks/test_dataproc.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,17 @@ async def test_instantiate_workflow_template(self, mock_client):
729729
metadata=(),
730730
)
731731

732+
@pytest.mark.asyncio
733+
@async_mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_operation"))
734+
async def test_get_operation(self, mock_client):
735+
mock_client.return_value = None
736+
hook = DataprocAsyncHook(
737+
gcp_conn_id="google_cloud_default", delegate_to=None, impersonation_chain=None
738+
)
739+
await hook.get_operation(region=GCP_LOCATION, operation_name="operation_name")
740+
mock_client.assert_called_once()
741+
mock_client.assert_called_with(region=GCP_LOCATION, operation_name="operation_name")
742+
732743
@mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_template_client"))
733744
def test_instantiate_workflow_template_missing_region(self, mock_client):
734745
with pytest.raises(TypeError):

‎tests/providers/google/cloud/operators/test_dataproc.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
DataprocClusterTrigger,
6060
DataprocDeleteClusterTrigger,
6161
DataprocSubmitTrigger,
62+
DataprocWorkflowTrigger,
6263
)
6364
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
6465
from airflow.serialization.serialized_objects import SerializedDAG
@@ -441,6 +442,7 @@ def test_deprecation_warning(self):
441442
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
442443
def test_execute(self, mock_hook, to_dict_mock):
443444
self.extra_links_manager_mock.attach_mock(mock_hook, "hook")
445+
mock_hook.return_value.create_cluster.result.return_value = None
444446
create_cluster_args = {
445447
"region": GCP_REGION,
446448
"project_id": GCP_PROJECT,
@@ -1363,6 +1365,36 @@ def test_execute(self, mock_hook):
13631365
metadata=METADATA,
13641366
)
13651367

1368+
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
1369+
@mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
1370+
def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
1371+
operator = DataprocInstantiateWorkflowTemplateOperator(
1372+
task_id=TASK_ID,
1373+
template_id=TEMPLATE_ID,
1374+
region=GCP_REGION,
1375+
project_id=GCP_PROJECT,
1376+
version=2,
1377+
parameters={},
1378+
request_id=REQUEST_ID,
1379+
retry=RETRY,
1380+
timeout=TIMEOUT,
1381+
metadata=METADATA,
1382+
gcp_conn_id=GCP_CONN_ID,
1383+
impersonation_chain=IMPERSONATION_CHAIN,
1384+
deferrable=True,
1385+
)
1386+
1387+
with pytest.raises(TaskDeferred) as exc:
1388+
operator.execute(mock.MagicMock())
1389+
1390+
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
1391+
1392+
mock_hook.return_value.instantiate_workflow_template.assert_called_once()
1393+
1394+
mock_hook.return_value.wait_for_operation.assert_not_called()
1395+
assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
1396+
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
1397+
13661398

13671399
@pytest.mark.need_serialized_dag
13681400
@mock.patch(DATAPROC_PATH.format("DataprocHook"))

0 commit comments

Comments
 (0)