Code Coverage for /src/SciPhp/NumPhp/ArithmeticTrait.php

 
Code Coverage
 
Lines
Functions and Methods
Classes and Traits
Total
97.60% covered (success)
97.60%
122 / 125
75.00% covered (warning)
75.00%
6 / 8
CRAP
0.00% covered (danger)
0.00%
0 / 1
ArithmeticTrait
97.60% covered (success)
97.60%
122 / 125
75.00% covered (warning)
75.00%
6 / 8
47
0.00% covered (danger)
0.00%
0 / 1
 reciprocal
100.00% covered (success)
100.00%
5 / 5
100.00% covered (success)
100.00%
1 / 1
2
 subtract
100.00% covered (success)
100.00%
7 / 7
100.00% covered (success)
100.00%
1 / 1
4
 add
100.00% covered (success)
100.00%
7 / 7
100.00% covered (success)
100.00%
1 / 1
4
 divide
100.00% covered (success)
100.00%
30 / 30
100.00% covered (success)
100.00%
1 / 1
14
 multiply
100.00% covered (success)
100.00%
26 / 26
100.00% covered (success)
100.00%
1 / 1
11
 dot
95.00% covered (success)
95.00%
38 / 40
0.00% covered (danger)
0.00%
0 / 1
10
 rowDot
100.00% covered (success)
100.00%
3 / 3
100.00% covered (success)
100.00%
1 / 1
1
 colDot
85.71% covered (warning)
85.71%
6 / 7
0.00% covered (danger)
0.00%
0 / 1
1.00
1 <?php
2
3 declare(strict_types=1);
4
5 namespace SciPhp\NumPhp;
6
7 use RecursiveArrayIterator;
8 use RecursiveIteratorIterator;
9 use SciPhp\Exception\Message;
10 use SciPhp\NdArray;
11 use Webmozart\Assert\Assert;
12
13 trait ArithmeticTrait
14 {
15     /**
16      * Return the reciprocal of the argument, element-wise.
17      *
18      * @param \SciPhp\NdArray|array|float|int $m
19      * @link http://sciphp.org/numphp.reciprocal
20      *    Documentation for reciprocal() method
21      * @api
22      */
23     final public static function reciprocal($m)
24     {
25         if (is_numeric($m)) {
26             Assert::notEq(0, $m);
27
28             return 1 / $m;
29         }
30
31         static::transform($m, true);
32
33         return static::ones($m->shape)->divide($m);
34     }
35
36     /**
37      * Subtract a matrix from matrix
38      *
39      * @param \SciPhp\NdArray|array|float|int $m
40      * @param \SciPhp\NdArray|array|float|int $n
41      * @link http://sciphp.org/numphp.subtract Documentation
42      * @api
43      */
44     final public static function subtract($m, $n)
45     {
46         if (static::allNumeric($m, $n)) {
47             return $m - $n;
48         }
49
50         static::transform($n);
51
52         // lambda - array
53         if (is_numeric($m) && $n instanceof NdArray) {
54             return static::full_like($n, $m)->subtract($n);
55         }
56
57         // array - array
58         static::transform($m, true);
59
60         // array - array OR array - lambda
61         return $m->negative()->add($n)->negative();
62     }
63
64     /**
65      * Add two array_like
66      *
67      * @param  \SciPhp\NdArray|array|int|float $m
68      * @param  \SciPhp\NdArray|array|int|float $n
69      * @return \SciPhp\NdArray|int|float
70      * @link http://sciphp.org/numphp.add Documentation
71      * @api
72      */
73     final public static function add($m, $n)
74     {
75         if (static::allNumeric($m, $n)) {
76             return $m + $n;
77         }
78
79         static::transform($n);
80
81         // lambda + array
82         if (is_numeric($m) && $n instanceof NdArray) {
83             return $n->copy()->add($m);
84         }
85
86         // array + array
87         static::transform($m, true);
88
89         // array + array OR array + lambda
90         return $m->copy()->add($n);
91     }
92
93     /**
94      * Divide two arrays, element-wise
95      *
96      * @param  \SciPhp\NdArray|array|float|int $m A 2-dim array.
97      * @param  \SciPhp\NdArray|array|float|int $n A 2-dim array.
98      * @return \SciPhp\NdArray|float|int
99      * @throws \InvalidArgumentException
100      * @link http://sciphp.org/numphp.divide
101      *    Documentation for divide()
102      * @api
103      */
104     final public static function divide($m, $n)
105     {
106         if (static::allNumeric($m, $n)) {
107             Assert::notEq(0, $n);
108
109             return $m / $n;
110         }
111
112         static::transform($m);
113         static::transform($n);
114
115         // array / lamba
116         if (is_numeric($n) && $m instanceof NdArray) {
117             return $m->copy()->divide($n);
118         }
119
120         // lamba / array
121         if (is_numeric($m) && $n instanceof NdArray) {
122             return static::full_like($n, $m)->divide($n);
123         }
124
125         // array / array
126         Assert::isInstanceof($m, 'SciPhp\NdArray');
127         Assert::isInstanceof($n, 'SciPhp\NdArray');
128
129         $shape_m = $m->shape;
130         $shape_n = $n->shape;
131
132         // n & m are vectors:
133         if (count($shape_m) === 1 && $m->ndim === $n->ndim) {
134             Assert::eq($shape_m, $shape_n, Message::MAT_NOT_ALIGNED);
135         }
136
137         // n is a vector
138         elseif (! isset($shape_n[1])) {
139             Assert::eq($shape_m[1], $shape_n[0], Message::MAT_NOT_ALIGNED);
140         }
141
142         // m is a vector
143         elseif (! isset($shape_m[1])) {
144             Assert::eq($shape_m[0], $shape_n[1], Message::MAT_NOT_ALIGNED);
145
146             $m = $m->resize($shape_n);
147         }
148
149         // array / array -> broadcast
150         elseif ($m->ndim === $n->ndim && $shape_m[0] === $shape_n[0] && $shape_m[1] > $shape_n[1]) {
151             $n = static::broadcast_to($n, $shape_m);
152         }
153
154         // array / array
155         elseif ($m->ndim === $n->ndim) {
156             Assert::eq($shape_m, $shape_n, Message::MAT_NOT_ALIGNED);
157         }
158
159         $iterator = new RecursiveIteratorIterator(
160             new RecursiveArrayIterator($n->data),
161             RecursiveIteratorIterator::LEAVES_ONLY
162         );
163
164         $func = static function (&$item) use (&$iterator, $n): void {
165             Assert::notEq(0, $value = $n->iterate($iterator));
166             $item /= $value;
167         };
168
169         return $m->copy()->walk_recursive($func);
170     }
171
172     /**
173      * Multiply two arrays, element-wise
174      *
175      * @param  \SciPhp\NdArray|array|float|int $m A 2-dim array.
176      * @param  \SciPhp\NdArray|array|float|int $n A 2-dim array.
177      * @return \SciPhp\NdArray|float|int
178      * @throws \InvalidArgumentException
179      * @link http://sciphp.org/numphp.multiply Documentation
180      * @api
181      */
182     final public static function multiply($m, $n)
183     {
184         if (static::allNumeric($m, $n)) {
185             return $m * $n;
186         }
187
188         static::transform($m);
189         static::transform($n);
190
191         // array * lamba
192         if (is_numeric($n) && $m instanceof NdArray) {
193             return $m->copy()->dot($n);
194         }
195
196         // lamba * array
197         if (is_numeric($m) && $n instanceof NdArray) {
198             return $n->copy()->dot($m);
199         }
200
201         // array * array
202         Assert::isInstanceof($m, NdArray::class);
203         Assert::isInstanceof($n, NdArray::class);
204
205         $shape_m = $m->shape;
206         $shape_n = $n->shape;
207
208         // n & m are vectors:
209         if (count($shape_m) === 1 && $m->ndim === $n->ndim) {
210             Assert::eq($shape_m, $shape_n, Message::MAT_NOT_ALIGNED);
211         }
212
213         // n is a vector
214         elseif (! isset($shape_n[1])) {
215             Assert::eq($shape_m[1], $shape_n[0], Message::MAT_NOT_ALIGNED);
216         }
217
218         // m is a vector
219         elseif (! isset($shape_m[1])) {
220             Assert::eq($shape_m[0], $shape_n[1], Message::MAT_NOT_ALIGNED);
221
222             $m = $m->resize($shape_n);
223         }
224
225         // array * array
226         elseif ($m->ndim === $n->ndim) {
227             Assert::eq($shape_m, $shape_n, Message::MAT_NOT_ALIGNED);
228         }
229
230         $iterator = new RecursiveIteratorIterator(
231             new RecursiveArrayIterator($n->data),
232             RecursiveIteratorIterator::LEAVES_ONLY
233         );
234
235         $func = static function (&$item) use (&$iterator, $n): void {
236             $item *= $n->iterate($iterator);
237         };
238
239         return $m->copy()->walk_recursive($func);
240     }
241
242     /**
243      * Dot product of two arrays
244      *
245      * @param  \SciPhp\NdArray|array|float|int $m A 2-dim array.
246      * @param  \SciPhp\NdArray|array|float|int $n A 2-dim array.
247      * @return \SciPhp\NdArray|float|int
248      * @throws \InvalidArgumentException
249      * @link http://sciphp.org/numphp.dot Documentation
250      * @api
251      */
252     final public static function dot($m, $n)
253     {
254         if (static::allNumeric($m, $n)) {
255             return $m * $n;
256         }
257
258         static::transform($m);
259         static::transform($n);
260
261         // array.lamba
262         if (is_numeric($n) && $m instanceof NdArray) {
263             return $m->copy()->dot($n);
264         }
265
266         // lamba.array
267         if (is_numeric($m) && $n instanceof NdArray) {
268             return $n->copy()->dot($m);
269         }
270
271         // array.array
272         Assert::isInstanceof($m, NdArray::class);
273         Assert::isInstanceof($n, NdArray::class);
274
275         $shape_m = $m->shape;
276         $shape_n = $n->shape;
277
278         // n & m are vectors:
279         if (count($shape_m) === 1 && $m->ndim === $n->ndim) {
280             Assert::eq($shape_m, $shape_n, Message::MAT_NOT_ALIGNED);
281
282             return array_sum(
283                 array_map(
284                     static function ($el_m, $el_n) {
285                         return $el_m * $el_n;
286                     },
287                     $m->data,
288                     $n->data
289                 )
290             );
291         }
292
293         // n is a vector
294         if (! isset($shape_n[1])) {
295             Assert::eq($shape_m[1], $shape_n[0], Message::MAT_NOT_ALIGNED);
296
297             return static::zeros($shape_m[0], 1)
298                 ->walk(
299                     self::rowDot(
300                         $m,
301                         $n->reshape($shape_n[0], 1)
302                     )
303                 )->reshape($shape_m[0]);
304         }
305
306         // m is a vector
307         if (! isset($shape_m[1])) {
308             Assert::eq($shape_m[0], $shape_n[0], Message::MAT_NOT_ALIGNED);
309
310             $callback = static function (&$item, $k_m) use ($m, $n): void {
311                 $item = array_sum(
312                     array_map(
313                         static function($el_n, $el_m) {
314                             return $el_n * $el_m;
315                         },
316                         $m->data,
317                         array_column($n->data, $k_m)
318                     )
319                 );
320             };
321
322             return static::zeros($shape_n[1])->walk($callback);
323         }
324
325         Assert::eq($shape_m[1], $shape_n[0], Message::MAT_NOT_ALIGNED);
326
327         return static::zeros($shape_m[0], $shape_n[1])->walk(
328             self::rowDot($m, $n)
329         );
330     }
331
332     /**
333      * Browse p rows
334      */
335     final protected static function rowDot(NdArray $m, NdArray $n): callable
336     {
337         return static function (&$row, $row_m) use ($m, $n): void {
338             array_walk(
339                 $row,
340                 self::colDot($row_m, $m, $n)
341             );
342         };
343     }
344
345     /**
346      * Browse p cols and sum products
347      */
348     final protected static function colDot($row_m, NdArray $m, NdArray $n): callable
349     {
350         // row_m * col_n
351         return static function (&$item, $col_m) use ($row_m, $m, $n): void {
352             $item = array_sum(
353                 array_map(
354                     static function ($el_m, $row_n) use ($col_m) {
355                         return $el_m * $row_n[$col_m];
356                     },
357                     $m->data[$row_m],
358                     $n->data
359                 )
360             );
361         };
362     }
363 }