Compare commits
8 Commits
clustering
...
main
Author | SHA1 | Date | |
---|---|---|---|
![]() |
f464f6166a | ||
3038bd9841 | |||
![]() |
86bd285193 | ||
![]() |
9bc9e21e45 | ||
![]() |
da1e97f07f | ||
![]() |
27e69b2af8 | ||
![]() |
4054395641 | ||
![]() |
01168f3588 |
@@ -1,9 +1,11 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.metrics import accuracy_score
|
from sklearn.metrics import accuracy_score,confusion_matrix
|
||||||
from sklearn.preprocessing import LabelEncoder
|
from sklearn.preprocessing import LabelEncoder
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
st.header("Prediction: Classification")
|
st.header("Prediction: Classification")
|
||||||
|
|
||||||
@@ -60,5 +62,18 @@ if "data" in st.session_state:
|
|||||||
prediction = label_encoders[target_name].inverse_transform(prediction)
|
prediction = label_encoders[target_name].inverse_transform(prediction)
|
||||||
|
|
||||||
st.write("Prediction:", prediction[0])
|
st.write("Prediction:", prediction[0])
|
||||||
|
|
||||||
|
if len(data_name) == 1:
|
||||||
|
fig = plt.figure()
|
||||||
|
|
||||||
|
y_pred = [model.predict(pd.DataFrame([pred_value[0]], columns=data_name)) for pred_value in X.values.tolist()]
|
||||||
|
cm = confusion_matrix(y, y_pred)
|
||||||
|
|
||||||
|
sns.heatmap(cm, annot=True, fmt="d")
|
||||||
|
|
||||||
|
plt.xlabel('Predicted')
|
||||||
|
plt.ylabel('True')
|
||||||
|
|
||||||
|
st.pyplot(fig)
|
||||||
else:
|
else:
|
||||||
st.error("File not loaded")
|
st.error("File not loaded")
|
||||||
|
@@ -1,6 +1,8 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
from sklearn.linear_model import LinearRegression
|
from sklearn.linear_model import LinearRegression
|
||||||
|
from sklearn.metrics import r2_score
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
st.header("Prediction: Regression")
|
st.header("Prediction: Regression")
|
||||||
|
|
||||||
@@ -25,5 +27,37 @@ if "data" in st.session_state:
|
|||||||
prediction = model.predict(pd.DataFrame([pred_values], columns=data_name))
|
prediction = model.predict(pd.DataFrame([pred_values], columns=data_name))
|
||||||
|
|
||||||
st.write("Prediction:", prediction[0])
|
st.write("Prediction:", prediction[0])
|
||||||
|
|
||||||
|
fig = plt.figure()
|
||||||
|
dataframe_sorted = pd.concat([X, y], axis=1).sort_values(by=data_name)
|
||||||
|
|
||||||
|
if len(data_name) == 1:
|
||||||
|
y_pred = [model.predict(pd.DataFrame([pred_value[0]], columns=data_name)) for pred_value in X.values.tolist()]
|
||||||
|
r2 = r2_score(y, y_pred)
|
||||||
|
st.write('R-squared score:', r2)
|
||||||
|
|
||||||
|
X = dataframe_sorted[data_name[0]]
|
||||||
|
y = dataframe_sorted[target_name]
|
||||||
|
|
||||||
|
prediction_array_y = [
|
||||||
|
model.predict(pd.DataFrame([[dataframe_sorted[data_name[0]].iloc[i]]], columns=data_name))[0]
|
||||||
|
for i in range(dataframe_sorted.shape[0])
|
||||||
|
]
|
||||||
|
|
||||||
|
plt.scatter(dataframe_sorted[data_name[0]], dataframe_sorted[target_name], color='b')
|
||||||
|
plt.plot(dataframe_sorted[data_name[0]], prediction_array_y, color='r')
|
||||||
|
elif len(data_name) == 2:
|
||||||
|
ax = fig.add_subplot(111, projection='3d')
|
||||||
|
|
||||||
|
prediction_array_y = [
|
||||||
|
model.predict(pd.DataFrame([[dataframe_sorted[data_name[0]].iloc[i], dataframe_sorted[data_name[1]].iloc[i]]], columns=data_name))[0]
|
||||||
|
for i in range(dataframe_sorted.shape[0])
|
||||||
|
]
|
||||||
|
|
||||||
|
ax.scatter(dataframe_sorted[data_name[0]], dataframe_sorted[data_name[1]], dataframe_sorted[target_name], color='b')
|
||||||
|
ax.plot(dataframe_sorted[data_name[0]], dataframe_sorted[data_name[1]], prediction_array_y, color='r')
|
||||||
|
|
||||||
|
st.pyplot(fig)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
st.error("File not loaded")
|
st.error("File not loaded")
|
||||||
|
Reference in New Issue
Block a user