Skip to content

Commit 05cd687

Browse files
committed
Make sure we can set binary artifacts to cache mfa related shaders.
1 parent 9f6cc40 commit 05cd687

25 files changed

+192
-49
lines changed

lib/nnc/ccv_nnc.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,15 @@ CCV_WARN_UNUSED(int) ccv_nnc_queue_watermark(void);
929929
* @param size The size of the array, only first 64 will be used.
930930
*/
931931
void ccv_nnc_set_device_permutation(const int type, const int* const device_map, const int size);
932+
/**
933+
* Set the path to binary artifacts that would accelerate command compilations. Note that the binary
934+
* artifact paths are split into paths to read from and a path to write to. They could be the same, but it would be
935+
* better to keep them separate to avoid competing with each other.
936+
* @param paths_to_read The file paths to read binary artifacts. Whether it is a file or directory is implementation dependent.
937+
* @param paths_to_read_size The number of paths in paths_to_read.
938+
* @param path_to_write The file path to write binary artifacts. Whether it is a file or directory is implementation dependent.
939+
*/
940+
void ccv_nnc_set_binary_artifacts(const char** const paths_to_read, const int paths_to_read_size, const char* const path_to_write);
932941
/**
933942
* Quantize a given memory region of a given datatype / memory resides, into nbits palette.
934943
* @param input The input memory region, it can be CCV_64F, CCV_32F or CCV_16F.

lib/nnc/ccv_nnc_cmd.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,3 +964,10 @@ void ccv_nnc_set_device_permutation(const int type, const int* const device_map,
964964
cusetdevicemap(device_map, size);
965965
#endif
966966
}
967+
968+
void ccv_nnc_set_binary_artifacts(const char** const paths_to_read, const int paths_to_read_size, const char* const path_to_write)
969+
{
970+
#ifdef HAVE_MPS
971+
ccv_nnc_mps_set_binary_artifacts(paths_to_read, paths_to_read_size, path_to_write);
972+
#endif
973+
}

lib/nnc/mfa/ccv_nnc_mfa.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ mtl_buffer_t* ccv_nnc_mfa_request_scratch(ccv_nnc_mfa_context_t* context, const
5555
return context->request_scratch(size);
5656
}
5757

58+
void ccv_nnc_mfa_set_binary_archives(ccv_nnc_mfa_context_t* const context, const char** const paths_to_read, const int paths_to_read_size, const char* const path_to_write) {
59+
std::vector<std::string> paths_to_read_vec;
60+
for (int i = 0; i < paths_to_read_size; i++) {
61+
paths_to_read_vec.push_back(std::string(paths_to_read[i]));
62+
}
63+
std::string path_to_write_str = path_to_write != nullptr ? std::string(path_to_write) : std::string();
64+
context->v2_cache.setBinaryArchives(context->device.get(), paths_to_read_vec, path_to_write_str);
65+
}
66+
5867
void ccv_nnc_mfa_log_message(const char* message) {
5968
std::cerr << METAL_LOG_HEADER << message << std::endl;
6069
}

lib/nnc/mfa/ccv_nnc_mfa.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ class context {
6565
} // namespace nnc
6666
} // namespace ccv
6767

68+
std::pair<std::string, std::string> ccv_nnc_mfa_get_binary_artifacts(void);
69+
6870
extern "C" {
6971
#endif // __cplusplus
7072

@@ -80,6 +82,7 @@ void ccv_nnc_mfa_log_message(const char* message);
8082
mtl_command_batch_t* ccv_nnc_start_command_batch(mtl_command_queue_t* command_queue);
8183
void ccv_nnc_finish_command_batch(mtl_command_batch_t* command_batch);
8284
mtl_buffer_t* ccv_nnc_mfa_request_scratch(ccv_nnc_mfa_context_t* context, const uint64_t size);
85+
void ccv_nnc_mfa_set_binary_archives(ccv_nnc_mfa_context_t* context, const char** paths_to_read, const int paths_to_read_size, const char* path_to_write);
8386

8487
#ifdef __cplusplus
8588
} // extern "C"

lib/nnc/mfa/v2/AddDescriptor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ std::size_t std::hash<AddDescriptor>::operator()(const AddDescriptor& hash) cons
1919
return seed;
2020
}
2121

22-
std::pair<AddKernelDescriptor, PipelineValue<AddKernel> *> AddDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, std::unordered_map<AddKernelDescriptor, std::unique_ptr<AddKernel>> *const libraryCache) const noexcept {
22+
std::pair<AddKernelDescriptor, PipelineValue<AddKernel> *> AddDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, NS::Array* const binaryArchivesToRead, MTL::BinaryArchive* const binaryArchiveToWrite, const std::string& pathToWrite, std::unordered_map<AddKernelDescriptor, std::unique_ptr<AddKernel>> *const libraryCache) const noexcept {
2323
// The caller is not responsible for calling 'delete' on this pointer. The
2424
// reference is saved in the 'libraryCache'. It will be deallocated whenever
2525
// the shader cache itself is cleaned up.

lib/nnc/mfa/v2/AddDescriptor.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct AddDescriptor {
3333

3434
bool operator==(const AddDescriptor& rhs) const;
3535

36-
std::pair<AddKernelDescriptor, PipelineValue<AddKernel> *> findKernel(MTL::Device* const device, const DeviceProperties &dprops, std::unordered_map<AddKernelDescriptor, std::unique_ptr<AddKernel>> *const libraryCache) const noexcept;
36+
std::pair<AddKernelDescriptor, PipelineValue<AddKernel> *> findKernel(MTL::Device* const device, const DeviceProperties &dprops, NS::Array* const binaryArchivesToRead, MTL::BinaryArchive* const binaryArchiveToWrite, const std::string& pathToWrite, std::unordered_map<AddKernelDescriptor, std::unique_ptr<AddKernel>> *const libraryCache) const noexcept;
3737
};
3838

3939
template<>

lib/nnc/mfa/v2/AttentionDescriptor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ AttentionKernelDescriptor AttentionDescriptor::kernelDescriptor(MTL::Device *con
127127
}
128128
}
129129

130-
std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> AttentionDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, std::unordered_map<AttentionKernelDescriptor, std::unique_ptr<AttentionKernel>> *const libraryCache) const noexcept {
130+
std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> AttentionDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, NS::Array* const binaryArchivesToRead, MTL::BinaryArchive* const binaryArchiveToWrite, const std::string& pathToWrite, std::unordered_map<AttentionKernelDescriptor, std::unique_ptr<AttentionKernel>> *const libraryCache) const noexcept {
131131
auto createPipeline =
132132
[=](MTL::Library* library) -> MTL::ComputePipelineState* {
133133
// Set the function constants.

lib/nnc/mfa/v2/AttentionDescriptor.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ struct AttentionDescriptor {
6161

6262
bool operator==(const AttentionDescriptor& rhs) const;
6363

64-
std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> findKernel(MTL::Device* const device, const DeviceProperties &dprops, std::unordered_map<AttentionKernelDescriptor, std::unique_ptr<AttentionKernel>> *const libraryCache) const noexcept;
64+
std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> findKernel(MTL::Device* const device, const DeviceProperties &dprops, NS::Array* const binaryArchivesToRead, MTL::BinaryArchive* const binaryArchiveToWrite, const std::string& pathToWrite, std::unordered_map<AttentionKernelDescriptor, std::unique_ptr<AttentionKernel>> *const libraryCache) const noexcept;
6565

6666
private:
6767
AttentionKernelDescriptor kernelDescriptor(MTL::Device *const device, const DeviceProperties &dprops) const noexcept;

lib/nnc/mfa/v2/CMulDescriptor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ std::size_t std::hash<CMulDescriptor>::operator()(const CMulDescriptor& hash) co
2727
return seed;
2828
}
2929

30-
std::pair<CMulKernelDescriptor, PipelineValue<CMulKernel> *> CMulDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, std::unordered_map<CMulKernelDescriptor, std::unique_ptr<CMulKernel>> *const libraryCache) const noexcept {
30+
std::pair<CMulKernelDescriptor, PipelineValue<CMulKernel> *> CMulDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, NS::Array* const binaryArchivesToRead, MTL::BinaryArchive* const binaryArchiveToWrite, const std::string& pathToWrite, std::unordered_map<CMulKernelDescriptor, std::unique_ptr<CMulKernel>> *const libraryCache) const noexcept {
3131
// The caller is not responsible for calling 'delete' on this pointer. The
3232
// reference is saved in the 'libraryCache'. It will be deallocated whenever
3333
// the shader cache itself is cleaned up.

lib/nnc/mfa/v2/CMulDescriptor.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ struct CMulDescriptor {
3939

4040
bool operator==(const CMulDescriptor& rhs) const;
4141

42-
std::pair<CMulKernelDescriptor, PipelineValue<CMulKernel> *> findKernel(MTL::Device* const device, const DeviceProperties &dprops, std::unordered_map<CMulKernelDescriptor, std::unique_ptr<CMulKernel>> *const libraryCache) const noexcept;
42+
std::pair<CMulKernelDescriptor, PipelineValue<CMulKernel> *> findKernel(MTL::Device* const device, const DeviceProperties &dprops, NS::Array* const binaryArchivesToRead, MTL::BinaryArchive* const binaryArchiveToWrite, const std::string& pathToWrite, std::unordered_map<CMulKernelDescriptor, std::unique_ptr<CMulKernel>> *const libraryCache) const noexcept;
4343
};
4444

4545
template<>

0 commit comments

Comments
 (0)