/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
staticinlinevoid nn_propagate_8to4( constfloat *const inputs, constfloat *const weights, constfloat *const bias, int num_inputs_to_process, int tot_num_inputs, int num_outputs, float *const output_nodes, int is_clip_required) {
__m256 hadd[2]; for (int out = 0; out < num_outputs; out += 4) {
__m128 bias_reg = _mm_loadu_ps(&bias[out]);
__m128 in_result = _mm_setzero_ps(); for (int in = 0; in < num_inputs_to_process; in += 8) { const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]); constint weight_idx = in + (out * tot_num_inputs); // Process two output row at a time. for (int i = 0; i < 2; i++) {
CALC_OUTPUT_FOR_2ROWS
}
// Process input multiple of 8 using AVX2 intrinsic. if (num_inputs % 8 == 0) {
nn_propagate_input_multiple_of_8(input_nodes, layer_weights, layer_bias,
num_inputs, num_inputs, is_output_layer,
num_outputs, output_nodes);
} else { // When number of inputs is not multiple of 8, use hybrid approach of AVX2 // and SSE3 based on the need. constint in_mul_8 = num_inputs / 8; constint num_inputs_to_process = in_mul_8 * 8; int bias_is_considered = 0; if (in_mul_8) {
nn_propagate_input_multiple_of_8(
input_nodes, layer_weights, layer_bias, num_inputs_to_process,
num_inputs, is_output_layer, num_outputs, output_nodes);
bias_is_considered = 1;
}
constfloat *out_temp = bias_is_considered ? output_nodes : layer_bias; constint input_remaining = num_inputs % 8; if (input_remaining % 4 == 0 && num_outputs % 8 == 0) { for (int out = 0; out < num_outputs; out += 8) {
__m128 out_h = _mm_loadu_ps(&out_temp[out + 4]);
__m128 out_l = _mm_loadu_ps(&out_temp[out]); for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
av1_nn_propagate_4to8_sse3(&input_nodes[in],
&layer_weights[out * num_inputs + in],
&out_h, &out_l, num_inputs);
} if (!is_output_layer) { const __m128 zero = _mm_setzero_ps();
out_h = _mm_max_ps(out_h, zero);
out_l = _mm_max_ps(out_l, zero);
}
_mm_storeu_ps(&output_nodes[out + 4], out_h);
_mm_storeu_ps(&output_nodes[out], out_l);
}
} elseif (input_remaining % 4 == 0 && num_outputs % 4 == 0) { for (int out = 0; out < num_outputs; out += 4) {
__m128 outputs = _mm_loadu_ps(&out_temp[out]); for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
av1_nn_propagate_4to4_sse3(&input_nodes[in],
&layer_weights[out * num_inputs + in],
&outputs, num_inputs);
} if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
_mm_storeu_ps(&output_nodes[out], outputs);
}
} elseif (input_remaining % 4 == 0) { for (int out = 0; out < num_outputs; out++) {
__m128 outputs = _mm_load1_ps(&out_temp[out]); for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
av1_nn_propagate_4to1_sse3(&input_nodes[in],
&layer_weights[out * num_inputs + in],
&outputs);
} if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
output_nodes[out] = _mm_cvtss_f32(outputs);
}
} else { // Use SSE instructions for scalar operations to avoid the latency // of swapping between SIMD and FPU modes. for (int out = 0; out < num_outputs; out++) {
__m128 outputs = _mm_load1_ps(&out_temp[out]); for (int in_node = in_mul_8 * 8; in_node < num_inputs; in_node++) {
__m128 input = _mm_load1_ps(&input_nodes[in_node]);
__m128 weight =
_mm_load1_ps(&layer_weights[num_inputs * out + in_node]);
outputs = _mm_add_ps(outputs, _mm_mul_ps(input, weight));
} if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
output_nodes[out] = _mm_cvtss_f32(outputs);
}
}
} // Before processing the next layer, treat the output of current layer as // input to next layer.
input_nodes = output_nodes;
num_inputs = num_outputs;
buf_index = 1 - buf_index;
} if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
}
Messung V0.5
¤ Dauer der Verarbeitung: 0.1 Sekunden
(vorverarbeitet)
¤
Die Informationen auf dieser Webseite wurden
nach bestem Wissen sorgfältig zusammengestellt. Es wird jedoch weder Vollständigkeit noch Richtigkeit
noch Qualität der bereitgestellten Informationen zugesichert.
Bemerkung:
Die farbliche Syntaxdarstellung und die Messung sind noch experimentell.