@@ -13,27 +13,41 @@ def pillow_to_tensor(image):
1313 return torch .from_numpy (np .array (image ).astype (np .float32 ) / 255.0 ).unsqueeze (0 )
1414
1515
16- def create_image_grid (images : t .List [Image .Image ], gap : int , ncol : int ):
17- # Calculate the number of rows needed based on the number of images and columns
18- nrow = (len (images ) + ncol - 1 ) // ncol
19-
20- # Get the size of the first image to use as a template for the grid
16+ def create_image_grid_by_columns (
17+ images : t .List [Image .Image ],
18+ gap : int ,
19+ max_columns : int ,
20+ ) -> Image .Image :
21+ max_rows = (len (images ) + max_columns - 1 ) // max_columns
22+ return create_image_grid (images = images , gap = gap , max_columns = max_columns , max_rows = max_rows )
23+
24+
25+ def create_image_grid_by_rows (
26+ images : t .List [Image .Image ],
27+ gap : int ,
28+ max_rows : int ,
29+ ) -> Image .Image :
30+ max_columns = (len (images ) + max_rows - 1 ) // max_rows
31+ return create_image_grid (images = images , gap = gap , max_columns = max_columns , max_rows = max_rows )
32+
33+
34+ def create_image_grid (
35+ images : t .List [Image .Image ],
36+ gap : int ,
37+ max_columns : int ,
38+ max_rows : int ,
39+ ) -> Image .Image :
2140 size = images [0 ].size
2241
23- # Calculate the total size of the grid with gaps
24- width = size [0 ] * ncol + gap * (ncol - 1 )
25- height = size [1 ] * nrow + gap * (nrow - 1 )
42+ width = size [0 ] * max_columns + (max_columns - 1 ) * gap
43+ height = size [1 ] * max_rows + (max_rows - 1 ) * gap
2644
27- # Create a new image for the grid
2845 grid_image = Image .new ("RGB" , (width , height ), color = "white" )
2946
30- # Iterate over each image and paste it into the grid
3147 for i , image in enumerate (images ):
32- # Calculate the position of the image in the grid
33- x = (i % ncol ) * (size [0 ] + gap )
34- y = (i // ncol ) * (size [1 ] + gap )
48+ x = (i % max_columns ) * (size [0 ] + gap )
49+ y = (i // max_columns ) * (size [1 ] + gap )
3550
36- # Paste the image into the grid
3751 grid_image .paste (image , (x , y ))
3852
3953 return grid_image
0 commit comments