finalized

This commit is contained in:
Holger Börchers 2021-12-08 23:17:11 +01:00
parent ccef445d4e
commit bb2c5af1fa
3 changed files with 103 additions and 79 deletions

54
src/Helper.cs Normal file
View File

@ -0,0 +1,54 @@
using System.Diagnostics;
using System.Drawing;
using System.Globalization;
using ScottPlot;
namespace kMeans;
/// <summary>
/// Plotting and CSV-loading helpers for the k-means program.
/// </summary>
public static class Helper
{
    /// <summary>
    /// Builds a scatter plot of the clustered points and their centroids.
    /// </summary>
    /// <param name="points">Data points; they are grouped by <c>ClusterId</c> and each group is drawn in its own colour.</param>
    /// <param name="centroids">Cluster centres; each is drawn as a labelled cross marker.</param>
    /// <returns>The populated <see cref="Plot"/> with the legend enabled in the upper-right corner.</returns>
    public static Plot CreatePlot(IEnumerable<Point> points, IEnumerable<Point> centroids)
    {
        var plot = new Plot();
        plot.Legend(true, Alignment.UpperRight);

        // Remember the colour handed to each cluster so its centroid can reuse it.
        var colors = new Dictionary<int, Color>();
        var colorGroups = points.GroupBy(x => x.ClusterId);
        foreach (var clusterGroup in colorGroups)
        {
            var color = plot.GetNextColor();
            colors.Add(clusterGroup.Key, color);
            var xs = clusterGroup.Select(p => p.X).ToArray();
            var ys = clusterGroup.Select(p => p.Y).ToArray();
            plot.AddScatterPoints(xs, ys, color);
        }

        const MarkerShape marker = MarkerShape.cross;
        const float size = 10f;
        foreach (var (x, y, clusterId) in centroids)
        {
            // A centroid whose cluster has no points has no stored colour yet; take a fresh one.
            var color = colors.TryGetValue(clusterId, out var c) ? c : plot.GetNextColor();
            plot.AddScatterPoints(new[] { x }, new[] { y }, color, size, marker, $"ClusterId: {clusterId}");
        }

        return plot;
    }

    /// <summary>
    /// Saves the plot to <paramref name="path"/>, opens the saved file through the operating-system shell
    /// and blocks until the started process exits.
    /// </summary>
    /// <param name="plot">The plot to render.</param>
    /// <param name="path">Destination file path of the rendered image.</param>
    public static void ExportAndPreviewPlot(Plot plot, string path)
    {
        plot.SaveFig(path);
        using var p = Process.Start(new ProcessStartInfo(path) { UseShellExecute = true });

        // Process.Start may return null when no new process was started; then there is nothing to wait for.
        p?.WaitForExit();
    }

    /// <summary>
    /// Lazily reads comma-separated coordinates from a text file.
    /// </summary>
    /// <param name="path">Path of the file to read.</param>
    /// <returns>
    /// One <see cref="Point"/> per line whose first two fields parse as invariant-culture numbers,
    /// each created with <c>ClusterId</c> set to -1. Lines with fewer than two fields or with
    /// non-numeric leading fields (for example a header row) are skipped.
    /// </returns>
    public static IEnumerable<Point> ParseCsv(string path)
    {
        var lines = File.ReadLines(path);
        foreach (var line in lines)
        {
            var current = line.Split(',');

            // Guard against rows without a second column; indexing current[1] would throw otherwise.
            if (current.Length < 2) continue;
            if (!double.TryParse(current[0], NumberStyles.Any, CultureInfo.InvariantCulture, out var x)) continue;
            if (!double.TryParse(current[1], NumberStyles.Any, CultureInfo.InvariantCulture, out var y)) continue;
            yield return new Point(x, y, -1);
        }
    }
}

8
src/Point.cs Normal file
View File

@ -0,0 +1,8 @@
namespace kMeans;
/// <summary>
/// A 2D coordinate tagged with the id of the cluster it belongs to.
/// Used both for data points and for centroids.
/// </summary>
/// <param name="X">Initial horizontal coordinate.</param>
/// <param name="Y">Initial vertical coordinate.</param>
/// <param name="ClusterId">Initial cluster id; <c>Helper.ParseCsv</c> passes -1 for freshly loaded points.</param>
public record Point(double X, double Y, int ClusterId)
{
    // The positional properties are redeclared with setters so they can be reassigned after construction
    // (the compiler-generated ones would be init-only). Each is initialised from its constructor parameter.
    // NOTE(review): record equality and hash code are value-based, so they change when these are mutated.

    /// <summary>Id of the cluster this point is currently assigned to.</summary>
    public int ClusterId { get; set; } = ClusterId;
    /// <summary>Horizontal coordinate.</summary>
    public double X { get; set; } = X;
    /// <summary>Vertical coordinate.</summary>
    public double Y { get; set; } = Y;
}

View File

@ -1,7 +1,4 @@
using ScottPlot; using static kMeans.Helper;
using System.Diagnostics;
using System.Drawing;
using System.Globalization;
namespace kMeans; namespace kMeans;
@ -13,78 +10,43 @@ public static class Program
public static void Main() public static void Main()
{ {
Console.WriteLine("k-Means-Algorithmus"); Console.WriteLine("k-Means-Algorithm");
var points = ParseCsv(Coordinates).ToList(); var points = ParseCsv(Coordinates).ToList();
var centroids = InitializeRandomlyCentroids(points, K).ToList();
AssignPointsToCentroids(points, centroids);
UpdateCentroids(points, centroids);
var plot = CreatePlot(points, centroids);
ExportAndPreviewPlot(plot);
}
private static bool UpdateCentroids(List<Point> points, List<Point> centroids) var centroids = InitializeRandomlyCentroids(points, K).ToList();
{ var update = true;
// calculate mean of all points of one cluster. while (update)
foreach (var centroid in centroids)
{ {
var clusterPoints = points.Where(x => x.ClusterId == centroid.ClusterId); AssignPointsToCentroids(points, centroids);
double newMeanX = 0; update = UpdateCentroids(points, centroids);
double newMeanY = 0;
foreach (var cluster in clusterPoints)
{
centroid.X = 3;
}
} }
return false; var plot = CreatePlot(points, centroids);
ExportAndPreviewPlot(plot, PlotOut);
} }
private static void AssignPointsToCentroids(List<Point> points, List<Point> centroids) private static void AssignPointsToCentroids(List<Point> points, List<Point> centroids)
{ {
foreach (var point in points) foreach (var point in points)
{ {
var id = 0;
var distance = double.MaxValue;
foreach (var centroid in centroids) foreach (var centroid in centroids)
{ {
// calculate euclid distance and assign id. // calculate euclid distance and assign id.
var currentDistance = Distance(centroid, point);
if (currentDistance > distance) continue;
distance = currentDistance;
id = centroid.ClusterId;
} }
point.ClusterId = id;
} }
} }
private static Plot CreatePlot(IEnumerable<Point> points, IEnumerable<Point> centroids) private static double Distance(Point centroid, Point point)
{ {
var sw = Stopwatch.StartNew(); return Math.Sqrt(Math.Pow(centroid.X - point.X, 2) + Math.Pow(centroid.Y - point.Y, 2));
var plot = new Plot();
var colors = new Dictionary<int, Color>();
var colorGroups = points.GroupBy(x => x.ClusterId);
foreach (var clusterGroup in colorGroups)
{
var color = plot.GetNextColor();
colors.Add(clusterGroup.Key, color);
var xs = clusterGroup.Select(p => p.X).ToArray();
var ys = clusterGroup.Select(p => p.Y).ToArray();
plot.AddScatterPoints(xs, ys, color);
}
const MarkerShape marker = MarkerShape.cross;
const float size = 10f;
foreach (var centroid in centroids)
{
var color = colors.TryGetValue(centroid.ClusterId, out var c) ? c : plot.GetNextColor();
plot.AddScatterPoints(new[] { centroid.X }, new[] { centroid.Y }, color, size, marker);
}
sw.Stop();
Console.WriteLine($"[{nameof(CreatePlot)}] Elapsed: {sw.ElapsedMilliseconds}ms");
return plot;
}
private static void ExportAndPreviewPlot(Plot plot)
{
var sw = Stopwatch.StartNew();
plot.SaveFig(PlotOut);
Process.Start(new ProcessStartInfo(PlotOut) { UseShellExecute = true });
sw.Stop();
Console.WriteLine($"[{nameof(ExportAndPreviewPlot)}] Elapsed: {sw.ElapsedMilliseconds}ms");
} }
private static IEnumerable<Point> InitializeRandomlyCentroids(IReadOnlyCollection<Point> points, int k) private static IEnumerable<Point> InitializeRandomlyCentroids(IReadOnlyCollection<Point> points, int k)
@ -99,35 +61,35 @@ public static class Program
{ {
var x = (minX + maxX) * rnd.NextDouble(); var x = (minX + maxX) * rnd.NextDouble();
var y = (minY + maxY) * rnd.NextDouble(); var y = (minY + maxY) * rnd.NextDouble();
yield return new Point(x, y, true, i); var point = new Point(x, y, i);
Console.WriteLine(point);
yield return point;
} }
} }
private static IEnumerable<Point> ParseCsv(string path) private static bool UpdateCentroids(IReadOnlyCollection<Point> points, List<Point> centroids)
{ {
var lines = File.ReadLines(path); // calculate mean of all points of one cluster.
foreach (var line in lines) var updated = false;
foreach (var centroid in centroids)
{ {
var current = line.Split(','); updated |= UpdateClusterCentroid(centroid, points);
if (!double.TryParse(current[0], NumberStyles.Any, CultureInfo.InvariantCulture, out var x)) continue;
if (!double.TryParse(current[1], NumberStyles.Any, CultureInfo.InvariantCulture, out var y)) continue;
yield return new Point(x, y, false, -1);
} }
}
}
public class Point return updated;
{ }
public Point(double x, double y, bool focusPoint, int clusterId)
private static bool UpdateClusterCentroid(Point centroid, IEnumerable<Point> points)
{ {
X = x; var pointsOfCluster = points.Where(p => p.ClusterId == centroid.ClusterId).ToArray();
Y = y; var sumX = pointsOfCluster.Sum(p => p.X);
FocusPoint = focusPoint; var meanX = sumX / pointsOfCluster.Length;
ClusterId = clusterId; var sumY = pointsOfCluster.Sum(p => p.Y);
} var meanY = sumY / pointsOfCluster.Length;
public int ClusterId { get; set; } if (!(Math.Abs(centroid.X - meanX) > 1e-5) || !(Math.Abs(centroid.Y - meanY) > 1e-5)) return false;
public bool FocusPoint { get; set; } centroid.X = meanX;
public double X { get; set; } centroid.Y = meanY;
public double Y { get; set; } return true;
}
} }