KMeans/KMeansBase/KMeans.cs
2022-01-22 13:27:08 +01:00

77 lines
2.5 KiB
C#

namespace KMeansBase;
/// <summary>
/// Minimal k-means clustering over two-dimensional <see cref="Point"/> instances
/// (only the <c>X</c> and <c>Y</c> coordinates are used).
/// </summary>
public static class KMeans
{
    /// <summary>Upper bound on the number of assign/update iterations per run.</summary>
    private const int MaxLoops = 30;

    /// <summary>
    /// Clusters <paramref name="points"/> into at most <paramref name="k"/> clusters.
    /// Each point's <c>ClusterId</c> is overwritten with the id of its nearest centroid.
    /// </summary>
    /// <param name="points">Points to cluster; their <c>ClusterId</c> is mutated in place.</param>
    /// <param name="k">Number of centroids to create. A non-positive value yields no centroids.</param>
    /// <param name="centroids">Receives the centroids, whose <c>ClusterId</c> values are <c>0..k-1</c>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="points"/> is <c>null</c>.</exception>
    /// <remarks>
    /// The loop stops early once an update pass reports that no centroid changed, and
    /// otherwise after <see cref="MaxLoops"/> iterations. An empty <paramref name="points"/>
    /// collection is not handled specially; the LINQ Min/Max calls used to compute the
    /// bounding box will typically throw <see cref="InvalidOperationException"/> in that case.
    /// </remarks>
    public static void KMeansCalculation(IReadOnlyCollection<Point> points, int k,
        out IReadOnlyCollection<Point> centroids)
    {
        ArgumentNullException.ThrowIfNull(points);

        // Materialize once: the iterator would otherwise produce fresh random centroids
        // on every enumeration.
        centroids = InitializeRandomCentroids(points, k).ToArray();

        for (var i = 0; i < MaxLoops; i++)
        {
            AssignPointsToCentroids(points, centroids);
            if (!UpdateCentroids(points, centroids))
            {
                break;
            }
        }
    }

    /// <summary>
    /// Recomputes every centroid from the points currently assigned to it.
    /// </summary>
    /// <returns><c>true</c> if at least one centroid update reported a change.</returns>
    private static bool UpdateCentroids(IReadOnlyCollection<Point> points, IEnumerable<Point> centroids)
    {
        var updated = false;
        foreach (var centroid in centroids)
        {
            // Non-short-circuiting OR: every centroid must be updated even after one
            // has already reported a change.
            updated |= UpdateClusterCentroid(centroid, points);
        }

        return updated;
    }

    /// <summary>
    /// Moves <paramref name="centroid"/> to the mean of the points sharing its <c>ClusterId</c>.
    /// </summary>
    /// <returns>
    /// The result of <c>Point.SetCoordinates</c> (presumably whether the coordinates
    /// changed; verify against <c>Point</c>), or <c>false</c> when the cluster is empty.
    /// </returns>
    private static bool UpdateClusterCentroid(Point centroid, IEnumerable<Point> points)
    {
        var pointsOfCluster = points.Where(p => p.ClusterId == centroid.ClusterId).ToArray();

        // An empty cluster has no mean; dividing by zero would turn the centroid into
        // NaN and corrupt later distance comparisons. Leave the centroid where it is.
        if (pointsOfCluster.Length == 0)
        {
            return false;
        }

        var meanX = pointsOfCluster.Average(p => p.X);
        var meanY = pointsOfCluster.Average(p => p.Y);
        return centroid.SetCoordinates(meanX, meanY);
    }

    /// <summary>
    /// Sets each point's <c>ClusterId</c> to that of its nearest centroid by Euclidean distance.
    /// When two centroids are equally close, the one enumerated later wins.
    /// </summary>
    private static void AssignPointsToCentroids(IEnumerable<Point> points, IReadOnlyCollection<Point> centroids)
    {
        foreach (var point in points)
        {
            var distance = double.MaxValue;
            foreach (var centroid in centroids)
            {
                var currentDistance = Distance(centroid, point);
                if (currentDistance > distance)
                {
                    continue;
                }

                distance = currentDistance;
                point.ClusterId = centroid.ClusterId;
            }
        }
    }

    /// <summary>Euclidean distance between two points in the X/Y plane.</summary>
    private static double Distance(Point centroid, Point point)
    {
        var dx = centroid.X - point.X;
        var dy = centroid.Y - point.Y;
        return Math.Sqrt(dx * dx + dy * dy);
    }

    /// <summary>
    /// Lazily yields <paramref name="k"/> centroids placed uniformly at random inside the
    /// axis-aligned bounding box of <paramref name="points"/>, with cluster ids <c>0..k-1</c>.
    /// </summary>
    private static IEnumerable<Point> InitializeRandomCentroids(IReadOnlyCollection<Point> points, int k)
    {
        var (minX, maxX) = (points.Min(p => p.X), points.Max(p => p.X));
        var (minY, maxY) = (points.Min(p => p.Y), points.Max(p => p.Y));

        for (var i = 0; i < k; i++)
        {
            // min + range * u maps u in [0, 1) onto [min, max); the former
            // (min + max) * u only matched the bounding box when min was 0.
            var x = minX + (maxX - minX) * Random.Shared.NextDouble();
            var y = minY + (maxY - minY) * Random.Shared.NextDouble();
            yield return new Point(x, y, i);
        }
    }
}