ai-0/lab4/src/main.rs

175 lines
5.8 KiB
Rust

#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
mod algo;
use crate::algo::{gen_images, ComparationOperatorsModel};
use bgtu_ai_utility::gui::{boot_eframe, labeled_slider};
use eframe::egui;
use eframe::egui::Widget;
use egui_extras::{Column, TableBuilder};
use std::cmp::min;
/// Program entry point: passes the window title and a [`MyApp`] factory to `boot_eframe`
/// and returns whatever result it produces.
fn main() -> eframe::Result {
    boot_eframe("Neural net", || MyApp::new())
}
/// Whether a training run is currently in progress, and if so how far along it is.
enum TrainingState {
    /// No run is active.
    NoTrain,
    /// A run is active and is advanced a few epochs at a time.
    Training {
        /// Epochs that still have to be executed before the run is complete.
        epochs_left: usize,
        /// Upper bound on the number of epochs executed per UI update.
        epochs_per_update: usize,
        /// Value handed to the model's `train_epoch` for every epoch of this run.
        n: f64,
    },
}
/// Application state kept between frames.
struct MyApp {
    /// The model being trained and queried; its `input` grid is edited through the UI.
    model: ComparationOperatorsModel<7, 7, 8>,
    /// Progress of the current training run, if any.
    training: TrainingState,
    /// Value of the "Hidden layer size" slider.
    hidden_layer_size: usize,
    /// Value of the "η" slider.
    n: f64,
    /// Value of the "Epochs count" slider.
    epochs_count: usize,
    /// Class labels returned by `gen_images`; used for the preset buttons and the result table.
    symbols: [&'static str; 8],
    /// Output of the most recent `calculate_predictions` call, one entry per class.
    compute_result: [f64; 8],
}
impl MyApp {
    /// Builds the initial application state.
    ///
    /// The model is constructed from the second component of `gen_images()` and the class
    /// labels are taken from the first. No training is active, both the hidden layer size and
    /// the epoch count start at 1, η starts at 0.1 and the result table is zeroed.
    fn new() -> Self {
        let generated = gen_images();
        let model = ComparationOperatorsModel::new(generated.1);
        let symbols = generated.0;

        Self {
            model,
            training: TrainingState::NoTrain,
            hidden_layer_size: 1,
            n: 0.1,
            epochs_count: 1,
            symbols,
            compute_result: [0.0; 8],
        }
    }
}
impl eframe::App for MyApp {
fn update(&mut self, ui: &eframe::egui::Context, _frame: &mut eframe::Frame) {
egui::CentralPanel::default().show(ui, |ui| {
ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| {
labeled_slider(ui, "Hidden layer size", &mut self.hidden_layer_size, 1..=49, 1f64);
if (self.hidden_layer_size != self.model.hidden_layer_size()) {
self.model.resize_hidden_layer(self.hidden_layer_size);
}
ui.label("");
labeled_slider(ui, "η", &mut self.n, 0.0..=1.0, 0.001);
ui.label("");
labeled_slider(ui, "Epochs count", &mut self.epochs_count, 1..=500, 1f64);
ui.label("");
ui.horizontal(|ui| {
if ui.button("Train").clicked() {
self.training = TrainingState::Training {
epochs_left: self.epochs_count,
epochs_per_update: min(10, self.epochs_count / 60),
n: 0.0,
};
self.model.set_random_weights(&mut rand::rng());
}
match &mut self.training {
TrainingState::NoTrain => {
egui::ProgressBar::new(1.0).ui(ui);
}
TrainingState::Training {
epochs_left,
epochs_per_update,
n,
} => {
if *epochs_left >= *epochs_per_update {
for _ in 0..*epochs_per_update {
self.model.train_epoch(*n);
}
*epochs_left -= *epochs_per_update;
} else {
for _ in 0..*epochs_left {
self.model.train_epoch(*n);
}
*epochs_left = 0;
}
ui.add_enabled_ui(true, |ui| {
egui::ProgressBar::new(
1.0 - (*epochs_left as f32) / (self.epochs_count as f32),
)
.ui(ui)
});
if 0 == *epochs_left {
self.training = TrainingState::NoTrain;
}
}
}
})
});
ui.label("");
ui.horizontal(|ui| {
for (i, s) in self.symbols.iter().enumerate() {
if ui.button(s.to_owned()).clicked() {
self.model.set_input_from_etalon(i);
}
}
});
ui.label("");
for y in 0..7 {
ui.horizontal(|ui| {
for x in 0..7 {
ui.checkbox(&mut self.model.input[y][x], "");
}
});
}
ui.label("");
ui.add_enabled_ui(matches!(self.training, TrainingState::NoTrain), |ui| {
if ui.button("Check").clicked() {
self.compute_result = self.model.calculate_predictions();
}
});
TableBuilder::new(ui)
.striped(true) // Alternating row colors
.resizable(true)
.column(Column::remainder())
.column(Column::remainder())
.header(20.0, |mut header| {
header.col(|ui| {
ui.label("Class");
});
header.col(|ui| {
ui.label("Probability");
});
})
.body(|body| {
body.rows(20.0, self.symbols.len(), |mut row| {
let i = row.index();
row.col(|ui| {
ui.label(self.symbols[i]);
});
row.col(|ui| {
ui.label(self.compute_result[i].to_string());
});
});
});
});
if matches!(self.training, TrainingState::Training {..}) {
ui.request_repaint()
}
}
}