forked from Theano/libgpuarray
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgpuarray_array_collectives.c
More file actions
117 lines (105 loc) · 3.89 KB
/
gpuarray_array_collectives.c
File metadata and controls
117 lines (105 loc) · 3.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include "gpuarray/array.h"
#include "gpuarray/buffer_collectives.h"
#include "gpuarray/collectives.h"
#include "gpuarray/error.h"
#include "private.h"
/**
 * \brief Returns the number of elements held by `array`, computed as the
 * product of all of its dimensions (an array with `nd == 0` yields 1).
 */
static inline size_t find_total_elems(const GpuArray* array) {
  size_t product = 1;
  unsigned int axis = 0;
  while (axis < array->nd) {
    product *= array->dimensions[axis];
    ++axis;
  }
  return product;
}
/**
* \brief Checks if `src` and `dest` arrays are appropriate to participate in a
* collective operation.
*
* Checks to see if they contain the appropriate number of elements, if they are
* properly aligned (contiguous) and writeable (for `dest`) and if they contain
* elements of the same datatype. It returns the number of elements of the array
* with
* the less length.
*/
static inline int check_gpuarrays(int times_src, const GpuArray* src,
int times_dest, const GpuArray* dest,
size_t* count) {
size_t count_src, count_dest;
count_src = find_total_elems(src);
count_dest = find_total_elems(dest);
if (times_src * count_src != times_dest * count_dest)
return GA_VALUE_ERROR;
if (src->typecode != dest->typecode)
return GA_VALUE_ERROR;
if (!GpuArray_ISALIGNED(src) || !GpuArray_CHKFLAGS(dest, GA_BEHAVED))
return GA_UNALIGNED_ERROR;
if (times_src >= times_dest)
*count = count_src;
else
*count = count_dest;
return GA_NO_ERROR;
}
/**
 * \brief Contributes `src` to a reduction whose result is collected on rank
 * `root`; the calling rank receives nothing (a NULL destination is passed to
 * gpucomm_reduce()).
 *
 * \param src    array to contribute; must be aligned and form one contiguous
 *               (C- or F-ordered) segment
 * \param opcode reduction operator, forwarded to gpucomm_reduce()
 * \param root   rank that receives the result
 * \param comm   communicator to use
 *
 * \return GA_UNALIGNED_ERROR if `src` is not contiguous or not aligned,
 *         otherwise the result of gpucomm_reduce().
 */
int GpuArray_reduce_from(const GpuArray* src, int opcode, int root,
                         gpucomm* comm) {
  size_t total_elems;
  /* gpucomm_reduce() reads total_elems consecutive elements starting at
     src->offset, so the array must not be a strided view. */
  if (!(GpuArray_CHKFLAGS(src, GA_C_CONTIGUOUS) ||
        GpuArray_CHKFLAGS(src, GA_F_CONTIGUOUS)))
    return GA_UNALIGNED_ERROR;
  if (!GpuArray_ISALIGNED(src))
    return GA_UNALIGNED_ERROR;
  total_elems = find_total_elems(src);
  return gpucomm_reduce(src->data, src->offset, NULL, 0, total_elems,
                        src->typecode, opcode, root, comm);
}
/**
 * \brief Reduces `src` from every rank of `comm` into `dest` on rank `root`.
 *
 * On the root rank, `src` and `dest` are validated with check_gpuarrays() and
 * the result is written to `dest`. On any other rank `dest` is ignored and the
 * call is forwarded to GpuArray_reduce_from().
 *
 * \return GA_NO_ERROR or the first error reported along the way.
 */
int GpuArray_reduce(const GpuArray* src, GpuArray* dest, int opcode, int root,
                    gpucomm* comm) {
  int my_rank = 0;
  size_t n_elems = 0;
  GA_CHECK(gpucomm_get_rank(comm, &my_rank));
  /* Non-root ranks only contribute data; they never touch `dest`. */
  if (my_rank != root)
    return GpuArray_reduce_from(src, opcode, root, comm);
  GA_CHECK(check_gpuarrays(1, src, 1, dest, &n_elems));
  return gpucomm_reduce(src->data, src->offset,
                        dest->data, dest->offset,
                        n_elems, src->typecode, opcode, root, comm);
}
/**
 * \brief Reduces `src` across every rank of `comm` and stores the result in
 * `dest` on each rank.
 *
 * `src` and `dest` must hold the same number of elements of the same type, as
 * validated by check_gpuarrays().
 *
 * \return GA_NO_ERROR or the first error reported by the checks or by
 *         gpucomm_all_reduce().
 */
int GpuArray_all_reduce(const GpuArray* src, GpuArray* dest, int opcode,
                        gpucomm* comm) {
  size_t n_elems = 0;
  GA_CHECK(check_gpuarrays(1, src, 1, dest, &n_elems));
  return gpucomm_all_reduce(src->data, src->offset,
                            dest->data, dest->offset,
                            n_elems, src->typecode, opcode, comm);
}
/**
 * \brief Reduces `src` across all ranks of `comm` and scatters the result so
 * that each rank receives one `dest`-sized chunk.
 *
 * `src` must hold exactly as many elements as there are ranks in `comm` times
 * the element count of `dest`. The count forwarded to
 * gpucomm_reduce_scatter() is the per-rank chunk size.
 *
 * \return GA_NO_ERROR or the first error reported along the way.
 */
int GpuArray_reduce_scatter(const GpuArray* src, GpuArray* dest, int opcode,
                            gpucomm* comm) {
  int n_ranks = 0;
  size_t chunk_elems = 0;
  GA_CHECK(gpucomm_get_count(comm, &n_ranks));
  /* `src` holds one `dest`-sized chunk per rank. */
  GA_CHECK(check_gpuarrays(1, src, n_ranks, dest, &chunk_elems));
  return gpucomm_reduce_scatter(src->data, src->offset,
                                dest->data, dest->offset,
                                chunk_elems, src->typecode, opcode, comm);
}
/**
 * \brief Broadcasts the contents of `array` on rank `root` into `array` on
 * every other rank of `comm`.
 *
 * The same buffer is used for both roles, as in a standard collective
 * broadcast such as ncclBroadcast:
 *  - on the root it is only read, so it only needs to be aligned;
 *  - on every other rank it is overwritten, so it must be aligned and
 *    writeable (GA_BEHAVED).
 *
 * In both cases it must be one contiguous (C- or F-ordered) segment, because
 * only (data, offset, count) is passed to gpucomm_broadcast().
 *
 * \return GA_UNALIGNED_ERROR on a layout/flags problem, otherwise the first
 *         error from gpucomm_get_rank() or the result of gpucomm_broadcast().
 */
int GpuArray_broadcast(GpuArray* array, int root, gpucomm* comm) {
  int rank = 0;
  size_t total_elems;
  GA_CHECK(gpucomm_get_rank(comm, &rank));
  if (!(GpuArray_CHKFLAGS(array, GA_C_CONTIGUOUS) ||
        GpuArray_CHKFLAGS(array, GA_F_CONTIGUOUS)))
    return GA_UNALIGNED_ERROR;
  if (rank == root) {
    /* The root only sends its data. */
    if (!GpuArray_ISALIGNED(array))
      return GA_UNALIGNED_ERROR;
  } else {
    /* Receiving ranks have their buffer overwritten. */
    if (!GpuArray_CHKFLAGS(array, GA_BEHAVED))
      return GA_UNALIGNED_ERROR;
  }
  total_elems = find_total_elems(array);
  return gpucomm_broadcast(array->data, array->offset, total_elems,
                           array->typecode, root, comm);
}
/**
 * \brief Gathers `src` from every rank of `comm` into `dest` on each rank.
 *
 * `dest` must hold exactly as many elements as there are ranks in `comm` times
 * the element count of `src`. The count forwarded to gpucomm_all_gather() is
 * the per-rank (`src`) element count.
 *
 * \return GA_NO_ERROR or the first error reported along the way.
 */
int GpuArray_all_gather(const GpuArray* src, GpuArray* dest, gpucomm* comm) {
  int n_ranks = 0;
  size_t per_rank_elems = 0;
  GA_CHECK(gpucomm_get_count(comm, &n_ranks));
  /* `dest` receives one `src`-sized chunk per rank. */
  GA_CHECK(check_gpuarrays(n_ranks, src, 1, dest, &per_rank_elems));
  return gpucomm_all_gather(src->data, src->offset,
                            dest->data, dest->offset,
                            per_rank_elems, src->typecode, comm);
}