chore: Update GNG implementation and requirements
This commit is contained in:
@ -31,7 +31,7 @@ class GrowingNeuralGas:
|
||||
self.edges.append(Edge(self.nodes[0], self.nodes[1]))
|
||||
|
||||
def fit(self, X, num_iterations=1000):
|
||||
bar = IncrementalBar("Training", max=num_iterations)
|
||||
bar = IncrementalBar("Training", max=num_iterations, suffix="%(percent)d%% - %(eta)ds/%(elapsed)ds")
|
||||
for iteration in range(num_iterations):
|
||||
# Step 1: Select a random input sample
|
||||
x = X[np.random.randint(len(X))]
|
||||
|
@ -8,50 +8,138 @@ from sklearn import datasets as sk
|
||||
dash.register_page(__name__, path="/gng", title="GNG", name="GNG")
|
||||
|
||||
# Generate synthetic data
|
||||
X, _ = sk.make_moons(n_samples=200, noise=0.1)
|
||||
X, Y = sk.make_moons(n_samples=200, noise=0.1, random_state=0)
|
||||
|
||||
|
||||
clics = None
|
||||
|
||||
# Create and fit the GNG model
|
||||
|
||||
layout = html.Div([dbc.Button("Generate GNG", id="generate-gng"), dcc.Graph(id="gng-graph")])
|
||||
layout = html.Div(
|
||||
[
|
||||
dbc.Row(
|
||||
[
|
||||
dbc.Col(
|
||||
dbc.Button("Generate GNG", id="generate-gng"),
|
||||
),
|
||||
dbc.Col(
|
||||
[
|
||||
dbc.Label("Noise: "),
|
||||
dbc.Input(
|
||||
placeholder="Noise",
|
||||
value=0.1,
|
||||
type="number",
|
||||
id="noise",
|
||||
step=0.1,
|
||||
),
|
||||
]
|
||||
),
|
||||
dbc.Col(
|
||||
[
|
||||
dbc.Label("Iterations: "),
|
||||
dbc.Input(
|
||||
placeholder="Iterations",
|
||||
value=1000,
|
||||
type="number",
|
||||
id="iterations",
|
||||
),
|
||||
]
|
||||
),
|
||||
dbc.Col(
|
||||
[
|
||||
dbc.Label("Max Nodes: "),
|
||||
dbc.Input(
|
||||
placeholder="Max Nodes",
|
||||
value=200,
|
||||
type="number",
|
||||
id="nodes",
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
dcc.Graph(id="base-graph"),
|
||||
dcc.Graph(id="gng-graph"),
|
||||
]
|
||||
)
|
||||
|
||||
@callback(Output("gng-graph", "figure"), [Input("generate-gng", "n_clicks")])
|
||||
def generate_gng(n_clicks):
|
||||
if n_clicks is None:
|
||||
return go.Figure().update_layout(
|
||||
showlegend=False,
|
||||
margin=dict(l=0, r=0, t=0, b=0),
|
||||
xaxis=dict(visible=False),
|
||||
yaxis=dict(visible=False),
|
||||
)
|
||||
|
||||
@callback(
|
||||
[Output("gng-graph", "figure"), Output("base-graph", "figure")],
|
||||
[
|
||||
Input("generate-gng", "n_clicks"),
|
||||
Input("noise", "value"),
|
||||
Input("iterations", "value"),
|
||||
Input("nodes", "value"),
|
||||
],
|
||||
)
|
||||
def generate_gng(n_clicks, noise, iterations, nodes):
|
||||
global clics
|
||||
X, Y = sk.make_moons(n_samples=200, noise=noise, random_state=0)
|
||||
|
||||
gng = GrowingNeuralGas(input_dim=2)
|
||||
gng.fit(X, num_iterations=20000)
|
||||
|
||||
fig = go.Figure()
|
||||
for edge in gng.edges:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[edge.nodes[0].position[0], edge.nodes[1].position[0]],
|
||||
y=[edge.nodes[0].position[1], edge.nodes[1].position[1]],
|
||||
mode="lines",
|
||||
line=dict(width=2, color="white"),
|
||||
)
|
||||
print(X, Y)
|
||||
|
||||
fig2 = go.Figure()
|
||||
fig2.add_trace(
|
||||
go.Scatter(
|
||||
x=X[Y == 0, 0], y=X[Y == 0, 1], mode="markers", marker=dict(color="blue")
|
||||
)
|
||||
for node in gng.nodes:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[node.position[0]],
|
||||
y=[node.position[1]],
|
||||
mode="markers",
|
||||
marker=dict(size=10, color="red"),
|
||||
)
|
||||
)
|
||||
fig2.add_trace(
|
||||
go.Scatter(
|
||||
x=X[Y == 1, 0], y=X[Y == 1, 1], mode="markers", marker=dict(color="red")
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
)
|
||||
fig2.update_layout(
|
||||
showlegend=False,
|
||||
margin=dict(l=0, r=0, t=0, b=0),
|
||||
xaxis=dict(visible=False),
|
||||
yaxis=dict(visible=False),
|
||||
title="Base Data",
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
if n_clicks != clics:
|
||||
gng = GrowingNeuralGas(input_dim=2, max_nodes=nodes)
|
||||
gng.fit(X, num_iterations=iterations)
|
||||
|
||||
fig = go.Figure()
|
||||
for edge in gng.edges:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[edge.nodes[0].position[0], edge.nodes[1].position[0]],
|
||||
y=[edge.nodes[0].position[1], edge.nodes[1].position[1]],
|
||||
mode="lines",
|
||||
line=dict(width=2, color="white"),
|
||||
)
|
||||
)
|
||||
for node in gng.nodes:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[node.position[0]],
|
||||
y=[node.position[1]],
|
||||
mode="markers",
|
||||
marker=dict(size=10, color="red"),
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
showlegend=False,
|
||||
margin=dict(l=0, r=0, t=0, b=0),
|
||||
xaxis=dict(visible=False),
|
||||
yaxis=dict(visible=False),
|
||||
title="GNG Model",
|
||||
)
|
||||
|
||||
clics = n_clicks
|
||||
|
||||
return [fig, fig2]
|
||||
|
||||
return [
|
||||
go.Figure().update_layout(
|
||||
showlegend=False,
|
||||
margin=dict(l=0, r=0, t=0, b=0),
|
||||
xaxis=dict(visible=False),
|
||||
yaxis=dict(visible=False),
|
||||
),
|
||||
fig2,
|
||||
]
|
||||
|
Reference in New Issue
Block a user