
Commit 0fabfb0

Batch remaining functions for archive creation
Parent: 8950cc9


src/aiida/tools/archive/create.py

Lines changed: 90 additions & 56 deletions
@@ -16,12 +16,11 @@
 import tempfile
 from datetime import datetime
 from pathlib import Path
-from typing import Callable, Iterable, Optional, Sequence, Union
+from typing import Callable, Iterable, Optional, Union
 
 from tabulate import tabulate
 
 from aiida import orm
-from aiida.common.exceptions import LicensingException
 from aiida.common.lang import type_check
 from aiida.common.links import GraphTraversalRules
 from aiida.common.log import AIIDA_LOGGER
@@ -42,17 +41,18 @@
 EXPORT_LOGGER = AIIDA_LOGGER.getChild('export')
 QbType = Callable[[], orm.QueryBuilder]
 
+
 def _batch_query_ids(
     querybuilder: QbType,
     source_entity_type: EntityTypes,
     source_ids: set[int],
     target_entity_type: EntityTypes,
     relationship: str,
     batch_size_limit: int = 10000,
-    query_batch_size: int = 1000
+    query_batch_size: int = 1000,
 ) -> set[int]:
     """Batch query to avoid PostgreSQL parameter limits.
-
+
     :param querybuilder: QueryBuilder factory function
     :param source_entity_type: Entity type to filter on
     :param source_ids: Set of IDs to filter
@@ -81,7 +81,7 @@ def _batch_query_ids(
     source_id_list = list(source_ids)
 
     for i in range(0, len(source_id_list), batch_size_limit):
-        batch_ids = source_id_list[i:i + batch_size_limit]
+        batch_ids = source_id_list[i : i + batch_size_limit]
 
         qb = querybuilder()
         qb.append(source_orm, filters={'id': {'in': batch_ids}}, tag='source')
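This chunk-and-union loop is the heart of the commit: PostgreSQL's extended-query protocol encodes the number of bind parameters in a 16-bit field, so a single IN filter over a huge ID list can exceed the 65535-parameter cap (the code stays at 10000 to keep a wide margin). A minimal, self-contained sketch of the same pattern, with illustrative names that are not part of aiida:

from typing import Callable, Iterable


def union_over_batches(
    ids: Iterable[int],
    run_query: Callable[[list[int]], set[int]],
    chunk: int = 10000,
) -> set[int]:
    """Run ``run_query`` once per chunk of ``ids`` and union the per-batch results."""
    id_list = list(ids)
    found: set[int] = set()
    for i in range(0, len(id_list), chunk):
        # Each call sees at most ``chunk`` IDs, so the IN clause stays small.
        found.update(run_query(id_list[i : i + chunk]))
    return found


# Stand-in query that 'matches' even IDs; the real code puts the chunk into an
# 'IN' filter on the QueryBuilder, as _batch_query_ids does above.
matched = union_over_batches(range(25_000), lambda batch: {n for n in batch if n % 2 == 0})
print(len(matched))  # 12500, identical to an unbatched run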
@@ -356,12 +356,14 @@ def querybuilder():
     with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress:
         for etype, ids in entity_ids.items():
             if etype == EntityTypes.NODE and strip_checkpoints:
+
                 def transform(row):
                     data = row['entity']
                     if data.get('node_type', '').startswith('process.'):
                         data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None)
                     return data
             else:
+
                 def transform(row):
                     return row['entity']
 
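The two transform callbacks above decide, row by row, whether process-node checkpoints are stripped before rows reach the archive writer. A runnable restatement of that branch, assuming a simplified row shape and a stand-in value for orm.ProcessNode.CHECKPOINT_KEY (its actual value is not shown in this diff):

CHECKPOINT_KEY = 'checkpoints'  # assumption: stands in for orm.ProcessNode.CHECKPOINT_KEY


def transform(row: dict) -> dict:
    """Drop the checkpoint attribute from process nodes; pass everything else through."""
    data = row['entity']
    if data.get('node_type', '').startswith('process.'):
        data['attributes'].pop(CHECKPOINT_KEY, None)
    return data


row = {'entity': {'node_type': 'process.workflow.', 'attributes': {'checkpoints': '...', 'x': 1}}}
assert transform(row)['attributes'] == {'x': 1}  # checkpoint removed, other attributes kept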
@@ -370,14 +372,14 @@ def transform(row):
             # Batch the IDs to avoid parameter limits
             ids_list = list(ids)
             for i in range(0, len(ids_list), 10000):  # Batch of 10k to stay under 32k limit
-                batch_ids = ids_list[i:i + 10000]
+                batch_ids = ids_list[i : i + 10000]
                 for nrows, rows in batch_iter(
                     querybuilder()
                     .append(
                         entity_type_to_orm[etype],
                         filters={'id': {'in': batch_ids}},  # Now only 10k parameters max
                         tag='entity',
-                        project=['**']
+                        project=['**'],
                     )
                     .iterdict(batch_size=batch_size),
                     batch_size,
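Note the two levels of batching in this hunk: the ID list is chunked so each IN filter carries at most 10000 parameters, and batch_iter then chunks the result rows for writing. A minimal stand-in for such a row-chunking helper (the real batch_iter ships with aiida; this only sketches the idea):

from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')


def batch_rows(rows: Iterable[T], size: int) -> Iterator[tuple[int, list[T]]]:
    """Yield (count, chunk) pairs so a writer can commit in fixed-size groups."""
    iterator = iter(rows)
    while chunk := list(islice(iterator, size)):
        yield len(chunk), chunk


for nrows, chunk in batch_rows(range(10), 4):
    print(nrows, chunk)  # 4 [0, 1, 2, 3] / 4 [4, 5, 6, 7] / 2 [8, 9]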
@@ -569,7 +571,7 @@ def progress_str(name):
     if entity_ids[EntityTypes.GROUP]:
         group_id_list = list(entity_ids[EntityTypes.GROUP])
         for i in range(0, len(group_id_list), 10000):  # Batch to avoid parameter limits
-            batch_ids = group_id_list[i:i + 10000]
+            batch_ids = group_id_list[i : i + 10000]
             qbuilder = querybuilder()
             qbuilder.append(orm.Group, filters={'id': {'in': batch_ids}}, project='id', tag='group')
             qbuilder.append(orm.Node, with_group='group', project='id')
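The same chunking applied across an edge: each batch of group IDs is joined to the nodes those groups contain, and the node IDs accumulate across batches. A toy in-memory model of that join (illustrative only; the real code routes it through the QueryBuilder as shown above):

# Hypothetical group -> node edges; each group holds three nodes.
group_to_nodes = {g: {g * 10 + k for k in range(3)} for g in range(25_000)}


def nodes_of_groups(batch: list[int]) -> set[int]:
    """Model of the Group->Node join for one batch of group IDs."""
    out: set[int] = set()
    for g in batch:
        out |= group_to_nodes.get(g, set())
    return out


group_ids = list(group_to_nodes)
node_ids: set[int] = set()
for i in range(0, len(group_ids), 10_000):
    node_ids |= nodes_of_groups(group_ids[i : i + 10_000])
print(len(node_ids))  # 75000: every edge found, in three queries instead of one oversized one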
@@ -600,7 +602,7 @@ def progress_str(name):
                 EntityTypes.COMPUTER,
                 'with_node',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
 
@@ -616,7 +618,7 @@ def progress_str(name):
                 EntityTypes.AUTHINFO,
                 'with_computer',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
 
@@ -632,7 +634,7 @@ def progress_str(name):
                 EntityTypes.LOG,
                 'with_node',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
 
@@ -648,7 +650,7 @@ def progress_str(name):
                 EntityTypes.COMMENT,
                 'with_node',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
 
@@ -664,7 +666,7 @@ def progress_str(name):
                 EntityTypes.USER,
                 'with_node',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
     if entity_ids[EntityTypes.GROUP]:
@@ -676,7 +678,7 @@ def progress_str(name):
                 EntityTypes.USER,
                 'with_group',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
     if entity_ids[EntityTypes.COMMENT]:
@@ -688,7 +690,7 @@ def progress_str(name):
                 EntityTypes.USER,
                 'with_comment',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
     if entity_ids[EntityTypes.AUTHINFO]:
@@ -700,7 +702,7 @@ def progress_str(name):
                 EntityTypes.USER,
                 'with_authinfo',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
 
@@ -713,55 +715,81 @@ def _stream_repo_files(
     key_format: str, writer: ArchiveWriterAbstract, node_ids: set[int], backend: StorageBackend, batch_size: int
 ) -> None:
     """Collect all repository object keys from the nodes, then stream the files to the archive."""
-    keys = set(
-        orm.Node.get_collection(backend).iter_repo_keys(filters={'id': {'in': list(node_ids)}}, batch_size=batch_size)
-    )
+
+    # Batch the node IDs to avoid parameter limits when getting repo keys
+    node_ids_list = list(node_ids)
+    batch_size_limit = 10000  # Stay well under 65535 parameter limit
+    all_keys = set()
+
+    for i in range(0, len(node_ids_list), batch_size_limit):
+        batch_ids = node_ids_list[i : i + batch_size_limit]
+        batch_keys = set(
+            orm.Node.get_collection(backend).iter_repo_keys(filters={'id': {'in': batch_ids}}, batch_size=batch_size)
+        )
+        all_keys.update(batch_keys)
 
     repository = backend.get_repository()
     if not repository.key_format == key_format:
         # Here we would have to go back and replace all the keys in the `BackendNode.repository_metadata`s
         raise NotImplementedError(
             f'Backend repository key format incompatible: {repository.key_format!r} != {key_format!r}'
         )
-    with get_progress_reporter()(desc='Archiving files: ', total=len(keys)) as progress:
-        for key, stream in repository.iter_object_streams(keys):  # type: ignore[arg-type]
+    with get_progress_reporter()(desc='Archiving files: ', total=len(all_keys)) as progress:
+        for key, stream in repository.iter_object_streams(all_keys):  # type: ignore[arg-type]
             # to-do should we use assume the key here is correct, or always re-compute and check?
             writer.put_object(stream, key=key)
             progress.update()
 
 
 def _check_unsealed_nodes(querybuilder: QbType, node_ids: set[int], batch_size: int) -> None:
     """Check no process nodes are unsealed, i.e. all processes have completed."""
-    qbuilder = (
-        querybuilder()
-        .append(
-            orm.ProcessNode,
-            filters={
-                'id': {'in': list(node_ids)},
-                'attributes.sealed': {
-                    '!in': [True]  # better operator?
+    if not node_ids:
+        return
+
+    # Batch the node IDs to avoid parameter limits
+    node_ids_list = list(node_ids)
+    batch_size_limit = 10000  # Stay well under 65535 parameter limit
+    all_unsealed_pks = []
+
+    for i in range(0, len(node_ids_list), batch_size_limit):
+        batch_ids = node_ids_list[i : i + batch_size_limit]
+
+        qbuilder = (
+            querybuilder()
+            .append(
+                orm.ProcessNode,
+                filters={
+                    'id': {'in': batch_ids},
+                    'attributes.sealed': {
+                        '!in': [True]  # better operator?
+                    },
                 },
-            },
-            project='id',
+                project='id',
+            )
+            .distinct()
         )
-        .distinct()
-    )
-    unsealed_node_pks = qbuilder.all(batch_size=batch_size, flat=True)
-    if unsealed_node_pks:
+        batch_unsealed_pks = qbuilder.all(batch_size=batch_size, flat=True)
+        all_unsealed_pks.extend(batch_unsealed_pks)
+
+    if all_unsealed_pks:
         raise ExportValidationError(
             'All ProcessNodes must be sealed before they can be exported. '
-            f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed."
+            f"Node(s) with PK(s): {', '.join(str(pk) for pk in all_unsealed_pks)} is/are not sealed."
        )
 
 
 def _check_node_licenses(
     querybuilder: QbType,
     node_ids: set[int],
-    allowed_licenses: Union[None, Sequence[str], Callable],
-    forbidden_licenses: Union[None, Sequence[str], Callable],
+    allowed_licenses: Optional[Union[list, Callable]],
+    forbidden_licenses: Optional[Union[list, Callable]],
     batch_size: int,
 ) -> None:
     """Check the nodes to be archived for disallowed licences."""
+    from typing import Sequence
+
+    from aiida.common.exceptions import LicensingException
+
     if allowed_licenses is None and forbidden_licenses is None:
         return None
 
@@ -807,24 +835,31 @@ def check_forbidden(lic):
     else:
         raise TypeError('forbidden_licenses not a list or function')
 
-    # create query
-    qbuilder = querybuilder().append(
-        orm.Node,
-        project=['id', 'attributes.source.license'],
-        filters={'id': {'in': list(node_ids)}},
-    )
+    # Batch the node IDs to avoid parameter limits
+    node_ids_list = list(node_ids)
+    batch_size_limit = 10000  # Stay well under 65535 parameter limit
 
-    for node_id, name in qbuilder.iterall(batch_size=batch_size):
-        if name is None:
-            continue
-        if not check_allowed(name):
-            raise LicensingException(
-                f"Node {node_id} is licensed under '{name}' license, which is not in the list of allowed licenses"
-            )
-        if check_forbidden(name):
-            raise LicensingException(
-                f"Node {node_id} is licensed under '{name}' license, which is in the list of forbidden licenses"
-            )
+    for i in range(0, len(node_ids_list), batch_size_limit):
+        batch_ids = node_ids_list[i : i + batch_size_limit]
+
+        # create query for this batch
+        qbuilder = querybuilder().append(
+            orm.Node,
+            project=['id', 'attributes.source.license'],
+            filters={'id': {'in': batch_ids}},
+        )
+
+        for node_id, name in qbuilder.iterall(batch_size=batch_size):
+            if name is None:
+                continue
+            if not check_allowed(name):
+                raise LicensingException(
+                    f"Node {node_id} is licensed under '{name}' license, which is not in the list of allowed licenses"
+                )
+            if check_forbidden(name):
+                raise LicensingException(
+                    f"Node {node_id} is licensed under '{name}' license, which is in the list of forbidden licenses"
+                )
 
 
 def get_init_summary(
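For context on what the batched license loop feeds into: the hunk header references check_forbidden(lic), and the unchanged code above this hunk builds check_allowed and check_forbidden from either a list of license names or a user-supplied callable. A hedged sketch of that normalization (the real closures in create.py may differ in detail):

from typing import Callable, Optional, Union


def make_license_check(spec: Optional[Union[list, Callable]], default: bool) -> Callable[[str], bool]:
    """Turn a license spec (None, list of names, or predicate) into a checker."""
    if spec is None:
        return lambda lic: default  # no constraint configured
    if callable(spec):
        return spec  # user-supplied predicate
    if isinstance(spec, list):
        return lambda lic: lic in spec  # membership in the given list
    raise TypeError('license spec is not a list or function')


check_allowed = make_license_check(['MIT', 'BSD-3-Clause'], default=True)
check_forbidden = make_license_check(None, default=False)
assert check_allowed('MIT') and not check_forbidden('MIT')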
@@ -856,4 +891,3 @@ def get_init_summary(
     result += f"\n\n{tabulate(rules_table, headers=['Traversal rules', ''])}"
 
     return result + '\n'
-