#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release #![allow(incomplete_features)] #![feature(generic_const_exprs)] mod algo; use crate::algo::{gen_images, ComparationOperatorsModel}; use bgtu_ai_utility::gui::{boot_eframe, labeled_slider}; use eframe::egui; use eframe::egui::Widget; use egui_extras::{Column, TableBuilder}; use std::cmp::min; fn main() -> eframe::Result { return boot_eframe( "Neural net", || MyApp::new() ); } enum TrainingState { NoTrain, Training { epochs_left: usize, epochs_per_update: usize, n: f64, }, } struct MyApp { model: ComparationOperatorsModel<7, 7, 8>, training: TrainingState, hidden_layer_size: usize, n: f64, epochs_count: usize, symbols: [&'static str; 8], compute_result: [f64; 8], } impl MyApp { fn new() -> Self { let imgs = gen_images(); return Self { model: ComparationOperatorsModel::new(imgs.1), training: TrainingState::NoTrain, hidden_layer_size: 1, n: 0.1, epochs_count: 1, symbols: imgs.0, compute_result: [0.0; 8], }; } } impl eframe::App for MyApp { fn update(&mut self, ui: &eframe::egui::Context, _frame: &mut eframe::Frame) { egui::CentralPanel::default().show(ui, |ui| { ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| { labeled_slider(ui, "Hidden layer size", &mut self.hidden_layer_size, 1..=49, 1f64); if (self.hidden_layer_size != self.model.hidden_layer_size()) { self.model.resize_hidden_layer(self.hidden_layer_size); } ui.label(""); labeled_slider(ui, "η", &mut self.n, 0.0..=1.0, 0.001); ui.label(""); labeled_slider(ui, "Epochs count", &mut self.epochs_count, 1..=500, 1f64); ui.label(""); ui.horizontal(|ui| { if ui.button("Train").clicked() { self.training = TrainingState::Training { epochs_left: self.epochs_count, epochs_per_update: min(10, self.epochs_count / 60), n: 0.0, }; self.model.set_random_weights(&mut rand::rng()); } match &mut self.training { TrainingState::NoTrain => { egui::ProgressBar::new(1.0).ui(ui); } TrainingState::Training { epochs_left, epochs_per_update, n, } => { if *epochs_left >= *epochs_per_update { for _ in 0..*epochs_per_update { self.model.train_epoch(*n); } *epochs_left -= *epochs_per_update; } else { for _ in 0..*epochs_left { self.model.train_epoch(*n); } *epochs_left = 0; } ui.add_enabled_ui(true, |ui| { egui::ProgressBar::new( 1.0 - (*epochs_left as f32) / (self.epochs_count as f32), ) .ui(ui) }); if 0 == *epochs_left { self.training = TrainingState::NoTrain; } } } }) }); ui.label(""); ui.horizontal(|ui| { for (i, s) in self.symbols.iter().enumerate() { if ui.button(s.to_owned()).clicked() { self.model.set_input_from_etalon(i); } } }); ui.label(""); for y in 0..7 { ui.horizontal(|ui| { for x in 0..7 { ui.checkbox(&mut self.model.input[y][x], ""); } }); } ui.label(""); ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| { if ui.button("Check").clicked() { self.compute_result = self.model.calculate_predictions(); } }); TableBuilder::new(ui) .striped(true) // Alternating row colors .resizable(true) .column(Column::remainder()) .column(Column::remainder()) .header(20.0, |mut header| { header.col(|ui| { ui.label("Class"); }); header.col(|ui| { ui.label("Probability"); }); }) .body(|body| { body.rows(20.0, self.symbols.len(), |mut row| { let i = row.index(); row.col(|ui| { ui.label(self.symbols[i]); }); row.col(|ui| { ui.label(self.compute_result[i].to_string()); }); }); }); }); if matches!(self.training, TrainingState::Training {..}) { ui.request_repaint() } } }