diff --git a/lib/src/arithmetic/floating_point/fft/butterfly.dart b/lib/src/arithmetic/floating_point/fft/butterfly.dart new file mode 100644 index 000000000..39c5f106c --- /dev/null +++ b/lib/src/arithmetic/floating_point/fft/butterfly.dart @@ -0,0 +1,35 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/src/arithmetic/signals/floating_point_logics/complex_floating_point_logic.dart'; + +class Butterfly extends Module { + late final ComplexFloatingPoint outA; + late final ComplexFloatingPoint outB; + + Butterfly({ + required ComplexFloatingPoint inA, + required ComplexFloatingPoint inB, + required ComplexFloatingPoint twiddleFactor, + super.name = 'butterfly', + }) { + final _inA = inA.clone()..gets(addInput('inA', inA, width: inA.width)); + final _inB = inA.clone()..gets(addInput('inB', inB, width: inA.width)); + final _twiddleFactor = inA.clone() + ..gets( + addInput('twiddleFactor', twiddleFactor, width: twiddleFactor.width), + ); + + final outALogic = addOutput('outA', width: inA.width); + final outBLogic = addOutput('outB', width: inA.width); + + final temp = _twiddleFactor.multiplier(_inB); + + outALogic <= _inA.adder(temp.negated); + outBLogic <= _inA.adder(temp); + + outA = inA.clone()..gets(outALogic); + outB = inA.clone()..gets(outBLogic); + } +} diff --git a/lib/src/arithmetic/floating_point/fft/fft_stage.dart b/lib/src/arithmetic/floating_point/fft/fft_stage.dart new file mode 100644 index 000000000..4a1bf6c5d --- /dev/null +++ b/lib/src/arithmetic/floating_point/fft/fft_stage.dart @@ -0,0 +1,181 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +import 'dart:math'; + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; +import 'package:rohd_hcl/src/arithmetic/floating_point/fft/butterfly.dart'; +import 'package:rohd_hcl/src/arithmetic/signals/floating_point_logics/complex_floating_point_logic.dart'; + +class BadFFTStage extends Module { + final int logStage; + final int exponentWidth; + final int mantissaWidth; + Logic clk; + Logic reset; + Logic go; + DataPortInterface inputSamplesA; + DataPortInterface inputSamplesB; + DataPortInterface twiddleFactorROM; + + late final Logic done; + + BadFFTStage({ + required this.logStage, + required this.exponentWidth, + required this.mantissaWidth, + required this.clk, + required this.reset, + required this.go, + required this.inputSamplesA, + required this.inputSamplesB, + required this.twiddleFactorROM, + required DataPortInterface outputSamplesA, + required DataPortInterface outputSamplesB, + super.name = 'badfftstage', + }) : assert(go.width == 1), + assert( + inputSamplesA.dataWidth == 2 * (1 + exponentWidth + mantissaWidth), + ), + assert( + inputSamplesB.dataWidth == 2 * (1 + exponentWidth + mantissaWidth), + ) { + clk = addInput('clk', clk); + reset = addInput('reset', reset); + go = addInput('go', go); + final doneInner = Logic(name: '_done'); + done = addOutput('done')..gets(doneInner); + final en = (go & ~done).named('enable'); + + inputSamplesA = addInterfacePorts( + inputSamplesA, + inputTags: [DataPortGroup.data], + outputTags: [DataPortGroup.control], + uniquify: (name) => 'inputSamplesA$name', + ); + inputSamplesB = addInterfacePorts( + inputSamplesB, + inputTags: [DataPortGroup.data], + outputTags: [DataPortGroup.control], + uniquify: (name) => 'inputSamplesB$name', + ); + twiddleFactorROM = addInterfacePorts( + twiddleFactorROM, + inputTags: [DataPortGroup.data], + outputTags: [DataPortGroup.control], + uniquify: (name) => 'twiddleFactorROM$name', + ); + + outputSamplesA = addInterfacePorts( + outputSamplesA, + inputTags: [DataPortGroup.control], + outputTags: [DataPortGroup.data], + uniquify: (name) => 'outputSamplesA$name', + ); + outputSamplesB = addInterfacePorts( + outputSamplesB, + inputTags: [DataPortGroup.control], + outputTags: [DataPortGroup.data], + uniquify: (name) => 'outputSamplesB$name', + ); + + final outputSamplesWritePortA = DataPortInterface( + inputSamplesA.dataWidth, + inputSamplesA.addrWidth, + ); + final outputSamplesWritePortB = DataPortInterface( + inputSamplesA.dataWidth, + inputSamplesA.addrWidth, + ); + final outputSamplesReadPortA = DataPortInterface( + inputSamplesA.dataWidth, + inputSamplesA.addrWidth, + ); + final outputSamplesReadPortB = DataPortInterface( + inputSamplesA.dataWidth, + inputSamplesA.addrWidth, + ); + + final n = 1 << inputSamplesA.addrWidth; + RegisterFile( + clk, + reset, + [outputSamplesWritePortA, outputSamplesWritePortB], + [outputSamplesReadPortA, outputSamplesReadPortB], + numEntries: n, + name: 'outputSamplesBuffer', + ); + outputSamplesA.data <= outputSamplesReadPortA.data; + outputSamplesReadPortA.en <= outputSamplesA.en; + outputSamplesReadPortA.addr <= outputSamplesA.addr; + outputSamplesB.data <= outputSamplesReadPortB.data; + outputSamplesReadPortB.en <= outputSamplesB.en; + outputSamplesReadPortB.addr <= outputSamplesB.addr; + + final log2Length = inputSamplesA.addrWidth; + final m = 1 << logStage; + final mShift = log2Ceil(m); + + final i = Counter.ofLogics( + [flop(clk, en)], + clk: clk, + reset: reset | (go & doneInner), + width: max(log2Length - 1, 1), + maxValue: n ~/ 2, + name: 'i', + ); + doneInner <= i.equalsMax; + + final k = ((i.count >> (mShift - 1)) << mShift).named('k'); + final j = (i.count & Const((m >> 1) - 1, width: i.width)).named('j'); + + // for k = 0 to n-1 by m do + // ω ← 1 + // for j = 0 to m/2 – 1 do + // t ← ω A[k + j + m/2] + // u ← A[k + j] + // A[k + j] ← u + t + // A[k + j + m/2] ← u – t + // ω ← ω ωm + final addressA = (k + j).named('addressA'); + final addressB = (addressA + m ~/ 2).named('addressB'); + inputSamplesA.addr <= addressA; + inputSamplesA.en <= en; + inputSamplesB.addr <= addressB; + inputSamplesB.en <= en; + twiddleFactorROM.addr <= j; + twiddleFactorROM.en <= en; + + final butterfly = Butterfly( + inA: ComplexFloatingPoint.of( + inputSamplesA.data, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + ), + inB: ComplexFloatingPoint.of( + inputSamplesB.data, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + ), + twiddleFactor: ComplexFloatingPoint.of( + twiddleFactorROM.data, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + ), + ); + + outputSamplesWritePortA.addr <= addressA; + outputSamplesWritePortA.en <= en; + outputSamplesWritePortB.addr <= addressB; + outputSamplesWritePortB.en <= en; + + Sequential( + clk, + [ + outputSamplesWritePortA.data < butterfly.outA.named('butterflyOutA'), + outputSamplesWritePortB.data < butterfly.outB.named('butterflyOutB'), + ], + reset: reset); + } +} diff --git a/lib/src/arithmetic/signals/floating_point_logics/complex_floating_point_logic.dart b/lib/src/arithmetic/signals/floating_point_logics/complex_floating_point_logic.dart new file mode 100644 index 000000000..49d6e3f20 --- /dev/null +++ b/lib/src/arithmetic/signals/floating_point_logics/complex_floating_point_logic.dart @@ -0,0 +1,102 @@ +// Copyright (C) 2024-2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +import 'package:meta/meta.dart'; +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +class ComplexFloatingPoint extends LogicStructure { + final FloatingPoint realPart; + + final FloatingPoint imaginaryPart; + + static String _nameJoin(String? structName, String signalName) { + if (structName == null) { + return signalName; + } + return '${structName}_$signalName'; + } + + ComplexFloatingPoint({ + required int exponentWidth, + required int mantissaWidth, + String? name, + }) : this._internal( + realPart: FloatingPoint( + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + name: _nameJoin(name, 're'), + ), + imaginaryPart: FloatingPoint( + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + name: _nameJoin(name, 'im'), + ), + name: name, + ); + + ComplexFloatingPoint.of( + Logic input, { + required int exponentWidth, + required int mantissaWidth, + String? name, + }) : this._internal( + realPart: FloatingPoint( + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + name: _nameJoin(name, 're'), + )..gets(input.getRange(0, 1 + exponentWidth + mantissaWidth)), + imaginaryPart: FloatingPoint( + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + name: _nameJoin(name, 'im'), + )..gets( + input.getRange(1 + exponentWidth + mantissaWidth, input.width)), + name: name); + + ComplexFloatingPoint._internal( + {required this.realPart, required this.imaginaryPart, super.name}) + : assert(realPart.exponent.width == imaginaryPart.exponent.width), + assert(realPart.mantissa.width == imaginaryPart.mantissa.width), + super([realPart, imaginaryPart]); + + @mustBeOverridden + @override + ComplexFloatingPoint clone({String? name}) => ComplexFloatingPoint( + exponentWidth: realPart.exponent.width, + mantissaWidth: realPart.mantissa.width, + name: name, + ); + + ComplexFloatingPoint adder(ComplexFloatingPoint other) => + ComplexFloatingPoint._internal( + realPart: FloatingPointAdderSinglePath(realPart, other.realPart).sum, + imaginaryPart: + FloatingPointAdderSinglePath(imaginaryPart, other.imaginaryPart) + .sum, + name: _nameJoin(name, 'adder')); + + ComplexFloatingPoint multiplier(ComplexFloatingPoint other) { + // use only 3 multipliers: https://mathworld.wolfram.com/ComplexMultiplication.html + final ac = FloatingPointMultiplierSimple(realPart, other.realPart).product; + final bd = FloatingPointMultiplierSimple(imaginaryPart, other.imaginaryPart) + .product; + final abcd = FloatingPointMultiplierSimple( + FloatingPointAdderSinglePath(realPart, imaginaryPart).sum, + FloatingPointAdderSinglePath(other.realPart, other.imaginaryPart) + .sum) + .product; + + return ComplexFloatingPoint._internal( + realPart: FloatingPointAdderSinglePath(ac, bd.negated()).sum, + imaginaryPart: FloatingPointAdderSinglePath(abcd, + FloatingPointAdderSinglePath(ac.negated(), bd.negated()).sum) + .sum, + name: _nameJoin(name, 'multiplier')); + } + + late final ComplexFloatingPoint negated = ComplexFloatingPoint._internal( + realPart: realPart.negated(), + imaginaryPart: imaginaryPart.negated(), + name: _nameJoin(name, 'negated')); +} diff --git a/lib/src/arithmetic/signals/floating_point_logics/floating_point_logic.dart b/lib/src/arithmetic/signals/floating_point_logics/floating_point_logic.dart index 11bb33930..71ae6facf 100644 --- a/lib/src/arithmetic/signals/floating_point_logics/floating_point_logic.dart +++ b/lib/src/arithmetic/signals/floating_point_logics/floating_point_logic.dart @@ -35,31 +35,38 @@ class FloatingPoint extends LogicStructure { /// [FloatingPoint] constructor for a variable size binary /// floating point number. - FloatingPoint( - {required int exponentWidth, - required int mantissaWidth, - bool explicitJBit = false, - bool subNormalAsZero = false, - String? name}) - : this._( - Logic(name: 'sign', naming: Naming.mergeable), - Logic( - width: exponentWidth, - name: 'exponent', - naming: Naming.mergeable), - Logic( - width: mantissaWidth, - name: 'mantissa', - naming: Naming.mergeable), - explicitJBit, - subNormalAsZero, - name: name); + FloatingPoint({ + required int exponentWidth, + required int mantissaWidth, + bool explicitJBit = false, + bool subNormalAsZero = false, + String? name, + }) : this._( + Logic(name: 'sign', naming: Naming.mergeable), + Logic( + width: exponentWidth, + name: 'exponent', + naming: Naming.mergeable, + ), + Logic( + width: mantissaWidth, + name: 'mantissa', + naming: Naming.mergeable, + ), + explicitJBit, + subNormalAsZero, + name: name, + ); /// [FloatingPoint] internal constructor. - FloatingPoint._(this.sign, this.exponent, this.mantissa, this.explicitJBit, - this.subNormalAsZero, - {super.name}) - : super([mantissa, exponent, sign]); + FloatingPoint._( + this.sign, + this.exponent, + this.mantissa, + this.explicitJBit, + this.subNormalAsZero, { + super.name, + }) : super([mantissa, exponent, sign]); @mustBeOverridden @override @@ -75,10 +82,11 @@ class FloatingPoint extends LogicStructure { /// [FloatingPoint] type. @mustBeOverridden FloatingPointValuePopulator valuePopulator() => FloatingPointValue.populator( - exponentWidth: exponent.width, - mantissaWidth: mantissa.width, - explicitJBit: explicitJBit, - subNormalAsZero: subNormalAsZero); + exponentWidth: exponent.width, + mantissaWidth: mantissa.width, + explicitJBit: explicitJBit, + subNormalAsZero: subNormalAsZero, + ); /// Return `true` if the J-bit is explicitly represented in the mantissa. final bool explicitJBit; @@ -91,14 +99,18 @@ class FloatingPoint extends LogicStructure { FloatingPoint resolveSubNormalAsZero() { if (subNormalAsZero) { return clone() - ..gets(mux( + ..gets( + mux( isNormal, this, FloatingPoint.zero( - exponentWidth: exponent.width, - mantissaWidth: mantissa.width, - explicitJBit: explicitJBit, - subNormalAsZero: subNormalAsZero))); + exponentWidth: exponent.width, + mantissaWidth: mantissa.width, + explicitJBit: explicitJBit, + subNormalAsZero: subNormalAsZero, + ), + ), + ); } else { return this; } @@ -122,10 +134,7 @@ class FloatingPoint extends LogicStructure { /// by having its exponent field set to the NaN value (typically all /// ones) and a non-zero mantissa. late final isNaN = exponent.eq(valuePopulator().nan.exponent) & - mantissa.or().named( - _nameJoin('isNaN', name), - naming: Naming.mergeable, - ); + mantissa.or().named(_nameJoin('isNaN', name), naming: Naming.mergeable); /// Return a [Logic] `1` if this [FloatingPoint] is an infinity /// by having its exponent field set to the NaN value (typically all @@ -150,29 +159,38 @@ class FloatingPoint extends LogicStructure { .named(_nameJoin('isAZero', name), naming: Naming.mergeable); /// Return the zero exponent representation for this type of [FloatingPoint]. - late final zeroExponent = Const(LogicValue.zero, width: exponent.width) - .named(_nameJoin('zeroExponent', name), naming: Naming.mergeable); + late final zeroExponent = Const( + LogicValue.zero, + width: exponent.width, + ).named(_nameJoin('zeroExponent', name), naming: Naming.mergeable); /// Return the one exponent representation for this type of [FloatingPoint]. - late final oneExponent = Const(LogicValue.one, width: exponent.width) - .named(_nameJoin('oneExponent', name), naming: Naming.mergeable); + late final oneExponent = Const( + LogicValue.one, + width: exponent.width, + ).named(_nameJoin('oneExponent', name), naming: Naming.mergeable); /// Return the exponent [Logic] representing the [bias] of this /// [FloatingPoint] signal, the offset of the exponent, also representing the /// zero exponent `2^0 = 1`. - late final bias = Const((1 << exponent.width - 1) - 1, width: exponent.width) - .named(_nameJoin('bias', name), naming: Naming.mergeable); + late final bias = Const( + (1 << exponent.width - 1) - 1, + width: exponent.width, + ).named(_nameJoin('bias', name), naming: Naming.mergeable); /// Construct a [FloatingPoint] that represents infinity for this FP type. FloatingPoint inf({Logic? sign, bool negative = false}) => FloatingPoint.inf( - exponentWidth: exponent.width, - mantissaWidth: mantissa.width, - sign: sign, - negative: negative); + exponentWidth: exponent.width, + mantissaWidth: mantissa.width, + sign: sign, + negative: negative, + ); /// Construct a [FloatingPoint] that represents NaN for this FP type. late final nan = FloatingPoint.nan( - exponentWidth: exponent.width, mantissaWidth: mantissa.width); + exponentWidth: exponent.width, + mantissaWidth: mantissa.width, + ); @override void put(dynamic val, {bool fill = false}) { @@ -186,7 +204,8 @@ class FloatingPoint extends LogicStructure { } if (val.subNormalAsZero != subNormalAsZero) { throw RohdHclException( - 'FloatingPoint subnormal as zero does not match'); + 'FloatingPoint subnormal as zero does not match', + ); } put(val.value); } else { @@ -194,44 +213,70 @@ class FloatingPoint extends LogicStructure { } } + FloatingPoint negated() => FloatingPoint._( + ~sign, + exponent.clone()..gets(exponent), + mantissa.clone()..gets(mantissa), + explicitJBit, + subNormalAsZero, + ); + /// Construct a [FloatingPoint] that represents infinity. - factory FloatingPoint.inf( - {required int exponentWidth, - required int mantissaWidth, - Logic? sign, - bool negative = false, - bool explicitJBit = false, - bool subNormalAsZero = false}) { + factory FloatingPoint.inf({ + required int exponentWidth, + required int mantissaWidth, + Logic? sign, + bool negative = false, + bool explicitJBit = false, + bool subNormalAsZero = false, + }) { final signLogic = Logic()..gets(sign ?? Const(negative)); final exponent = Const(1, width: exponentWidth, fill: true); final mantissa = Const(0, width: mantissaWidth, fill: true); return FloatingPoint._( - signLogic, exponent, mantissa, explicitJBit, subNormalAsZero); + signLogic, + exponent, + mantissa, + explicitJBit, + subNormalAsZero, + ); } /// Construct a [FloatingPoint] that represents NaN. - factory FloatingPoint.nan( - {required int exponentWidth, - required int mantissaWidth, - bool explicitJBit = false, - bool subNormalAsZero = false}) { + factory FloatingPoint.nan({ + required int exponentWidth, + required int mantissaWidth, + bool explicitJBit = false, + bool subNormalAsZero = false, + }) { final signLogic = Const(0); final exponent = Const(1, width: exponentWidth, fill: true); final mantissa = Const(1, width: mantissaWidth); return FloatingPoint._( - signLogic, exponent, mantissa, explicitJBit, subNormalAsZero); + signLogic, + exponent, + mantissa, + explicitJBit, + subNormalAsZero, + ); } /// Construct a [FloatingPoint] that represents zero. - factory FloatingPoint.zero( - {required int exponentWidth, - required int mantissaWidth, - bool explicitJBit = false, - bool subNormalAsZero = false}) { + factory FloatingPoint.zero({ + required int exponentWidth, + required int mantissaWidth, + bool explicitJBit = false, + bool subNormalAsZero = false, + }) { final signLogic = Const(0); final exponent = Const(0, width: exponentWidth, fill: true); final mantissa = Const(0, width: mantissaWidth); return FloatingPoint._( - signLogic, exponent, mantissa, explicitJBit, subNormalAsZero); + signLogic, + exponent, + mantissa, + explicitJBit, + subNormalAsZero, + ); } } diff --git a/lib/src/bit_reversal.dart b/lib/src/bit_reversal.dart new file mode 100644 index 000000000..418bd3a00 --- /dev/null +++ b/lib/src/bit_reversal.dart @@ -0,0 +1,44 @@ +// Copyright (C) 2021-2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +int bitReverse(int value, int bits) { + var reversed = 0; + for (var i = 0; i < bits; i++) { + reversed <<= 1; + reversed |= value & 1; + value >>= 1; + } + return reversed; +} + +class BitReversal extends Module { + LogicArray get out => output('out') as LogicArray; + + BitReversal(LogicArray input, {super.name = 'bit_reversal'}) + : assert(input.dimensions.length == 1, 'Can only bit reverse 1D arrays') { + input = addInputArray( + 'input_array', + input, + dimensions: input.dimensions, // it seems like these are needed + elementWidth: input.elementWidth, + numUnpackedDimensions: input.numUnpackedDimensions, + ); + + final out = addOutputArray( + 'out', + dimensions: input.dimensions, + elementWidth: input.elementWidth, + numUnpackedDimensions: input.numUnpackedDimensions, + ); + + final length = input.dimensions[0]; + final bits = log2Ceil(length); + + for (var i = 0; i < length; i++) { + out.elements[bitReverse(i, bits)] <= input.elements[i]; + } + } +} diff --git a/pubspec.yaml b/pubspec.yaml index c561a429b..0aa98f3b8 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -18,7 +18,3 @@ dependencies: dev_dependencies: logging: ^1.0.1 test: ^1.25.0 - - - - diff --git a/test.dart b/test.dart new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/test.dart @@ -0,0 +1 @@ + diff --git a/test/arithmetic/floating_point/bad_fft_stage_test.dart b/test/arithmetic/floating_point/bad_fft_stage_test.dart new file mode 100644 index 000000000..bf613ee75 --- /dev/null +++ b/test/arithmetic/floating_point/bad_fft_stage_test.dart @@ -0,0 +1,158 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +import 'dart:async'; + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; +import 'package:rohd_hcl/src/arithmetic/floating_point/fft/fft_stage.dart'; +import 'package:rohd_hcl/src/arithmetic/signals/floating_point_logics/complex_floating_point_logic.dart'; +import 'package:rohd_vf/rohd_vf.dart'; +import 'package:test/test.dart'; +import 'fft_utils.dart'; + +Future write( + Logic clk, + DataPortInterface writePort, + LogicValue value, + int addr, +) async { + await clk.nextNegedge; + writePort.addr.inject(LogicValue.ofInt(addr, writePort.addrWidth)); + writePort.data.inject(value); + writePort.en.inject(1); + + await clk.nextNegedge; + writePort.en.inject(0); + await clk.nextNegedge; +} + +Future read(Logic clk, DataPortInterface readPort, int addr) async { + readPort.addr.inject(LogicValue.ofInt(addr, readPort.addrWidth)); + readPort.en.inject(1); + await clk.nextPosedge; + final value = readPort.data.value; + + await clk.nextNegedge; + readPort.en.inject(0); + await clk.nextNegedge; + + return value; +} + +void main() { + tearDown(() async { + await Simulator.reset(); + }); + + test('fft stage unit test', () async { + final a = Complex(real: 1, imaginary: 2); + final b = Complex(real: -3, imaginary: -4); + final twiddle = Complex(real: 1, imaginary: 0); + final aLogic = newComplex(a.real, a.imaginary); + final bLogic = newComplex(b.real, b.imaginary); + final twiddleLogic = newComplex(twiddle.real, twiddle.imaginary); + final clk = SimpleClockGenerator(10).clk; + final reset = Logic()..put(0); + final go = Logic()..put(0); + + const n = 2; + + final exponentWidth = aLogic.realPart.exponent.width; + final mantissaWidth = aLogic.realPart.mantissa.width; + final dataWidth = aLogic.width; //2 * (1 + exponentWidth + mantissaWidth); + final addrWidth = log2Ceil(n); + + final tempMemoryWritePort = DataPortInterface(dataWidth, addrWidth); + final tempMemoryReadPortA = DataPortInterface(dataWidth, addrWidth); + final tempMemoryReadPortB = DataPortInterface(dataWidth, addrWidth); + final twiddleFactorROMWritePort = DataPortInterface(dataWidth, addrWidth); + final twiddleFactorROMReadPort = DataPortInterface(dataWidth, addrWidth); + final outputSamplesA = DataPortInterface(dataWidth, addrWidth); + final outputSamplesB = DataPortInterface(dataWidth, addrWidth); + + final twiddleFactorROM = RegisterFile( + clk, + reset, + [twiddleFactorROMWritePort], + [twiddleFactorROMReadPort], + numEntries: n, + ); + final tempMemory = RegisterFile( + clk, + reset, + [tempMemoryWritePort], + [tempMemoryReadPortA, tempMemoryReadPortB], + numEntries: n, + ); + + final stage = BadFFTStage( + logStage: 1, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + clk: clk, + reset: reset, + go: go, + inputSamplesA: tempMemoryReadPortA, + inputSamplesB: tempMemoryReadPortB, + twiddleFactorROM: twiddleFactorROMReadPort, + outputSamplesA: outputSamplesA, + outputSamplesB: outputSamplesB, + ); + + await stage.build(); + + WaveDumper(stage); + + unawaited(Simulator.run()); + + reset.inject(1); + await clk.waitCycles(10); + reset.inject(0); + await clk.waitCycles(10); + + await write(clk, tempMemoryWritePort, aLogic.value, 0); + await write(clk, tempMemoryWritePort, bLogic.value, 1); + await write(clk, twiddleFactorROMWritePort, twiddleLogic.value, 0); + + go.inject(1); + flop(clk, stage.done).posedge.listen((_) { + go.inject(0); + }); + await clk.waitCycles(5); + + final output1 = await read(clk, outputSamplesA, 0); + final output2 = await read(clk, outputSamplesA, 1); + final output1float = ComplexFloatingPoint.of( + Const(output1), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + ); + final output2float = ComplexFloatingPoint.of( + Const(output2), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + ); + + final expected = butterfly(a, b, twiddle); + + compareDouble( + output1float.realPart.floatingPointValue.toDouble(), + expected[0].real, + ); + compareDouble( + output1float.imaginaryPart.floatingPointValue.toDouble(), + expected[0].imaginary, + ); + compareDouble( + output2float.realPart.floatingPointValue.toDouble(), + expected[1].real, + ); + compareDouble( + output2float.imaginaryPart.floatingPointValue.toDouble(), + expected[1].imaginary, + ); + + await Simulator.endSimulation(); + }); +} diff --git a/test/arithmetic/floating_point/butterfly_test.dart b/test/arithmetic/floating_point/butterfly_test.dart new file mode 100644 index 000000000..742d4c87f --- /dev/null +++ b/test/arithmetic/floating_point/butterfly_test.dart @@ -0,0 +1,62 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +import 'dart:math'; + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/src/arithmetic/floating_point/fft/butterfly.dart'; +import 'package:test/test.dart'; +import 'fft_utils.dart'; + +void main() { + tearDown(() async { + await Simulator.reset(); + }); + + test('butterfly unit test', () { + final random = Random(); + for (var i = 0; i < 5; i++) { + final a = Complex( + real: random.nextDouble(), + imaginary: random.nextDouble(), + ); + final b = Complex( + real: random.nextDouble(), + imaginary: random.nextDouble(), + ); + final twiddle = Complex( + real: random.nextDouble(), + imaginary: random.nextDouble(), + ); + + final aLogic = newComplex(a.real, a.imaginary); + final bLogic = newComplex(b.real, b.imaginary); + final twiddleLogic = newComplex(twiddle.real, twiddle.imaginary); + + final expected = butterfly(a, b, twiddle); + + final butterflyModule = Butterfly( + inA: aLogic, + inB: bLogic, + twiddleFactor: twiddleLogic, + ); + + compareDouble( + butterflyModule.outA.realPart.floatingPointValue.toDouble(), + expected[0].real, + ); + compareDouble( + butterflyModule.outA.imaginaryPart.floatingPointValue.toDouble(), + expected[0].imaginary, + ); + compareDouble( + butterflyModule.outB.realPart.floatingPointValue.toDouble(), + expected[1].real, + ); + compareDouble( + butterflyModule.outB.imaginaryPart.floatingPointValue.toDouble(), + expected[1].imaginary, + ); + } + }); +} diff --git a/test/arithmetic/floating_point/complex_floating_point_test.dart b/test/arithmetic/floating_point/complex_floating_point_test.dart new file mode 100644 index 000000000..9ebcba671 --- /dev/null +++ b/test/arithmetic/floating_point/complex_floating_point_test.dart @@ -0,0 +1,57 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; +import 'package:rohd_hcl/src/arithmetic/signals/floating_point_logics/complex_floating_point_logic.dart'; +import 'package:test/test.dart'; + +ComplexFloatingPoint newComplex(double real, double imaginary) { + final realFP = FloatingPoint64(); + final imaginaryFP = FloatingPoint64(); + + final realFPValue = FloatingPoint64Value.populator().ofDouble(real); + final imaginaryFPValue = FloatingPoint64Value.populator().ofDouble(imaginary); + + realFP.put(realFPValue); + imaginaryFP.put(imaginaryFPValue); + + final complex = ComplexFloatingPoint( + exponentWidth: realFP.exponent.width, + mantissaWidth: realFP.mantissa.width); + complex.realPart <= realFP; + complex.imaginaryPart <= imaginaryFP; + + return complex; +} + +void main() { + tearDown(() async { + await Simulator.reset(); + }); + + test('complex constructor', () { + final complex = newComplex(1.23, 3.45); + + expect(complex.realPart.floatingPointValue.toDouble(), 1.23); + expect(complex.imaginaryPart.floatingPointValue.toDouble(), 3.45); + }); + + test('complex addition', () { + final a = newComplex(1, 0); + final b = newComplex(0, -1); + final c = a.adder(b); + + expect(c.realPart.floatingPointValue.toDouble(), 1.0); + expect(b.imaginaryPart.floatingPointValue.toDouble(), -1.0); + }); + + test('complex multiplication', () { + final a = newComplex(1, 2); + final b = newComplex(-3, -4); + final c = a.multiplier(b); + + expect(c.realPart.floatingPointValue.toDouble(), 5.0); + expect(c.imaginaryPart.floatingPointValue.toDouble(), -10.0); + }); +} diff --git a/test/arithmetic/floating_point/fft_utils.dart b/test/arithmetic/floating_point/fft_utils.dart new file mode 100644 index 000000000..0ca863198 --- /dev/null +++ b/test/arithmetic/floating_point/fft_utils.dart @@ -0,0 +1,64 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +import 'package:rohd_hcl/rohd_hcl.dart'; +import 'package:rohd_hcl/src/arithmetic/signals/floating_point_logics/complex_floating_point_logic.dart'; + +ComplexFloatingPoint newComplex(double real, double imaginary) { + final realFP = FloatingPoint64(); + final imaginaryFP = FloatingPoint64(); + + final realFPValue = FloatingPoint64Value.populator().ofDouble(real); + final imaginaryFPValue = FloatingPoint64Value.populator().ofDouble(imaginary); + + realFP.put(realFPValue); + imaginaryFP.put(imaginaryFPValue); + + final complex = ComplexFloatingPoint( + exponentWidth: realFP.exponent.width, + mantissaWidth: realFP.mantissa.width, + ); + complex.realPart <= realFP; + complex.imaginaryPart <= imaginaryFP; + + return complex; +} + +class Complex { + double real; + double imaginary; + + Complex({required this.real, required this.imaginary}); + + Complex add(Complex other) => Complex( + real: real + other.real, + imaginary: imaginary + other.imaginary, + ); + + Complex subtract(Complex other) => Complex( + real: real - other.real, + imaginary: imaginary - other.imaginary, + ); + + Complex multiply(Complex other) => Complex( + real: (real * other.real) - (imaginary * other.imaginary), + imaginary: (real * other.imaginary) + (imaginary * other.real), + ); + + @override + String toString() => '$real${imaginary >= 0 ? '+' : ''}${imaginary}i'; +} + +List butterfly(Complex inA, Complex inB, Complex twiddleFactor) { + final temp = twiddleFactor.multiply(inB); + return [inA.subtract(temp), inA.add(temp)]; +} + +const epsilon = 1e-15; + +void compareDouble(double actual, double expected) { + assert( + (actual - expected).abs() < epsilon, + 'actual $actual, expected $expected', + ); +}