資源簡介
k-means的python實現及數據,使用numpy實現了k-means的算法實例

代碼片段和文件信息
import matplotlib.pyplot as plt
import numpy as np
def group_one(sample, centers):
    """Return the index of the centroid closest to ``sample``.

    Args:
        sample: One sample, a 1-D array-like of length ``d``.
        centers: Centroids, a 2-D array-like of shape ``(k, d)``
            (one centroid per row).

    Returns:
        The row index into ``centers`` with the smallest squared Euclidean
        distance to ``sample``. On ties the lowest index wins
        (``np.argmin`` semantics).
    """
    # Squared distance is enough: the square root does not change the argmin.
    offsets = np.asarray(sample) - np.asarray(centers)
    distance_vect = np.sum(offsets ** 2, axis=1)
    return np.argmin(distance_vect)
def group_all(data, k, centers):
    """Assign every sample in ``data`` to its nearest centroid.

    A list of lists is used instead of an ndarray because the number of
    samples per group is not known in advance and differs between groups;
    forcing that into an array would produce an object array, which is slow
    and error-prone.

    Args:
        data: Iterable of samples; each sample must provide ``tolist()``
            (e.g. the rows of a 2-D ndarray).
        k: Number of groups (centroids).
        centers: Centroids, one per row, passed through to ``group_one``.

    Returns:
        A list of ``k`` lists. Entry ``i`` holds the samples (as plain Python
        lists) whose nearest centroid is ``centers[i]``, in the order they
        appear in ``data``. Groups may be empty.
    """
    groups = [[] for _ in range(k)]
    for sample in data:
        nearest = group_one(sample, centers)
        groups[nearest].append(sample.tolist())
    return groups
def update_centers(data, k, groups):
    """Recompute each centroid from the samples assigned to it.

    Args:
        data: 2-D ndarray of samples, shape ``(n, d)``. Used for the sample
            dimension ``d`` and as the pool for re-seeding empty groups.
        k: Number of centroids.
        groups: Sequence of ``k`` groups; ``groups[i]`` is the collection of
            samples assigned to centroid ``i``.

    Returns:
        A new ndarray of shape ``(k, d)``. Row ``i`` is the mean of
        ``groups[i]``; if ``groups[i]`` is empty, row ``i`` is a randomly
        chosen sample from ``data``.
    """
    centers = np.zeros((k, data.shape[1]))
    for index in range(k):
        members = groups[index]
        if len(members) > 0:
            centers[index] = np.mean(np.array(members), axis=0)
        else:
            # The mean of an empty group is NaN, and a NaN centroid never
            # recovers (it also turns the convergence loss into NaN).
            # Re-seed the centroid on a random sample instead.
            centers[index] = data[np.random.randint(data.shape[0])]
    return centers
def iter_diff(old_centers, new_centers):
    """Measure how far the centroids moved between two iterations.

    Args:
        old_centers: Centroids before the update (array-like).
        new_centers: Centroids after the update, same shape.

    Returns:
        The sum of the absolute element-wise differences; ``0`` when the
        centroids did not move at all.
    """
    return np.sum(np.abs(np.asarray(old_centers) - np.asarray(new_centers)))
def rand_center(data, k):
    """Generate ``k`` random initial centroids inside the data's bounding box.

    Args:
        data: 2-D ndarray of samples, shape ``(n, d)``.
        k: Number of centroids to generate.

    Returns:
        ndarray of shape ``(k, d)``; every coordinate lies between the
        minimum and maximum of ``data`` along that dimension.
    """
    # np.random.rand yields values in [0, 1). Scaling by the maximum alone
    # only covers [0, max], which misses data that is negative or offset from
    # the origin, so map onto [min, max] per dimension instead.
    low = np.min(data, axis=0)
    high = np.max(data, axis=0)
    return low + np.random.rand(k, data.shape[1]) * (high - low)
def classify(data, k, threshold, max_iter=0):
    """Run k-means on ``data`` and return the centroids and the grouping.

    Iteration stops as soon as the centroid movement (``iter_diff``) is no
    longer greater than ``threshold``, or after ``max_iter`` iterations when
    ``max_iter`` is non-zero. Each iteration prints its number and loss.

    Args:
        data: 2-D ndarray of samples, shape ``(n, d)``.
        k: Number of clusters.
        threshold: Stop once the centroid movement is ``<= threshold``.
        max_iter: Maximum number of iterations; ``0`` means no limit.

    Returns:
        A tuple ``(centers, groups)``: the final centroids and the samples
        grouped against those centroids (see ``group_all``).
    """
    centers = rand_center(data, k)
    # Group up front so ``groups`` is always bound, even when the loop body
    # never executes (e.g. an infinite threshold).
    groups = group_all(data, k, centers)
    loss = float("inf")
    iter_count = 0
    while loss > threshold and (max_iter == 0 or iter_count < max_iter):
        old_centers = centers
        centers = update_centers(data, k, groups)
        loss = iter_diff(old_centers, centers)
        iter_count += 1
        print("iter_%d : loss=%f" % (iter_count, loss))
        # Regroup against the updated centroids so the returned groups always
        # correspond to the returned centers.
        groups = group_all(data, k, centers)
    return centers, groups
def paint_result(data, centers, k, groups, debug=False):
    """Scatter-plot the grouped samples and their centroids, then show it.

    Only the first two coordinates of each sample and centroid are plotted.

    Args:
        data: Not used by the plot; kept so existing callers keep working.
        centers: 2-D ndarray of centroids, drawn in red.
        k: Number of groups.
        groups: Sequence of ``k`` groups of samples (see ``group_all``).
        debug: When true, draw all samples in a single colour instead of
            colouring them by group index.
    """
    # Flatten the groups into one point list plus a parallel list of group
    # indices, which matplotlib uses as the per-point colour value.
    labels = []
    flatten_group = []
    for index in range(k):
        for item in groups[index]:
            labels.append(index)
            flatten_group.append(item)
    points = np.array(flatten_group)
    if debug:
        plt.scatter(points[:, 0], points[:, 1])
    else:
        plt.scatter(points[:, 0], points[:, 1], c=labels)
    plt.scatter(centers[:, 0], centers[:, 1], color="red")
    plt.show()
def main():
    """Load the sample data, cluster it into three groups and plot the result."""
    data = np.loadtxt("d:/data.csv", delimiter=",")
    # View the values as (x, y) pairs. reshape(-1, 2) adapts to however many
    # samples the file holds, and raises on an odd count instead of silently
    # truncating or zero-padding as an in-place resize to (500, 2) would.
    data = data.reshape(-1, 2)
    center, groups = classify(data, 3, 0, 0)
    paint_result(data, center, 3, groups)
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
 屬性            大小     日期    時間   名稱
----------- ---------  ---------- -----  ----
     文件       25501  2018-03-18 23:56  data.csv
     文件        3160  2018-03-19 22:17  kmeans.py
評論
共有 條評論