ai-0/lab4/src/algo/comparation_operators_model.rs

105 lines
2.6 KiB
Rust

use crate::algo::net::Net;
use rand::Rng;
/// Image classifier that wraps a [`Net`] with `W * H` inputs and `ClassCount` outputs.
///
/// Images are `H` rows of `W` boolean cells. Each class has exactly one reference
/// image (an "etalon"), which is what the training methods feed to the network.
///
/// The `[(); { W * H }]:` bound is needed because `W * H` appears as a const generic
/// argument of `Net` (nightly `generic_const_exprs`).
struct ComparationOperatorsModel<const W: usize, const H: usize, const ClassCount: usize>
where
[(); { W * H }]:,
{
// Underlying network; receives the flattened image and produces one value per class.
net: Net<{ W * H }, ClassCount>,
// Image that `calculate_predictions` classifies; callers write it directly.
pub input: [[bool; W]; H],
// Reference image for every class; `etalons[c]` is the training sample for class `c`.
etalons: [[[bool; W]; H]; ClassCount],
}
impl<const W: usize, const H: usize, const ClassCount: usize>
    ComparationOperatorsModel<W, H, ClassCount>
where
    [(); { W * H }]:,
{
    /// Builds a model from one reference image ("etalon") per class.
    ///
    /// The network is created with `Net::new()` and the classification input
    /// starts out as an all-`false` image.
    pub fn new(etalons: [[[bool; W]; H]; ClassCount]) -> Self {
        Self {
            net: Net::new(),
            input: [[false; W]; H],
            etalons,
        }
    }

    /// Copies `data` into the network's input buffer in row-major order,
    /// encoding `true` as `1.0` and `false` as `0.0`.
    ///
    /// This is an associated function (no `self`) so callers can borrow
    /// `self.net` mutably and another field of `self` immutably at the same time.
    fn set_inputs(net: &mut Net<{ W * H }, ClassCount>, data: &[[bool; W]; H]) {
        for (i, &cell) in data.iter().flatten().enumerate() {
            net.input_data[i] = if cell { 1.0 } else { 0.0 };
        }
    }

    /// Runs one training step on the reference image of class `image`.
    ///
    /// The expected output is one-hot: `1.0` at index `image`, `0.0` elsewhere.
    /// `n` is forwarded unchanged to `Net::fix` (presumably the learning rate;
    /// confirm against `Net`).
    ///
    /// # Panics
    ///
    /// Panics if `image` is not a valid class index.
    fn train_one(&mut self, n: f64, image: usize) {
        assert!(
            image < self.etalons.len(),
            "class index {image} is out of range"
        );
        Self::set_inputs(&mut self.net, &self.etalons[image]);
        self.net.compute();

        let mut expected = [0.0; ClassCount];
        expected[image] = 1.0;
        self.net.fix(&expected, n);
    }

    /// Runs one training step for every class, in ascending class order.
    pub fn train_epoch(&mut self, n: f64) {
        for class in 0..ClassCount {
            self.train_one(n, class);
        }
    }

    /// Returns a lazy iterator that runs [`Self::train_epoch`] each time it is
    /// advanced; no training happens until the iterator is consumed.
    ///
    /// `epochs` is handed to the iterator as its remaining-epoch budget.
    /// The iterator mutably borrows the model, which is why the return type
    /// carries `+ '_`: without it the captured borrow is rejected (E0700) on
    /// editions before 2024.
    pub fn train(&mut self, n: f64, epochs: usize) -> impl Iterator<Item = ()> + '_ {
        TrainIterator {
            owner: self,
            n,
            epochs_left: epochs,
        }
    }

    /// Returns the hidden layer size reported by the underlying network.
    pub fn hidden_layer_size(&self) -> usize {
        self.net.hidden_layer_size()
    }

    /// Asks the underlying network to resize its hidden layer to `new_size`.
    pub fn resize_hidden_layer(&mut self, new_size: usize) {
        self.net.resize_hidden_layer(new_size);
    }

    /// Asks the underlying network to re-initialise its weights using `rng`.
    pub fn set_random_weights(&mut self, rng: &mut impl Rng) {
        self.net.set_random_weights(rng)
    }

    /// Classifies the image currently stored in `self.input`.
    ///
    /// Loads the image into the network, runs `Net::compute` and returns a copy
    /// of the network's output buffer (one value per class).
    pub fn calculate_predictions(&mut self) -> [f64; ClassCount] {
        Self::set_inputs(&mut self.net, &self.input);
        self.net.compute();
        self.net.output_data
    }
}
/// Iterator state created by `ComparationOperatorsModel::train`.
///
/// Holds the model by mutable reference, so the model cannot be used directly
/// while this iterator is alive.
struct TrainIterator<'a, const W: usize, const H: usize, const ClassCount: usize>
where
[(); { W * H }]:,
{
// Model whose `train_epoch` is invoked on every step.
owner: &'a mut ComparationOperatorsModel<W, H, ClassCount>,
// Value passed to `train_epoch` on every step.
n: f64,
// Remaining-epoch budget; `next` returns `None` once this is zero.
epochs_left: usize,
}
/// Drives training one epoch per step.
///
/// Each call to `next` consumes one epoch from the budget and runs
/// `train_epoch` on the borrowed model; the iterator ends when the budget
/// is exhausted, so it yields exactly as many items as epochs were requested.
impl<const W: usize, const H: usize, const ClassCount: usize> Iterator
    for TrainIterator<'_, W, H, ClassCount>
where
    [(); { W * H }]:,
{
    type Item = ();

    fn next(&mut self) -> Option<Self::Item> {
        // Consume one epoch; `checked_sub` yields `None` at zero, which ends
        // the iteration. Without this decrement the iterator never terminates.
        self.epochs_left = self.epochs_left.checked_sub(1)?;
        self.owner.train_epoch(self.n);
        Some(())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The number of remaining steps is known exactly.
        (self.epochs_left, Some(self.epochs_left))
    }
}