[lab4] Model implementation

This commit is contained in:
Andrew Golovashevich 2026-02-08 21:40:19 +03:00
parent 5c8e1f9dc2
commit ef711b5e95
4 changed files with 91 additions and 15 deletions

View File

@ -1,8 +1,8 @@
use std::ops::Index;
fn compute_potential<WA: Index<usize, Output = f64>>(
weights: &[WA],
pub(super) fn compute<WA: Index<usize, Output = f64>>(
input_data: &[f64],
weights: &[WA],
potential_data: &mut [f64],
output_data: &mut [f64],
f: impl Fn(f64) -> f64,
@ -28,7 +28,7 @@ fn test_compiles_static() {
let mut potential_data = [0.0, 0.0];
let mut output_data = [0.0, 0.0];
compute_potential(&weights, &input_data, &mut potential_data, &mut output_data, |x| x);
compute(&input_data,&weights, &mut potential_data, &mut output_data, |x| x);
}
#[test]
@ -38,5 +38,5 @@ fn test_compiles_dynamic() {
let mut potential_data = [0.0, 0.0];
let mut output_data = [0.0, 0.0];
compute_potential(&weights, &input_data, &mut potential_data, &mut output_data, |x| x);
compute(&input_data, &weights, &mut potential_data, &mut output_data, |x| x);
}

View File

@ -1,12 +1,12 @@
use std::ops::{Index, IndexMut};
fn calc_error<NWA: Index<usize, Output = f64>>(
pub(super) fn calc_error<NWA: Index<usize, Output = f64>>(
next_errors: &[f64],
weights: &[NWA],
next_weights: &[NWA],
current_errors: &mut [f64],
) {
for i in 0..current_errors.len() {
current_errors[i] = weights
current_errors[i] = next_weights
.iter()
.enumerate()
.map(|(j, ww)| ww[i] * next_errors[j])
@ -14,19 +14,19 @@ fn calc_error<NWA: Index<usize, Output = f64>>(
}
}
fn apply_error<NWA: IndexMut<usize, Output = f64>>(
pub(super) fn apply_error<NWA: IndexMut<usize, Output = f64>>(
n: f64,
errors: &[f64],
weights: &mut [NWA],
current_potentials: &[f64],
next_errors: &[f64],
next_weights: &mut [NWA],
current_outputs: &[f64],
next_potentials: &[f64],
f: impl Fn(f64) -> f64,
f1: impl Fn(f64) -> f64,
) {
for i in 0..current_potentials.len() {
for i in 0..current_outputs.len() {
for j in 0..next_potentials.len() {
let dw = n * errors[j] * f1(next_potentials[j]) * f(current_potentials[i]);
weights[j][i] += dw;
let dw = n * next_errors[j] * f1(next_potentials[j]) * current_outputs[i];
next_weights[j][i] += dw;
}
}
}

View File

@ -1,2 +1,3 @@
mod compute;
mod fix;
mod net;

75
lab4/src/algo/net.rs Normal file
View File

@ -0,0 +1,75 @@
use crate::algo::compute::compute;
use crate::algo::fix::{apply_error, calc_error};
use std::ops::{Div, Mul};
struct Net<const InputLayerSize: usize, const OutputLayerSize: usize> {
inner_layer_size: usize,
pub input_data: [f64; InputLayerSize],
i_to_h_weights: Vec<[f64; InputLayerSize]>,
hidden_potentials: Vec<f64>,
hidden_data: Vec<f64>,
h_to_o_weights: [Vec<f64>; OutputLayerSize],
output_potentials: [f64; OutputLayerSize],
pub output_data: [f64; OutputLayerSize],
output_errors: [f64; OutputLayerSize],
hidden_errors: Vec<f64>,
}
impl<const InputLayerSize: usize, const OutputLayerSize: usize>
Net<InputLayerSize, OutputLayerSize>
{
fn _sigmoid(x: f64) -> f64 {
return 1.0.div(1.0 + (-x).exp());
}
fn _sigmoidDerivative(x: f64) -> f64 {
return x.mul(1.0 - x);
}
pub fn compute(&mut self) {
compute(
&self.input_data,
self.i_to_h_weights.as_slice(),
self.hidden_potentials.as_mut_slice(),
self.hidden_data.as_mut_slice(),
Self::_sigmoid,
);
compute(
self.hidden_data.as_slice(),
&self.h_to_o_weights,
&mut self.output_potentials,
&mut self.output_data,
Self::_sigmoid,
)
}
pub fn fix(&mut self, expected: &[f64; OutputLayerSize], n: f64) {
for (i, (a, e)) in self.output_data.iter().zip(expected).enumerate() {
self.output_errors[i] = e - a;
}
calc_error(
&self.output_errors,
&self.h_to_o_weights,
self.hidden_errors.as_mut_slice(),
);
apply_error(
n,
&self.output_errors,
&mut self.h_to_o_weights,
self.hidden_data.as_slice(),
&self.output_potentials,
Self::_sigmoidDerivative
);
apply_error(
n,
self.hidden_errors.as_slice(),
self.i_to_h_weights.as_mut_slice(),
&self.input_data,
self.hidden_potentials.as_slice(),
Self::_sigmoidDerivative
);
}
}