forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTextLoader.cs
More file actions
202 lines (175 loc) · 9.07 KB
/
Copy pathTextLoader.cs
File metadata and controls
202 lines (175 loc) · 9.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using Microsoft.ML.Data;
namespace Microsoft.ML.Legacy.Data
{
public sealed partial class TextLoaderRange
{
    /// <summary>
    /// Creates a range without assigning any of its members; they can be filled in afterwards,
    /// for example through an object initializer.
    /// </summary>
    public TextLoaderRange()
    {
    }

    /// <summary>
    /// Creates a range for the scalar case: one schema column that maps to exactly one column
    /// of the dataset. Both <see cref="Min"/> and <see cref="Max"/> receive <paramref name="ordinal"/>.
    /// </summary>
    /// <param name="ordinal">Index of the dataset column; zero is allowed, negative values are rejected.</param>
    public TextLoaderRange(int ordinal)
    {
        bool isNonNegative = ordinal >= 0;
        Contracts.CheckParam(isNonNegative, nameof(ordinal), "Cannot be a negative number");

        this.Min = ordinal;
        this.Max = ordinal;
    }

    /// <summary>
    /// Creates a range for the vector case: one schema column that spans the contiguous dataset
    /// columns from <paramref name="min"/> through <paramref name="max"/>.
    /// </summary>
    /// <param name="min">Index of the first dataset column; must not be negative.</param>
    /// <param name="max">Index of the last dataset column; must not be smaller than <paramref name="min"/>.</param>
    public TextLoaderRange(int min, int max)
    {
        // The lower bound is validated first, so it is the one reported when both checks would fail.
        bool minIsValid = min >= 0;
        Contracts.CheckParam(minIsValid, nameof(min), "Cannot be a negative number.");

        bool boundsAreOrdered = max >= min;
        Contracts.CheckParam(boundsAreOrdered, nameof(max), $"Cannot be less than {nameof(min)}.");

        this.Min = min;
        this.Max = max;
    }
}
public sealed partial class TextLoader
{
    /// <summary>
    /// Configures this loader from the public members of <typeparamref name="TInput"/>. Each public
    /// instance field, and each public instance non-indexer property with a public getter and a
    /// public setter, becomes one entry of <c>Arguments.Column</c>.
    /// </summary>
    /// <remarks>
    /// Every such member must carry a <see cref="LoadColumnAttribute"/> that says which dataset
    /// column(s) it is read from. A <see cref="ColumnNameAttribute"/>, when present, supplies the
    /// column name; otherwise the member name is used. Array-typed members take the
    /// <see cref="DataKind"/> of their element type.
    /// All members are validated before <c>Arguments</c> is modified. If a member is rejected
    /// (missing attribute, malformed or invalid source specification, unsupported type), an
    /// exception propagates and the loader's existing arguments are left as they were.
    /// </remarks>
    /// <typeparam name="TInput">Type whose fields/properties describe the columns to load.</typeparam>
    /// <param name="useHeader">Whether the file contains a header row.</param>
    /// <param name="separator">Column separator character. Default is '\t'.</param>
    /// <param name="allowQuotedStrings">Whether the input may include quoted values,
    /// which can contain separator characters and colons, and which distinguish empty values
    /// from missing values. When true, consecutive separators denote a missing value and an
    /// empty value is denoted by \"\". When false, consecutive separators denote an empty value.</param>
    /// <param name="supportSparse">Whether the input may include sparse representations. For example,
    /// the row "5 2:6 4:3" declares 5 columns that are all zero except the 3rd and 5th
    /// (indices 2 and 4), which hold the values 6 and 3.</param>
    /// <param name="trimWhitespace">Remove trailing whitespace from lines.</param>
    /// <returns>This same instance, so calls can be chained.</returns>
    public TextLoader CreateFrom<TInput>(bool useHeader = false,
        char separator = '\t', bool allowQuotedStrings = true,
        bool supportSparse = true, bool trimWhitespace = false)
    {
        var userType = typeof(TInput);
        var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);
        var propertyInfos = userType
            .GetProperties(BindingFlags.Public | BindingFlags.Instance)
            .Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0);
        var memberInfos = fieldInfos.Concat<MemberInfo>(propertyInfos).ToArray();

        // Build into a local array first: if any member fails validation, Arguments.Column must not
        // be left holding a partially filled array (with null entries) in place of the old columns.
        var columns = new TextLoaderColumn[memberInfos.Length];
        for (int index = 0; index < memberInfos.Length; index++)
            columns[index] = CreateColumnFromMember(memberInfos[index]);

        Arguments.Column = columns;
        Arguments.HasHeader = useHeader;
        Arguments.Separator = new[] { separator };
        Arguments.AllowQuoting = allowQuotedStrings;
        Arguments.AllowSparse = supportSparse;
        Arguments.TrimWhitespace = trimWhitespace;
        return this;
    }

    /// <summary>
    /// Builds the column description for one field or property of the user type. It validates the
    /// member's <see cref="LoadColumnAttribute"/>, resolves the column name, and maps the member
    /// (element) type to a <see cref="DataKind"/>.
    /// </summary>
    /// <param name="memberInfo">A <see cref="FieldInfo"/> or <see cref="PropertyInfo"/>.</param>
    /// <returns>A column with its name, source ranges and type filled in.</returns>
    private static TextLoaderColumn CreateColumnFromMember(MemberInfo memberInfo)
    {
        var mappingAttr = memberInfo.GetCustomAttribute<LoadColumnAttribute>();
        if (mappingAttr == null)
            throw Contracts.Except($"Field or property {memberInfo.Name} is missing LoadColumnAttribute");

#pragma warning disable 618 // CS0618: some members used below are marked [Obsolete] (legacy source-specification API).
        string start = mappingAttr.Start;
        // Fail with a clear message instead of an ArgumentNullException from Regex.Match.
        if (start == null)
            throw Contracts.Except($"Field or property {memberInfo.Name} has a LoadColumnAttribute without a column specification.");
        // Cheap character pre-check for a friendlier message. NOTE(review): the negated class also
        // admits ',', '(' and ')'; presumably TryParseSourceEx rejects malformed input below - verify.
        if (Regex.Match(start, @"[^(0-9,\*\-~)]+").Success)
            throw Contracts.Except($"{start} contains invalid characters. " +
                $"Valid characters are 0-9, *, - and ~");
        var name = memberInfo.GetCustomAttribute<ColumnNameAttribute>()?.Name ?? memberInfo.Name;
        ML.Data.TextLoader.Range[] sources;
        if (!ML.Data.TextLoader.Column.TryParseSourceEx(start, out sources))
            throw Contracts.Except($"{start} could not be parsed.");
#pragma warning restore 618
        Contracts.Assert(sources != null);

        Type memberType;
        string memberKind;
        switch (memberInfo)
        {
            case FieldInfo field:
                memberType = field.FieldType;
                memberKind = "Field";
                break;
            case PropertyInfo property:
                memberType = property.PropertyType;
                memberKind = "Property";
                break;
            default:
                Contracts.Assert(false);
                throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
        }

        // Array members describe vector columns; their DataKind is that of the element type.
        var itemType = memberType.IsArray ? memberType.GetElementType() : memberType;
        if (!TryGetDataKind(itemType, out DataKind kind))
            throw Contracts.Except($"{memberKind} {name} is of unsupported type.");

        return new TextLoaderColumn
        {
            Name = name,
            Source = sources
                .Select(src => new TextLoaderRange
                {
                    AllOther = src.AllOther,
                    AutoEnd = src.AutoEnd,
                    ForceVector = src.ForceVector,
                    VariableEnd = src.VariableEnd,
                    Max = src.Max,
                    Min = src.Min
                })
                .ToArray(),
            Type = kind
        };
    }

    /// <summary>
    /// Try to map a System.Type to a corresponding DataKind value.
    /// </summary>
    /// <param name="type">The (element) type to map; must not be null.</param>
    /// <param name="kind">The matching kind, or <c>default(DataKind)</c> when there is none.</param>
    /// <returns><c>true</c> if <paramref name="type"/> is one of the supported types; otherwise <c>false</c>.</returns>
    private static bool TryGetDataKind(Type type, out DataKind kind)
    {
        Contracts.AssertValue(type);
        // REVIEW: Make this more efficient. Should we have a global dictionary?
        if (type == typeof(sbyte))
            kind = DataKind.I1;
        // NOTE(review): char is a 16-bit type but is mapped to the same kind as byte (U1);
        // confirm this narrowing is intended.
        else if (type == typeof(byte) || type == typeof(char))
            kind = DataKind.U1;
        else if (type == typeof(short))
            kind = DataKind.I2;
        else if (type == typeof(ushort))
            kind = DataKind.U2;
        else if (type == typeof(int))
            kind = DataKind.I4;
        else if (type == typeof(uint))
            kind = DataKind.U4;
        else if (type == typeof(long))
            kind = DataKind.I8;
        else if (type == typeof(ulong))
            kind = DataKind.U8;
        else if (type == typeof(Single))
            kind = DataKind.R4;
        else if (type == typeof(Double))
            kind = DataKind.R8;
        else if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
            kind = DataKind.TX;
        else if (type == typeof(bool))
            kind = DataKind.BL;
        else if (type == typeof(TimeSpan))
            kind = DataKind.TS;
        else if (type == typeof(DateTime))
            kind = DataKind.DT;
        else if (type == typeof(DateTimeOffset))
            kind = DataKind.DZ;
        else if (type == typeof(RowId))
            kind = DataKind.UG;
        else
        {
            kind = default(DataKind);
            return false;
        }
        return true;
    }
}
}