@@ -16,12 +16,11 @@
 import tempfile
 from datetime import datetime
 from pathlib import Path
-from typing import Callable, Iterable, Optional, Sequence, Union
+from typing import Callable, Iterable, Optional, Union
 
 from tabulate import tabulate
 
 from aiida import orm
-from aiida.common.exceptions import LicensingException
 from aiida.common.lang import type_check
 from aiida.common.links import GraphTraversalRules
 from aiida.common.log import AIIDA_LOGGER
@@ -42,17 +41,18 @@
 EXPORT_LOGGER = AIIDA_LOGGER.getChild('export')
 QbType = Callable[[], orm.QueryBuilder]
 
+
 def _batch_query_ids(
     querybuilder: QbType,
     source_entity_type: EntityTypes,
     source_ids: set[int],
     target_entity_type: EntityTypes,
     relationship: str,
     batch_size_limit: int = 10000,
-    query_batch_size: int = 1000
+    query_batch_size: int = 1000,
 ) -> set[int]:
     """Batch query to avoid PostgreSQL parameter limits.
-
+
     :param querybuilder: QueryBuilder factory function
     :param source_entity_type: Entity type to filter on
     :param source_ids: Set of IDs to filter
@@ -81,7 +81,7 @@ def _batch_query_ids(
     source_id_list = list(source_ids)
 
     for i in range(0, len(source_id_list), batch_size_limit):
-        batch_ids = source_id_list[i:i + batch_size_limit]
+        batch_ids = source_id_list[i : i + batch_size_limit]
 
         qb = querybuilder()
         qb.append(source_orm, filters={'id': {'in': batch_ids}}, tag='source')
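The pattern this hunk introduces is worth seeing in isolation. Below is a minimal, AiiDA-independent sketch of the same idea: a large `IN (...)` filter is split into slices small enough that each round trip stays well under PostgreSQL's 65535 bind-parameter ceiling (the figure cited in the diff's comments). `fetch_chunk` is a hypothetical stand-in for the per-batch QueryBuilder query.

```python
from typing import Callable


def batched_in_query(
    ids: set[int],
    fetch_chunk: Callable[[list[int]], set[int]],  # hypothetical per-batch query
    limit: int = 10_000,
) -> set[int]:
    """Union the results of ``fetch_chunk`` over ``ids`` in slices of at most ``limit``."""
    id_list = list(ids)
    results: set[int] = set()
    for i in range(0, len(id_list), limit):
        # Each slice contributes at most ``limit`` bind parameters to the
        # generated ``IN`` clause, keeping the statement under the protocol cap.
        results.update(fetch_chunk(id_list[i : i + limit]))
    return results
```

The union over batches is equivalent to one giant `IN` query, because set membership distributes over any partition of the ID list.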
@@ -356,12 +356,14 @@ def querybuilder():
     with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress:
         for etype, ids in entity_ids.items():
             if etype == EntityTypes.NODE and strip_checkpoints:
+
                 def transform(row):
                     data = row['entity']
                     if data.get('node_type', '').startswith('process.'):
                         data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None)
                     return data
             else:
+
                 def transform(row):
                     return row['entity']
 
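To make the two `transform` branches above concrete, here is a hedged sketch of what the checkpoint-stripping variant does to a row as yielded by `iterdict`. The row shape and the `'checkpoints'` key are illustrative; the real key comes from `orm.ProcessNode.CHECKPOINT_KEY`.

```python
def strip_checkpoint(row: dict) -> dict:
    """Return the entity dict, minus the checkpoint attribute for process nodes."""
    data = row['entity']
    if data.get('node_type', '').startswith('process.'):
        # Checkpoints carry serialized live-process state, which is not
        # meaningful in an exported archive, so they are dropped here.
        data['attributes'].pop('checkpoints', None)  # illustrative key name
    return data


row = {'entity': {'node_type': 'process.workflow.', 'attributes': {'checkpoints': '...', 'sealed': True}}}
assert 'checkpoints' not in strip_checkpoint(row)['attributes']
assert strip_checkpoint({'entity': {'node_type': 'data.int.', 'attributes': {}}})['attributes'] == {}
```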
@@ -370,14 +372,14 @@ def transform(row):
             # Batch the IDs to avoid parameter limits
             ids_list = list(ids)
             for i in range(0, len(ids_list), 10000):  # Batch of 10k to stay under 32k limit
-                batch_ids = ids_list[i:i + 10000]
+                batch_ids = ids_list[i : i + 10000]
                 for nrows, rows in batch_iter(
                     querybuilder()
                     .append(
                         entity_type_to_orm[etype],
                         filters={'id': {'in': batch_ids}},  # Now only 10k parameters max
                         tag='entity',
-                        project=['**']
+                        project=['**'],
                     )
                     .iterdict(batch_size=batch_size),
                     batch_size,
@@ -569,7 +571,7 @@ def progress_str(name):
     if entity_ids[EntityTypes.GROUP]:
         group_id_list = list(entity_ids[EntityTypes.GROUP])
         for i in range(0, len(group_id_list), 10000):  # Batch to avoid parameter limits
-            batch_ids = group_id_list[i:i + 10000]
+            batch_ids = group_id_list[i : i + 10000]
             qbuilder = querybuilder()
             qbuilder.append(orm.Group, filters={'id': {'in': batch_ids}}, project='id', tag='group')
             qbuilder.append(orm.Node, with_group='group', project='id')
@@ -600,7 +602,7 @@ def progress_str(name):
                 EntityTypes.COMPUTER,
                 'with_node',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
 
@@ -616,7 +618,7 @@ def progress_str(name):
                 EntityTypes.AUTHINFO,
                 'with_computer',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
 
@@ -632,7 +634,7 @@ def progress_str(name):
                 EntityTypes.LOG,
                 'with_node',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
 
@@ -648,7 +650,7 @@ def progress_str(name):
                 EntityTypes.COMMENT,
                 'with_node',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
 
@@ -664,7 +666,7 @@ def progress_str(name):
                 EntityTypes.USER,
                 'with_node',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
     if entity_ids[EntityTypes.GROUP]:
@@ -676,7 +678,7 @@ def progress_str(name):
                 EntityTypes.USER,
                 'with_group',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
     if entity_ids[EntityTypes.COMMENT]:
@@ -688,7 +690,7 @@ def progress_str(name):
                 EntityTypes.USER,
                 'with_comment',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
     if entity_ids[EntityTypes.AUTHINFO]:
@@ -700,7 +702,7 @@ def progress_str(name):
                 EntityTypes.USER,
                 'with_authinfo',
                 batch_size_limit=10000,
-                query_batch_size=batch_size
+                query_batch_size=batch_size,
             )
         )
 
@@ -713,55 +715,81 @@ def _stream_repo_files(
     key_format: str, writer: ArchiveWriterAbstract, node_ids: set[int], backend: StorageBackend, batch_size: int
 ) -> None:
     """Collect all repository object keys from the nodes, then stream the files to the archive."""
-    keys = set(
-        orm.Node.get_collection(backend).iter_repo_keys(filters={'id': {'in': list(node_ids)}}, batch_size=batch_size)
-    )
+
+    # Batch the node IDs to avoid parameter limits when getting repo keys
+    node_ids_list = list(node_ids)
+    batch_size_limit = 10000  # Stay well under 65535 parameter limit
+    all_keys = set()
+
+    for i in range(0, len(node_ids_list), batch_size_limit):
+        batch_ids = node_ids_list[i : i + batch_size_limit]
+        batch_keys = set(
+            orm.Node.get_collection(backend).iter_repo_keys(filters={'id': {'in': batch_ids}}, batch_size=batch_size)
+        )
+        all_keys.update(batch_keys)
 
     repository = backend.get_repository()
     if not repository.key_format == key_format:
         # Here we would have to go back and replace all the keys in the `BackendNode.repository_metadata`s
         raise NotImplementedError(
             f'Backend repository key format incompatible: {repository.key_format!r} != {key_format!r}'
         )
-    with get_progress_reporter()(desc='Archiving files: ', total=len(keys)) as progress:
-        for key, stream in repository.iter_object_streams(keys):  # type: ignore[arg-type]
+    with get_progress_reporter()(desc='Archiving files: ', total=len(all_keys)) as progress:
+        for key, stream in repository.iter_object_streams(all_keys):  # type: ignore[arg-type]
             # to-do should we assume the key here is correct, or always re-compute and check?
             writer.put_object(stream, key=key)
             progress.update()
 
 
 def _check_unsealed_nodes(querybuilder: QbType, node_ids: set[int], batch_size: int) -> None:
     """Check no process nodes are unsealed, i.e. all processes have completed."""
-    qbuilder = (
-        querybuilder()
-        .append(
-            orm.ProcessNode,
-            filters={
-                'id': {'in': list(node_ids)},
-                'attributes.sealed': {
-                    '!in': [True]  # better operator?
+    if not node_ids:
+        return
+
+    # Batch the node IDs to avoid parameter limits
+    node_ids_list = list(node_ids)
+    batch_size_limit = 10000  # Stay well under 65535 parameter limit
+    all_unsealed_pks = []
+
+    for i in range(0, len(node_ids_list), batch_size_limit):
+        batch_ids = node_ids_list[i : i + batch_size_limit]
+
+        qbuilder = (
+            querybuilder()
+            .append(
+                orm.ProcessNode,
+                filters={
+                    'id': {'in': batch_ids},
+                    'attributes.sealed': {
+                        '!in': [True]  # better operator?
+                    },
                 },
-            },
-            project='id',
+                project='id',
+            )
+            .distinct()
         )
-        .distinct()
-    )
-    unsealed_node_pks = qbuilder.all(batch_size=batch_size, flat=True)
-    if unsealed_node_pks:
+        batch_unsealed_pks = qbuilder.all(batch_size=batch_size, flat=True)
+        all_unsealed_pks.extend(batch_unsealed_pks)
+
+    if all_unsealed_pks:
         raise ExportValidationError(
             'All ProcessNodes must be sealed before they can be exported. '
-            f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed."
+            f"Node(s) with PK(s): {', '.join(str(pk) for pk in all_unsealed_pks)} is/are not sealed."
         )
 
 
 def _check_node_licenses(
     querybuilder: QbType,
     node_ids: set[int],
-    allowed_licenses: Union[None, Sequence[str], Callable],
-    forbidden_licenses: Union[None, Sequence[str], Callable],
+    allowed_licenses: Optional[Union[list, Callable]],
+    forbidden_licenses: Optional[Union[list, Callable]],
     batch_size: int,
 ) -> None:
     """Check the nodes to be archived for disallowed licences."""
+    from typing import Sequence
+
+    from aiida.common.exceptions import LicensingException
+
    if allowed_licenses is None and forbidden_licenses is None:
         return None
 
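The reworked `_check_unsealed_nodes` follows an accumulate-then-raise shape: every batch is scanned before any exception fires, so the error message lists all offending PKs rather than only those from the first failing batch. A minimal sketch of the pattern, with `find_unsealed` as a hypothetical stand-in for the per-batch `QueryBuilder` query:

```python
from typing import Callable, Iterable


def check_all_sealed(
    node_ids: set[int],
    find_unsealed: Callable[[list[int]], Iterable[int]],  # hypothetical per-batch query
    limit: int = 10_000,
) -> None:
    """Raise once, with the complete offender list, after scanning every batch."""
    if not node_ids:
        return  # mirrors the early return the diff adds
    ids = list(node_ids)
    offenders: list[int] = []
    for i in range(0, len(ids), limit):
        offenders.extend(find_unsealed(ids[i : i + limit]))
    if offenders:
        raise ValueError(f"unsealed process nodes: {', '.join(map(str, offenders))}")
```

Raising only after the loop costs nothing on the happy path and makes a failure actionable in a single pass.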
@@ -807,24 +835,31 @@ def check_forbidden(lic):
     else:
         raise TypeError('forbidden_licenses not a list or function')
 
-    # create query
-    qbuilder = querybuilder().append(
-        orm.Node,
-        project=['id', 'attributes.source.license'],
-        filters={'id': {'in': list(node_ids)}},
-    )
+    # Batch the node IDs to avoid parameter limits
+    node_ids_list = list(node_ids)
+    batch_size_limit = 10000  # Stay well under 65535 parameter limit
 
-    for node_id, name in qbuilder.iterall(batch_size=batch_size):
-        if name is None:
-            continue
-        if not check_allowed(name):
-            raise LicensingException(
-                f"Node {node_id} is licensed under '{name}' license, which is not in the list of allowed licenses"
-            )
-        if check_forbidden(name):
-            raise LicensingException(
-                f"Node {node_id} is licensed under '{name}' license, which is in the list of forbidden licenses"
-            )
+    for i in range(0, len(node_ids_list), batch_size_limit):
+        batch_ids = node_ids_list[i : i + batch_size_limit]
+
+        # create query for this batch
+        qbuilder = querybuilder().append(
+            orm.Node,
+            project=['id', 'attributes.source.license'],
+            filters={'id': {'in': batch_ids}},
+        )
+
+        for node_id, name in qbuilder.iterall(batch_size=batch_size):
+            if name is None:
+                continue
+            if not check_allowed(name):
+                raise LicensingException(
+                    f"Node {node_id} is licensed under '{name}' license, which is not in the list of allowed licenses"
+                )
+            if check_forbidden(name):
+                raise LicensingException(
+                    f"Node {node_id} is licensed under '{name}' license, which is in the list of forbidden licenses"
+                )
 
 
 def get_init_summary(
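Per the updated type hints, `allowed_licenses` and `forbidden_licenses` each accept either a list of licence names or a predicate. A hedged usage sketch; the licence identifiers are illustrative:

```python
def is_permissive(license_name: str) -> bool:
    """Illustrative predicate: accept a small allow-list of SPDX identifiers."""
    return license_name in {'MIT', 'BSD-3-Clause', 'Apache-2.0'}


# Both call styles satisfy the Optional[Union[list, Callable]] annotation:
# _check_node_licenses(qb, node_ids, allowed_licenses=['MIT', 'Apache-2.0'], forbidden_licenses=None, batch_size=100)
# _check_node_licenses(qb, node_ids, allowed_licenses=is_permissive, forbidden_licenses=None, batch_size=100)
```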
@@ -856,4 +891,3 @@ def get_init_summary(
     result += f"\n\n{tabulate(rules_table, headers=['Traversal rules', ''])}"
 
     return result + '\n'
-