`timescale 1ns / 1ps
// Description : Fully connected (linear) layer
// Change Logs : 2024.05.10 - Yang.Long - 1.0.0 -
//
// Operation (as implemented below):
//   * One SigMultiply instance per output neuron receives every incoming
//     sample (s_axis_rdata) together with that neuron's G_WDEPTH-bit slice of
//     s_axis_weight (presumably it multiplies the two - SigMultiply is defined
//     outside this file, confirm there).
//   * buffer_xcnt counts the products reported valid by lane 0, wrapping after
//     G_FEATURESi products; s_axis_ruser restarts the count.
//   * Each lane accumulates its products in buffer_csum[i]. On the
//     G_FEATURESi-th product the sum (plus the lane's bias slice when
//     G_FEATURESb != 0) is written to the lane's G_PDEPTH-bit slice of
//     m_axis_tdata and m_axis_tvalid is raised for one clock.
//   * All sums are kept in G_PDEPTH bits, so results are truncated to that
//     width.
module nnLinear #(
    parameter G_WDEPTH    = 12,     // weight bit depth
    parameter G_PDEPTH    = 8,      // pixel (sample) bit depth
    parameter G_LINEXLEN  = 160,    // image width  (not referenced in this module)
    parameter G_LINEYLEN  = 160,    // image height (not referenced in this module)
    parameter G_FEATURESi = 120,    // number of input neurons
    parameter G_FEATURESo = 10,     // number of output neurons
    parameter G_FEATURESb = 1'b1    // whether a bias term is included
)(
    input  wire                                     isysclk,        // clock
    input  wire                                     isysrst,        // asynchronous reset, active low (see ACTIVERST)
    input  wire                                     s_axis_ruser,   // restarts the input counter; forwarded to m_axis_tuser
    input  wire                                     s_axis_rvalid,  // input sample valid
    input  wire signed [G_PDEPTH-1:0]               s_axis_rdata,   // input sample, shared by all lanes
    // NOTE(review): s_axis_readen is never driven inside this module - confirm
    // whether it is meant to be tied off or generated here.
    output wire                                     s_axis_readen,
    input  wire signed [G_WDEPTH*G_FEATURESo-1:0]   s_axis_weight,  // one G_WDEPTH-bit weight per output lane
    input  wire signed [G_WDEPTH*G_FEATURESo-1:0]   s_axis_bias,    // one G_WDEPTH-bit bias per output lane
    output reg                                      m_axis_tuser,   // s_axis_ruser delayed by one clock
    output reg                                      m_axis_tvalid,  // one-clock strobe: m_axis_tdata holds a result
    output reg  signed [G_PDEPTH*G_FEATURESo-1:0]   m_axis_tdata    // one G_PDEPTH-bit result per output lane
);

/*
Reference behaviour (PyTorch):

import torch
linear = torch.nn.Linear(in_features=3, out_features=5, bias=True)
b = torch.tensor([[1, 1, 1]], dtype=torch.float32)
out2 = linear(b)
print(linear.weight.data)
print(linear.bias.data)
print(out2)

b = torch.tensor([[1, 1, 1]], dtype=torch.float32)
tensor([[-0.1069, -0.3522,  0.3378],
        [ 0.2721,  0.3001,  0.4206],
        [-0.1825,  0.1193, -0.0052],
        [-0.1361, -0.3696, -0.3186],
        [-0.5642,  0.5640,  0.4559]])
tensor([ 0.0126, -0.3215,  0.3172, -0.0352, -0.5045])
tensor([[-0.1088,  0.6713,  0.2488, -0.8595, -0.0489]], grad_fn=<AddmmBackward0>)
*/

localparam ACTIVERST = 1'b0;    // reset is asserted while isysrst is low

// Smallest integer e such that 2**e >= number.
function integer log2;
    input integer number;
    begin
        log2 = 0;
        while (2**log2 < number) begin
            log2 = log2 + 1;
        end
    end
endfunction

localparam A         = G_FEATURESi + 2;
localparam G_XPCOUNT = log2(A + 4);     // bit width of the product counter

reg         [G_XPCOUNT-1:0] buffer_xcnt;                            // index of the product currently arriving
reg  signed [G_PDEPTH-1:0]  buffer_csum      [G_FEATURESo-1:0];     // per-lane running sum
wire                        temp_axis_tvalid [G_FEATURESo-1:0];     // per-lane product valid
wire signed [G_PDEPTH-1:0]  temp_axis_tdata  [G_FEATURESo-1:0];     // per-lane product

// Product counter: advances on every valid product of lane 0, wraps after
// G_FEATURESi products, and is restarted by s_axis_ruser.
always @(posedge isysclk or negedge isysrst) begin
    if (isysrst == ACTIVERST)
        buffer_xcnt <= 0;
    else if (s_axis_ruser == 1'b1)
        buffer_xcnt <= 0;
    else if (temp_axis_tvalid[0] == 1'b1) begin
        if (buffer_xcnt == G_FEATURESi - 1)
            buffer_xcnt <= 0;
        else
            buffer_xcnt <= buffer_xcnt + 1'b1;
    end
end

generate
genvar i;
for (i = 0; i <= G_FEATURESo-1; i = i + 1) begin

    // Per-lane multiplier: shared input sample, lane-specific weight slice.
    SigMultiply #(
        .G_PDEPTH       ( G_PDEPTH )
    ) u_SigMultiply (
        .isysclk        ( isysclk ),
        .isysrst        ( isysrst ),
        .s_axis_rvalid  ( s_axis_rvalid ),
        .s_axis_rdat1   ( s_axis_rdata ),
        .s_axis_rdat2   ( s_axis_weight[(i+1)*G_WDEPTH-1:G_WDEPTH*i] ),
        .m_axis_tvalid  ( temp_axis_tvalid[i] ),
        .m_axis_tdata   ( temp_axis_tdata[i] )
    );

    // Running sum. The first product of a vector (buffer_xcnt == 0) reloads
    // the accumulator; later products are added. Accumulation is gated on the
    // product-valid flag so multiplier output between valid beats is never
    // summed.
    always @(posedge isysclk or negedge isysrst) begin
        if (isysrst == ACTIVERST)
            buffer_csum[i] <= 0;
        else if (temp_axis_tvalid[i] == 1'b1) begin
            if (buffer_xcnt == 0)
                buffer_csum[i] <= temp_axis_tdata[i];
            else
                buffer_csum[i] <= buffer_csum[i] + temp_axis_tdata[i];
        end
    end

    // Lane result: presented for one clock when the last product arrives,
    // zero otherwise.
    if (G_FEATURESb == 1'b0) begin
        // No bias: result = accumulated sum + last product.
        always @(posedge isysclk or negedge isysrst) begin
            if (isysrst == ACTIVERST)
                m_axis_tdata[(i+1)*G_PDEPTH-1:G_PDEPTH*i] <= 0;
            else if (temp_axis_tvalid[i] == 1'b1 && buffer_xcnt == G_FEATURESi - 1)
                m_axis_tdata[(i+1)*G_PDEPTH-1:G_PDEPTH*i] <= buffer_csum[i] + temp_axis_tdata[i];
            else
                m_axis_tdata[(i+1)*G_PDEPTH-1:G_PDEPTH*i] <= 0;
        end
    end else begin
        // With bias: result = accumulated sum + last product + bias slice.
        // A part-select is unsigned in Verilog, so the slice is cast back to
        // signed to keep the whole sum signed (matters when
        // G_WDEPTH < G_PDEPTH and the bias is negative).
        always @(posedge isysclk or negedge isysrst) begin
            if (isysrst == ACTIVERST)
                m_axis_tdata[(i+1)*G_PDEPTH-1:G_PDEPTH*i] <= 0;
            else if (temp_axis_tvalid[i] == 1'b1 && buffer_xcnt == G_FEATURESi - 1)
                m_axis_tdata[(i+1)*G_PDEPTH-1:G_PDEPTH*i] <= buffer_csum[i] + temp_axis_tdata[i]
                                                           + $signed(s_axis_bias[(i+1)*G_WDEPTH-1:G_WDEPTH*i]);
            else
                m_axis_tdata[(i+1)*G_PDEPTH-1:G_PDEPTH*i] <= 0;
        end
    end

end
endgenerate

// Output strobe. Driven from a single always block (keyed on lane 0, the same
// lane that advances buffer_xcnt) so the register has exactly one driver
// regardless of G_FEATURESo.
always @(posedge isysclk or negedge isysrst) begin
    if (isysrst == ACTIVERST)
        m_axis_tvalid <= 1'b0;
    else if (temp_axis_tvalid[0] == 1'b1 && buffer_xcnt == G_FEATURESi - 1)
        m_axis_tvalid <= 1'b1;
    else
        m_axis_tvalid <= 1'b0;
end

// Frame marker forwarded with one clock of delay (no reset on this register).
always @(posedge isysclk) begin
    m_axis_tuser <= s_axis_ruser;
end

endmodule