This repository was archived by the owner on Jul 24, 2022. It is now read-only.
forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathRegressionTree.cs
More file actions
241 lines (215 loc) · 12.5 KB
/
Copy pathRegressionTree.cs
File metadata and controls
241 lines (215 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Collections.Immutable;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Trainers.FastTree
{
/// <summary>
/// A container base class for exposing <see cref="InternalRegressionTree"/>'s and
/// <see cref="InternalQuantileRegressionTree"/>'s attributes to users.
/// This class should not be mutable, so it contains a lot of read-only members.
/// </summary>
public abstract class RegressionTreeBase
{
/// <summary>
/// <see cref="RegressionTreeBase"/> is an immutable wrapper over <see cref="_tree"/> for exposing some tree's
/// attribute to users.
/// </summary>
private readonly InternalRegressionTree _tree;
/// <summary>
/// See <see cref="LeftChild"/>.
/// </summary>
private readonly ImmutableArray<int> _lteChild;
/// <summary>
/// See <see cref="RightChild"/>.
/// </summary>
private readonly ImmutableArray<int> _gtChild;
/// <summary>
/// See <see cref="NumericalSplitFeatureIndexes"/>.
/// </summary>
private readonly ImmutableArray<int> _numericalSplitFeatureIndexes;
/// <summary>
/// See <see cref="NumericalSplitThresholds"/>.
/// </summary>
private readonly ImmutableArray<float> _numericalSplitThresholds;
/// <summary>
/// See <see cref="CategoricalSplitFlags"/>.
/// </summary>
private readonly ImmutableArray<bool> _categoricalSplitFlags;
/// <summary>
/// See <see cref="LeafValues"/>.
/// </summary>
private readonly ImmutableArray<double> _leafValues;
/// <summary>
/// See <see cref="SplitGains"/>.
/// </summary>
private readonly ImmutableArray<double> _splitGains;
/// <summary>
/// <see cref="LeftChild"/>[i] is the i-th node's child index used when
/// (1) the numerical feature indexed by <see cref="NumericalSplitFeatureIndexes"/>[i] is less than or equal
/// to the threshold <see cref="NumericalSplitThresholds"/>[i], or
/// (2) the categorical features indexed by <see cref="GetCategoricalCategoricalSplitFeatureRangeAt(int)"/>'s
/// returned value with nodeIndex=i is NOT a sub-set of <see cref="GetCategoricalSplitFeaturesAt(int)"/> with
/// nodeIndex=i.
/// Note that the case (1) happens only when <see cref="CategoricalSplitFlags"/>[i] is true and otherwise (2)
/// occurs. A non-negative returned value means a node (i.e., not a leaf); for example, 2 means the 3rd node in
/// the underlying <see cref="_tree"/>. A negative returned value means a leaf; for example, -1 stands for the
/// <see langword="~"/>(-1)-th leaf in the underlying <see cref="_tree"/>. Note that <see langword="~"/> is the
/// bitwise complement operator in C#; for details, see
/// https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/operators/bitwise-complement-operator.
/// </summary>
public IReadOnlyList<int> LeftChild => _lteChild;
/// <summary>
/// <see cref="RightChild"/>[i] is the i-th node's child index used when the two conditions, (1) and (2),
/// described in <see cref="LeftChild"/>'s document are not true. Its return value follows the format
/// used in <see cref="LeftChild"/>.
/// </summary>
public IReadOnlyList<int> RightChild => _gtChild;
/// <summary>
/// <see cref="NumericalSplitFeatureIndexes"/>[i] is the feature index used the splitting function of the
/// i-th node. This value is valid only if <see cref="CategoricalSplitFlags"/>[i] is false.
/// </summary>
public IReadOnlyList<int> NumericalSplitFeatureIndexes => _numericalSplitFeatureIndexes;
/// <summary>
/// <see cref="NumericalSplitThresholds"/>[i] is the threshold on feature indexed by
/// <see cref="NumericalSplitFeatureIndexes"/>[i], where i is the i-th node's index
/// (for example, i is 1 for the 2nd node in <see cref="_tree"/>).
/// </summary>
public IReadOnlyList<float> NumericalSplitThresholds => _numericalSplitThresholds;
/// <summary>
/// Determine the types of splitting function. If <see cref="CategoricalSplitFlags"/>[i] is true, the i-th
/// node's uses categorical splitting function. Otherwise, traditional numerical split is used.
/// </summary>
public IReadOnlyList<bool> CategoricalSplitFlags => _categoricalSplitFlags;
/// <summary>
/// <see cref="LeafValues"/>[i] is the learned value at the i-th leaf.
/// </summary>
public IReadOnlyList<double> LeafValues => _leafValues;
/// <summary>
/// Return categorical thresholds used at node indexed by nodeIndex. If the considered input feature does NOT
/// matche any of values returned by <see cref="GetCategoricalSplitFeaturesAt(int)"/>, we call it a
/// less-than-threshold event and therefore <see cref="LeftChild"/>[nodeIndex] is the child node that input
/// should go next. The returned value is valid only if <see cref="CategoricalSplitFlags"/>[nodeIndex] is true.
/// </summary>
public IReadOnlyList<int> GetCategoricalSplitFeaturesAt(int nodeIndex)
{
if (nodeIndex < 0 || nodeIndex >= NumberOfNodes)
throw Contracts.Except($"The input index, {nodeIndex}, is invalid. Its valid range is from 0 (inclusive) to {NumberOfNodes} (exclusive).");
if (_tree.CategoricalSplitFeatures == null || _tree.CategoricalSplitFeatures[nodeIndex] == null)
return new List<int>(); // Zero-length vector.
else
return _tree.CategoricalSplitFeatures[nodeIndex];
}
/// <summary>
/// Return categorical thresholds' range used at node indexed by nodeIndex. A categorical split at node indexed
/// by nodeIndex can consider multiple consecutive input features at one time; their range is specified by
/// <see cref="GetCategoricalCategoricalSplitFeatureRangeAt(int)"/>. The returned value is always a 2-element
/// array; its 1st element is the starting index and its 2nd element is the endining index of a feature segment.
/// The returned value is valid only if <see cref="CategoricalSplitFlags"/>[nodeIndex] is true.
/// </summary>
public IReadOnlyList<int> GetCategoricalCategoricalSplitFeatureRangeAt(int nodeIndex)
{
if (nodeIndex < 0 || nodeIndex >= NumberOfNodes)
throw Contracts.Except($"The input node index, {nodeIndex}, is invalid. Its valid range is from 0 (inclusive) to {NumberOfNodes} (exclusive).");
if (_tree.CategoricalSplitFeatureRanges == null || _tree.CategoricalSplitFeatureRanges[nodeIndex] == null)
return new List<int>(); // Zero-length vector.
else
return _tree.CategoricalSplitFeatureRanges[nodeIndex];
}
/// <summary>
/// The gains obtained by splitting data at nodes. Its i-th value is computed from to the split at the i-th node.
/// </summary>
public IReadOnlyList<double> SplitGains => _splitGains;
/// <summary>
/// Number of leaves in the tree. Note that <see cref="NumberOfLeaves"/> does not take non-leaf nodes into account.
/// </summary>
public int NumberOfLeaves => _tree.NumLeaves;
/// <summary>
/// Number of nodes in the tree. This doesn't include any leaves. For example, a tree with node0->node1,
/// node0->leaf3, node1->leaf1, node1->leaf2, <see cref="NumberOfNodes"/> and <see cref="NumberOfLeaves"/> should
/// be 2 and 3, respectively.
/// </summary>
// A visualization of the example mentioned in this doc string.
// node0
// / \
// node1 leaf3
// / \
// leaf1 leaf2
// The index of leaf starts with 1 because interally we use "-1" as the 1st leaf's index, "-2" for the 2nd leaf's index, and so on.
public int NumberOfNodes => _tree.NumNodes;
internal RegressionTreeBase(InternalRegressionTree tree)
{
_tree = tree;
_lteChild = ImmutableArray.Create(_tree.LteChild, 0, _tree.NumNodes);
_gtChild = ImmutableArray.Create(_tree.GtChild, 0, _tree.NumNodes);
_numericalSplitFeatureIndexes = ImmutableArray.Create(_tree.SplitFeatures, 0, _tree.NumNodes);
_numericalSplitThresholds = ImmutableArray.Create(_tree.RawThresholds, 0, _tree.NumNodes);
_categoricalSplitFlags = ImmutableArray.Create(_tree.CategoricalSplit, 0, _tree.NumNodes);
_leafValues = ImmutableArray.Create(_tree.LeafValues, 0, _tree.NumLeaves);
_splitGains = ImmutableArray.Create(_tree.SplitGains, 0, _tree.NumNodes);
}
}
/// <summary>
/// A container class for exposing <see cref="InternalRegressionTree"/>'s attributes to users.
/// This class should not be mutable, so it contains a lot of read-only members. Note that
/// <see cref="RegressionTree"/> is identical to <see cref="RegressionTreeBase"/> but in
/// another derived class <see cref="QuantileRegressionTree"/> some attributes are added.
/// </summary>
public sealed class RegressionTree : RegressionTreeBase
{
internal RegressionTree(InternalRegressionTree tree) : base(tree) { }
}
/// <summary>
/// A container class for exposing <see cref="InternalQuantileRegressionTree"/>'s attributes to users.
/// This class should not be mutable, so it contains a lot of read-only members. In addition to
/// things inherited from <see cref="RegressionTreeBase"/>, we add <see cref="GetLeafSamplesAt(int)"/>
/// and <see cref="GetLeafSampleWeightsAt(int)"/> to expose (sub-sampled) training labels falling into
/// the leafIndex-th leaf and their weights.
/// </summary>
public sealed class QuantileRegressionTree : RegressionTreeBase
{
/// <summary>
/// Sample labels from training data. <see cref="_leafSamples"/>[i] stores the labels falling into the
/// i-th leaf.
/// </summary>
private readonly double[][] _leafSamples;
/// <summary>
/// Sample labels' weights from training data. <see cref="_leafSampleWeights"/>[i] stores the weights for
/// labels falling into the i-th leaf. <see cref="_leafSampleWeights"/>[i][j] is the weight of
/// <see cref="_leafSamples"/>[i][j].
/// </summary>
private readonly double[][] _leafSampleWeights;
/// <summary>
/// Return the training labels falling into the specified leaf.
/// </summary>
/// <param name="leafIndex">The index of the specified leaf.</param>
/// <returns>Training labels</returns>
public IReadOnlyList<double> GetLeafSamplesAt(int leafIndex)
{
if (leafIndex < 0 || leafIndex >= NumberOfLeaves)
throw Contracts.Except($"The input leaf index, {leafIndex}, is invalid. Its valid range is from 0 (inclusive) to {NumberOfLeaves} (exclusive).");
// _leafSample always contains valid values assigned in constructor.
return _leafSamples[leafIndex];
}
/// <summary>
/// Return the weights for training labels falling into the specified leaf. If <see cref="GetLeafSamplesAt"/>
/// and <see cref="GetLeafSampleWeightsAt"/> use the same input, the i-th returned value of this function is
/// the weight of the i-th label in <see cref="GetLeafSamplesAt"/>.
/// </summary>
/// <param name="leafIndex">The index of the specified leaf.</param>
/// <returns>Training labels' weights</returns>
public IReadOnlyList<double> GetLeafSampleWeightsAt(int leafIndex)
{
if (leafIndex < 0 || leafIndex >= NumberOfLeaves)
throw Contracts.Except($"The input leaf index, {leafIndex}, is invalid. Its valid range is from 0 (inclusive) to {NumberOfLeaves} (exclusive).");
// _leafSampleWeights always contains valid values assigned in constructor.
return _leafSampleWeights[leafIndex];
}
internal QuantileRegressionTree(InternalQuantileRegressionTree tree) : base(tree)
{
tree.ExtractLeafSamplesAndTheirWeights(out _leafSamples, out _leafSampleWeights);
}
}
}