max - What is a low-latency hardware algorithm for finding the maximum value in an array?
Problem description
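I need to build a low-latency, single-cycle hardware module that finds the index and value of the largest element in an array. I am currently using a comparator tree, but the latency is unsatisfactory. Are there other algorithms that might have lower latency?
I expect the input array to be large (256 to 4096 elements) and the values to be small (3 to 5 bits). I also expect the array to be sparse, i.e. many small values and a few large ones.
I mainly care about latency; area is not very important.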
My current comparator-tree implementation looks like this:
implicit class reduceTreeOp[A](seq: Seq[A]) {
  // Pairwise reduction: combine neighbours level by level until one element is left
  def reduceTree[B >: A](op: (B, B) => B): B = {
    if (seq.length == 0)
      throw new NoSuchElementException("cannot reduce empty Seq")
    var rseq: Seq[B] = seq
    while (rseq.length != 1)
      rseq = rseq.grouped(2).toSeq
        .map(s => if (s.length == 1) s(0) else op(s(0), s(1)))
    rseq(0)
  }
}
val (value, index) = array
  .zipWithIndex
  .map { case (v, i) => (v, i.U) }
  .reduceTree[(UInt, UInt)] { case ((val1, idx1), (val2, idx2)) =>
    // Compare value with value; on a tie the left (lower-index) element wins
    val is1 = val1 >= val2
    (Mux(is1, val1, val2),
     Mux(is1, idx1, idx2))
  }
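As a software cross-check, here is a plain-Scala model of the same pairwise reduction, run on ordinary Ints instead of UInt wires, so no Chisel is needed. It compares value against value, with ties keeping the left, lower-index element, and it also counts how many tree levels the reduction takes. The object `ComparatorTreeModel` and its methods are hypothetical names for illustration; this is a sketch of the question's structure, not part of the original code.

```scala
// Hypothetical helper for illustration: a software model of the comparator tree.
object ComparatorTreeModel {
  // Same grouping as reduceTree: pair up neighbours, pass an odd element through,
  // and repeat until one element remains. Also returns how many levels that took.
  def reduceTreeWithDepth[B](seq: Seq[B])(op: (B, B) => B): (B, Int) = {
    if (seq.isEmpty) throw new NoSuchElementException("cannot reduce empty Seq")
    var rseq: Seq[B] = seq
    var levels = 0
    while (rseq.length != 1) {
      rseq = rseq.grouped(2).toVector
        .map(s => if (s.length == 1) s(0) else op(s(0), s(1)))
      levels += 1
    }
    (rseq(0), levels)
  }

  // Returns ((maxValue, firstIndexOfMax), treeLevels).
  def maxAndIndex(values: Seq[Int]): ((Int, Int), Int) =
    reduceTreeWithDepth(values.zipWithIndex) { case ((v1, i1), (v2, i2)) =>
      // Compare value with value; >= keeps the left element on a tie.
      if (v1 >= v2) (v1, i1) else (v2, i2)
    }
}
```

For example, `ComparatorTreeModel.maxAndIndex(Seq(3, 7, 7, 1))` returns `((7, 1), 2)`: the first 7 wins the tie after two levels. For 256 elements the tree is 8 levels deep and for 4096 it is 12, and each level puts a comparison followed by a mux on the critical path.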
FWIW, this is being designed for 7nm hardware, though I doubt that actually matters for my problem.
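The question asks whether anything beats a comparator tree when values are only a few bits wide. One swapped-in technique, not taken from the question or the answer below, exploits the small value range with thermometer codes: encode each value v as a word with its low v bits set, OR all of those words together, and the result is the thermometer code of the maximum. An element holds the maximum if its code covers every bit set in that OR, and a priority encoder then picks the first such element. The sketch below is a plain-Scala model of that idea, assuming values stay below 63 so each code fits in a Long (3 to 5 bit values need at most 31 code bits). `ThermometerMaxModel` is a hypothetical name, and whether this is actually faster than the comparator tree in a 7nm flow would have to be confirmed with synthesis and timing analysis.

```scala
// Hypothetical model of a thermometer-code max finder (an alternative technique,
// not the question's or the answer's method). Assumes 0 <= value < 63.
object ThermometerMaxModel {
  // Thermometer code of v: bits 0 until v are set, e.g. 3 -> 0b111.
  def thermometer(v: Int): Long = (1L << v) - 1

  // Returns (maxValue, firstIndexOfMax) using only ORs, a coverage check and a priority pick.
  def maxAndIndex(values: Seq[Int]): (Int, Int) = {
    require(values.nonEmpty, "cannot take the max of an empty array")
    require(values.forall(v => v >= 0 && v < 63), "values must be in 0..62")
    val codes = values.map(thermometer)
    // Step 1: OR every code together; in hardware, one N-input OR tree per code bit.
    val orAll = codes.foldLeft(0L)(_ | _)
    // The number of set bits in a thermometer code is the value it encodes.
    val maxValue = java.lang.Long.bitCount(orAll)
    // Step 2: element i holds the max iff its code covers every bit set in orAll.
    val isMax = codes.map(c => (c & orAll) == orAll)
    // Step 3: priority encoder, the lowest matching index wins (first occurrence).
    val index = isMax.indexOf(true)
    (maxValue, index)
  }
}
```

For example, `ThermometerMaxModel.maxAndIndex(Seq(1, 0, 6, 2, 6, 3))` returns `(6, 2)`. The possible latency benefit is that there are no chained compare-and-mux stages: the critical path is roughly one wide OR tree, a per-element coverage check against its result, and one priority encoder. The cost is 2^bitWidth - 1 OR trees of N inputs each, which matters less when area is not a concern.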
Solution
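This is pretty straightforward to do in chisel3. The constraint that the result be returned in a single cycle will cause a whole lot of hardware to be generated, so it may be a good idea to evaluate carefully what you actually need here.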
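To put rough numbers on "a whole lot of hardware": a binary comparator tree over N elements has N - 1 compare-and-select nodes arranged in ceil(log2 N) levels, and in the ArrayMax code that follows each node also muxes a bitWidth-bit value and a log2Ceil(depth)-bit index. The small plain-Scala sketch below computes those counts. `TreeCost` is a hypothetical name, and the mux-bit total is an upper bound before synthesis strips constant index bits.

```scala
// Hypothetical helper: rough size of a single-cycle comparator tree.
object TreeCost {
  // Matches chisel3.util.log2Ceil for n >= 1 (returns 0 for n = 1).
  def log2Ceil(n: Int): Int = {
    require(n >= 1, "n must be positive")
    32 - Integer.numberOfLeadingZeros(n - 1)
  }

  // Returns (compare-and-select nodes, tree levels, total mux bits), assuming
  // every node muxes a full bitWidth value and a full log2Ceil(depth) index.
  def cost(depth: Int, bitWidth: Int): (Int, Int, Int) = {
    val nodes = depth - 1
    val levels = log2Ceil(depth)
    val muxBits = nodes * (bitWidth + log2Ceil(depth))
    (nodes, levels, muxBits)
  }
}
```

At the top of the question's range, `TreeCost.cost(4096, 5)` gives `(4095, 12, 69615)`: 4095 comparators feeding 12 levels of logic, all of which must settle within one clock cycle.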
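Anyway, this is an interesting problem, and it shows off some of Chisel's power. I've provided a runnable Scastie example.
Here is the code, with a very simple test suite: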
import chisel3._
import chisel3.util.log2Ceil
import chiseltest._
import org.scalatest.freespec.AnyFreeSpec
import treadle.extremaOfUIntOfWidth
/** Write-only memory that continuously outputs the first index of the highest value in an array and
  * that highest value. It works by building an evaluation network of highest values.
  *
  * It would be trivial to add the ability to read the values in this memory.
  *
  * @param depth    number of elements in the array
  * @param bitWidth width in bits of each element
  */
class ArrayMax(val depth: Int, val bitWidth: Int) extends MultiIOModule {
  val writeEnable = IO(Input(Bool()))
  val writeAddress = IO(Input(UInt(log2Ceil(depth).W)))
  val writeData = IO(Input(UInt(bitWidth.W)))
  val maxValue = IO(Output(UInt(bitWidth.W)))
  val indexOfMaxValue = IO(Output(UInt(log2Ceil(depth).W)))
  val array = Reg(Vec(depth, UInt(bitWidth.W)))
  when(writeEnable) {
    array(writeAddress) := writeData
  }
  val valuesAndIndices = array.zipWithIndex.map { case (value, index) => (value, index.U) }.toList
  // Look through the array pairwise and return the value and index of the higher of each pair
  def compareAdjacentValues(valuesAndIndices: Seq[(UInt, UInt)]): Seq[(UInt, UInt)] = {
    val pairs = valuesAndIndices.sliding(2, 2)
    pairs.map {
      case (aValue, aIndex) :: (bValue, bIndex) :: Nil =>
        val (higherValue, higherIndex) = (Wire(UInt(bitWidth.W)), Wire(UInt(log2Ceil(depth).W)))
        // On a tie, keep a (the lower index) so the first occurrence wins
        when(aValue < bValue) {
          higherValue := bValue
          higherIndex := bIndex
        } otherwise {
          higherValue := aValue
          higherIndex := aIndex
        }
        (higherValue, higherIndex)
      case (aValue, aIndex) :: Nil =>
        // An odd element out passes through to the next level unchanged
        (aValue, aIndex)
      case a =>
        throw new Exception(s"Cannot get here, sliding should return list of size 1 or 2, $a")
    }.toList
  }
  // Apply compareAdjacentValues repeatedly until a single (value, index) pair remains
  def reduceToOne(pairs: Seq[(UInt, UInt)]): (UInt, UInt) = {
    if (pairs.length == 1) {
      pairs.head
    } else {
      reduceToOne(compareAdjacentValues(pairs))
    }
  }
  val (highestValue, index) = reduceToOne(valuesAndIndices)
  maxValue := highestValue
  indexOfMaxValue := index
}
/** Pumps random values at random indices into the DUT and into a buffer that models it.
  * Checks that the highest value is returned, along with the first index where that value
  * appears (there can be multiple occurrences).
  *
  * A simpler model might just keep a record of the highest value written so far,
  * but that breaks down if that value is later overwritten by something lower.
  */
class ArrayMaxSpec extends AnyFreeSpec with ChiselScalatestTester {
  val rand = new scala.util.Random
  "should act like a little write only memory" in {
    test(new ArrayMax(depth = 256, bitWidth = 8)) { dut =>
      val testArray = Array.fill(dut.depth)(0)
      for (_ <- 0 until dut.depth * 10) {
        val index = rand.nextInt(dut.depth)
        dut.writeEnable.poke(true.B)
        dut.writeAddress.poke(index.U)
        // + 1 so the largest representable value can also be written
        val newValue = rand.nextInt(extremaOfUIntOfWidth(dut.bitWidth)._2.toInt + 1)
        dut.writeData.poke(newValue.U)
        testArray(index) = newValue
        dut.clock.step()
        dut.maxValue.expect(testArray.max.U)
        dut.indexOfMaxValue.expect(testArray.indexOf(testArray.max).U)
      }
    }
  }
}