forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathRandomForest.cs
More file actions
92 lines (82 loc) · 3.98 KB
/
Copy pathRandomForest.cs
File metadata and controls
92 lines (82 loc) · 3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Trainers.FastTree.Internal;
namespace Microsoft.ML.Trainers.FastTree
{
/// <summary>
/// Shared base for random-forest trainers layered on top of <c>FastTreeTrainerBase</c>.
/// It plugs a <c>RandomForestOptimizer</c> and a <c>RandomForestLeastSquaresTreeLearner</c>
/// into the training flow defined by the base class.
/// </summary>
/// <typeparam name="TArgs">Arguments type; constrained to <c>FastForestArgumentsBase</c> with a parameterless constructor.</typeparam>
/// <typeparam name="TTransformer">Transformer type; constrained to <c>ISingleFeaturePredictionTransformer&lt;TModel&gt;</c>.</typeparam>
/// <typeparam name="TModel">Model type; constrained to <c>IPredictorProducing&lt;float&gt;</c>.</typeparam>
public abstract class RandomForestTrainerBase<TArgs, TTransformer, TModel>
    : FastTreeTrainerBase<TArgs, TTransformer, TModel>
    where TArgs : FastForestArgumentsBase, new()
    where TModel : IPredictorProducing<float>
    where TTransformer : ISingleFeaturePredictionTransformer<TModel>
{
    // Captured at construction time and handed to the tree learner built in ConstructTreeLearner.
    private readonly bool _quantileEnabled;

    /// <summary>
    /// Constructor invoked by the maml code-path.
    /// </summary>
    /// <param name="env">Host environment, forwarded to the base constructor.</param>
    /// <param name="args">Trainer arguments, forwarded to the base constructor.</param>
    /// <param name="label">Label column description, forwarded to the base constructor.</param>
    /// <param name="quantileEnabled">Value stored and later passed to the tree learner.</param>
    protected RandomForestTrainerBase(
        IHostEnvironment env,
        TArgs args,
        SchemaShape.Column label,
        bool quantileEnabled = false)
        : base(env, args, label)
    {
        _quantileEnabled = quantileEnabled;
    }

    /// <summary>
    /// Constructor invoked by the API code-path.
    /// </summary>
    /// <param name="env">Host environment, forwarded to the base constructor.</param>
    /// <param name="label">Label column description, forwarded to the base constructor.</param>
    /// <param name="featureColumn">Feature column name, forwarded to the base constructor.</param>
    /// <param name="weightColumn">Weight column name, forwarded to the base constructor.</param>
    /// <param name="groupIdColumn">
    /// Not forwarded: the base constructor receives <c>null</c> in the argument slot after
    /// <paramref name="weightColumn"/>. NOTE(review): confirm that dropping this value is intended.
    /// </param>
    /// <param name="numLeaves">Forwarded to the base constructor.</param>
    /// <param name="numTrees">Forwarded to the base constructor.</param>
    /// <param name="minDatapointsInLeaves">Forwarded to the base constructor.</param>
    /// <param name="learningRate">Accepted for signature compatibility; not read by this constructor.</param>
    /// <param name="advancedSettings">Forwarded to the base constructor.</param>
    /// <param name="quantileEnabled">Value stored and later passed to the tree learner.</param>
    protected RandomForestTrainerBase(
        IHostEnvironment env,
        SchemaShape.Column label,
        string featureColumn,
        string weightColumn,
        string groupIdColumn,
        int numLeaves,
        int numTrees,
        int minDatapointsInLeaves,
        double learningRate,
        Action<TArgs> advancedSettings,
        bool quantileEnabled = false)
        : base(
            env,
            label,
            featureColumn,
            weightColumn,
            null,
            numLeaves,
            numTrees,
            minDatapointsInLeaves,
            advancedSettings)
    {
        _quantileEnabled = quantileEnabled;
    }

    /// <summary>
    /// Builds the optimizer used for training: a <c>RandomForestOptimizer</c> over the current
    /// ensemble, training set and initial training scores, configured with the tree learner,
    /// objective function and smoothing value of this trainer.
    /// </summary>
    /// <param name="ch">Channel; validated as non-null through <c>Host.CheckValue</c>.</param>
    /// <returns>The configured optimizer, with <c>PrintTestGraph</c> subscribed to its pre-score-update event.</returns>
    protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
    {
        Host.CheckValue(ch, nameof(ch));

        IGradientAdjuster adjuster = MakeGradientWrapper(ch);
        var forestOptimizer = new RandomForestOptimizer(Ensemble, TrainSet, InitTrainScores, adjuster)
        {
            TreeLearner = ConstructTreeLearner(ch),
            ObjectiveFunction = ConstructObjFunc(ch),
            Smoothing = Args.Smoothing,
            // No notion of dropout for non-boosting applications.
            DropoutRate = 0,
            DropoutRng = null
        };

        forestOptimizer.PreScoreUpdateEvent += PrintTestGraph;
        return forestOptimizer;
    }

    /// <summary>
    /// Overrides the base hook with an empty body; nothing is initialized here.
    /// </summary>
    protected override void InitializeTests()
    {
    }

    /// <summary>
    /// Creates the <c>RandomForestLeastSquaresTreeLearner</c> for the current training set,
    /// configured from <c>Args</c>, the stored quantile flag and <c>ParallelTraining</c>.
    /// </summary>
    /// <param name="ch">Channel; not used by this implementation.</param>
    /// <returns>A newly constructed tree learner.</returns>
    protected override TreeLearner ConstructTreeLearner(IChannel ch) =>
        new RandomForestLeastSquaresTreeLearner(
            TrainSet,
            Args.NumLeaves,
            Args.MinDocumentsInLeafs,
            Args.EntropyCoefficient,
            Args.FeatureFirstUsePenalty,
            Args.FeatureReusePenalty,
            Args.SoftmaxTemperature,
            Args.HistogramPoolSize,
            Args.RngSeed,
            Args.SplitFraction,
            Args.AllowEmptyTrees,
            Args.GainConfidenceLevel,
            Args.MaxCategoricalGroupsPerNode,
            Args.MaxCategoricalSplitPoints,
            _quantileEnabled,
            Args.QuantileSampleCount,
            ParallelTraining,
            Args.MinDocsPercentageForCategoricalSplit,
            Args.Bundling,
            Args.MinDocsForCategoricalSplit,
            Args.Bias);

    /// <summary>
    /// Base objective function for random-forest trainers. It fixes the base-class settings
    /// that do not apply to random forests and forwards the remaining values.
    /// </summary>
    public abstract class RandomForestObjectiveFunction : ObjectiveFunctionBase
    {
        /// <summary>
        /// Initializes the objective function.
        /// </summary>
        /// <param name="trainData">Training dataset, forwarded to the base constructor.</param>
        /// <param name="args">Trainer arguments; only <c>RngSeed</c> is read here.</param>
        /// <param name="maxStepSize">Forwarded to the base constructor unchanged.</param>
        protected RandomForestObjectiveFunction(Dataset trainData, TArgs args, double maxStepSize)
            : base(
                trainData,
                1,          // Random forests have no learning rate.
                1,          // Random forests apply no shrinkage.
                maxStepSize,
                1,          // Random forests do no derivative sampling.
                false,      // Quasi-newton step improvements are not relevant to RF.
                args.RngSeed)
        {
        }
    }
}
}