forked from Theano/libgpuarray
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path gpuarray_buffer_collectives.c
More file actions
98 lines (86 loc) · 3.31 KB
/
gpuarray_buffer_collectives.c
File metadata and controls
98 lines (86 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include "gpuarray/buffer.h"
#include "gpuarray/buffer_collectives.h"
#include "gpuarray/error.h"
#include "private.h"
/*
 * Create a collectives communicator through the backend attached to `ctx`.
 *
 * comm    - out: receives the new communicator. Set to NULL when `ctx` has
 *           no collectives backend.
 * ctx     - context whose `comm_ops` table performs the work.
 * comm_id - clique identifier forwarded to the backend (presumably obtained
 *           from gpucomm_gen_clique_id - confirm with callers).
 * ndev    - forwarded unchanged to the backend's comm_new.
 * rank    - forwarded unchanged to the backend's comm_new.
 *
 * Returns the backend's comm_new result. Returns GA_UNSUPPORTED_ERROR when
 * ctx->comm_ops is NULL.
 */
int gpucomm_new(gpucomm **comm, gpucontext *ctx, gpucommCliqueId comm_id,
                int ndev, int rank) {
  if (ctx->comm_ops != NULL)
    return ctx->comm_ops->comm_new(comm, ctx, comm_id, ndev, rank);

  /* No backend: still leave the caller with a well-defined (NULL) handle. */
  *comm = NULL;
  return GA_UNSUPPORTED_ERROR;
}
/*
 * Release a communicator through its context's collectives backend.
 *
 * comm - communicator to release. NULL is accepted and ignored, like
 *        free(NULL). gpucomm_new hands back a NULL communicator on failure,
 *        so unconditional cleanup code must not crash on it.
 *
 * If the owning context has no collectives backend, nothing is released.
 */
void gpucomm_free(gpucomm *comm) {
  gpucontext *ctx;

  /* gpucomm_context() dereferences comm, so reject NULL before calling it. */
  if (comm == NULL)
    return;

  ctx = gpucomm_context(comm);
  if (ctx->comm_ops != NULL)
    ctx->comm_ops->comm_free(comm);
}
/*
 * Return a human-readable error string for collectives on `ctx`.
 *
 * When `ctx` has no collectives backend, a fixed explanatory message is
 * returned. Otherwise the context's `error_msg` field is returned as-is
 * (presumably the most recent backend error - confirm in the backend code).
 */
const char *gpucomm_error(gpucontext *ctx) {
  if (ctx->comm_ops == NULL)
    return "No collective ops available, API error. Is a collectives library "
           "installed?";
  return ctx->error_msg;
}
/*
 * Return the context that owns `comm`.
 *
 * The opaque communicator is read through the `partial_gpucomm` view, so the
 * `ctx` member must sit where partial_gpucomm declares it. NOTE(review):
 * presumably every backend's communicator struct begins with the
 * partial_gpucomm layout - confirm in private.h. `comm` must not be NULL.
 */
gpucontext *gpucomm_context(gpucomm *comm) {
  partial_gpucomm *header = (partial_gpucomm *)comm;
  return header->ctx;
}
/*
 * Ask the collectives backend of `ctx` to fill in a clique identifier.
 *
 * comm_id - out: written by the backend's generate_clique_id.
 *
 * Returns the backend's result, or GA_COMM_ERROR when ctx has no
 * collectives backend.
 */
int gpucomm_gen_clique_id(gpucontext *ctx, gpucommCliqueId *comm_id) {
  return (ctx->comm_ops != NULL)
             ? ctx->comm_ops->generate_clique_id(ctx, comm_id)
             : GA_COMM_ERROR;
}
/*
 * Query the participant count of `comm` from its backend.
 *
 * gpucount - out: written by the backend's get_count.
 *
 * Returns the backend's result, or GA_COMM_ERROR when the owning context has
 * no collectives backend.
 */
int gpucomm_get_count(gpucomm *comm, int *gpucount) {
  gpucontext *owner = gpucomm_context(comm);

  if (owner->comm_ops != NULL)
    return owner->comm_ops->get_count(comm, gpucount);
  return GA_COMM_ERROR;
}
/*
 * Query this participant's rank within `comm` from its backend.
 *
 * rank - out: written by the backend's get_rank.
 *
 * Returns the backend's result, or GA_COMM_ERROR when the owning context has
 * no collectives backend.
 */
int gpucomm_get_rank(gpucomm *comm, int *rank) {
  gpucontext *owner = gpucomm_context(comm);

  if (owner->comm_ops != NULL)
    return owner->comm_ops->get_rank(comm, rank);
  return GA_COMM_ERROR;
}
/*
 * Reduce operation dispatched to the collectives backend of `comm`'s context.
 *
 * All arguments (source buffer/offset, destination buffer/offset, element
 * count, type code, reduction op code, root rank, communicator) are forwarded
 * unchanged to comm_ops->reduce.
 *
 * Returns the backend's result, or GA_COMM_ERROR when the owning context has
 * no collectives backend.
 */
int gpucomm_reduce(gpudata *src, size_t offsrc, gpudata *dest, size_t offdest,
                   size_t count, int typecode, int opcode, int root,
                   gpucomm *comm) {
  gpucontext *owner = gpucomm_context(comm);

  if (owner->comm_ops != NULL)
    return owner->comm_ops->reduce(src, offsrc, dest, offdest, count,
                                   typecode, opcode, root, comm);
  return GA_COMM_ERROR;
}
/*
 * All-reduce operation dispatched to the collectives backend of `comm`'s
 * context.
 *
 * All arguments are forwarded unchanged to comm_ops->all_reduce.
 *
 * Returns the backend's result, or GA_COMM_ERROR when the owning context has
 * no collectives backend.
 */
int gpucomm_all_reduce(gpudata *src, size_t offsrc, gpudata *dest,
                       size_t offdest, size_t count, int typecode, int opcode,
                       gpucomm *comm) {
  gpucontext *owner = gpucomm_context(comm);

  if (owner->comm_ops != NULL)
    return owner->comm_ops->all_reduce(src, offsrc, dest, offdest, count,
                                       typecode, opcode, comm);
  return GA_COMM_ERROR;
}
/*
 * Reduce-scatter operation dispatched to the collectives backend of `comm`'s
 * context.
 *
 * All arguments are forwarded unchanged to comm_ops->reduce_scatter. Whether
 * `count` refers to the input or the per-rank output is defined by the
 * backend (NOTE(review): confirm against the backend implementation).
 *
 * Returns the backend's result, or GA_COMM_ERROR when the owning context has
 * no collectives backend.
 */
int gpucomm_reduce_scatter(gpudata *src, size_t offsrc, gpudata *dest,
                           size_t offdest, size_t count, int typecode,
                           int opcode, gpucomm *comm) {
  gpucontext *owner = gpucomm_context(comm);

  if (owner->comm_ops != NULL)
    return owner->comm_ops->reduce_scatter(src, offsrc, dest, offdest, count,
                                           typecode, opcode, comm);
  return GA_COMM_ERROR;
}
/*
 * Broadcast operation dispatched to the collectives backend of `comm`'s
 * context.
 *
 * The single buffer `array` (starting at `offset`), `count`, `typecode` and
 * `root` are forwarded unchanged to comm_ops->broadcast.
 *
 * Returns the backend's result, or GA_COMM_ERROR when the owning context has
 * no collectives backend.
 */
int gpucomm_broadcast(gpudata *array, size_t offset, size_t count, int typecode,
                      int root, gpucomm *comm) {
  gpucontext *owner = gpucomm_context(comm);

  return (owner->comm_ops == NULL)
             ? GA_COMM_ERROR
             : owner->comm_ops->broadcast(array, offset, count, typecode, root,
                                          comm);
}
/*
 * All-gather operation dispatched to the collectives backend of `comm`'s
 * context.
 *
 * All arguments are forwarded unchanged to comm_ops->all_gather.
 *
 * Returns the backend's result, or GA_COMM_ERROR when the owning context has
 * no collectives backend.
 */
int gpucomm_all_gather(gpudata *src, size_t offsrc, gpudata *dest,
                       size_t offdest, size_t count, int typecode,
                       gpucomm *comm) {
  gpucontext *owner = gpucomm_context(comm);

  if (owner->comm_ops != NULL)
    return owner->comm_ops->all_gather(src, offsrc, dest, offdest, count,
                                       typecode, comm);
  return GA_COMM_ERROR;
}