[lab4] Model implementation
This commit is contained in:
parent
5c8e1f9dc2
commit
ef711b5e95
@ -1,8 +1,8 @@
|
||||
use std::ops::Index;
|
||||
|
||||
fn compute_potential<WA: Index<usize, Output = f64>>(
|
||||
weights: &[WA],
|
||||
pub(super) fn compute<WA: Index<usize, Output = f64>>(
|
||||
input_data: &[f64],
|
||||
weights: &[WA],
|
||||
potential_data: &mut [f64],
|
||||
output_data: &mut [f64],
|
||||
f: impl Fn(f64) -> f64,
|
||||
@ -28,7 +28,7 @@ fn test_compiles_static() {
|
||||
let mut potential_data = [0.0, 0.0];
|
||||
let mut output_data = [0.0, 0.0];
|
||||
|
||||
compute_potential(&weights, &input_data, &mut potential_data, &mut output_data, |x| x);
|
||||
compute(&input_data,&weights, &mut potential_data, &mut output_data, |x| x);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -38,5 +38,5 @@ fn test_compiles_dynamic() {
|
||||
let mut potential_data = [0.0, 0.0];
|
||||
let mut output_data = [0.0, 0.0];
|
||||
|
||||
compute_potential(&weights, &input_data, &mut potential_data, &mut output_data, |x| x);
|
||||
compute(&input_data, &weights, &mut potential_data, &mut output_data, |x| x);
|
||||
}
|
||||
@ -1,12 +1,12 @@
|
||||
use std::ops::{Index, IndexMut};
|
||||
|
||||
fn calc_error<NWA: Index<usize, Output = f64>>(
|
||||
pub(super) fn calc_error<NWA: Index<usize, Output = f64>>(
|
||||
next_errors: &[f64],
|
||||
weights: &[NWA],
|
||||
next_weights: &[NWA],
|
||||
current_errors: &mut [f64],
|
||||
) {
|
||||
for i in 0..current_errors.len() {
|
||||
current_errors[i] = weights
|
||||
current_errors[i] = next_weights
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(j, ww)| ww[i] * next_errors[j])
|
||||
@ -14,19 +14,19 @@ fn calc_error<NWA: Index<usize, Output = f64>>(
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_error<NWA: IndexMut<usize, Output = f64>>(
|
||||
pub(super) fn apply_error<NWA: IndexMut<usize, Output = f64>>(
|
||||
n: f64,
|
||||
errors: &[f64],
|
||||
weights: &mut [NWA],
|
||||
current_potentials: &[f64],
|
||||
next_errors: &[f64],
|
||||
next_weights: &mut [NWA],
|
||||
current_outputs: &[f64],
|
||||
next_potentials: &[f64],
|
||||
f: impl Fn(f64) -> f64,
|
||||
f1: impl Fn(f64) -> f64,
|
||||
) {
|
||||
for i in 0..current_potentials.len() {
|
||||
for i in 0..current_outputs.len() {
|
||||
for j in 0..next_potentials.len() {
|
||||
let dw = n * errors[j] * f1(next_potentials[j]) * f(current_potentials[i]);
|
||||
weights[j][i] += dw;
|
||||
let dw = n * next_errors[j] * f1(next_potentials[j]) * current_outputs[i];
|
||||
next_weights[j][i] += dw;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,2 +1,3 @@
|
||||
mod compute;
|
||||
mod fix;
|
||||
mod net;
|
||||
|
||||
75
lab4/src/algo/net.rs
Normal file
75
lab4/src/algo/net.rs
Normal file
@ -0,0 +1,75 @@
|
||||
use crate::algo::compute::compute;
|
||||
use crate::algo::fix::{apply_error, calc_error};
|
||||
use std::ops::{Div, Mul};
|
||||
|
||||
struct Net<const InputLayerSize: usize, const OutputLayerSize: usize> {
|
||||
inner_layer_size: usize,
|
||||
pub input_data: [f64; InputLayerSize],
|
||||
i_to_h_weights: Vec<[f64; InputLayerSize]>,
|
||||
hidden_potentials: Vec<f64>,
|
||||
hidden_data: Vec<f64>,
|
||||
h_to_o_weights: [Vec<f64>; OutputLayerSize],
|
||||
output_potentials: [f64; OutputLayerSize],
|
||||
pub output_data: [f64; OutputLayerSize],
|
||||
output_errors: [f64; OutputLayerSize],
|
||||
hidden_errors: Vec<f64>,
|
||||
}
|
||||
|
||||
impl<const InputLayerSize: usize, const OutputLayerSize: usize>
|
||||
Net<InputLayerSize, OutputLayerSize>
|
||||
{
|
||||
fn _sigmoid(x: f64) -> f64 {
|
||||
return 1.0.div(1.0 + (-x).exp());
|
||||
}
|
||||
|
||||
fn _sigmoidDerivative(x: f64) -> f64 {
|
||||
return x.mul(1.0 - x);
|
||||
}
|
||||
|
||||
pub fn compute(&mut self) {
|
||||
compute(
|
||||
&self.input_data,
|
||||
self.i_to_h_weights.as_slice(),
|
||||
self.hidden_potentials.as_mut_slice(),
|
||||
self.hidden_data.as_mut_slice(),
|
||||
Self::_sigmoid,
|
||||
);
|
||||
|
||||
compute(
|
||||
self.hidden_data.as_slice(),
|
||||
&self.h_to_o_weights,
|
||||
&mut self.output_potentials,
|
||||
&mut self.output_data,
|
||||
Self::_sigmoid,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn fix(&mut self, expected: &[f64; OutputLayerSize], n: f64) {
|
||||
for (i, (a, e)) in self.output_data.iter().zip(expected).enumerate() {
|
||||
self.output_errors[i] = e - a;
|
||||
}
|
||||
|
||||
calc_error(
|
||||
&self.output_errors,
|
||||
&self.h_to_o_weights,
|
||||
self.hidden_errors.as_mut_slice(),
|
||||
);
|
||||
|
||||
apply_error(
|
||||
n,
|
||||
&self.output_errors,
|
||||
&mut self.h_to_o_weights,
|
||||
self.hidden_data.as_slice(),
|
||||
&self.output_potentials,
|
||||
Self::_sigmoidDerivative
|
||||
);
|
||||
apply_error(
|
||||
n,
|
||||
self.hidden_errors.as_slice(),
|
||||
self.i_to_h_weights.as_mut_slice(),
|
||||
&self.input_data,
|
||||
self.hidden_potentials.as_slice(),
|
||||
Self::_sigmoidDerivative
|
||||
);
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user