1 Commits

Author SHA1 Message Date
bastien ollier
c8cf0fe045 add stats 2024-06-25 08:33:32 +02:00
4 changed files with 105 additions and 135 deletions

View File

@@ -1,40 +1,25 @@
from sklearn.cluster import DBSCAN, KMeans from sklearn.cluster import DBSCAN, KMeans
import numpy as np import numpy as np
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Any, Optional
@dataclass class DBSCAN_cluster():
class ClusterResult: def __init__(self, eps, min_samples,data):
labels: np.array
centers: Optional[np.array]
statistics: list[dict[str, Any]]
class Cluster(ABC):
@abstractmethod
def run(self, data: np.array) -> ClusterResult:
pass
class DBSCANCluster(Cluster):
def __init__(self, eps: float = 0.5, min_samples: int = 5):
self.eps = eps self.eps = eps
self.min_samples = min_samples self.min_samples = min_samples
self.data = data
self.labels = np.array([])
#@typing.override def run(self):
def run(self, data: np.array) -> ClusterResult:
dbscan = DBSCAN(eps=self.eps, min_samples=self.min_samples) dbscan = DBSCAN(eps=self.eps, min_samples=self.min_samples)
labels = dbscan.fit_predict(data) self.labels = dbscan.fit_predict(self.data)
return ClusterResult(labels, None, self.get_statistics(data, labels)) return self.labels
def get_statistics(self, data: np.array, labels: np.array) -> list[dict[str, Any]]: def get_stats(self):
unique_labels = np.unique(labels) unique_labels = np.unique(self.labels)
stats = [] stats = []
for label in unique_labels: for label in unique_labels:
if label == -1: if label == -1:
continue continue
cluster_points = data[labels == label] cluster_points = self.data[self.labels == label]
num_points = len(cluster_points) num_points = len(cluster_points)
density = num_points / (np.max(cluster_points, axis=0) - np.min(cluster_points, axis=0)).prod() density = num_points / (np.max(cluster_points, axis=0) - np.min(cluster_points, axis=0)).prod()
stats.append({ stats.append({
@@ -42,42 +27,37 @@ class DBSCANCluster(Cluster):
"num_points": num_points, "num_points": num_points,
"density": density "density": density
}) })
return stats return stats
def __str__(self) -> str:
return "DBScan"
class KMeans_cluster():
class KMeansCluster(Cluster): def __init__(self, n_clusters, n_init, max_iter, data):
def __init__(self, n_clusters: int = 8, n_init: int = 1, max_iter: int = 300):
self.n_clusters = n_clusters self.n_clusters = n_clusters
self.n_init = n_init self.n_init = n_init
self.max_iter = max_iter self.max_iter = max_iter
self.data = data
self.labels = np.array([])
self.centers = []
#@typing.override def run(self):
def run(self, data: np.array) -> ClusterResult:
kmeans = KMeans(n_clusters=self.n_clusters, init="random", n_init=self.n_init, max_iter=self.max_iter, random_state=111) kmeans = KMeans(n_clusters=self.n_clusters, init="random", n_init=self.n_init, max_iter=self.max_iter, random_state=111)
labels = kmeans.fit_predict(data) self.labels = kmeans.fit_predict(self.data)
centers = kmeans.cluster_centers_ self.centers = kmeans.cluster_centers_
return ClusterResult(labels, centers, self.get_statistics(data, labels, centers)) return self.labels
def get_statistics(self, data: np.array, labels: np.array, centers: np.array) -> list[dict[str, Any]]:
unique_labels = np.unique(labels) def get_stats(self):
unique_labels = np.unique(self.labels)
stats = [] stats = []
for label in unique_labels: for label in unique_labels:
cluster_points = data[labels == label] cluster_points = self.data[self.labels == label]
num_points = len(cluster_points) num_points = len(cluster_points)
center = centers[label] center = self.centers[label]
stats.append({ stats.append({
"cluster": label, 'cluster': label,
"num_points": num_points, 'num_points': num_points,
"center": center, 'center': center
}) })
return stats return stats
def __str__(self) -> str:
return "KMeans"
CLUSTERING_STRATEGIES = [DBSCANCluster(), KMeansCluster()]

View File

@@ -1,86 +0,0 @@
import streamlit as st
import matplotlib.pyplot as plt
from clusters import DBSCANCluster, KMeansCluster, CLUSTERING_STRATEGIES
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
import numpy as np
st.header("Clustering")
if "data" in st.session_state:
data = st.session_state.data
general_row = st.columns([1, 1, 1])
clustering = general_row[0].selectbox("Clustering method", CLUSTERING_STRATEGIES)
data_name = general_row[1].multiselect("Columns", data.select_dtypes(include="number").columns)
n_components = general_row[2].number_input("Reduce dimensions to (PCA)", min_value=1, max_value=3, value=2)
with st.form("cluster_form"):
if isinstance(clustering, KMeansCluster):
row1 = st.columns([1, 1, 1])
clustering.n_clusters = row1[0].number_input("Number of clusters", min_value=1, max_value=data.shape[0], value=clustering.n_clusters)
clustering.n_init = row1[1].number_input("n_init", min_value=1, value=clustering.n_init)
clustering.max_iter = row1[2].number_input("max_iter", min_value=1, value=clustering.max_iter)
elif isinstance(clustering, DBSCANCluster):
row1 = st.columns([1, 1])
clustering.eps = row1[0].slider("eps", min_value=0.0001, max_value=1.0, step=0.05, value=clustering.eps)
clustering.min_samples = row1[1].number_input("min_samples", min_value=1, value=clustering.min_samples)
st.form_submit_button("Launch")
if len(data_name) > 0:
x = data[data_name].to_numpy()
n_components = min(n_components, len(data_name))
if len(data_name) > n_components:
pca = PCA(n_components)
x = pca.fit_transform(x)
if n_components == 2:
(fig, ax) = plt.subplots(figsize=(8, 8))
for i in range(0, pca.components_.shape[1]):
ax.arrow(
0,
0,
pca.components_[0, i],
pca.components_[1, i],
head_width=0.1,
head_length=0.1
)
plt.text(
pca.components_[0, i] + 0.05,
pca.components_[1, i] + 0.05,
data_name[i]
)
circle = plt.Circle((0, 0), radius=1, edgecolor='b', facecolor='None')
ax.add_patch(circle)
plt.axis("equal")
ax.set_title("PCA result - Correlation circle")
st.pyplot(fig)
result = clustering.run(x)
st.write("## Cluster stats")
st.table(result.statistics)
st.write("## Graphical representation")
fig = plt.figure()
if n_components == 1:
plt.scatter(x, np.zeros_like(x))
elif n_components == 2:
ax = fig.add_subplot(projection='rectilinear')
plt.scatter(x[:, 0], x[:, 1], c=result.labels, s=50, cmap="viridis")
if result.centers is not None:
plt.scatter(result.centers[:, 0], result.centers[:, 1], c="black", s=200, marker="X")
else:
ax = fig.add_subplot(projection='3d')
ax.scatter(x[:, 0], x[:, 1],x[:, 2], c=result.labels, s=50, cmap="viridis")
if result.centers is not None:
ax.scatter(result.centers[:, 0], result.centers[:, 1], result.centers[:, 2], c="black", s=200, marker="X")
st.pyplot(fig)
if not (result.labels == 0).all():
st.write("Silhouette score:", silhouette_score(x, result.labels))
else:
st.error("Select at least one column")
else:
st.error("file not loaded")

View File

@@ -0,0 +1,32 @@
import streamlit as st
import matplotlib.pyplot as plt
from clusters import DBSCAN_cluster
st.header("Clustering: dbscan")
if "data" in st.session_state:
data = st.session_state.data
with st.form("my_form"):
data_name = st.multiselect("Data Name", data.select_dtypes(include="number").columns, max_selections=3)
eps = st.slider("eps", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
min_samples = st.number_input("min_samples", step=1, min_value=1, value=5)
st.form_submit_button("launch")
if len(data_name) >= 2 and len(data_name) <=3:
x = data[data_name].to_numpy()
dbscan = DBSCAN_cluster(eps,min_samples,x)
y_dbscan = dbscan.run()
st.table(dbscan.get_stats())
fig = plt.figure()
if len(data_name) == 2:
ax = fig.add_subplot(projection='rectilinear')
plt.scatter(x[:, 0], x[:, 1], c=y_dbscan, s=50, cmap="viridis")
else:
ax = fig.add_subplot(projection='3d')
ax.scatter(x[:, 0], x[:, 1],x[:, 2], c=y_dbscan, s=50, cmap="viridis")
st.pyplot(fig)
else:
st.error("file not loaded")

View File

@@ -0,0 +1,44 @@
import streamlit as st
import matplotlib.pyplot as plt
from clusters import KMeans_cluster
st.header("Clustering: kmeans")
if "data" in st.session_state:
data = st.session_state.data
with st.form("my_form"):
row1 = st.columns([1,1,1])
n_clusters = row1[0].selectbox("Number of clusters", range(1,data.shape[0]))
data_name = row1[1].multiselect("Data Name",data.select_dtypes(include="number").columns, max_selections=3)
n_init = row1[2].number_input("n_init",step=1,min_value=1)
row2 = st.columns([1,1])
max_iter = row1[0].number_input("max_iter",step=1,min_value=1)
st.form_submit_button("launch")
if len(data_name) >= 2 and len(data_name) <=3:
x = data[data_name].to_numpy()
kmeans = KMeans_cluster(n_clusters, n_init, max_iter, x)
y_kmeans = kmeans.run()
st.table(kmeans.get_stats())
centers = kmeans.centers
fig = plt.figure()
if len(data_name) == 2:
ax = fig.add_subplot(projection='rectilinear')
plt.scatter(x[:, 0], x[:, 1], c=y_kmeans, s=50, cmap="viridis")
plt.scatter(centers[:, 0], centers[:, 1], c="black", s=200, marker="X")
else:
ax = fig.add_subplot(projection='3d')
ax.scatter(x[:, 0], x[:, 1],x[:, 2], c=y_kmeans, s=50, cmap="viridis")
ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], c="black", s=200, marker="X")
st.pyplot(fig)
else:
st.error("file not loaded")