forked from Tensor-Array/Tensor-Array-Python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path linear.py
More file actions
37 lines (32 loc) · 1.16 KB
/
linear.py
File metadata and controls
37 lines (32 loc) · 1.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from .. import Layer
from .. import Parameter
from tensor_array.core import Tensor
from tensor_array.core import zeros
from tensor_array.core import DataTypes
from typing import Any
class Linear(Layer):
    """
    Fully connected layer that computes ``t @ w + b``.

    The bias parameter ``b`` is built in the constructor. The weight parameter
    ``w`` is built in ``layer_init`` instead, because its first dimension is
    taken from ``t[-1]``, which is only available there.
    """

    def __init__(self, bias: int) -> None:
        """
        Create the layer together with a zero-filled bias parameter.
        Args:
            bias (int): Number of output features. It is the length of the
                bias vector and the second dimension of the weight matrix.
        """
        super().__init__()
        self.bias_shape = bias
        bias_init = zeros(shape=(bias,), dtype=DataTypes.FLOAT)
        self.b = Parameter(bias_init)

    def layer_init(self, t: Any) -> None:
        """
        Allocate the zero-filled weight parameter of shape
        ``(t[-1], bias_shape)``.
        Args:
            t: Description of the incoming input. Only ``t[-1]`` is read, and
                it becomes the weight matrix's first dimension (the number of
                input features).
                NOTE(review): presumably this is the input's shape as passed
                in by ``Layer``, not the tensor itself, since ``t[-1]`` is used
                as a dimension size. Confirm against ``Layer``.
        """
        in_features = t[-1]
        weight_shape = (in_features, self.bias_shape)
        # NOTE(review): weights start at all zeros, like the bias. All-zero
        # weights are usually avoided because units then receive identical
        # updates. Consider a random initializer if tensor_array.core offers one.
        self.w = Parameter(zeros(shape=weight_shape, dtype=DataTypes.FLOAT))

    def calculate(self, t: Tensor) -> Tensor:
        """
        Apply the affine map to the input.
        Args:
            t (Tensor): Input tensor. Its last dimension is expected to match
                the first dimension of ``w`` so that ``t @ w`` is defined.
        Returns:
            Tensor: ``t @ w + b``.
        """
        projected = t @ self.w
        return projected + self.b