cleanup code.

This commit is contained in:
Holger Börchers 2021-12-09 17:12:32 +01:00
parent bb2c5af1fa
commit a23e6458bc
3 changed files with 38 additions and 22 deletions

View File

@ -1,7 +1,7 @@
using System.Diagnostics; using ScottPlot;
using System.Diagnostics;
using System.Drawing; using System.Drawing;
using System.Globalization; using System.Globalization;
using ScottPlot;
namespace kMeans; namespace kMeans;
@ -33,11 +33,9 @@ public static class Helper
return plot; return plot;
} }
public static void ExportAndPreviewPlot(Plot plot, string path) public static void ExportPlot(Plot plot, string path)
{ {
plot.SaveFig(path); plot.SaveFig(path);
using var p = Process.Start(new ProcessStartInfo(path) { UseShellExecute = true });
p?.WaitForExit();
} }
public static IEnumerable<Point> ParseCsv(string path) public static IEnumerable<Point> ParseCsv(string path)
@ -51,4 +49,10 @@ public static class Helper
yield return new Point(x, y, -1); yield return new Point(x, y, -1);
} }
} }
public static void PreviewPlot(string path)
{
using var p = Process.Start(new ProcessStartInfo(path) { UseShellExecute = true });
p?.WaitForExit();
}
} }

View File

@ -3,6 +3,14 @@
public record Point(double X, double Y, int ClusterId) public record Point(double X, double Y, int ClusterId)
{ {
public int ClusterId { get; set; } = ClusterId; public int ClusterId { get; set; } = ClusterId;
public double X { get; set; } = X; public double X { get; private set; } = X;
public double Y { get; set; } = Y; public double Y { get; private set; } = Y;
public bool SetCoordinates(double x, double y)
{
if (!(Math.Abs(X - x) > 1e-5) || !(Math.Abs(Y - y) > 1e-5)) return false;
X = x;
Y = y;
return true;
}
} }

View File

@ -6,6 +6,7 @@ public static class Program
{ {
private const string Coordinates = "coordinates.csv"; private const string Coordinates = "coordinates.csv";
private const int K = 3; private const int K = 3;
private const int MaxLoops = 30;
private const string PlotOut = "plot.png"; private const string PlotOut = "plot.png";
public static void Main() public static void Main()
@ -13,19 +14,14 @@ public static class Program
Console.WriteLine("k-Means-Algorithm"); Console.WriteLine("k-Means-Algorithm");
var points = ParseCsv(Coordinates).ToList(); var points = ParseCsv(Coordinates).ToList();
var centroids = InitializeRandomlyCentroids(points, K).ToList(); KMeansCalculation(points, K, out var centroids);
var update = true;
while (update)
{
AssignPointsToCentroids(points, centroids);
update = UpdateCentroids(points, centroids);
}
var plot = CreatePlot(points, centroids); var plot = CreatePlot(points, centroids);
ExportAndPreviewPlot(plot, PlotOut); ExportPlot(plot, PlotOut);
PreviewPlot(PlotOut);
} }
private static void AssignPointsToCentroids(List<Point> points, List<Point> centroids) private static void AssignPointsToCentroids(IEnumerable<Point> points, IReadOnlyCollection<Point> centroids)
{ {
foreach (var point in points) foreach (var point in points)
{ {
@ -49,7 +45,7 @@ public static class Program
return Math.Sqrt(Math.Pow(centroid.X - point.X, 2) + Math.Pow(centroid.Y - point.Y, 2)); return Math.Sqrt(Math.Pow(centroid.X - point.X, 2) + Math.Pow(centroid.Y - point.Y, 2));
} }
private static IEnumerable<Point> InitializeRandomlyCentroids(IReadOnlyCollection<Point> points, int k) private static IEnumerable<Point> InitializeRandomCentroids(IReadOnlyCollection<Point> points, int k)
{ {
var minX = points.Min(p => p.X); var minX = points.Min(p => p.X);
var maxX = points.Max(p => p.X); var maxX = points.Max(p => p.X);
@ -67,7 +63,18 @@ public static class Program
} }
} }
private static bool UpdateCentroids(IReadOnlyCollection<Point> points, List<Point> centroids) private static void KMeansCalculation(IReadOnlyCollection<Point> points, int k, out IReadOnlyCollection<Point> centroids)
{
centroids = InitializeRandomCentroids(points, k).ToArray();
for (var i = 0; i < MaxLoops; i++)
{
AssignPointsToCentroids(points, centroids);
if (!UpdateCentroids(points, centroids)) break;
}
}
private static bool UpdateCentroids(IReadOnlyCollection<Point> points, IEnumerable<Point> centroids)
{ {
// calculate mean of all points of one cluster. // calculate mean of all points of one cluster.
var updated = false; var updated = false;
@ -87,9 +94,6 @@ public static class Program
var sumY = pointsOfCluster.Sum(p => p.Y); var sumY = pointsOfCluster.Sum(p => p.Y);
var meanY = sumY / pointsOfCluster.Length; var meanY = sumY / pointsOfCluster.Length;
if (!(Math.Abs(centroid.X - meanX) > 1e-5) || !(Math.Abs(centroid.Y - meanY) > 1e-5)) return false; return centroid.SetCoordinates(meanX, meanY);
centroid.X = meanX;
centroid.Y = meanY;
return true;
} }
} }