chore: Update GNG implementation and requirements
feat: Update GNG implementation in gng2.py and gng.py to include a progress bar during training. Also, update requirements.txt.
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
import plotly.graph_objects as plt
|
||||
from progress.bar import IncrementalBar
|
||||
|
||||
class Neuron:
|
||||
def __init__(self, position):
|
||||
@ -31,6 +31,7 @@ class GrowingNeuralGas:
|
||||
self.edges.append(Edge(self.nodes[0], self.nodes[1]))
|
||||
|
||||
def fit(self, X, num_iterations=1000):
|
||||
bar = IncrementalBar("Training", max=num_iterations)
|
||||
for iteration in range(num_iterations):
|
||||
# Step 1: Select a random input sample
|
||||
x = X[np.random.randint(len(X))]
|
||||
@ -78,6 +79,9 @@ class GrowingNeuralGas:
|
||||
# Step 10: Decrease all errors
|
||||
for node in self.nodes:
|
||||
node.error *= self.delta
|
||||
|
||||
bar.next()
|
||||
bar.finish()
|
||||
|
||||
def find_edge(self, node1, node2):
|
||||
for edge in self.edges:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import dash
|
||||
import plotly.graph_objects as go
|
||||
from dash import html, dcc
|
||||
from dash import html, dcc, callback, Input, Output
|
||||
import dash_bootstrap_components as dbc
|
||||
from gng2 import GrowingNeuralGas
|
||||
from sklearn import datasets as sk
|
||||
|
||||
@ -10,34 +11,47 @@ dash.register_page(__name__, path="/gng", title="GNG", name="GNG")
|
||||
X, _ = sk.make_moons(n_samples=200, noise=0.1)
|
||||
|
||||
# Create and fit the GNG model
|
||||
gng = GrowingNeuralGas(input_dim=2)
|
||||
gng.fit(X, num_iterations=2000)
|
||||
|
||||
fig = go.Figure()
|
||||
for edge in gng.edges:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[edge.nodes[0].position[0], edge.nodes[1].position[0]],
|
||||
y=[edge.nodes[0].position[1], edge.nodes[1].position[1]],
|
||||
mode="lines",
|
||||
line=dict(width=2, color="black"),
|
||||
layout = html.Div([dbc.Button("Generate GNG", id="generate-gng"), dcc.Graph(id="gng-graph")])
|
||||
|
||||
@callback(Output("gng-graph", "figure"), [Input("generate-gng", "n_clicks")])
|
||||
def generate_gng(n_clicks):
|
||||
if n_clicks is None:
|
||||
return go.Figure().update_layout(
|
||||
showlegend=False,
|
||||
margin=dict(l=0, r=0, t=0, b=0),
|
||||
xaxis=dict(visible=False),
|
||||
yaxis=dict(visible=False),
|
||||
)
|
||||
)
|
||||
for node in gng.nodes:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[node.position[0]],
|
||||
y=[node.position[1]],
|
||||
mode="markers",
|
||||
marker=dict(size=10, color="red"),
|
||||
|
||||
gng = GrowingNeuralGas(input_dim=2)
|
||||
gng.fit(X, num_iterations=20000)
|
||||
|
||||
fig = go.Figure()
|
||||
for edge in gng.edges:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[edge.nodes[0].position[0], edge.nodes[1].position[0]],
|
||||
y=[edge.nodes[0].position[1], edge.nodes[1].position[1]],
|
||||
mode="lines",
|
||||
line=dict(width=2, color="white"),
|
||||
)
|
||||
)
|
||||
for node in gng.nodes:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[node.position[0]],
|
||||
y=[node.position[1]],
|
||||
mode="markers",
|
||||
marker=dict(size=10, color="red"),
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
showlegend=False,
|
||||
margin=dict(l=0, r=0, t=0, b=0),
|
||||
xaxis=dict(visible=False),
|
||||
yaxis=dict(visible=False),
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
showlegend=False,
|
||||
margin=dict(l=0, r=0, t=0, b=0),
|
||||
xaxis=dict(visible=False),
|
||||
yaxis=dict(visible=False),
|
||||
)
|
||||
|
||||
layout = html.Div(dcc.Graph(figure=fig))
|
||||
return fig
|
||||
|
Reference in New Issue
Block a user