Skip to content

Commit a16ac9e

Browse files
jerryyangliYang Li
andauthored
Add checkpoint comparison (#62)
* Add checkpoint comparison * Corrected a typo Co-authored-by: Yang Li <[email protected]>
1 parent b1f02b2 commit a16ac9e

File tree

1 file changed

+123
-0
lines changed

1 file changed

+123
-0
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#This script is for testing whether two checkpoints match; it prints all the differences
2+
3+
import torch
4+
import os
5+
import sys
6+
import pickle
7+
from collections import OrderedDict
8+
9+
exclude_key_str = {'ds_config/checkpoint/writer'}
10+
11+
def main():
12+
dir1 = sys.argv[1]
13+
dir2 = sys.argv[2]
14+
print ("Begin comparison")
15+
print ("The first directory {}" .format(dir1))
16+
print ("The second directory {}" .format(dir2))
17+
print (' ')
18+
19+
file_list1 = [f for f in os.listdir(dir1) if os.path.isfile(os.path.join(dir1, f))]
20+
file_list2 = [f for f in os.listdir(dir2) if os.path.isfile(os.path.join(dir2, f))]
21+
common_files = []
22+
23+
for f in file_list1:
24+
if not (f in file_list2):
25+
log_error_file_mismatch_first(f)
26+
else:
27+
common_files.append(f)
28+
for f in file_list2:
29+
if not (f in file_list1):
30+
log_error_file_mismatch_second(f)
31+
32+
for f in common_files:
33+
full_dir1 = os.path.join(dir1, f)
34+
full_dir2 = os.path.join(dir2, f)
35+
print ("Begin comparison")
36+
print("The first checkpoint {}" .format(full_dir1))
37+
print("The second checkpoint {}" .format(full_dir2))
38+
print(' ')
39+
model_first = torch.load(full_dir1)
40+
model_second = torch.load(full_dir2)
41+
object_compare(model_first, model_second, [])
42+
43+
44+
def object_compare(model_first, model_second, key_chain):
45+
if not (type(model_first) == type(model_second)):
46+
log_error_value_mismatch(model_first, model_second, key_chain)
47+
return
48+
49+
if type(model_first) is list:
50+
if len(model_first) != len(model_second):
51+
log_error_value_mismatch(model_first, model_second, key_chain)
52+
return
53+
for i in range(len(model_first)):
54+
object_compare(model_first[i], model_second[i], key_chain)
55+
return
56+
57+
if type(model_first) is dict or type(model_first) is OrderedDict:
58+
common_keys = []
59+
for key in model_first:
60+
if key not in model_second:
61+
key_chain.append(key)
62+
log_error_key_mismatch_first(model_first[key], key_chain)
63+
key_chain.pop()
64+
else:
65+
common_keys.append(key)
66+
67+
for key in model_second:
68+
if key not in model_first:
69+
key_chain.append(key)
70+
log_error_key_mismatch_second(model_second[key], key_chain)
71+
key_chain.pop()
72+
73+
for key in common_keys:
74+
key_chain.append(key)
75+
object_compare(model_first[key], model_second[key], key_chain)
76+
key_chain.pop()
77+
return
78+
79+
if hasattr(model_first, '__dict__'):
80+
equality = (model_first.__dict__ == model_second.__dict__)
81+
else:
82+
equality = (model_first == model_second)
83+
if type(equality) is not bool:
84+
equality = (equality.all())
85+
if not equality:
86+
log_error_value_mismatch(model_first, model_second, key_chain)
87+
return
88+
89+
90+
def log_error_file_mismatch_first(filename):
91+
print("The following file appeared in the first but not the second directory: {}" .format(filename))
92+
print(' ')
93+
94+
95+
def log_error_file_mismatch_second(filename):
96+
print("The following key appeared in the second but not the first directory: {}" .format(filename))
97+
print(" ")
98+
99+
100+
def log_error_key_mismatch_first(model, key_chain):
101+
key_str = "/".join(key_chain)
102+
if not (key_str in exclude_key_str):
103+
print("The following key appeared in the first but not the second model: {}" .format(key_str))
104+
print("The value of the first model is: {}" .format(model))
105+
print(" ")
106+
107+
108+
def log_error_key_mismatch_second(model, key_chain):
109+
key_str = "/".join(key_chain)
110+
if not (key_str in exclude_key_str):
111+
print("The following key appeared in the second but not the first model: {}" .format(key_str))
112+
print("The value of the second model is: {}" .format(model))
113+
print(" ")
114+
115+
116+
def log_error_value_mismatch(model_first, model_second, key_chain):
117+
print ("The values of the following key do not match: {}" .format("/".join(key_chain)))
118+
print ("The value of the first model is: {}" .format(model_first))
119+
print ("The value of the second model is: {}" .format(model_second))
120+
print(" ")
121+
122+
if __name__ == "__main__":
123+
main()

0 commit comments

Comments
 (0)