Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/memory/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ cc_library(memcpy SRCS memcpy.cc DEPS place device_context)
cc_library(stats SRCS stats.cc DEPS enforce)
cc_library(memory DEPS malloc memcpy stats)

cc_test(memory_stats_test SRCS memory_stats_test.cc DEPS memory)
cc_test(stats_test SRCS stats_test.cc DEPS stats)

if (WITH_GPU)
Expand Down
5 changes: 1 addition & 4 deletions paddle/fluid/memory/allocation/allocator_facade.cc
Original file line number Diff line number Diff line change
Expand Up @@ -931,10 +931,7 @@ class AllocatorFacadePrivate {

// Wraps every registered allocator, for host and device places alike, in a
// StatAllocator so that its allocations and frees are recorded in the memory
// statistics. Each allocator is wrapped exactly once: a StatAllocator updates
// the stats and then delegates, so nesting two of them would count every
// allocation twice.
void WrapStatAllocator() {
  for (auto& pair : allocators_) {
    pair.second = std::make_shared<StatAllocator>(pair.second);
  }
}

Expand Down
20 changes: 16 additions & 4 deletions paddle/fluid/memory/allocation/stat_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,28 @@ class StatAllocator : public Allocator {

protected:
// Subtracts the allocation's size from the "Allocated" stat of the place it
// lives on (the host stat for CPU places, the device stat otherwise) and then
// hands the allocation back to the wrapped allocator.
void FreeImpl(phi::Allocation* allocation) override {
  // Convert to a signed type before negating: negating the unsigned size
  // directly relies on wrap-around and yields a large positive increment on
  // targets where the size type is narrower than int64_t.
  const int64_t released_bytes = -static_cast<int64_t>(allocation->size());

  if (platform::is_cpu_place(allocation->place())) {
    HOST_MEMORY_STAT_UPDATE(Allocated, allocation->place().GetDeviceId(),
                            released_bytes);
  } else {
    DEVICE_MEMORY_STAT_UPDATE(Allocated, allocation->place().GetDeviceId(),
                              released_bytes);
  }

  // The stat is updated first so that `allocation` is never read after it has
  // been returned to the underlying allocator.
  underlying_allocator_->Free(allocation);
}

// Allocates `size` bytes from the wrapped allocator and adds the size actually
// obtained to the "Allocated" stat of the resulting place (the host stat for
// CPU places, the device stat otherwise). The stat is updated exactly once per
// allocation. Ownership of the returned allocation passes to the caller.
phi::Allocation* AllocateImpl(size_t size) override {
  phi::Allocator::AllocationPtr allocation =
      underlying_allocator_->Allocate(size);

  // Use allocation->size() rather than the requested `size`, so the value
  // added here is the same one FreeImpl later subtracts.
  if (platform::is_cpu_place(allocation->place())) {
    HOST_MEMORY_STAT_UPDATE(Allocated, allocation->place().GetDeviceId(),
                            allocation->size());
  } else {
    DEVICE_MEMORY_STAT_UPDATE(Allocated, allocation->place().GetDeviceId(),
                              allocation->size());
  }
  return allocation.release();
}

Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/memory/detail/system_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License. */

#include "paddle/fluid/memory/detail/system_allocator.h"

#include "paddle/fluid/memory/stats.h"

#ifdef _WIN32
#include <malloc.h>
#ifndef NOMINMAX
Expand Down Expand Up @@ -92,6 +94,8 @@ void* CPUAllocator::Alloc(size_t* index, size_t size) {
}
}

HOST_MEMORY_STAT_UPDATE(Reserved, 0, size);

return p;
}

Expand All @@ -108,6 +112,8 @@ void CPUAllocator::Free(void* p, size_t size, size_t index) {
#else
free(p);
#endif

HOST_MEMORY_STAT_UPDATE(Reserved, 0, -size);
}

bool CPUAllocator::UseGpu() const { return false; }
Expand Down
64 changes: 64 additions & 0 deletions paddle/fluid/memory/memory_stats_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/memory/memory.h"
#include <algorithm>
#include <vector>
#include "gtest/gtest.h"

namespace paddle {
namespace memory {

// Allocates a series of host buffers one at a time and checks that the host
// "Allocated" stat matches the size of the buffer currently held, and that
// the recorded peak equals the largest single buffer.
TEST(stat_allocator_test, host_memory_stat_test) {
  const std::vector<int64_t> request_sizes{
      5278, 9593, 8492, 5041, 3351, 4232, 3706, 5963, 5896, 5057, 7527,
      6235, 0,    7810, 940,  1239, 1945, 789,  2891, 7553, 8046, 2685,
      1332, 6547, 5238, 5345, 1133, 5475, 9137, 3111, 8478, 6350, 9395,
      4,    1185, 2186, 357,  9774, 6743, 6136, 7073, 7674, 5640, 3935,
      528,  6699, 9821, 8717, 2264, 4708, 9936, 3566, 1373, 6955, 3694,
      221,  309,  3617, 3793, 3334, 7281, 1302};

  int64_t largest_seen = 0;
  for (const int64_t request : request_sizes) {
    // NOTE(review): presumably the previous iteration's buffer is released
    // when its AllocationPtr goes out of scope, which is why the current
    // value is expected to equal this single buffer's size.
    AllocationPtr buffer = Alloc(platform::CPUPlace(), request);
    const auto live_bytes = static_cast<int64_t>(buffer->size());
    if (largest_seen < live_bytes) {
      largest_seen = live_bytes;
    }
    EXPECT_EQ(HostMemoryStatCurrentValue("Allocated", 0), live_bytes);
  }
  EXPECT_EQ(HostMemoryStatPeakValue("Allocated", 0), largest_seen);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Device counterpart of the host test: allocates a series of buffers on the
// default CUDAPlace one at a time and checks that the device "Allocated" stat
// for device 0 matches the buffer currently held, and that the recorded peak
// equals the largest single buffer.
TEST(stat_allocator_test, device_memory_stat_test) {
  const std::vector<int64_t> request_sizes{
      5278, 9593, 8492, 5041, 3351, 4232, 3706, 5963, 5896, 5057, 7527,
      6235, 0,    7810, 940,  1239, 1945, 789,  2891, 7553, 8046, 2685,
      1332, 6547, 5238, 5345, 1133, 5475, 9137, 3111, 8478, 6350, 9395,
      4,    1185, 2186, 357,  9774, 6743, 6136, 7073, 7674, 5640, 3935,
      528,  6699, 9821, 8717, 2264, 4708, 9936, 3566, 1373, 6955, 3694,
      221,  309,  3617, 3793, 3334, 7281, 1302};

  int64_t largest_seen = 0;
  for (const int64_t request : request_sizes) {
    // NOTE(review): presumably the previous iteration's buffer is released
    // when its AllocationPtr goes out of scope, which is why the current
    // value is expected to equal this single buffer's size.
    AllocationPtr buffer = Alloc(platform::CUDAPlace(), request);
    const auto live_bytes = static_cast<int64_t>(buffer->size());
    if (largest_seen < live_bytes) {
      largest_seen = live_bytes;
    }
    EXPECT_EQ(DeviceMemoryStatCurrentValue("Allocated", 0), live_bytes);
  }
  EXPECT_EQ(DeviceMemoryStatPeakValue("Allocated", 0), largest_seen);
}
#endif

} // namespace memory
} // namespace paddle
92 changes: 58 additions & 34 deletions paddle/fluid/memory/stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class StatRegistry {
}

// Builds the registry key for a stat: the stat type immediately followed by
// the device id, e.g. ("DeviceAllocated", 3) -> "DeviceAllocated3".
// The stat type already carries its "Host"/"Device" prefix, so no further
// prefix is added here.
std::string GetStatKey(const std::string& stat_type, int dev_id) {
  return stat_type + std::to_string(dev_id);
}

int64_t GetCurrentValue(const std::string& stat_type, int dev_id) {
Expand All @@ -49,6 +49,10 @@ class StatRegistry {
return GetStat(stat_type, dev_id)->GetPeakValue();
}

// Applies `increment` (negative to decrease) to the stat registered for
// (stat_type, dev_id). The stat is resolved through GetStat.
void Update(const std::string& stat_type, int dev_id, int64_t increment) {
  auto* stat = GetStat(stat_type, dev_id);
  stat->Update(increment);
}

void Register(const std::string& stat_type, int dev_id, StatBase* stat) {
std::lock_guard<SpinLock> lock_guard(stat_map_lock_);
stat_map_[GetStatKey(stat_type, dev_id)] = stat;
Expand All @@ -59,10 +63,6 @@ class StatRegistry {
stat_map_.erase(GetStatKey(stat_type, dev_id));
}

void Update(const std::string& stat_type, int dev_id, int64_t increment) {
stat_map_[GetStatKey(stat_type, dev_id)]->Update(increment);
}

private:
StatRegistry() = default;

Expand All @@ -72,43 +72,67 @@ class StatRegistry {
SpinLock stat_map_lock_;
};

// Returns the current value of the device-memory stat `stat_type` (for
// example "Allocated" or "Reserved") on device `dev_id`. Device stats are
// stored in the registry under the "Device"-prefixed type name.
int64_t DeviceMemoryStatCurrentValue(const std::string& stat_type, int dev_id) {
  return StatRegistry::GetInstance()->GetCurrentValue("Device" + stat_type,
                                                      dev_id);
}

// Returns the peak value recorded for the device-memory stat `stat_type` (for
// example "Allocated" or "Reserved") on device `dev_id`. Device stats are
// stored in the registry under the "Device"-prefixed type name.
int64_t DeviceMemoryStatPeakValue(const std::string& stat_type, int dev_id) {
  return StatRegistry::GetInstance()->GetPeakValue("Device" + stat_type,
                                                   dev_id);
}

// Adds `increment` (negative to decrease) to the device-memory stat
// `stat_type` on device `dev_id`. Device stats are stored in the registry
// under the "Device"-prefixed type name.
void DeviceMemoryStatUpdate(const std::string& stat_type, int dev_id,
                            int64_t increment) {
  StatRegistry::GetInstance()->Update("Device" + stat_type, dev_id, increment);
}

#define MEMORY_STAT_REGISTER_WITH_ID(item, id) \
StatRegistry::GetInstance()->Register( \
#item, id, Stat<ThreadLocalStatDevice##id##item>::GetInstance());

#define MEMORY_STAT_REGISTER(item) \
MEMORY_STAT_REGISTER_WITH_ID(item, 0); \
MEMORY_STAT_REGISTER_WITH_ID(item, 1); \
MEMORY_STAT_REGISTER_WITH_ID(item, 2); \
MEMORY_STAT_REGISTER_WITH_ID(item, 3); \
MEMORY_STAT_REGISTER_WITH_ID(item, 4); \
MEMORY_STAT_REGISTER_WITH_ID(item, 5); \
MEMORY_STAT_REGISTER_WITH_ID(item, 6); \
MEMORY_STAT_REGISTER_WITH_ID(item, 7); \
MEMORY_STAT_REGISTER_WITH_ID(item, 8); \
MEMORY_STAT_REGISTER_WITH_ID(item, 9); \
MEMORY_STAT_REGISTER_WITH_ID(item, 10); \
MEMORY_STAT_REGISTER_WITH_ID(item, 11); \
MEMORY_STAT_REGISTER_WITH_ID(item, 12); \
MEMORY_STAT_REGISTER_WITH_ID(item, 13); \
MEMORY_STAT_REGISTER_WITH_ID(item, 14); \
MEMORY_STAT_REGISTER_WITH_ID(item, 15)
// Returns the current value of the host-memory stat `stat_type` for
// `dev_id`, looked up in the registry under the "Host"-prefixed type name.
int64_t HostMemoryStatCurrentValue(const std::string& stat_type, int dev_id) {
  const std::string host_stat_type = "Host" + stat_type;
  return StatRegistry::GetInstance()->GetCurrentValue(host_stat_type, dev_id);
}

// Returns the peak value recorded for the host-memory stat `stat_type` for
// `dev_id`, looked up in the registry under the "Host"-prefixed type name.
int64_t HostMemoryStatPeakValue(const std::string& stat_type, int dev_id) {
  const std::string host_stat_type = "Host" + stat_type;
  return StatRegistry::GetInstance()->GetPeakValue(host_stat_type, dev_id);
}

// Adds `increment` (negative to decrease) to the host-memory stat `stat_type`
// for `dev_id`, addressed in the registry by the "Host"-prefixed type name.
void HostMemoryStatUpdate(const std::string& stat_type, int dev_id,
                          int64_t increment) {
  const std::string host_stat_type = "Host" + stat_type;
  StatRegistry::GetInstance()->Update(host_stat_type, dev_id, increment);
}

// Registers the Stat singleton for device stat `item` on device `id` under
// the stat type "Device<item>". `id` must be an integer literal because it is
// token-pasted into the DeviceMemoryStat<item><id> type name.
// The macro deliberately has no trailing semicolon: the caller writes it, so
// each use expands to exactly one expression statement.
#define DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, id) \
  StatRegistry::GetInstance()->Register(              \
      "Device" #item, id, Stat<DeviceMemoryStat##item##id>::GetInstance())

// Registers device stat `item` for device ids 0 through 15. The ids are
// spelled out because each one is token-pasted into a distinct type name by
// DEVICE_MEMORY_STAT_REGISTER_WITH_ID, so a runtime loop cannot replace them.
// Wrapped in do { } while (0) so the macro behaves as a single statement,
// including as the body of an unbraced if/else.
#define DEVICE_MEMORY_STAT_REGISTER(item)          \
  do {                                             \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 0);  \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 1);  \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 2);  \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 3);  \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 4);  \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 5);  \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 6);  \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 7);  \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 8);  \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 9);  \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 10); \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 11); \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 12); \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 13); \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 14); \
    DEVICE_MEMORY_STAT_REGISTER_WITH_ID(item, 15); \
  } while (0)

// Registers the Stat singleton for host stat `item` under the stat type
// "Host<item>" with id 0, the only host id registered here.
// The macro deliberately has no trailing semicolon: the caller writes it, so
// each use expands to exactly one expression statement.
#define HOST_MEMORY_STAT_REGISTER(item)   \
  StatRegistry::GetInstance()->Register(  \
      "Host" #item, 0, Stat<HostMemoryStat##item##0>::GetInstance())

// Registers every memory stat with the StatRegistry: "Allocated" and
// "Reserved" for each device id covered by DEVICE_MEMORY_STAT_REGISTER, and
// the same two stats for the host. Always returns 0.
int RegisterAllStats() {
  DEVICE_MEMORY_STAT_REGISTER(Allocated);
  DEVICE_MEMORY_STAT_REGISTER(Reserved);

  HOST_MEMORY_STAT_REGISTER(Allocated);
  HOST_MEMORY_STAT_REGISTER(Reserved);
  return 0;
}

Expand Down
Loading