105 lines
2.6 KiB
Rust
105 lines
2.6 KiB
Rust
use crate::algo::net::Net;
|
|
use rand::Rng;
|
|
|
|
struct ComparationOperatorsModel<const W: usize, const H: usize, const ClassCount: usize>
|
|
where
|
|
[(); { W * H }]:,
|
|
{
|
|
net: Net<{ W * H }, ClassCount>,
|
|
pub input: [[bool; W]; H],
|
|
etalons: [[[bool; W]; H]; ClassCount],
|
|
}
|
|
|
|
impl<const W: usize, const H: usize, const ClassCount: usize>
|
|
ComparationOperatorsModel<W, H, ClassCount>
|
|
where
|
|
[(); { W * H }]:,
|
|
{
|
|
pub fn new(etalons: [[[bool; W]; H]; ClassCount]) -> Self {
|
|
return Self {
|
|
net: Net::new(),
|
|
input: [[false; W]; H],
|
|
etalons,
|
|
};
|
|
}
|
|
|
|
fn set_inputs(net: &mut Net<{ W * H }, ClassCount>, data: &[[bool; W]; H]) {
|
|
let mut i = 0;
|
|
for y in 0..H {
|
|
for x in 0..W {
|
|
net.input_data[i] = if data[y][x] { 1.0 } else { 0.0 };
|
|
i += 1;
|
|
}
|
|
}
|
|
}
|
|
|
|
fn train_one(&mut self, n: f64, image: usize) {
|
|
assert!(image < self.etalons.len());
|
|
|
|
Self::set_inputs(&mut self.net, &self.etalons[image]);
|
|
|
|
self.net.compute();
|
|
|
|
let mut expected = [0.0; ClassCount];
|
|
expected[image] = 1.0;
|
|
|
|
self.net.fix(&expected, n);
|
|
}
|
|
|
|
pub fn train_epoch(&mut self, n: f64) {
|
|
for i in 0..ClassCount {
|
|
self.train_one(n, i)
|
|
}
|
|
}
|
|
|
|
pub fn train(&mut self, n: f64, epochs: usize) -> impl Iterator<Item = ()> {
|
|
return TrainIterator {
|
|
owner: self,
|
|
n: n,
|
|
epochs_left: epochs,
|
|
};
|
|
}
|
|
pub fn hidden_layer_size(&self) -> usize {
|
|
return self.net.hidden_layer_size();
|
|
}
|
|
|
|
pub fn resize_hidden_layer(&mut self, new_size: usize) {
|
|
self.net.resize_hidden_layer(new_size);
|
|
}
|
|
pub fn set_random_weights(&mut self, rng: &mut impl Rng) {
|
|
self.net.set_random_weights(rng)
|
|
}
|
|
|
|
pub fn calculate_predictions(&mut self) -> [f64; ClassCount] {
|
|
Self::set_inputs(&mut self.net, &self.input);
|
|
self.net.compute();
|
|
return self.net.output_data;
|
|
}
|
|
}
|
|
|
|
struct TrainIterator<'a, const W: usize, const H: usize, const ClassCount: usize>
|
|
where
|
|
[(); { W * H }]:,
|
|
{
|
|
owner: &'a mut ComparationOperatorsModel<W, H, ClassCount>,
|
|
n: f64,
|
|
epochs_left: usize,
|
|
}
|
|
|
|
impl<'a, const W: usize, const H: usize, const ClassCount: usize> Iterator
|
|
for TrainIterator<'_, W, H, ClassCount>
|
|
where
|
|
[(); { W * H }]:,
|
|
{
|
|
type Item = ();
|
|
|
|
fn next(&mut self) -> Option<Self::Item> {
|
|
if self.epochs_left == 0 {
|
|
return None;
|
|
}
|
|
|
|
self.owner.train_epoch(self.n);
|
|
return Some(());
|
|
}
|
|
}
|