diff --git a/XamlBrewer.Uwp.MachineLearningSample/Models/Automation/AutomationModel.cs b/XamlBrewer.Uwp.MachineLearningSample/Models/Automation/AutomationModel.cs index fc9af11..0d72f0d 100644 --- a/XamlBrewer.Uwp.MachineLearningSample/Models/Automation/AutomationModel.cs +++ b/XamlBrewer.Uwp.MachineLearningSample/Models/Automation/AutomationModel.cs @@ -76,7 +76,7 @@ public void SetUpExperiment() { MaxExperimentTimeInSeconds = 180, OptimizingMetric = MulticlassClassificationMetric.LogLoss, - CacheDirectory = null + CacheDirectoryName = null, }; // These two trainers yield no metrics in UWP: @@ -105,7 +105,7 @@ public void HyperParameterize() { MaxExperimentTimeInSeconds = 180, OptimizingMetric = MulticlassClassificationMetric.LogLoss, - CacheDirectory = null + CacheDirectoryName = null }; // There can be only one. diff --git a/XamlBrewer.Uwp.MachineLearningSample/Models/Regression/RegressionModel.cs b/XamlBrewer.Uwp.MachineLearningSample/Models/Regression/RegressionModel.cs index 7f1f90e..36ee5a5 100644 --- a/XamlBrewer.Uwp.MachineLearningSample/Models/Regression/RegressionModel.cs +++ b/XamlBrewer.Uwp.MachineLearningSample/Models/Regression/RegressionModel.cs @@ -41,9 +41,9 @@ public IEnumerable Load(string trainingDataPath) return _mlContext.Data.CreateEnumerable(trainingData, reuseRowObject: false); } - public void BuildAndTrain() + public void BuildAndTrain(string regressionTrainer) { - var pipeline = _mlContext.Transforms.ReplaceMissingValues("Age", "Age", MissingValueReplacingEstimator.ReplacementMode.Mean) + var prepipeline = _mlContext.Transforms.ReplaceMissingValues("Age", "Age", MissingValueReplacingEstimator.ReplacementMode.Mean) .Append(_mlContext.Transforms.ReplaceMissingValues("Ws", "Ws", MissingValueReplacingEstimator.ReplacementMode.Mean)) .Append(_mlContext.Transforms.ReplaceMissingValues("Bmp", "Bmp", MissingValueReplacingEstimator.ReplacementMode.Mean)) .Append(_mlContext.Transforms.ReplaceMissingValues("NBA_DraftNumber", "NBA_DraftNumber", MissingValueReplacingEstimator.ReplacementMode.Mean)) @@ -53,15 +53,43 @@ public void BuildAndTrain() .Append(_mlContext.Transforms.NormalizeMeanVariance("Bmp", "Bmp")) .Append(_mlContext.Transforms.Concatenate( "Features", - new[] { "NBA_DraftNumber", "Age", "Ws", "Bmp" })) - // .Append(_mlContext.Regression.Trainers.FastTree()); // PlatformNotSupportedException - // .Append(_mlContext.Regression.Trainers.OnlineGradientDescent(new OnlineGradientDescentTrainer.Options { })); // InvalidOperationException if you don't normalize. - // .Append(_mlContext.Regression.Trainers.StochasticDualCoordinateAscent()); - // .Append(_mlContext.Regression.Trainers.PoissonRegression()); - .Append(_mlContext.Regression.Trainers.Gam()); - - Model = pipeline.Fit(trainingData); - + new[] { "NBA_DraftNumber", "Age", "Ws", "Bmp" })); + // .Append(_mlContext.Regression.Trainers.FastTree()); // PlatformNotSupportedException + // .Append(_mlContext.Regression.Trainers.OnlineGradientDescent(new OnlineGradientDescentTrainer.Options { })); // InvalidOperationException if you don't normalize. + // .Append(_mlContext.Regression.Trainers.StochasticDualCoordinateAscent()); + // .Append(_mlContext.Regression.Trainers.PoissonRegression()); + //.Append(_mlContext.Regression.Trainers.Gam()); + switch (regressionTrainer) + { + //case "FastTree": // PlatformNotSupportedException + // var pipelineFastTree = prepipeline.Append(_mlContext.Regression.Trainers.FastTree()); + // Model = pipelineFastTree.Fit(trainingData); + // break; + //case "FastTreeTweedie": // PlatformNotSupportedException + // var pipelineFastTreeTweedie = prepipeline.Append(_mlContext.Regression.Trainers.FastTreeTweedie()); + // Model = pipelineFastTreeTweedie.Fit(trainingData); + // break; + case "Gam": + var pipelineGam = prepipeline.Append(_mlContext.Regression.Trainers.Gam()); + Model = pipelineGam.Fit(trainingData); + break; + case "LightGbm": + var pipelineLightGbm = prepipeline.Append(_mlContext.Regression.Trainers.LightGbm()); + Model = pipelineLightGbm.Fit(trainingData); + break; + case "Ols": + var pipelineOls = prepipeline.Append(_mlContext.Regression.Trainers.Ols()); + Model = pipelineOls.Fit(trainingData); + break; + case "Sdca": + var pipelineSdca = prepipeline.Append(_mlContext.Regression.Trainers.Sdca()); + Model = pipelineSdca.Fit(trainingData); + break; + default: + var pipeline = prepipeline.Append(_mlContext.Regression.Trainers.Gam()); + Model = pipeline.Fit(trainingData); + break; + } predictionEngine = _mlContext.Model.CreatePredictionEngine(Model); } diff --git a/XamlBrewer.Uwp.MachineLearningSample/ViewModels/RegressionPageViewModel.cs b/XamlBrewer.Uwp.MachineLearningSample/ViewModels/RegressionPageViewModel.cs index 94c9067..6a001e2 100644 --- a/XamlBrewer.Uwp.MachineLearningSample/ViewModels/RegressionPageViewModel.cs +++ b/XamlBrewer.Uwp.MachineLearningSample/ViewModels/RegressionPageViewModel.cs @@ -17,11 +17,11 @@ public Task> Load(string trainingDataPath) }); } - public Task BuildAndTrain() + public Task BuildAndTrain(string regressionTrainer) { return Task.Run(() => { - _model.BuildAndTrain(); + _model.BuildAndTrain(regressionTrainer); }); } diff --git a/XamlBrewer.Uwp.MachineLearningSample/Views/RegressionPage.xaml b/XamlBrewer.Uwp.MachineLearningSample/Views/RegressionPage.xaml index 582805f..dffe07f 100644 --- a/XamlBrewer.Uwp.MachineLearningSample/Views/RegressionPage.xaml +++ b/XamlBrewer.Uwp.MachineLearningSample/Views/RegressionPage.xaml @@ -105,9 +105,14 @@ IsChecked="False" Grid.Row="4" Grid.Column="1" /> - + + + + DataContext as RegressionPageViewModel; @@ -50,7 +52,7 @@ private async void Page_Loaded(object sender, Windows.UI.Xaml.RoutedEventArgs e) // Create and train the model TrainingBox.IsChecked = true; - await ViewModel.BuildAndTrain(); + await ViewModel.BuildAndTrain(RegressionTrainersCombo.SelectedItem.ToString()); // Save the model. await ViewModel.Save("regressionModel.zip"); @@ -172,5 +174,10 @@ private async void Slider_ValueChanged(object sender, Windows.UI.Xaml.Controls.P Diagram.Model.Annotations.Add(annotation); Diagram.InvalidatePlot(); } + + private void RegressionTrainersCombo_SelectionChanged(object sender, SelectionChangedEventArgs e) + { + } + } } diff --git a/XamlBrewer.Uwp.MachineLearningSample/XamlBrewer.Uwp.MachineLearningSample.csproj b/XamlBrewer.Uwp.MachineLearningSample/XamlBrewer.Uwp.MachineLearningSample.csproj index 536b2ca..5fe9f45 100644 --- a/XamlBrewer.Uwp.MachineLearningSample/XamlBrewer.Uwp.MachineLearningSample.csproj +++ b/XamlBrewer.Uwp.MachineLearningSample/XamlBrewer.Uwp.MachineLearningSample.csproj @@ -333,22 +333,22 @@ - 1.5.1 + 1.7.0 - 0.17.1 + 0.19.0 - 1.5.1 + 1.7.0 - 0.17.1 + 0.19.0 - 6.2.10 + 6.2.13 - 6.1.0 + 7.1.0 2.0.0-unstable1035