Skip to content

Commit 01d8a6d

Browse files
committed
add domain names for object storage
1 parent c3fc2c7 commit 01d8a6d

File tree

20 files changed

+881
-55
lines changed

20 files changed

+881
-55
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package ai.starwhale.mlops.configuration.security;
18+
19+
import ai.starwhale.mlops.storage.domain.DomainAwareStorageAccessService;
20+
import java.io.IOException;
21+
import java.util.regex.Pattern;
22+
import javax.servlet.FilterChain;
23+
import javax.servlet.ServletException;
24+
import javax.servlet.http.HttpServletRequest;
25+
import javax.servlet.http.HttpServletResponse;
26+
import org.jetbrains.annotations.NotNull;
27+
import org.springframework.stereotype.Component;
28+
import org.springframework.util.StringUtils;
29+
import org.springframework.web.filter.OncePerRequestFilter;
30+
31+
@Component
32+
public class ObjectStoreDomainDetectionFilter extends OncePerRequestFilter {
33+
34+
public static final String HEADER_NAME = "SW_CLIENT_FAVORED_OSS_DOMAIN_PATTERN";
35+
36+
@Override
37+
protected void doFilterInternal(
38+
HttpServletRequest request,
39+
@NotNull HttpServletResponse response,
40+
FilterChain filterChain
41+
) throws ServletException, IOException {
42+
String pattern = request.getHeader(HEADER_NAME);
43+
if (StringUtils.hasText(pattern)) {
44+
request.setAttribute(DomainAwareStorageAccessService.OSS_DOMAIN_PATTERN_ATTR, Pattern.compile(pattern));
45+
}
46+
filterChain.doFilter(request, response);
47+
}
48+
49+
}

server/controller/src/main/java/ai/starwhale/mlops/configuration/security/SecurityConfiguration.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ public class SecurityConfiguration extends WebSecurityConfigurerAdapter {
8484
@Resource
8585
private ContentCachingFilter contentCachingFilter;
8686

87+
@Resource
88+
private ObjectStoreDomainDetectionFilter objectStoreDomainDetectionFilter;
89+
8790

8891
public SecurityConfiguration() {
8992
super();
@@ -139,6 +142,7 @@ protected void configure(HttpSecurity http) throws Exception {
139142
JwtLoginFilter.class)
140143
.addFilterBefore(projectDetectionFilter, JwtTokenFilter.class)
141144
.addFilterBefore(contentCachingFilter, ProjectDetectionFilter.class)
145+
.addFilterAfter(objectStoreDomainDetectionFilter, JwtTokenFilter.class)
142146
;
143147
}
144148

server/controller/src/main/java/ai/starwhale/mlops/domain/blob/CachedBlobService.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import ai.starwhale.mlops.storage.LengthAbleInputStream;
2020
import ai.starwhale.mlops.storage.StorageAccessService;
21+
import ai.starwhale.mlops.storage.domain.DomainAwareStorageAccessService;
2122
import ai.starwhale.mlops.storage.memory.StorageAccessServiceMemory;
2223
import java.io.IOException;
2324
import java.util.HashMap;
@@ -40,9 +41,12 @@ public CachedBlobService(StorageAccessService defaultStorageAccessService,
4041
// for test only
4142
storageAccessService = new StorageAccessServiceMemory();
4243
} else {
43-
storageAccessService = StorageAccessService.getS3LikeStorageAccessService(
44-
cacheConfig.getStorageType(),
45-
cacheConfig);
44+
storageAccessService = new DomainAwareStorageAccessService(
45+
StorageAccessService.getS3LikeStorageAccessService(
46+
cacheConfig.getStorageType(),
47+
cacheConfig
48+
)
49+
);
4650
}
4751
this.caches.put(cacheConfig.getBlobIdPrefix(),
4852
new BlobServiceImpl(storageAccessService,

server/controller/src/main/java/ai/starwhale/mlops/domain/dataset/objectstore/StorageAccessParser.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import ai.starwhale.mlops.storage.StorageAccessService;
2222
import ai.starwhale.mlops.storage.StorageConnectionToken;
2323
import ai.starwhale.mlops.storage.StorageUri;
24+
import ai.starwhale.mlops.storage.domain.DomainAwareStorageAccessService;
2425
import ai.starwhale.mlops.storage.fs.FsConfig;
2526
import ai.starwhale.mlops.storage.s3.S3Config;
2627
import java.util.Map;
@@ -99,9 +100,12 @@ private StorageAccessService buildStorageAccessService(StorageConnectionToken to
99100
return StorageAccessService.getFileStorageAccessService(
100101
new FsConfig(token.getTokens().get("rootDir"), token.getTokens().get("serviceProvider")));
101102
default:
102-
return StorageAccessService.getS3LikeStorageAccessService(
103-
token.getType(),
104-
new S3Config(token.getTokens()));
103+
return new DomainAwareStorageAccessService(
104+
StorageAccessService.getS3LikeStorageAccessService(
105+
token.getType(),
106+
new S3Config(token.getTokens())
107+
)
108+
);
105109
}
106110
} catch (Exception e) {
107111
log.error("can not build storage access service", e);

server/controller/src/main/resources/application.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ sw:
9797
secret-key: ${SW_STORAGE_SECRETKEY:starwhale}
9898
region: ${SW_STORAGE_REGION:local}
9999
endpoint: ${SW_STORAGE_ENDPOINT:http://localhost:9000}
100+
endpoint-equivalents: ${SW_STORAGE_ENDPOINT_EQS:http://127.0.0.1:9000}
100101
huge-file-threshold: 10485760 # 10MB
101102
huge-file-part-size: 5242880 # 5MB
102103
controller:
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package ai.starwhale.mlops.configuration.security;
18+
19+
import static org.mockito.ArgumentMatchers.any;
20+
import static org.mockito.ArgumentMatchers.eq;
21+
import static org.mockito.Mockito.mock;
22+
import static org.mockito.Mockito.times;
23+
import static org.mockito.Mockito.verify;
24+
import static org.mockito.Mockito.when;
25+
26+
import java.io.IOException;
27+
import java.util.regex.Pattern;
28+
import javax.servlet.FilterChain;
29+
import javax.servlet.ServletException;
30+
import javax.servlet.http.HttpServletRequest;
31+
import javax.servlet.http.HttpServletResponse;
32+
import org.junit.jupiter.api.Assertions;
33+
import org.junit.jupiter.api.Test;
34+
import org.mockito.ArgumentCaptor;
35+
36+
class ObjectStoreDomainDetectionFilterTest {
37+
38+
ObjectStoreDomainDetectionFilter filter = new ObjectStoreDomainDetectionFilter();
39+
40+
@Test
41+
void doFilterInternal() throws ServletException, IOException {
42+
HttpServletRequest req = mock(HttpServletRequest.class);
43+
HttpServletResponse resp = mock(HttpServletResponse.class);
44+
FilterChain filterChain = mock(FilterChain.class);
45+
when(req.getHeader("SW_CLIENT_FAVORED_OSS_DOMAIN_PATTERN")).thenReturn(null);
46+
filter.doFilterInternal(req, resp, filterChain);
47+
verify(req, times(0)).setAttribute(any(), any());
48+
when(req.getHeader("SW_CLIENT_FAVORED_OSS_DOMAIN_PATTERN")).thenReturn("^aa$");
49+
filter.doFilterInternal(req, resp, filterChain);
50+
ArgumentCaptor<Pattern> ac = ArgumentCaptor.forClass(Pattern.class);
51+
verify(req).setAttribute(eq("SW_OSS_DOMAIN_REG_PATTERN"), ac.capture());
52+
Assertions.assertTrue(ac.getValue().matcher("aa").matches());
53+
}
54+
}

server/storage-access-layer/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@
9191
<artifactId>spring-boot-configuration-processor</artifactId>
9292
<optional>true</optional>
9393
</dependency>
94+
<dependency>
95+
<groupId>org.springframework</groupId>
96+
<artifactId>spring-web</artifactId>
97+
</dependency>
9498
</dependencies>
9599

96100
<build>

server/storage-access-layer/src/main/java/ai/starwhale/mlops/storage/StorageAccessService.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,5 +148,14 @@ static StorageAccessService getS3LikeStorageAccessService(String type, S3Config
148148
*/
149149
String signedUrl(String path, Long expTimeMillis) throws IOException;
150150

151+
default List<String> signedUrlAllDomains(String path, Long expTimeMillis) throws IOException {
152+
return List.of(signedUrl(path, expTimeMillis));
153+
}
154+
151155
String signedPutUrl(String path, String contentType, Long expTimeMillis) throws IOException;
156+
157+
default List<String> signedPutUrlAllDomains(String path, String contentType, Long expTimeMillis)
158+
throws IOException {
159+
return List.of(signedPutUrl(path, contentType, expTimeMillis));
160+
}
152161
}

server/storage-access-layer/src/main/java/ai/starwhale/mlops/storage/aliyun/StorageAccessServiceAliyun.java

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,17 @@
4747
import java.io.InputStream;
4848
import java.util.ArrayList;
4949
import java.util.Date;
50+
import java.util.List;
51+
import java.util.stream.Collectors;
5052
import java.util.stream.Stream;
5153
import lombok.extern.slf4j.Slf4j;
54+
import org.springframework.util.CollectionUtils;
5255

5356
@Slf4j
5457
public class StorageAccessServiceAliyun extends S3LikeStorageAccessService {
5558

5659
private final OSS ossClient;
60+
private final List<OSS> ossClientEquivalents;
5761

5862
public StorageAccessServiceAliyun(S3Config s3Config) {
5963
super(s3Config);
@@ -62,6 +66,15 @@ public StorageAccessServiceAliyun(S3Config s3Config) {
6266
config.setRequestTimeoutEnabled(true);
6367
this.ossClient = new OSSClientBuilder()
6468
.build(s3Config.getEndpoint(), s3Config.getAccessKey(), s3Config.getSecretKey(), config);
69+
if (!CollectionUtils.isEmpty(s3Config.getEndpointEquivalents())) {
70+
this.ossClientEquivalents = s3Config.getEndpointEquivalents()
71+
.stream()
72+
.map(edp -> new OSSClientBuilder()
73+
.build(edp, s3Config.getAccessKey(), s3Config.getSecretKey(), config))
74+
.collect(Collectors.toList());
75+
} else {
76+
this.ossClientEquivalents = List.of();
77+
}
6578
}
6679

6780
@Override
@@ -73,10 +86,12 @@ public StorageObjectInfo head(String path) throws IOException {
7386
public StorageObjectInfo head(String path, boolean md5sum) throws IOException {
7487
try {
7588
var resp = this.ossClient.headObject(new HeadObjectRequest(this.bucket, path));
76-
return new StorageObjectInfo(true,
89+
return new StorageObjectInfo(
90+
true,
7791
resp.getContentLength(),
7892
md5sum ? resp.getETag().replace("\"", "").toLowerCase() : null,
79-
MetaHelper.mapToString(resp.getUserMetadata()));
93+
MetaHelper.mapToString(resp.getUserMetadata())
94+
);
8095
} catch (OSSException e) {
8196
if (e.getErrorCode().equals(OSSErrorCode.NO_SUCH_KEY)) {
8297
return new StorageObjectInfo(false, 0L, null, null);
@@ -117,7 +132,8 @@ public void put(String path, InputStream inputStream) throws IOException {
117132
uploadId,
118133
i,
119134
new ByteArrayInputStream(data),
120-
data.length));
135+
data.length
136+
));
121137
etagList.add(resp.getPartETag());
122138
if (data.length < this.partSize) {
123139
break;
@@ -187,12 +203,41 @@ public void delete(String path) throws IOException {
187203

188204
@Override
189205
public String signedUrl(String path, Long expTimeMillis) {
206+
return signedUrl(path, expTimeMillis, this.ossClient);
207+
}
208+
209+
private String signedUrl(String path, Long expTimeMillis, OSS ossClient) {
190210
return ossClient.generatePresignedUrl(this.bucket, path, new Date(System.currentTimeMillis() + expTimeMillis))
191211
.toString();
192212
}
193213

214+
@Override
215+
public List<String> signedUrlAllDomains(String path, Long expTimeMillis) {
216+
return Stream.concat(Stream.of(ossClient), ossClientEquivalents.stream())
217+
.map(client -> signedUrl(
218+
path,
219+
expTimeMillis,
220+
client
221+
)).collect(Collectors.toList());
222+
}
223+
194224
@Override
195225
public String signedPutUrl(String path, String contentType, Long expTimeMillis) throws IOException {
226+
return signPutUrl(path, contentType, expTimeMillis, this.ossClient);
227+
}
228+
229+
@Override
230+
public List<String> signedPutUrlAllDomains(String path, String contentType, Long expTimeMillis) {
231+
return Stream.concat(Stream.of(ossClient), ossClientEquivalents.stream())
232+
.map(client -> signPutUrl(
233+
path,
234+
contentType,
235+
expTimeMillis,
236+
client
237+
)).collect(Collectors.toList());
238+
}
239+
240+
private String signPutUrl(String path, String contentType, Long expTimeMillis, OSS ossClient) {
196241
var request = new GeneratePresignedUrlRequest(this.bucket, path, HttpMethod.PUT);
197242
request.setExpiration(new Date(System.currentTimeMillis() + expTimeMillis));
198243
request.setContentType(contentType);

0 commit comments

Comments
 (0)