forked from BVLC/caffe
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsyncedmem.cpp
More file actions
141 lines (126 loc) · 2.62 KB
/
Copy pathsyncedmem.cpp
File metadata and controls
141 lines (126 loc) · 2.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include <cstring>
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/util/math_functions.hpp"
namespace caffe {
// Destructor: frees the host and device buffers, but only the ones this
// object allocated itself (own_*_data_ set). Buffers installed through
// set_cpu_data()/set_gpu_data() clear the ownership flag and are left alone.
SyncedMemory::~SyncedMemory() {
  const bool release_host = (cpu_ptr_ != NULL) && own_cpu_data_;
  if (release_host) {
    CaffeFreeHost(cpu_ptr_);
  }

#ifndef CPU_ONLY
  const bool release_device = (gpu_ptr_ != NULL) && own_gpu_data_;
  if (release_device) {
    CUDA_CHECK(cudaFree(gpu_ptr_));
  }
#endif  // CPU_ONLY
}
// Brings the host-side buffer up to date according to head_:
//   UNINITIALIZED -> allocate size_ bytes on the host, zero them, and move
//                    the head to the CPU.
//   HEAD_AT_GPU   -> allocate the host buffer if there is none yet, pass
//                    gpu_ptr_/cpu_ptr_ to caffe_gpu_memcpy, and mark the
//                    state SYNCED (CPU_ONLY builds report NO_GPU instead).
//   HEAD_AT_CPU / SYNCED -> nothing to do.
inline void SyncedMemory::to_cpu() {
  if (head_ == UNINITIALIZED) {
    // First touch: this object allocates, and therefore owns, the buffer.
    CaffeMallocHost(&cpu_ptr_, size_);
    caffe_memset(size_, 0, cpu_ptr_);
    own_cpu_data_ = true;
    head_ = HEAD_AT_CPU;
    return;
  }

  if (head_ != HEAD_AT_GPU) {
    // HEAD_AT_CPU and SYNCED already have a current host buffer.
    return;
  }

#ifndef CPU_ONLY
  if (!cpu_ptr_) {
    CaffeMallocHost(&cpu_ptr_, size_);
    own_cpu_data_ = true;
  }
  caffe_gpu_memcpy(size_, gpu_ptr_, cpu_ptr_);
  head_ = SYNCED;
#else
  NO_GPU;
#endif
}
// Brings the device-side buffer up to date according to head_:
//   UNINITIALIZED -> cudaMalloc size_ bytes, zero them with
//                    caffe_gpu_memset, and move the head to the GPU.
//   HEAD_AT_CPU   -> allocate the device buffer if there is none yet, pass
//                    cpu_ptr_/gpu_ptr_ to caffe_gpu_memcpy, and mark the
//                    state SYNCED.
//   HEAD_AT_GPU / SYNCED -> nothing to do.
// CPU_ONLY builds contain no device path and report NO_GPU.
inline void SyncedMemory::to_gpu() {
#ifndef CPU_ONLY
  if (head_ == UNINITIALIZED) {
    // First touch: this object allocates, and therefore owns, the buffer.
    CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
    caffe_gpu_memset(size_, 0, gpu_ptr_);
    own_gpu_data_ = true;
    head_ = HEAD_AT_GPU;
  } else if (head_ == HEAD_AT_CPU) {
    if (!gpu_ptr_) {
      CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
      own_gpu_data_ = true;
    }
    caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_);
    head_ = SYNCED;
  }
  // HEAD_AT_GPU and SYNCED already have a current device buffer.
#else
  NO_GPU;
#endif
}
// Read-only access to the host buffer. to_cpu() runs first so the pointer
// handed back reflects the current state; head_ is otherwise left as
// to_cpu() set it.
const void* SyncedMemory::cpu_data() {
  to_cpu();
  const void* host_view = cpu_ptr_;
  return host_view;
}
// Installs an externally supplied host buffer. `data` must be non-null
// (enforced by CHECK). A host buffer previously allocated by this object is
// freed first; afterwards the object points at `data`, does not own it, and
// the head is at the CPU.
void SyncedMemory::set_cpu_data(void* data) {
  CHECK(data);

  // Only release memory this object allocated itself.
  const bool had_private_buffer = own_cpu_data_;
  if (had_private_buffer) {
    CaffeFreeHost(cpu_ptr_);
  }

  own_cpu_data_ = false;
  cpu_ptr_ = data;
  head_ = HEAD_AT_CPU;
}
// Read-only access to the device buffer. to_gpu() runs first so the pointer
// handed back reflects the current state.
//
// In CPU_ONLY builds there is no device buffer: NO_GPU is reported and NULL
// is returned.
const void* SyncedMemory::gpu_data() {
#ifndef CPU_ONLY
  to_gpu();
  return gpu_ptr_;
#else
  NO_GPU;
  // NO_GPU is presumably fatal (confirm at its definition), but the compiler
  // cannot rely on that: without an explicit return this non-void function
  // would flow off its end, which is undefined behavior and triggers
  // -Wreturn-type.
  return NULL;
#endif
}
// Installs an externally supplied device buffer. `data` must be non-null
// (enforced by CHECK). A device buffer previously allocated by this object
// is released with cudaFree first; afterwards the object points at `data`,
// does not own it, and the head is at the GPU.
// CPU_ONLY builds report NO_GPU instead.
void SyncedMemory::set_gpu_data(void* data) {
#ifndef CPU_ONLY
  CHECK(data);

  // Only release memory this object allocated itself.
  const bool had_private_buffer = own_gpu_data_;
  if (had_private_buffer) {
    CUDA_CHECK(cudaFree(gpu_ptr_));
  }

  own_gpu_data_ = false;
  gpu_ptr_ = data;
  head_ = HEAD_AT_GPU;
#else
  NO_GPU;
#endif
}
// Writable access to the host buffer. After to_cpu() has run, the head is
// forced to HEAD_AT_CPU (even if the state was SYNCED) because the caller
// may modify the host copy, so a later to_gpu() will copy it across again.
void* SyncedMemory::mutable_cpu_data() {
  to_cpu();
  void* const host_buffer = cpu_ptr_;
  head_ = HEAD_AT_CPU;
  return host_buffer;
}
// Writable access to the device buffer. After to_gpu() has run, the head is
// forced to HEAD_AT_GPU (even if the state was SYNCED) because the caller
// may modify the device copy, so a later to_cpu() will copy it back again.
//
// In CPU_ONLY builds there is no device buffer: NO_GPU is reported and NULL
// is returned.
void* SyncedMemory::mutable_gpu_data() {
#ifndef CPU_ONLY
  to_gpu();
  head_ = HEAD_AT_GPU;
  return gpu_ptr_;
#else
  NO_GPU;
  // NO_GPU is presumably fatal (confirm at its definition), but the compiler
  // cannot rely on that: without an explicit return this non-void function
  // would flow off its end, which is undefined behavior and triggers
  // -Wreturn-type.
  return NULL;
#endif
}
#ifndef CPU_ONLY
void SyncedMemory::async_gpu_push(const cudaStream_t& stream) {
CHECK(head_ == HEAD_AT_CPU);
if (gpu_ptr_ == NULL) {
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
own_gpu_data_ = true;
}
CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, cudaMemcpyHostToDevice, stream));
// Assume caller will synchronize on the stream before use
head_ = SYNCED;
}
#endif
} // namespace caffe