forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFastTree.cs
More file actions
3406 lines (3020 loc) · 158 KB
/
Copy pathFastTree.cs
File metadata and controls
3406 lines (3020 loc) · 158 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.ML.Calibrator;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
using Microsoft.ML.Data.Conversion;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.Onnx;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Trainers.FastTree.Internal;
using Microsoft.ML.Training;
using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.Conversions;
using Microsoft.ML.TreePredictor;
using Newtonsoft.Json.Linq;
using Float = System.Single;
// All of these reviews apply in general to fast tree and random forest implementations.
//REVIEW: Decouple train method in Application.cs to have boosting and random forest logic separate.
//REVIEW: Do we need to keep all the fast tree based testers?
namespace Microsoft.ML.Trainers.FastTree
{
/// <summary>
/// Parameterless signature delegate for tree ensemble trainers.
/// NOTE(review): presumably a marker used when registering or looking up loadable components - confirm against its usages.
/// </summary>
public delegate void SignatureTreeEnsembleTrainer();
/// <summary>
/// Holds state shared by every closed instantiation of the generic FastTreeTrainerBase class.
/// Static members of a generic class exist once per instantiation, so an object that must be common
/// to all of them (here, the lock) is kept in this non-generic class instead.
/// </summary>
internal static class FastTreeShared
{
/// <summary>Lock object that TrainCore holds for the whole duration of a training run.</summary>
public static readonly object TrainLock = new object();
}
public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
TrainerEstimatorBaseWithGroupId<TTransformer, TModel>
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
where TArgs : TreeArgs, new()
where TModel : IPredictorProducing<Float>
{
protected readonly TArgs Args;
protected readonly bool AllowGC;
protected TreeEnsemble TrainedEnsemble;
protected int FeatureCount;
private protected RoleMappedData ValidData;
/// <summary>
/// If not null, it's a test data set passed in from training context. It will be converted to one element in
/// <see cref="Tests"/> by calling <see cref="ExamplesToFastTreeBins.GetCompatibleDataset"/> in <see cref="InitializeTests"/>.
/// </summary>
private protected RoleMappedData TestData;
protected IParallelTraining ParallelTraining;
protected OptimizationAlgorithm OptimizationAlgorithm;
protected Dataset TrainSet;
protected Dataset ValidSet;
/// <summary>
/// Data sets used to evaluate the prediction scores produced by the trained model during the training process.
/// </summary>
protected Dataset[] TestSets;
protected int[] FeatureMap;
/// <summary>
/// In the training process, <see cref="TrainSet"/>, <see cref="ValidSet"/>, <see cref="TestSets"/> would be
/// converted into <see cref="Tests"/> for efficient model evaluation.
/// </summary>
protected List<Test> Tests;
protected TestHistory PruningTest;
protected int[] CategoricalFeatures;
// Test for early stopping.
protected Test TrainTest;
protected Test ValidTest;
protected double[] InitTrainScores;
protected double[] InitValidScores;
protected double[][] InitTestScores;
//protected int Iteration;
protected TreeEnsemble Ensemble;
protected bool HasValidSet => ValidSet != null;
private const string RegisterName = "FastTreeTraining";
// random for active features selection
private Random _featureSelectionRandom;
protected string InnerArgs => CmdParser.GetSettings(Host, Args, new TArgs());
public override TrainerInfo Info { get; }
public bool HasCategoricalFeatures => Utils.Size(CategoricalFeatures) > 0;
private protected virtual bool NeedCalibration => false;
/// <summary>
/// Creates the trainer from directly supplied settings. Derived classes use this constructor when they
/// are instantiated through the API.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env,
    SchemaShape.Column label,
    string featureColumn,
    string weightColumn,
    string groupIdColumn,
    int numLeaves,
    int numTrees,
    int minDatapointsInLeaves,
    Action<TArgs> advancedSettings)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(RegisterName),
        TrainerUtils.MakeR4VecFeature(featureColumn),
        label,
        TrainerUtils.MakeR4ScalarWeightColumn(weightColumn),
        TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
{
    // Start from default arguments, overridden with the values passed directly to this constructor.
    Args = new TArgs();
    Args.NumLeaves = numLeaves;
    Args.NumTrees = numTrees;
    Args.MinDocumentsInLeafs = minDatapointsInLeaves;

    // The optional callback runs after the direct values, so it can change any of them.
    if (advancedSettings != null)
        advancedSettings(Args);

    // Column names are written after the callback, so the constructor parameters always win for them.
    Args.LabelColumn = label.Name;
    Args.FeatureColumn = featureColumn;

    if (weightColumn != null)
    {
        Args.WeightColumn = Optional<string>.Explicit(weightColumn);
    }

    if (groupIdColumn != null)
    {
        Args.GroupIdColumn = Optional<string>.Explicit(groupIdColumn);
    }

    // The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
    // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
    // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
    Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true, supportTest: true);

    // REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment
    // with memory consumption more than 5GB, GC get stuck in infinite loop.
    // Before, we could check a specific type of the environment here, but now it is internal, so we will need another
    // mechanism to detect that we are running in Scope.
    AllowGC = true;

    Initialize(env);
}
/// <summary>
/// Legacy constructor that takes a pre-populated arguments object. Derived classes use it when they are
/// invoked through maml.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(RegisterName),
        TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
        label,
        TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
{
    // NOTE(review): the base-constructor arguments above already dereference args, so a null args
    // fails there before this check is reached.
    Host.CheckValue(args, nameof(args));
    Args = args;

    // The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
    // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
    // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
    Info = new TrainerInfo(
        normalization: false,
        caching: false,
        calibration: NeedCalibration,
        supportValid: true,
        supportTest: true);

    // REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment
    // with memory consumption more than 5GB, GC get stuck in infinite loop.
    // Before, we could check a specific type of the environment here, but now it is internal, so we will need another
    // mechanism to detect that we are running in Scope.
    AllowGC = true;

    Initialize(env);
}
// Hooks that each concrete trainer must supply.
/// <summary>Called first by Initialize(IChannel), under the InitializeLabels timer.</summary>
protected abstract void PrepareLabels(IChannel ch);
/// <summary>Called by Initialize(IChannel), under the InitializeTests timer, after the optimization algorithm has been constructed.</summary>
protected abstract void InitializeTests();
/// <summary>Builds a test over the training data. NOTE(review): not called from the visible part of this class - confirm usage in derived trainers.</summary>
protected abstract Test ConstructTestForTrainingData();
/// <summary>Builds the optimization algorithm; Initialize(IChannel) stores the result in OptimizationAlgorithm.</summary>
protected abstract OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch);
/// <summary>Builds the tree learner. NOTE(review): not called from the visible part of this class - confirm usage in derived trainers.</summary>
protected abstract TreeLearner ConstructTreeLearner(IChannel ch);
/// <summary>Builds the objective function. NOTE(review): not called from the visible part of this class - confirm usage in derived trainers.</summary>
protected abstract ObjectiveFunctionBase ConstructObjFunc(IChannel ch);
/// <summary>
/// Maximum label value that ConvertData forwards to the bin converter.
/// The base implementation returns positive infinity.
/// </summary>
protected virtual Float GetMaxLabel() => Float.PositiveInfinity;
/// <summary>
/// Constructor-time setup: resolves the number of training threads, creates and initializes the
/// parallel-training component, allocates the test list and initializes the worker threads.
/// </summary>
/// <param name="env">Environment handed to the parallel trainer factory, when one is configured.</param>
private void Initialize(IHostEnvironment env)
{
    // Requested thread count, or one thread per processor when none was requested.
    int threadCount = Args.NumThreads ?? Environment.ProcessorCount;

    // A positive concurrency factor on the host caps the number of training threads.
    bool exceedsHostLimit = Host.ConcurrencyFactor > 0 && threadCount > Host.ConcurrencyFactor;
    if (exceedsHostLimit)
    {
        using (var warningChannel = Host.Start("FastTreeTrainerBase"))
        {
            threadCount = Host.ConcurrencyFactor;
            warningChannel.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor "
                + "setting of the environment. Using {0} training threads instead.", threadCount);
        }
    }

    // Fall back to the single-machine trainer when no parallel trainer was configured.
    if (Args.ParallelTrainer != null)
        ParallelTraining = Args.ParallelTrainer.CreateComponent(env);
    else
        ParallelTraining = new SingleTrainer();
    ParallelTraining.InitEnvironment();

    Tests = new List<Test>();
    InitializeThreads(threadCount);
}
/// <summary>
/// Bins the training data into <see cref="TrainSet"/> and, using the same converter, builds
/// <see cref="ValidSet"/> and <see cref="TestSets"/> from <see cref="ValidData"/> and <see cref="TestData"/>
/// when those are present. Also records the categorical feature indices and the feature map.
/// </summary>
/// <param name="trainData">Training examples with their role-mapped schema.</param>
private protected void ConvertData(RoleMappedData trainData)
{
    MetadataUtils.TryGetCategoricalFeatureIndices(
        trainData.Schema.Schema,
        trainData.Schema.Feature.Value.Index,
        out CategoricalFeatures);

    // Transposition is used only if it applies to the training data and, when present, to the validation data too.
    bool useTranspose = UseTranspose(Args.DiskTranspose, trainData);
    if (useTranspose && ValidData != null)
        useTranspose = UseTranspose(Args.DiskTranspose, ValidData);

    var binConverter = new ExamplesToFastTreeBins(
        Host, Args.MaxBins, useTranspose, !Args.FeatureFlocks, Args.MinDocumentsInLeafs, GetMaxLabel());

    TrainSet = binConverter.FindBinsAndReturnDataset(
        trainData, PredictionKind, ParallelTraining, CategoricalFeatures, Args.CategoricalSplit);
    FeatureMap = binConverter.FeatureMap;

    if (ValidData != null)
    {
        ValidSet = binConverter.GetCompatibleDataset(
            ValidData, PredictionKind, CategoricalFeatures, Args.CategoricalSplit);
    }

    if (TestData != null)
    {
        Dataset testSet = binConverter.GetCompatibleDataset(
            TestData, PredictionKind, CategoricalFeatures, Args.CategoricalSplit);
        TestSets = new[] { testSet };
    }
}
/// <summary>
/// Decides whether the transposed path should be used for <paramref name="data"/>.
/// </summary>
/// <param name="useTranspose">Explicit setting; when it has a value, that value is returned as is.</param>
/// <param name="data">Data whose view is inspected when no explicit setting is given.</param>
/// <returns>
/// The explicit setting if present; otherwise true only when the data view is an
/// <see cref="ITransposeDataView"/> whose transpose schema reports a slot type for the feature column.
/// </returns>
private bool UseTranspose(bool? useTranspose, RoleMappedData data)
{
    Host.AssertValue(data);
    Host.Assert(data.Schema.Feature.HasValue);

    if (useTranspose is bool explicitChoice)
        return explicitChoice;

    return data.Data is ITransposeDataView transposed
        && transposed.TransposeSchema.GetSlotType(data.Schema.Feature.Value.Index) != null;
}
/// <summary>
/// Runs a complete training session under the global training lock: argument checks and initialization,
/// the training loop, optional timing output, and publication of the result in <see cref="TrainedEnsemble"/>.
/// </summary>
/// <param name="ch">Channel used for all messages; must not be null.</param>
protected void TrainCore(IChannel ch)
{
    Contracts.CheckValue(ch, nameof(ch));

    // REVIEW: Get rid of this lock once all static classes, such as BlockingThreadPool, are removed from FastTree.
    lock (FastTreeShared.TrainLock)
    {
        using (Timer.Time(TimerEvent.TotalInitialization))
        {
            CheckArgs(ch);
            PrintPrologInfo(ch);
            Initialize(ch);
            PrintMemoryStats(ch);
        }

        using (Timer.Time(TimerEvent.TotalTrain))
        {
            Train(ch);
        }

        if (Args.ExecutionTimes)
        {
            PrintExecutionTimes(ch);
        }

        TrainedEnsemble = Ensemble;

        // A non-null feature map means the trained ensemble's feature indices are remapped through it.
        if (FeatureMap != null)
        {
            TrainedEnsemble.RemapFeatures(FeatureMap);
        }

        ParallelTraining.FinalizeEnvironment();
    }
}
/// <summary>
/// Asked by Train at the end of every iteration whether training should end early.
/// The base implementation never stops: it sets <paramref name="bestIteration"/> to the current
/// number of trees, leaves <paramref name="earlyStopping"/> untouched, and returns false.
/// </summary>
/// <returns>True to break out of the training loop; always false here.</returns>
protected virtual bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStopping, ref int bestIteration)
{
bestIteration = Ensemble.NumTrees;
return false;
}
/// <summary>
/// Iteration that Train passes to FinalizeLearning when no early stopping rule was set.
/// The base implementation returns the current number of trees in the ensemble.
/// </summary>
protected virtual int GetBestIteration(IChannel ch)
{
    return Ensemble.NumTrees;
}
/// <summary>
/// Initializes the thread task manager with <paramref name="numThreads"/> threads.
/// </summary>
protected virtual void InitializeThreads(int numThreads) => ThreadTaskManager.Initialize(numThreads);
/// <summary>
/// Writes the timer's execution time breakdown to the channel as an info message.
/// </summary>
protected virtual void PrintExecutionTimes(IChannel ch) =>
    ch.Info("Execution time breakdown:\n{0}", Timer.GetString());
/// <summary>
/// Validates the training arguments and adjusts some of them in place before training starts.
/// </summary>
/// <param name="ch">Channel passed to the argument check and used to create the exceptions thrown here.</param>
protected virtual void CheckArgs(IChannel ch)
{
Args.Check(ch);
// Sets IntArray.CompatibilityLevel from the arguments.
IntArray.CompatibilityLevel = Args.FeatureCompressionLevel;
// change arguments
// A pool size below 2 is replaced by two thirds of the leaf count (integer arithmetic)...
if (Args.HistogramPoolSize < 2)
Args.HistogramPoolSize = Args.NumLeaves * 2 / 3;
// ...and the pool size is then capped at NumLeaves - 1.
if (Args.HistogramPoolSize > Args.NumLeaves - 1)
Args.HistogramPoolSize = Args.NumLeaves - 1;
if (Args.BaggingSize > 0)
{
// With bagging enabled the tree count must divide evenly into bags of BaggingSize trees.
int bagCount = Args.NumTrees / Args.BaggingSize;
if (bagCount * Args.BaggingSize != Args.NumTrees)
throw ch.Except("Number of trees should be a multiple of number bag size");
}
// Written as a negated range test, so a NaN confidence level is rejected as well.
if (!(0 <= Args.GainConfidenceLevel && Args.GainConfidenceLevel < 1))
throw ch.Except("Gain confidence level must be in the range [0,1)");
// NOTE(review): the region below is compiled only when OLD_DATALOAD is defined and NO_STORE is not;
// it uses an _args member that is not declared in the visible part of this class.
#if OLD_DATALOAD
#if !NO_STORE
if (_args.offloadBinsToFileStore)
{
if (!string.IsNullOrEmpty(_args.offloadBinsDirectory) && !Directory.Exists(_args.offloadBinsDirectory))
{
try
{
Directory.CreateDirectory(_args.offloadBinsDirectory);
}
catch (Exception e)
{
throw ch.Except(e, "Failure creating bins offload directory {0} - Exception {1}", _args.offloadBinsDirectory, e.Message);
}
}
}
#endif
#endif
}
/// <summary>
/// Produces the header of the test graph. Applications that print a test graph are expected to
/// override this method; the base implementation returns an empty string.
/// </summary>
/// <returns>String representation of the test graph header.</returns>
protected virtual string GetTestGraphHeader()
{
    return string.Empty;
}
/// <summary>
/// Produces one line of the test graph. Applications that print a test graph are expected to
/// override this method so that a line is emitted after each finished iteration; the base
/// implementation returns an empty string.
/// </summary>
/// <returns>String representation of a single test graph line.</returns>
protected virtual string GetTestGraphLine()
{
    return string.Empty;
}
/// <summary>
/// A virtual method that is used to compute test results after each iteration is finished.
/// PrintTestGraph calls it unconditionally; the base implementation does nothing.
/// </summary>
protected virtual void ComputeTests()
{
}
/// <summary>
/// Computes the tests and, when test graph printing is enabled, writes either the graph header
/// (while the ensemble has no trees) or one graph line to the channel.
/// </summary>
protected void PrintTestGraph(IChannel ch)
{
    // Tests are computed whether or not the graph is printed.
    ComputeTests();

    if (Args.PrintTestGraph)
    {
        if (Ensemble.NumTrees != 0)
            ch.Info(GetTestGraphLine());
        else
            ch.Info(GetTestGraphHeader());
    }
}
/// <summary>
/// Per-training-run initialization: prepares labels, creates an empty ensemble and the optimization
/// algorithm, initializes the tests, and finally forces garbage collection when <see cref="AllowGC"/> is set.
/// Each phase runs under its own timer event.
/// </summary>
protected virtual void Initialize(IChannel ch)
{
    #region Load/Initialize State
    using (Timer.Time(TimerEvent.InitializeLabels))
    {
        PrepareLabels(ch);
    }

    using (Timer.Time(TimerEvent.InitializeTraining))
    {
        InitializeEnsemble();
        OptimizationAlgorithm = ConstructOptimizationAlgorithm(ch);
    }

    using (Timer.Time(TimerEvent.InitializeTests))
    {
        InitializeTests();
    }

    if (AllowGC)
    {
        // Two forced generation-2 collections in a row.
        for (int pass = 0; pass < 2; pass++)
            GC.Collect(2, GCCollectionMode.Forced);
    }
    #endregion
}
#if !NO_STORE
/// <summary>
/// Calculates the percentage of feature bins that will fit into memory based on current available memory in the machine.
/// NOTE(review): compiled only when NO_STORE is not defined; it uses _args, _optimizationAlgorithm and
/// CultureInfo, none of which are declared or imported in the visible part of this file.
/// </summary>
/// <param name="ch">Channel that receives the informational messages about available memory and feature size.</param>
/// <returns>A float number between 0 and 1 indicating the percentage of features to load.
/// The number will not be smaller than two times the feature fraction value</returns>
private float GetFeaturePercentInMemory(IChannel ch)
{
const float maxFeaturePercentValue = 1.0f;
float availableMemory = GetMachineAvailableBytes();
ch.Info("Available memory in the machine is = {0} bytes", availableMemory.ToString("N", CultureInfo.InvariantCulture));
// The floor is twice the feature fraction only when preloading is enabled; otherwise it is the fraction itself.
float minFeaturePercentThreshold = _args.preloadFeatureBinsBeforeTraining ? (float)_args.featureFraction * 2 : (float)_args.featureFraction;
if (minFeaturePercentThreshold >= maxFeaturePercentValue)
{
return maxFeaturePercentValue;
}
// Initial free memory allowance in bytes for single and parallel fastrank modes
float freeMemoryAllowance = 1024 * 1024 * 512;
if (_optimizationAlgorithm.TreeLearner != null)
{
// Get the size of memory in bytes needed by the tree learner internal data structures
freeMemoryAllowance += _optimizationAlgorithm.TreeLearner.GetSizeOfReservedMemory();
}
// Subtract the allowance, clamping at zero.
availableMemory = (availableMemory > freeMemoryAllowance) ? availableMemory - freeMemoryAllowance : 0;
// Total feature size across the training, validation and test sets.
long featureSize = TrainSet.FeatureSetSize;
if (ValidSet != null)
{
featureSize += ValidSet.FeatureSetSize;
}
if (TestSets != null)
{
foreach (var item in TestSets)
{
featureSize += item.FeatureSetSize;
}
}
ch.Info("Total Feature bins size is = {0} bytes", featureSize.ToString("N", CultureInfo.InvariantCulture));
// NOTE(review): featureSize is not checked for zero before the division - confirm it cannot be zero.
return Math.Min(Math.Max(minFeaturePercentThreshold, availableMemory / featureSize), maxFeaturePercentValue);
}
#endif
/// <summary>
/// Builds the per-feature activity mask for one training iteration.
/// </summary>
/// <returns>
/// An array with one entry per training feature. Every entry is true unless the feature fraction is
/// below 1.0, in which case each feature is kept when a random draw does not exceed the fraction.
/// </returns>
protected bool[] GetActiveFeatures()
{
    bool[] mask = Utils.CreateArray(TrainSet.NumFeatures, true);

    // No sub-sampling requested: every feature stays active.
    if (!(Args.FeatureFraction < 1.0))
        return mask;

    // The generator is created on first use from the configured seed and reused on later calls.
    if (_featureSelectionRandom == null)
    {
        _featureSelectionRandom = new Random(Args.FeatureSelectSeed);
    }

    for (int featureIndex = 0; featureIndex < TrainSet.NumFeatures; ++featureIndex)
    {
        if (!mask[featureIndex])
            continue;
        mask[featureIndex] = _featureSelectionRandom.NextDouble() <= Args.FeatureFraction;
    }

    return mask;
}
/// <summary>
/// Formats a one-line description of a dataset: document, query and feature counts, its total size
/// in MB, and the size in MB that remains after subtracting the skeleton.
/// </summary>
private string GetDatasetStatistics(Dataset set)
{
    long datasetSize = set.SizeInBytes();
    int skeletonSize = set.Skeleton.SizeInBytes();

    // Integer division: sizes are truncated to whole megabytes.
    long totalMegabytes = datasetSize / 1024 / 1024;
    long featureMegabytes = (datasetSize - skeletonSize) / 1024 / 1024;

    return string.Format("set contains {0} query-doc pairs in {1} queries with {2} features and uses {3} MB ({4} MB for features)",
        set.NumDocs, set.NumQueries, set.NumFeatures, totalMegabytes, featureMegabytes);
}
/// <summary>
/// Traces memory statistics: the size of the training, validation and test datasets, then the GC heap
/// size (only when <see cref="AllowGC"/> is set, since reading it forces a collection), and finally
/// the memory counters of the current process.
/// </summary>
/// <param name="ch">Channel that receives the trace messages.</param>
protected virtual void PrintMemoryStats(IChannel ch)
{
    Contracts.AssertValue(ch);

    ch.Trace("Training {0}", GetDatasetStatistics(TrainSet));
    if (ValidSet != null)
        ch.Trace("Validation {0}", GetDatasetStatistics(ValidSet));
    if (TestSets != null)
    {
        for (int i = 0; i < TestSets.Length; ++i)
            ch.Trace("ComputeTests[{1}] {0}",
                GetDatasetStatistics(TestSets[i]), i);
    }

    if (AllowGC)
        ch.Trace("GC Total Memory = {0} MB", GC.GetTotalMemory(true) / 1024 / 1024);

    // Process is IDisposable; dispose it once the counters have been read so it is not left for the finalizer.
    using (Process currentProcess = Process.GetCurrentProcess())
    {
        ch.Trace("Working Set = {0} MB", currentProcess.WorkingSet64 / 1024 / 1024);
        ch.Trace("Virtual Memory = {0} MB",
            currentProcess.VirtualMemorySize64 / 1024 / 1024);
        ch.Trace("Private Memory = {0} MB",
            currentProcess.PrivateMemorySize64 / 1024 / 1024);
        ch.Trace("Peak Working Set = {0} MB", currentProcess.PeakWorkingSet64 / 1024 / 1024);
        ch.Trace("Peak Virtual Memory = {0} MB",
            currentProcess.PeakVirtualMemorySize64 / 1024 / 1024);
    }
}
/// <summary>
/// Reports whether the training set carries sample weights. The channel argument is not used.
/// </summary>
protected bool AreSamplesWeighted(IChannel ch) => TrainSet.SampleWeights != null;
/// <summary>
/// Replaces <see cref="Ensemble"/> with a new, empty tree ensemble.
/// </summary>
private void InitializeEnsemble() => Ensemble = new TreeEnsemble();
/// <summary>
/// Creates the wrapper applied to gradient target values: a query-weights wrapper when the training
/// samples are weighted, and a trivial wrapper otherwise.
/// </summary>
protected virtual IGradientAdjuster MakeGradientWrapper(IChannel ch)
{
    if (!AreSamplesWeighted(ch))
    {
        return new TrivialGradientWrapper();
    }

    return new QueryWeightsGradientWrapper();
}
#if !NO_STORE
/// <summary>
/// Unloads feature bins being used in the current iteration.
/// NOTE(review): compiled only when NO_STORE is not defined; it uses an _optimizationAlgorithm member
/// that is not declared in the visible part of this class.
/// </summary>
/// <param name="featureToUnload">Boolean array indicating the features to unload</param>
private void UnloadFeatureBins(bool[] featureToUnload)
{
// The same mask is applied to the dataset of every tracked score.
foreach (ScoreTracker scoreTracker in this._optimizationAlgorithm.TrackedScores)
{
for (int i = 0; i < scoreTracker.Dataset.Features.Length; i++)
{
if (featureToUnload[i])
{
// Only return buffers to the pool that were allocated using the pool
// So far only this type of IntArrays below have buffer pool support.
// This is to avoid unexpected leaks in case a new IntArray is added but we are not allocating it from the pool.
if (scoreTracker.Dataset.Features[i].Bins is DenseIntArray ||
scoreTracker.Dataset.Features[i].Bins is DeltaSparseIntArray ||
scoreTracker.Dataset.Features[i].Bins is DeltaRepeatIntArray)
{
scoreTracker.Dataset.Features[i].Bins.ReturnBuffer();
scoreTracker.Dataset.Features[i].Bins = null;
}
}
}
}
}
/// <summary>
/// Worker thread delegate that loads features for the next training iteration
/// NOTE(review): compiled only when NO_STORE is not defined; it uses an _optimizationAlgorithm member
/// that is not declared in the visible part of this class.
/// </summary>
/// <param name="state">thread state object; cast to a bool[] mask of the features to load</param>
private void LazyFeatureLoad(object state)
{
bool[] featuresToLoad = (bool[])state;
foreach (ScoreTracker scoreTracker in this._optimizationAlgorithm.TrackedScores)
{
for (int i = 0; i < scoreTracker.Dataset.Features.Length; i++)
{
if (featuresToLoad[i])
{
// just using the Bins property so feature bins are loaded into memory
// (the local is intentionally unused; only the property read matters)
IntArray bins = scoreTracker.Dataset.Features[i].Bins;
}
}
}
}
/// <summary>
/// Iterates through the feature sets needed in future tree training iterations (i.e. in ActiveFeatureSetQueue),
/// using the same order as they were enqueued, and it returns the initial active features based on the percentage parameter.
/// NOTE(review): compiled only when NO_STORE is not defined; it uses an _activeFeatureSetQueue member
/// that is not declared in the visible part of this class.
/// </summary>
/// <param name="pctFeatureThreshold">A float value between 0 and 1 indicating maximum percentage of features to return</param>
/// <returns>Array indicating calculated feature list</returns>
private bool[] GetNextFeaturesByThreshold(float pctFeatureThreshold)
{
int totalUniqueFeatureCount = 0;
bool[] nextActiveFeatures = new bool[TrainSet.NumFeatures];
if (pctFeatureThreshold == 1.0f)
{
// return all features to load
// (the projection yields true for every element, producing an all-true array)
return nextActiveFeatures.Select(x => x = true).ToArray();
}
int maxNumberOfFeatures = (int)(pctFeatureThreshold * TrainSet.NumFeatures);
// Union the queued masks in queue order until the cap is passed.
for (int i = 0; i < _activeFeatureSetQueue.Count; i++)
{
bool[] tempActiveFeatures = _activeFeatureSetQueue.ElementAt(i);
for (int j = 0; j < tempActiveFeatures.Length; j++)
{
if (tempActiveFeatures[j] && !nextActiveFeatures[j])
{
nextActiveFeatures[j] = true;
// NOTE(review): post-increment combined with '>' returns only after more than
// maxNumberOfFeatures features have been marked - confirm whether that overshoot is intended.
if (totalUniqueFeatureCount++ > maxNumberOfFeatures)
return nextActiveFeatures;
}
}
}
return nextActiveFeatures;
}
/// <summary>
/// Adds several items in the ActiveFeature queue; each item is a fresh mask from GetActiveFeatures.
/// NOTE(review): compiled only when NO_STORE is not defined; it uses an _activeFeatureSetQueue member
/// that is not declared in the visible part of this class.
/// </summary>
/// <param name="numberOfItems">Number of items to add</param>
private void GenerateActiveFeatureLists(int numberOfItems)
{
for (int i = 0; i < numberOfItems; i++)
{
_activeFeatureSetQueue.Enqueue(GetActiveFeatures());
}
}
#endif
/// <summary>
/// Creates the bagging provider over the training set, passing the configured number of leaves,
/// random seed and bagging train fraction. Train calls this only when the bagging size is positive,
/// which the assertion below restates.
/// </summary>
protected virtual BaggingProvider CreateBaggingProvider()
{
Contracts.Assert(Args.BaggingSize > 0);
return new BaggingProvider(TrainSet, Args.NumLeaves, Args.RngSeed, Args.BaggingTrainFraction);
}
/// <summary>
/// Tells Train whether to randomize the training scores before the first iteration.
/// The base implementation never requests it.
/// </summary>
protected virtual bool ShouldRandomStartOptimizer() => false;
/// <summary>
/// Main training loop. Runs training iterations until the ensemble holds the requested number of trees
/// or <see cref="ShouldStop"/> asks to stop, then finalizes learning at the selected best iteration.
/// </summary>
/// <param name="ch">Channel used for informational messages and warnings.</param>
protected virtual void Train(IChannel ch)
{
Contracts.AssertValue(ch);
// Target tree count; decremented below whenever an iteration fails to produce a tree.
int numTotalTrees = Args.NumTrees;
ch.Info(
"Reserved memory for tree learner: {0} bytes",
OptimizationAlgorithm.TreeLearner.GetSizeOfReservedMemory());
#if !NO_STORE
if (_args.offloadBinsToFileStore)
{
// Initialize feature percent to load before loading any features
_featurePercentToLoad = GetFeaturePercentInMemory(ch);
ch.Info("Using featurePercentToLoad = {0} ", _featurePercentToLoad);
}
#endif
// random starting point
bool revertRandomStart = false;
if (Ensemble.NumTrees < numTotalTrees && ShouldRandomStartOptimizer())
{
ch.Info("Randomizing start point");
OptimizationAlgorithm.TrainingScores.RandomizeScores(Args.RngSeed, false);
revertRandomStart = true;
}
ch.Info("Starting to train ...");
// A bagging provider exists only when bagging is enabled; it is null otherwise.
BaggingProvider baggingProvider = Args.BaggingSize > 0 ? CreateBaggingProvider() : null;
#if OLD_DATALOAD
#if !NO_STORE
// Preload
GenerateActiveFeatureLists(_args.numTrees);
Thread featureLoadThread = null;
// Initial feature load
if (_args.offloadBinsToFileStore)
{
FileObjectStore<IntArrayFormatter>.GetDefaultInstance().SealObjectStore();
if (_args.preloadFeatureBinsBeforeTraining)
{
StartFeatureLoadThread(GetNextFeaturesByThreshold(_featurePercentToLoad)).Join();
}
}
#endif
#endif
// Both values are passed by reference to ShouldStop on every iteration.
IEarlyStoppingCriterion earlyStoppingRule = null;
int bestIteration = 0;
// Number of iterations that returned no tree.
int emptyTrees = 0;
using (var pch = Host.StartProgressChannel("FastTree training"))
{
pch.SetHeader(new ProgressHeader("trees"), e => e.SetProgress(0, Ensemble.NumTrees, numTotalTrees));
while (Ensemble.NumTrees < numTotalTrees)
{
using (Timer.Time(TimerEvent.Iteration))
{
#if NO_STORE
bool[] activeFeatures = GetActiveFeatures();
#else
bool[] activeFeatures = _activeFeatureSetQueue.Dequeue();
#endif
// Draw a new bag at the start of every group of BaggingSize trees.
if (Args.BaggingSize > 0 && Ensemble.NumTrees % Args.BaggingSize == 0)
{
baggingProvider.GenerateNewBag();
OptimizationAlgorithm.TreeLearner.Partitioning =
baggingProvider.GetCurrentTrainingPartition();
}
#if !NO_STORE
if (_args.offloadBinsToFileStore)
{
featureLoadThread = StartFeatureLoadThread(GetNextFeaturesByThreshold(_featurePercentToLoad));
if (!_args.preloadFeatureBinsBeforeTraining)
featureLoadThread.Join();
}
#endif
// call the weak learner
var tree = OptimizationAlgorithm.TrainingIteration(ch, activeFeatures);
if (tree == null)
{
// No tree this iteration: count it and lower the target so the loop can still terminate.
emptyTrees++;
numTotalTrees--;
}
else if (Args.BaggingSize > 0 && Ensemble.Trees.Count() > 0)
{
// With bagging, add the new tree's outputs to the scores of the current out-of-bag documents.
ch.Assert(Ensemble.Trees.Last() == tree);
Ensemble.Trees.Last()
.AddOutputsToScores(OptimizationAlgorithm.TrainingScores.Dataset,
OptimizationAlgorithm.TrainingScores.Scores,
baggingProvider.GetCurrentOutOfBagPartition().Documents);
}
// Hook for derived trainers; tree may be null here.
CustomizedTrainingIteration(tree);
using (Timer.Time(TimerEvent.Test))
{
PrintIterationMessage(ch, pch);
PrintTestResults(ch);
}
// revert randomized start
// (done once, after the first iteration, using the same seed with the reverse flag set)
if (revertRandomStart)
{
revertRandomStart = false;
ch.Info("Reverting random score assignment");
OptimizationAlgorithm.TrainingScores.RandomizeScores(Args.RngSeed, true);
}
#if !NO_STORE
if (_args.offloadBinsToFileStore)
{
// Unload only features that are not needed for the next iteration
bool[] featuresToUnload = activeFeatures;
if (_args.preloadFeatureBinsBeforeTraining)
{
featuresToUnload =
activeFeatures.Zip(GetNextFeaturesByThreshold(_featurePercentToLoad),
(current, next) => current && !next).ToArray();
}
UnloadFeatureBins(featuresToUnload);
if (featureLoadThread != null &&
_args.preloadFeatureBinsBeforeTraining)
{
// wait for loading the features needed for the next iteration
featureLoadThread.Join();
}
}
#endif
if (ShouldStop(ch, ref earlyStoppingRule, ref bestIteration))
break;
}
}
if (emptyTrees > 0)
{
ch.Warning("{0} of the boosting iterations failed to grow a tree. This is commonly because the " +
"minimum documents in leaf hyperparameter was set too high for this dataset.", emptyTrees);
}
}
// A non-null rule means ShouldStop installed an early stopping criterion and maintained bestIteration.
if (earlyStoppingRule != null)
{
Contracts.Assert(numTotalTrees == 0 || bestIteration > 0);
// REVIEW: Need to reconcile with future progress reporting changes.
ch.Info("The training is stopped at {0} and iteration {1} is picked",
Ensemble.NumTrees, bestIteration);
}
else
{
bestIteration = GetBestIteration(ch);
}
OptimizationAlgorithm.FinalizeLearning(bestIteration);
Ensemble.PopulateRawThresholds(TrainSet);
ParallelTraining.FinalizeTreeLearner();
}
#if !NO_STORE
/// <summary>
/// Gets the available bytes performance counter on the local machine
/// (category "Memory", counter "Available Bytes"). The counter object is disposed before returning.
/// NOTE(review): compiled only when NO_STORE is not defined.
/// </summary>
/// <returns>Available bytes number</returns>
private float GetMachineAvailableBytes()
{
using (var availableBytes = new System.Diagnostics.PerformanceCounter("Memory", "Available Bytes", true))
{
return availableBytes.NextValue();
}
}
#endif
/// <summary>
/// Called by Train at the end of each training iteration with the tree that was learnt on that iteration.
/// The base implementation does nothing.
/// </summary>
/// <param name="tree">The tree learnt this iteration; can be null if no tree was learnt.</param>
protected virtual void CustomizedTrainingIteration(RegressionTree tree)
{
}
/// <summary>
/// Reports progress after an iteration: when the current tree count modulo 50 equals 49, a progress
/// checkpoint is recorded with the tree count plus one. The regular channel is not used.
/// </summary>
protected virtual void PrintIterationMessage(IChannel ch, IProgressChannel pch)
{
    // REVIEW: report some metrics, not just number of trees?
    int treesSoFar = Ensemble.NumTrees;
    bool atCheckpoint = treesSoFar % 50 == 49;
    if (atCheckpoint)
    {
        pch.Checkpoint(treesSoFar + 1);
    }
}
/// <summary>
/// Computes every registered test and writes their formatted results as a single info message.
/// Nothing happens when the test frequency is <see cref="int.MaxValue"/>, or when the current tree count
/// is neither a multiple of the test frequency nor equal to the configured number of trees.
/// </summary>
protected virtual void PrintTestResults(IChannel ch)
{
    if (Args.TestFrequency == int.MaxValue)
        return;

    bool isDue = Ensemble.NumTrees % Args.TestFrequency == 0 || Ensemble.NumTrees == Args.NumTrees;
    if (!isDue)
        return;

    var report = new StringBuilder();
    using (var writer = new StringWriter(report))
    {
        foreach (var test in Tests)
        {
            // The result is not used here; the formatted info string is written right after computing.
            var computed = test.ComputeTests();
            writer.Write(test.FormatInfoString());
        }
    }

    // Stay silent when no test produced any text.
    if (report.Length > 0)
    {
        ch.Info(report.ToString());
    }
}
/// <summary>
/// Traces the run's context before training: machine name, the argument settings rendered as a
/// command line, whether server GC is active, and the arguments object itself.
/// </summary>
protected virtual void PrintPrologInfo(IChannel ch)
{
    Contracts.AssertValue(ch);

    string machineName = Environment.MachineName;
    ch.Trace("Host = {0}", machineName);

    string settings = CmdParser.GetSettings(Host, Args, new TArgs());
    ch.Trace("CommandLine = {0}", settings);

    bool serverGc = System.Runtime.GCSettings.IsServerGC;
    ch.Trace("GCSettings.IsServerGC = {0}", serverGc);

    ch.Trace("{0}", Args);
}
/// <summary>
/// Gets from the optimization algorithm the score tracker for one of the trainer's known datasets:
/// the training set ("train"), the validation set ("valid"), or a member of <see cref="TestSets"/>
/// ("test[i]"), passing the matching initial scores.
/// </summary>
/// <param name="set">Dataset to track; must be one of the trainer's own datasets.</param>
/// <returns>The score tracker for <paramref name="set"/>.</returns>
/// <remarks>
/// Fails the contract check when <paramref name="set"/> is none of the known datasets. This includes the
/// case where there are no test sets at all, which previously surfaced as a NullReferenceException.
/// </remarks>
protected ScoreTracker ConstructScoreTracker(Dataset set)
{
    // If not found, construct one.
    ScoreTracker st = null;
    if (set == TrainSet)
        st = OptimizationAlgorithm.GetScoreTracker("train", TrainSet, InitTrainScores);
    else if (set == ValidSet)
        st = OptimizationAlgorithm.GetScoreTracker("valid", ValidSet, InitValidScores);
    else if (TestSets != null)
    {
        // Same null guard as GetInitScores, so that an unknown set reaches the check below.
        for (int t = 0; t < TestSets.Length; ++t)
        {
            if (set == TestSets[t])
            {
                double[] initTestScores = InitTestScores?[t];
                st = OptimizationAlgorithm.GetScoreTracker(string.Format("test[{0}]", t), TestSets[t], initTestScores);
            }
        }
    }
    Contracts.Check(st != null, "unknown dataset passed to ConstructScoreTracker");
    return st;
}
/// <summary>
/// Returns the scores for <paramref name="set"/>, reusing the scores of an existing tracker for that
/// dataset when ensemble compression is off, and recomputing them from the ensemble otherwise.
/// </summary>
private double[] ComputeScoresSmart(IChannel ch, Dataset set)
{
    // With compression enabled the tracked scores are never reused.
    if (Args.CompressEnsemble)
        return ComputeScoresSlow(ch, set);

    foreach (var tracker in OptimizationAlgorithm.TrackedScores)
    {
        if (tracker.Dataset != set)
            continue;

        ch.Trace("Computing scores fast");
        return tracker.Scores;
    }

    // No tracker covers this dataset.
    return ComputeScoresSlow(ch, set);
}
/// <summary>
/// Computes the scores for <paramref name="set"/> from scratch: asks the ensemble for its outputs and
/// then adds the dataset's initial scores, if it has any.
/// </summary>
/// <returns>A new array with one score per document of the dataset.</returns>
private double[] ComputeScoresSlow(IChannel ch, Dataset set)
{
    ch.Trace("Computing scores slow");

    var result = new double[set.NumDocs];
    Ensemble.GetOutputs(set, result);

    double[] baseline = GetInitScores(set);
    if (baseline == null)
        return result;

    Contracts.Check(result.Length == baseline.Length, "Length of initscores and scores mismatch");
    for (int doc = 0; doc < result.Length; doc++)
    {
        result[doc] += baseline[doc];
    }

    return result;
}
/// <summary>
/// Looks up the initial scores that belong to one of the trainer's datasets.
/// </summary>
/// <returns>
/// The initial scores of the training set, validation set or matching test set; may be null when
/// none were provided.
/// </returns>
/// <remarks>Throws when <paramref name="set"/> is not one of the trainer's datasets.</remarks>
private double[] GetInitScores(Dataset set)
{
    if (set == TrainSet)
    {
        return InitTrainScores;
    }

    if (set == ValidSet)
    {
        return InitValidScores;
    }

    int index = 0;
    while (TestSets != null && index < TestSets.Length)
    {
        if (set == TestSets[index])
        {
            return InitTestScores?[index];
        }

        index++;
    }

    throw Contracts.Except("Queried for unknown set");
}
}
internal abstract class DataConverter
{
protected readonly int NumFeatures;
public abstract int NumExamples { get; }
protected readonly Float MaxLabel;
protected readonly PredictionKind PredictionKind;
/// <summary>
/// The per-feature bin upper bounds. Implementations may differ on when all of the items
/// in this array are initialized to non-null values but it must happen at least no later
/// than immediately after we return from <see cref="GetDataset"/>.
/// </summary>
public readonly Double[][] BinUpperBounds;
/// <summary>
/// In the event that any features are filtered, this will contain the feature map, where
/// the indices are the indices of features within the dataset, and the tree as we are
/// learning, and the values are the indices of the features within the original input
/// data. This array is used to "rehydrate" the tree once we finish training, so that the
/// feature indices are once again over the full set of features, as opposed to the subset
/// of features we actually trained on. This can be null in the event that no filtering
/// occurred.
/// </summary>
/// <seealso cref="TreeEnsemble.RemapFeatures"/>
public int[] FeatureMap;
protected readonly IHost Host;
protected readonly int[] CategoricalFeatureIndices;
protected readonly bool CategoricalSplit;
/// <summary>
/// True when a finite-or-otherwise-specific maximum label was supplied, i.e. MaxLabel is anything
/// other than positive infinity.
/// </summary>
protected bool UsingMaxLabel => MaxLabel != Float.PositiveInfinity;
/// <summary>
/// Shared constructor: validates the data's feature, weight and group columns, records the feature
/// count, and either adopts the supplied bin upper bounds or allocates an empty per-feature array.
/// </summary>
/// <param name="binUpperBounds">
/// Existing per-feature bin upper bounds, or null to allocate an array with one uninitialized entry per feature.
/// </param>
private DataConverter(RoleMappedData data, IHost host, Double[][] binUpperBounds, Float maxLabel,
    PredictionKind kind, int[] categoricalFeatureIndices, bool categoricalSplit)
{
    Contracts.AssertValue(host, "host");
    Host = host;
    Host.CheckValue(data, nameof(data));

    data.CheckFeatureFloatVector(out int featureLength);
    data.CheckOptFloatWeight();
    data.CheckOptGroup();
    NumFeatures = featureLength;

    if (binUpperBounds == null)
    {
        BinUpperBounds = new Double[NumFeatures][];
    }
    else
    {
        // Supplied bounds must cover every feature, with no missing entries.
        Host.AssertValue(binUpperBounds);
        Host.Assert(Utils.Size(binUpperBounds) == NumFeatures);
        Host.Assert(binUpperBounds.All(bounds => bounds != null));
        BinUpperBounds = binUpperBounds;
    }

    MaxLabel = maxLabel;
    PredictionKind = kind;
    CategoricalSplit = categoricalSplit;
    CategoricalFeatureIndices = categoricalFeatureIndices;
}
/// <summary>
/// Creates a converter that will compute its own bins, limited to <paramref name="maxBins"/> per feature.
/// The disk-based implementation is chosen when <paramref name="diskTranspose"/> is set, and the
/// in-memory implementation otherwise.
/// </summary>
public static DataConverter Create(RoleMappedData data, IHost host, int maxBins,
    Float maxLabel, bool diskTranspose, bool noFlocks, int minDocsPerLeaf, PredictionKind kind,
    IParallelTraining parallelTraining, int[] categoricalFeatureIndices, bool categoricalSplit)
{
    Contracts.AssertValue(host, "host");
    host.AssertValue(data);
    host.Assert(maxBins > 0);

    DataConverter converter;
    using (var ch = host.Start("CreateConverter"))
    {
        if (diskTranspose)
        {
            converter = new DiskImpl(data, host, maxBins, maxLabel, kind, parallelTraining,
                categoricalFeatureIndices, categoricalSplit);
        }
        else
        {
            converter = new MemImpl(data, host, maxBins, maxLabel, noFlocks, minDocsPerLeaf, kind,
                parallelTraining, categoricalFeatureIndices, categoricalSplit);
        }
    }

    return converter;
}
public static DataConverter Create(RoleMappedData data, IHost host, Double[][] binUpperBounds,
Float maxLabel, bool diskTranspose, bool noFlocks, PredictionKind kind, int[] categoricalFeatureIndices, bool categoricalSplit)
{
Contracts.AssertValue(host, "host");
host.AssertValue(data);
DataConverter conv;
using (var ch = host.Start("CreateConverter"))
{
if (!diskTranspose)
conv = new MemImpl(data, host, binUpperBounds, maxLabel, noFlocks, kind, categoricalFeatureIndices, categoricalSplit);
else
conv = new DiskImpl(data, host, binUpperBounds, maxLabel, kind, categoricalFeatureIndices, categoricalSplit);