Răsfoiți Sursa

Added multiplication of two tensors (interpreted as matrices, otherwise trap)

git-svn-id: https://svn.inf.ethz.ch/svn/lecturers/a2/trunk@7160 8c9fc860-2736-0410-a75d-ab315db34111
felixf 8 ani în urmă
părinte
comite
e7991724e7
1 a modificat fișierele cu 78 adăugiri și 2 ștergeri
  1. 78 2
      source/FoxArrayBase.Mod

+ 78 - 2
source/FoxArrayBase.Mod

@@ -258,7 +258,7 @@ VAR
 
 	PROCEDURE GetFlags( base: UnsafeArray ): SET;
 	BEGIN
-		RETURN base.flags
+		IF base = 0 THEN RETURN {} ELSE RETURN base.flags END;
 	END GetFlags;
 
 	PROCEDURE PutDim( base: UnsafeArray;  dim: SIZE );
@@ -342,6 +342,21 @@ VAR
 		ptr.flags := {TensorFlag};
 		RETURN ptr;
 	END GetArrayDesc;
+	
+	PROCEDURE EnsureArrayDesc*(dim: SIZE; VAR d: Tensor);
+	BEGIN
+		IF d = NIL THEN
+			d := GetArrayDesc(dim);
+		ELSIF d.dim # dim THEN
+			IF ~(TensorFlag IN d.flags) &
+				~(TemporaryFlag IN d.flags) THEN  (* no, not allowed*)
+				HALT( 100 );
+			END;
+			d := GetArrayDesc(dim)
+		(* ELSE keep as is *)
+		END;
+	END EnsureArrayDesc;
+	
 
 	PROCEDURE Halt( code: LONGINT;  left, right, dest: LONGINT );
 	VAR reason: ARRAY 64 OF CHAR;
@@ -1545,7 +1560,8 @@ Sufficient (but not necessary) conditions:
 	VAR p: ANY;
 	BEGIN
 		(* Report("dest",dest); Report("src",src); *)
-		IF (dest = 0) OR ~(SameShape( dest, src )) OR (GetAdr( dest ) = 0) THEN
+		IF (src = NIL) THEN dest := NIL
+		ELSIF (dest = 0) OR ~(SameShape( dest, src )) OR (GetAdr( dest ) = 0) THEN
 			p := AllocateSame( dest, src, elementsize );   (* includes check if allocation is allowed *)
 			CopyContent( dest, src, elementsize );
 		ELSIF dest = src THEN CopyTensorSelf( dest, src, elementsize );
@@ -9081,6 +9097,66 @@ TYPE
 		RESULT[1] := vl3 * vr1 - vl1 * vr3;  RESULT[2] := vl1 * vr2 - vl2 * vr1;
 		RETURN RESULT
 	END "*";
+	
+
+	OPERATOR "*"*(CONST left, right: ARRAY [ ? ] OF LONGREAL ): ARRAY [ ? ] OF LONGREAL;
+	VAR tensor: Tensor;
+	BEGIN
+		IF (DIM(left) = 2) & (DIM(right)=2) THEN
+			EnsureArrayDesc(2, SYSTEM.VAL(Tensor, RESULT));
+			ApplyMatMulLoop(SYSTEM.VAL(Tensor, RESULT), SYSTEM.VAL(Tensor, left), SYSTEM.VAL(Tensor, right), SIZEOF( LONGREAL ),
+									    loopMatMulAXAX, matMulX );
+		ELSE HALT(200);
+		END;
+		RETURN RESULT
+	END "*";
+
+	OPERATOR "*"*(CONST left, right: ARRAY [ ? ] OF REAL ): ARRAY [ ? ] OF REAL;
+	BEGIN
+		IF (DIM(left) = 2) & (DIM(right)=2) THEN
+			EnsureArrayDesc(2, SYSTEM.VAL(Tensor, RESULT));
+			ApplyMatMulLoop(SYSTEM.VAL(Tensor, RESULT), SYSTEM.VAL(Tensor, left), SYSTEM.VAL(Tensor, right), SIZEOF( REAL ),
+									    loopMatMulARAR, matMulR );
+		ELSE HALT(200);
+		END;
+		RETURN RESULT
+	END "*";
+
+	OPERATOR "*"*(CONST left, right: ARRAY [ ? ] OF LONGINT ): ARRAY [ ? ] OF LONGINT;
+	BEGIN
+		IF (DIM(left) = 2) & (DIM(right)=2) THEN
+			EnsureArrayDesc(2, SYSTEM.VAL(Tensor, RESULT));
+			ApplyMatMulLoop(SYSTEM.VAL(Tensor, RESULT), SYSTEM.VAL(Tensor, left), SYSTEM.VAL(Tensor, right),  SIZEOF( LONGINT ),
+									    MatMulALALLoop, NIL );
+		ELSE HALT(200);
+		END;
+		RETURN RESULT
+	END "*";
+
+
+	OPERATOR "*"*(CONST left, right: ARRAY [ ? ] OF INTEGER ): ARRAY [ ? ] OF INTEGER;
+	BEGIN
+		IF (DIM(left) = 2) & (DIM(right)=2) THEN
+			EnsureArrayDesc(2, SYSTEM.VAL(Tensor, RESULT));
+			ApplyMatMulLoop(SYSTEM.VAL(Tensor, RESULT), SYSTEM.VAL(Tensor, left), SYSTEM.VAL(Tensor, right), SIZEOF( INTEGER ),
+									    MatMulAIAILoop,NIL );
+		ELSE HALT(200);
+		END;
+		RETURN RESULT
+	END "*";
+
+	OPERATOR "*"*(CONST left, right: ARRAY [ ? ] OF SHORTINT ): ARRAY [ ? ] OF SHORTINT;
+	BEGIN
+		IF (DIM(left) = 2) & (DIM(right)=2) THEN
+			EnsureArrayDesc(2, SYSTEM.VAL(Tensor, RESULT));
+			ApplyMatMulLoop(SYSTEM.VAL(Tensor, RESULT), SYSTEM.VAL(Tensor, left), SYSTEM.VAL(Tensor, right), SIZEOF( SHORTINT ),
+									    MatMulASASLoop, NIL );
+		ELSE HALT(200);
+		END;
+		RETURN RESULT
+	END "*";
+
+
 
 (** Transpose  ********************************************************************)