175 lines
5.8 KiB
Rust
175 lines
5.8 KiB
Rust
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release
|
|
#![allow(incomplete_features)]
|
|
#![feature(generic_const_exprs)]
|
|
|
|
mod algo;
|
|
|
|
use crate::algo::{gen_images, ComparationOperatorsModel};
|
|
use bgtu_ai_utility::gui::{boot_eframe, labeled_slider};
|
|
use eframe::egui;
|
|
use eframe::egui::Widget;
|
|
use egui_extras::{Column, TableBuilder};
|
|
use std::cmp::min;
|
|
|
|
fn main() -> eframe::Result {
|
|
return boot_eframe(
|
|
"Neural net",
|
|
|| MyApp::new()
|
|
);
|
|
}
|
|
|
|
|
|
/// Whether a training run is currently in progress.
///
/// While not `NoTrain`, the settings controls of the UI are disabled and the run is
/// advanced a few epochs per frame.
enum TrainingState {
    /// Idle: no training run is active.
    NoTrain,
    /// A training run that is spread across several UI frames.
    Training {
        /// Epochs that still have to be run before the state returns to `NoTrain`.
        epochs_left: usize,
        /// Maximum number of epochs executed during a single UI update.
        epochs_per_update: usize,
        /// Value handed to `train_epoch` for every epoch of this run
        /// (presumably the learning rate η — TODO confirm against `algo`).
        n: f64,
    },
}
|
|
|
|
struct MyApp {
|
|
model: ComparationOperatorsModel<7, 7, 8>,
|
|
training: TrainingState,
|
|
hidden_layer_size: usize,
|
|
n: f64,
|
|
epochs_count: usize,
|
|
symbols: [&'static str; 8],
|
|
compute_result: [f64; 8],
|
|
}
|
|
|
|
impl MyApp {
|
|
fn new() -> Self {
|
|
let imgs = gen_images();
|
|
return Self {
|
|
model: ComparationOperatorsModel::new(imgs.1),
|
|
training: TrainingState::NoTrain,
|
|
hidden_layer_size: 1,
|
|
n: 0.1,
|
|
epochs_count: 1,
|
|
symbols: imgs.0,
|
|
compute_result: [0.0; 8],
|
|
};
|
|
}
|
|
}
|
|
|
|
|
|
impl eframe::App for MyApp {
|
|
fn update(&mut self, ui: &eframe::egui::Context, _frame: &mut eframe::Frame) {
|
|
egui::CentralPanel::default().show(ui, |ui| {
|
|
ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| {
|
|
labeled_slider(ui, "Hidden layer size", &mut self.hidden_layer_size, 1..=49, 1f64);
|
|
if (self.hidden_layer_size != self.model.hidden_layer_size()) {
|
|
self.model.resize_hidden_layer(self.hidden_layer_size);
|
|
}
|
|
ui.label("");
|
|
labeled_slider(ui, "η", &mut self.n, 0.0..=1.0, 0.001);
|
|
ui.label("");
|
|
labeled_slider(ui, "Epochs count", &mut self.epochs_count, 1..=500, 1f64);
|
|
ui.label("");
|
|
|
|
ui.horizontal(|ui| {
|
|
if ui.button("Train").clicked() {
|
|
self.training = TrainingState::Training {
|
|
epochs_left: self.epochs_count,
|
|
epochs_per_update: min(10, self.epochs_count / 60),
|
|
n: 0.0,
|
|
};
|
|
self.model.set_random_weights(&mut rand::rng());
|
|
}
|
|
|
|
match &mut self.training {
|
|
TrainingState::NoTrain => {
|
|
egui::ProgressBar::new(1.0).ui(ui);
|
|
}
|
|
TrainingState::Training {
|
|
epochs_left,
|
|
epochs_per_update,
|
|
n,
|
|
} => {
|
|
if *epochs_left >= *epochs_per_update {
|
|
for _ in 0..*epochs_per_update {
|
|
self.model.train_epoch(*n);
|
|
}
|
|
*epochs_left -= *epochs_per_update;
|
|
} else {
|
|
for _ in 0..*epochs_left {
|
|
self.model.train_epoch(*n);
|
|
}
|
|
*epochs_left = 0;
|
|
}
|
|
|
|
ui.add_enabled_ui(true, |ui| {
|
|
egui::ProgressBar::new(
|
|
1.0 - (*epochs_left as f32) / (self.epochs_count as f32),
|
|
)
|
|
.ui(ui)
|
|
});
|
|
|
|
if 0 == *epochs_left {
|
|
self.training = TrainingState::NoTrain;
|
|
}
|
|
}
|
|
}
|
|
})
|
|
});
|
|
|
|
ui.label("");
|
|
|
|
ui.horizontal(|ui| {
|
|
for (i, s) in self.symbols.iter().enumerate() {
|
|
if ui.button(s.to_owned()).clicked() {
|
|
self.model.set_input_from_etalon(i);
|
|
}
|
|
}
|
|
});
|
|
|
|
ui.label("");
|
|
|
|
for y in 0..7 {
|
|
ui.horizontal(|ui| {
|
|
for x in 0..7 {
|
|
ui.checkbox(&mut self.model.input[y][x], "");
|
|
}
|
|
});
|
|
}
|
|
|
|
ui.label("");
|
|
ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| {
|
|
if ui.button("Check").clicked() {
|
|
self.compute_result = self.model.calculate_predictions();
|
|
}
|
|
});
|
|
|
|
TableBuilder::new(ui)
|
|
.striped(true) // Alternating row colors
|
|
.resizable(true)
|
|
.column(Column::remainder())
|
|
.column(Column::remainder())
|
|
.header(20.0, |mut header| {
|
|
header.col(|ui| {
|
|
ui.label("Class");
|
|
});
|
|
header.col(|ui| {
|
|
ui.label("Probability");
|
|
});
|
|
})
|
|
.body(|body| {
|
|
body.rows(20.0, self.symbols.len(), |mut row| {
|
|
let i = row.index();
|
|
row.col(|ui| {
|
|
ui.label(self.symbols[i]);
|
|
});
|
|
row.col(|ui| {
|
|
ui.label(self.compute_result[i].to_string());
|
|
});
|
|
});
|
|
});
|
|
});
|
|
|
|
if matches!(self.training, TrainingState::Training {..}) {
|
|
ui.request_repaint()
|
|
}
|
|
}
|
|
}
|