Overview

Source code: /xiangshan/frontend/Tage.scala

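XiangShan's TAGE is organized as a main TAGE predictor whose output has a 2-cycle latency, plus an SC (statistical corrector) with a 3-cycle latency that corrects the main TAGE prediction. The main TAGE predictor predicts the outcomes of 2 branches (Br) per cycle (numBr = 2).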

Parameters

trait TageParams extends HasBPUConst with HasXSParameter {
  // println(BankTageTableInfos)
  // val TageTableInfos = coreParams.TageTableInfos
  // TageTableInfos: Seq[Tuple3[Int,Int,Int]] =
    //       Sets  Hist   Tag
    // Seq(( 4096,    8,    8),
    //     ( 4096,   13,    8),
    //     ( 4096,   32,    8),
    //     ( 4096,  119,    8)),
  val TageNTables = TageTableInfos.size // 4
  // val BankTageNTables = BankTageTableInfos.map(_.size) // Number of tage tables
  // val UBitPeriod = 256
  val TageCtrBits = 3
  val TickWidth = 7

  val USE_ALT_ON_NA_WIDTH = 4
  val NUM_USE_ALT_ON_NA = 128
  def use_alt_idx(pc: UInt) = (pc >> instOffsetBits)(log2Ceil(NUM_USE_ALT_ON_NA)-1, 0)

  // ...
}
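  • TageNTables: the number of tag-indexed tables used by TAGE, i.e. the number of Tx tables, which is 4. The configuration of the 4 tables (sets, history length, tag length) is shown in the comment above.
  • TageCtrBits: the width of the prediction saturating counter in each TAGE Tx table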
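
As a rough illustration of use_alt_idx, the plain-Scala model below drops the instOffsetBits low bits of the PC and keeps the next log2(NUM_USE_ALT_ON_NA) = 7 bits. It assumes instOffsetBits = 1 (2-byte instruction alignment with RVC); the object and method names are hypothetical and not part of XiangShan.

```scala
object UseAltIdxModel {
  val NumUseAltOnNa = 128 // NUM_USE_ALT_ON_NA
  val InstOffsetBits = 1  // assumption: RVC enabled, 2-byte instruction alignment

  private def log2Ceil(x: Int): Int = 32 - Integer.numberOfLeadingZeros(x - 1)

  // Models (pc >> instOffsetBits)(log2Ceil(NUM_USE_ALT_ON_NA)-1, 0)
  def useAltIdx(pc: Long): Int = {
    val bits = log2Ceil(NumUseAltOnNa) // 7
    ((pc >>> InstOffsetBits) & ((1L << bits) - 1)).toInt
  }
}
```

Under these assumptions, PCs that differ by a multiple of 256 bytes share a USE_ALT_ON_NA counter.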

TageMeta

TageMeta is the prediction metadata that TAGE outputs with its resp. The Bundle is defined as follows:

class TageMeta(implicit p: Parameters)
  extends TageBundle with HasSCParameter
{
  val providers = Vec(numBr, ValidUndirectioned(UInt(log2Ceil(TageNTables).W)))
  val providerResps = Vec(numBr, new TageResp)
  // val altProviders = Vec(numBr, ValidUndirectioned(UInt(log2Ceil(TageNTables).W)))
  // val altProviderResps = Vec(numBr, new TageResp)
  val altUsed = Vec(numBr, Bool())
  val altDiffers = Vec(numBr, Bool())
  val basecnts = Vec(numBr, UInt(2.W))
  val allocates = Vec(numBr, UInt(TageNTables.W))
  val takens = Vec(numBr, Bool())
  val scMeta = if (EnableSC) Some(new SCMeta(SCNTables)) else None // SCNTables = 4
  val pred_cycle = if (!env.FPGAPlatform) Some(UInt(64.W)) else None
  val use_alt_on_na = if (!env.FPGAPlatform) Some(Vec(numBr, Bool())) else None

  def altPreds = basecnts.map(_(1))
  def allocateValid = allocates.map(_.orR)
}
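  • providers: for each of the 2 predictions, which tagged table Tx provided it
  • providerResps: for each of the 2 predictions, the information returned by the provider table Tx, namely the prediction saturating counter ctr and the useful bit u. TageBundle and TageResp are defined as: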
abstract class TageBundle(implicit p: Parameters)
  extends XSBundle with TageParams with BPUUtils

class TageResp(implicit p: Parameters) extends TageBundle {
  val ctr = UInt(TageCtrBits.W) // TageCtrBits = 3
  val u = Bool()
}
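  • altUsed: for each of the 2 predictions, whether the alternate prediction (altpred) was used
  • altDiffers: for each of the 2 predictions, whether the alternate prediction differs from the final TAGE prediction
  • basecnts: the 2-bit counters read from the base table T0
  • allocates: for each of the 2 predictions, a TageNTables-bit mask of the Tx tables that are candidates for allocation
  • takens: for each of the 2 predictions, whether it is predicted taken
  • scMeta: SC-related prediction information, including whether the main TAGE prediction was taken, whether SC was used, the SC prediction, and the SC counters (see SCMeta):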
class SCMeta(val ntables: Int)(implicit p: Parameters) extends XSBundle with HasSCParameter {
  val tageTakens = Vec(numBr, Bool())
  val scUsed = Vec(numBr, Bool())
  val scPreds = Vec(numBr, Bool())
  // Suppose ctrbits of all tables are identical
  val ctrs = Vec(numBr, Vec(ntables, SInt(SCCtrBits.W)))  // SCCtrBits = 6
}
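  • pred_cycle: the value of the global cycle timer (GTimer) when the prediction was made; present only in non-FPGA builds
  • use_alt_on_na: whether the alternate prediction was chosen because the longest-history match had low confidence (the USE_ALT_ON_NA mechanism described later); present only in non-FPGA builds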

TAGE

Tage Table

TageTable is the module that implements each of TAGE's 4 tag-indexed Tx tables. The parameters of the 4 Tx tables are:

     Sets  histLen  tagLen
T1   4096        8       8
T2   4096       13       8
T3   4096       32       8
T4   4096      119       8

Each entry of a Tage table consists of:

class TageEntry() extends TageBundle {
  val valid = Bool()               
  val tag = UInt(tagLen.W)         
  val ctr = UInt(TageCtrBits.W)    
}
  • valid: valid bit
  • tag: 8-bit tag, compared against the lookup tag to detect a hit
  • ctr: 3-bit saturating counter giving the predicted direction
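
A small plain-Scala sketch of how the 3-bit ctr is interpreted and initialized may help: the predicted direction is its MSB (as in providerInfo.resp.ctr(TageCtrBits-1) in the top-level logic), and a newly allocated entry starts weakly biased toward the actual outcome (4 if taken, 3 if not taken, per the Tage Table update logic). The object name is hypothetical.

```scala
object TageCtrModel {
  val TageCtrBits = 3

  // Predicted direction = MSB of the saturating counter (values 4..7 mean taken).
  def predTaken(ctr: Int): Boolean = ((ctr >> (TageCtrBits - 1)) & 1) == 1

  // Initial counter of a freshly allocated entry: weakly taken or weakly not taken.
  def allocCtr(taken: Boolean): Int = if (taken) 4 else 3
}
```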

SRAM parameters

The sizes of the SRAMs used by a Tage table are derived as follows:

val SRAM_SIZE = 256 // physical size
require(nRows % SRAM_SIZE == 0)
require(isPow2(numBr))
val nRowsPerBr = nRows / numBr
val nBanks = 8
val bankSize = nRowsPerBr / nBanks
val bankFoldWidth = if (bankSize >= SRAM_SIZE) bankSize / SRAM_SIZE else 1
val uFoldedWidth = nRowsPerBr / SRAM_SIZE
val uWays = uFoldedWidth * numBr
val uRows = SRAM_SIZE
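  • SRAM_SIZE = 256 → physical depth (number of rows) of each SRAM
  • nRowsPerBr = 4096 / 2 = 2048 → number of entries per Br
  • nBanks = 8 → number of SRAM banks
  • bankSize = 2048 / 8 = 256 → number of sets in each SRAM bank
  • bankFoldWidth = 1 → bankSize equals SRAM_SIZE, so no folding is needed
  • uFoldedWidth = 2048 / 256 = 8 → number of useful bits per Br stored in each row of the useful-bit SRAM
  • uWays = 8 * 2 = 16 → unused
  • uRows = 256 → unused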
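
The derivation can be replayed in plain Scala to see how the numbers depend on nRows. This is a sketch with hypothetical names that only mirrors the arithmetic above, with numBr = 2 and nBanks = 8.

```scala
object TageSramParams {
  val SRAM_SIZE = 256
  val numBr = 2
  val nBanks = 8

  final case class Derived(nRowsPerBr: Int, bankSize: Int, bankFoldWidth: Int,
                           uFoldedWidth: Int, uWays: Int, uRows: Int)

  // Re-derives the Tage table SRAM parameters for a table with nRows entries.
  def derive(nRows: Int): Derived = {
    require(nRows % SRAM_SIZE == 0)
    val nRowsPerBr = nRows / numBr
    val bankSize = nRowsPerBr / nBanks
    val bankFoldWidth = if (bankSize >= SRAM_SIZE) bankSize / SRAM_SIZE else 1
    val uFoldedWidth = nRowsPerBr / SRAM_SIZE
    Derived(nRowsPerBr, bankSize, bankFoldWidth, uFoldedWidth, uFoldedWidth * numBr, SRAM_SIZE)
  }
}
```

With a hypothetical 8192-entry table, bankSize would become 512 and bankFoldWidth 2, so each 256-row physical SRAM would pack two logical sets per row.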

Index, tag, and folded global branch history management

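The folded global history used by each Tx table is described by an (Int, Int) tuple: the first element is the total length of global branch history required, and the second is the length after hashing (folding). The folded length is the minimum of the history length and either log2 of the per-Br table depth (for the index) or the tag length (tagLen for the tag, tagLen - 1 for the alternate tag).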

val idxFhInfo = (histLen, min(log2Ceil(nRowsPerBr), histLen))
val tagFhInfo = (histLen, min(histLen, tagLen))
val altTagFhInfo = (histLen, min(histLen, tagLen-1))
val allFhInfos = Seq(idxFhInfo, tagFhInfo, altTagFhInfo)

// allFhInfos: 
// T1: Seq((8,   8 ), (8,   8), (8,   7))
// T2: Seq((13,  11), (13,  8), (13,  7))
// T3: Seq((32,  11), (32,  8), (32,  7))
// T4: Seq((119, 11), (119, 8), (119, 7))
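
The tuples in the comment can be recomputed with a short plain-Scala sketch (hypothetical names), assuming nRowsPerBr = 2048 as derived above.

```scala
object FhInfoModel {
  val nRowsPerBr = 2048

  private def log2Ceil(x: Int): Int = 32 - Integer.numberOfLeadingZeros(x - 1)

  // (history length, folded length) for the index, tag and alt-tag hashes of one table.
  def allFhInfos(histLen: Int, tagLen: Int): Seq[(Int, Int)] = Seq(
    (histLen, math.min(log2Ceil(nRowsPerBr), histLen)), // idxFhInfo
    (histLen, math.min(histLen, tagLen)),               // tagFhInfo
    (histLen, math.min(histLen, tagLen - 1))            // altTagFhInfo
  )
}
```

Taking the set of these tuples over TAGE's four tables leaves 11 distinct folded histories.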

XiangShan's global branch history management is rather tricky. Let us start with the description from the official documentation:

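Each history table of a TAGE-style predictor has its own history length. To XOR the history with the PC and index the table, a long branch-history sequence has to be split into many segments that are all XORed together; each segment is usually as long as the log2 of the table depth. Because this typically takes many XORs, the folded history is stored directly to avoid the delay of multi-level XOR on the prediction path. Since histories of different lengths fold differently, the number of folded histories required equals the number of distinct (history length, folded length) tuples. When one bit of history is updated, it is enough to XOR the oldest bit (before folding) and the newest bit into the corresponding positions and then perform a shift.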
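
As a reference point, the direct (non-incremental) fold described in the quote can be sketched in plain Scala: the first len history bits are cut into compLen-bit segments and all segments are XORed together. The sketch assumes history bit 0 is the newest outcome; the names are hypothetical.

```scala
object DirectFold {
  // history(p) is the outcome p branches ago (p = 0 is the newest).
  // Cut the first `len` bits into compLen-bit segments and XOR them together.
  def fold(history: Seq[Boolean], len: Int, compLen: Int): Int =
    history.take(len).grouped(compLen).foldLeft(0) { (acc, seg) =>
      val segVal = seg.zipWithIndex.foldLeft(0) { case (v, (b, i)) => if (b) v | (1 << i) else v }
      acc ^ segVal
    }
}
```

Equivalently, history bit p lands on folded bit p % compLen, so for a (13, 8) history bit 12 ends up at folded position 4.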

The class XiangShan uses to manage a folded global branch history is FoldedHistory:

class FoldedHistory(val len: Int, val compLen: Int, val max_update_num: Int)(implicit p: Parameters)
  extends XSBundle with HasBPUConst {
  require(compLen >= 1)
  require(len > 0)
  // require(folded_len <= len)
  require(compLen >= max_update_num)
  val folded_hist = UInt(compLen.W)

	//...
}

uBTB and TAGE obtain the folded global history in a similar way. Taking TAGE as the example:

def getFoldedHistoryInfo = allFhInfos.filter(_._1 >0).toSet
def compute_tag_and_hash(unhashed_idx: UInt, allFh: AllFoldedHistories) = {
  val idx_fh = allFh.getHistWithInfo(idxFhInfo).folded_hist
  val tag_fh = allFh.getHistWithInfo(tagFhInfo).folded_hist
  val alt_tag_fh = allFh.getHistWithInfo(altTagFhInfo).folded_hist
  // require(idx_fh.getWidth == log2Ceil(nRows))
  val idx = (unhashed_idx ^ idx_fh)(log2Ceil(nRowsPerBr)-1, 0)
  val tag = (unhashed_idx ^ tag_fh ^ (alt_tag_fh << 1)) (tagLen - 1, 0)
  (idx, tag)
}
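
The hash itself is only XOR and truncation. A plain-integer sketch with hypothetical names, assuming the three folded histories are already available as Ints and that log2Ceil(nRowsPerBr) = 11 and tagLen = 8:

```scala
object TageHashModel {
  val IdxBits = 11 // log2Ceil(nRowsPerBr) with nRowsPerBr = 2048
  val TagLen = 8

  // Mirrors compute_tag_and_hash on plain integers.
  def computeTagAndHash(unhashedIdx: Long, idxFh: Int, tagFh: Int, altTagFh: Int): (Int, Int) = {
    val idx = ((unhashedIdx ^ idxFh) & ((1L << IdxBits) - 1)).toInt
    val tag = ((unhashedIdx ^ tagFh ^ (altTagFh.toLong << 1)) & ((1L << TagLen) - 1)).toInt
    (idx, tag)
  }
}
```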

allFh is an instance of the AllFoldedHistories class, which is defined as follows:

class AllFoldedHistories(val gen: Seq[Tuple2[Int, Int]])(implicit p: Parameters) extends XSBundle with HasBPUConst {
  val hist = MixedVec(gen.map{case (l, cl) => new FoldedHistory(l, cl, numBr)})
  // println(gen.mkString)
  require(gen.toSet.toList.equals(gen))
  def getHistWithInfo(info: Tuple2[Int, Int]) = {
    val selected = hist.filter(_.info.equals(info))
    require(selected.length == 1)
    selected(0)
  }
}

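As can be seen, each distinct (history length, folded length) pair required by the tables is stored separately and is computed and updated independently. The T2 tag hash is used below as an example to explain the whole scheme. T2's tag hash info is (13, 8); that is, in the hist vector of AllFoldedHistories, the FoldedHistory instance generated for the T2 tag has parameters (len = 13, compLen = 8, max_update_num = numBr = 2).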

The following walks through the methods of FoldedHistory step by step to show how the folded branch history is obtained.

First, FoldedHistory has only one member, folded_hist, which stores the folded form of the corresponding global branch history:

val folded_hist = UInt(compLen.W)  // compLen = 8

For the T2 tag hash, folded_hist is 8 bits wide. Next come several helper functions for computing the folded history:

def need_oldest_bits = len > compLen  // 13 > 8, True
def info = (len, compLen)             // (13, 8)
def oldest_bit_to_get_from_ghr = (0 until max_update_num).map(len - _ - 1)      // (0 until 2).map(13 - _ - 1) -> (12, 11)
def oldest_bit_pos_in_folded = oldest_bit_to_get_from_ghr map (_ % compLen)     // (12 % 8, 11 % 8)            -> (4 , 3 )
def oldest_bit_wrap_around = oldest_bit_to_get_from_ghr map (_ / compLen > 0)   // (True, True)
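
The helper values for any (len, compLen) pair can be recomputed with a small plain-Scala model (hypothetical names); for T2's tag hash it reproduces the (12, 11), (4, 3) and (true, true) values noted in the comments.

```scala
object OldestBitsModel {
  // Mirrors the FoldedHistory helper defs for a given (len, compLen, max_update_num).
  def oldestBitToGetFromGhr(len: Int, maxUpdateNum: Int): Seq[Int] =
    (0 until maxUpdateNum).map(len - _ - 1)

  def oldestBitPosInFolded(len: Int, compLen: Int, maxUpdateNum: Int): Seq[Int] =
    oldestBitToGetFromGhr(len, maxUpdateNum).map(_ % compLen)

  def oldestBitWrapAround(len: Int, compLen: Int, maxUpdateNum: Int): Seq[Boolean] =
    oldestBitToGetFromGhr(len, maxUpdateNum).map(_ / compLen > 0)
}
```

For a (9, 8) history, by contrast, only the oldest bit (position 8) wraps around; position 7 does not.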

get_oldest_bits_from_ghr fetches the oldest bits from the unfolded global history register (ghr):

def get_oldest_bits_from_ghr(ghr: Vec[Bool], histPtr: CGHPtr) = {
  // TODO: wrap inc for histPtr value
  // For T2 Tag hash: (ghr((histPtr+13.U).value), ghr((histPtr+12.U).value))
  oldest_bit_to_get_from_ghr.map(i => ghr((histPtr + (i+1).U).value))
}

Note that histPtr is a pointer designed for a circular FIFO, so it wraps around automatically. circular_shift_left rotates its input left:

def circular_shift_left(src: UInt, shamt: Int) = {
  val srcLen = src.getWidth
  val src_doubled = Cat(src, src)
  val shifted = src_doubled(srcLen*2-1-shamt, srcLen-shamt)
  shifted
}
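
The bit-slicing trick in circular_shift_left is an ordinary rotate. A plain-integer sketch with hypothetical names, valid for 0 <= shamt <= width:

```scala
object CircularShiftModel {
  // Rotate a width-bit value left by shamt, as circular_shift_left does by
  // taking bits [2*width-1-shamt, width-shamt] of Cat(src, src).
  def circularShiftLeft(src: Int, width: Int, shamt: Int): Int = {
    val mask = (1 << width) - 1
    val doubled = ((src & mask).toLong << width) | (src & mask)
    ((doubled >> (width - shamt)) & mask).toInt
  }
}
```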

The update function is the core of global branch history management. Internally it defines two helpers, the function bitsets_xor and the value new_folded_hist:

def bitsets_xor(len: Int, bitsets: Seq[Seq[Tuple2[Int, Bool]]]) = {
  val res = Wire(Vec(len, Bool()))
  val resArr = Array.fill(len)(List[Bool]())
  for (bs <- bitsets) {
    for ((n, b) <- bs) {
      resArr(n) = b :: resArr(n)
    }
  }
  for (i <- 0 until len) {
    if (resArr(i).length > 2) {
      println(f"[warning] update logic of foldest history has two or more levels of xor gates! " +
        f"histlen:${this.len}, compLen:$compLen, at bit $i")
    }
    if (resArr(i).length == 0) {
      println(f"[error] bits $i is not assigned in folded hist update logic! histlen:${this.len}, compLen:$compLen")
    }
    res(i) := resArr(i).foldLeft(false.B)(_^_)
  }
  res.asUInt
}
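
A software model of bitsets_xor on plain integers, with hypothetical names; the warning and error printouts of the original are omitted.

```scala
object BitsetsXorModel {
  // Each bitset is a list of (bit position, bit value) pairs; all values that
  // target the same position are XORed together. Returns the result as an Int.
  def bitsetsXor(len: Int, bitsets: Seq[Seq[(Int, Boolean)]]): Int = {
    val perPos = Array.fill(len)(List.empty[Boolean])
    for (bs <- bitsets; (n, b) <- bs) perPos(n) = b :: perPos(n)
    (0 until len).foldLeft(0) { (acc, i) =>
      if (perPos(i).foldLeft(false)(_ ^ _)) acc | (1 << i) else acc
    }
  }
}
```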

bitsets_xor XORs together all bits that are mapped to the same position of the folded history; it is used in new_folded_hist:

val new_folded_hist = if (need_oldest_bits) {
  val oldest_bits = ob
  require(oldest_bits.length == max_update_num)
  // mask off bits that do not update
  val oldest_bits_masked = oldest_bits.zipWithIndex.map{
    case (ob, i) => ob && (i < num).B
  }
  // if a bit does not wrap around, it should not be xored when it exits
  val oldest_bits_set = (0 until max_update_num).filter(oldest_bit_wrap_around).map(i => (oldest_bit_pos_in_folded(i), oldest_bits_masked(i)))

  // only the last bit could be 1, as we have at most one taken branch at a time
  val newest_bits_masked = VecInit((0 until max_update_num).map(i => taken && ((i+1) == num).B)).asUInt
  // if a bit does not wrap around, newest bits should not be xored onto it either
  val newest_bits_set = (0 until max_update_num).map(i => (compLen-1-i, newest_bits_masked(i)))

  val original_bits_masked = VecInit(folded_hist.asBools.zipWithIndex.map{
    case (fb, i) => fb && !(num >= (len-i)).B
  })
  val original_bits_set = (0 until compLen).map(i => (i, original_bits_masked(i)))

  // do xor then shift
  val xored = bitsets_xor(compLen, Seq(original_bits_set, oldest_bits_set, newest_bits_set))
  circular_shift_left(xored, num)
} else {
  // histLen too short to wrap around
  ((folded_hist << num) | taken)(compLen-1,0)
}

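oldest_bits holds the oldest bits fetched from the unfolded global history register; for T2 these are (ghr((histPtr+13.U).value), ghr((histPtr+12.U).value)). For convenience, assume histPtr = 1, which gives (ghr(14), ghr(13)). The parameter num is the number of history bits shifted in by this update, i.e. how many Brs of the prediction block are recorded: 0 means no Br, and 2 means the history covers up to the second Br of the block. Based on num, the oldest bits that do not leave the history window are masked off: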

// mask off bits that do not update
val oldest_bits_masked = oldest_bits.zipWithIndex.map{
  case (ob, i) => ob && (i < num).B
}

// oldest_bits_masked in 3 cases:
// num = 2: (ghr(14), ghr(13))
// num = 1: (ghr(14), False  )
// num = 0: (False  , False  )

If an oldest bit did not wrap around, it must not be XORed out:

// if a bit does not wrap around, it should not be xored when it exits
val oldest_bits_set = (0 until max_update_num).filter(oldest_bit_wrap_around).map(i => (oldest_bit_pos_in_folded(i), oldest_bits_masked(i)))

// For T2 Tag hash, all oldest bits are wrap around:
// oldest_bits_set = ((4, oldest_bits_masked(0)), (3, oldest_bits_masked(1)))

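XiangShan predicts two Brs per cycle. Since at most one branch in a block can be taken, a taken bit can only belong to the last Br recorded in the prediction block, so only the newest bit at index num - 1 can be 1: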

// only the last bit could be 1, as we have at most one taken branch at a time
val newest_bits_masked = VecInit((0 until max_update_num).map(i => taken && ((i+1) == num).B)).asUInt
// newest_bits_masked = (taken && (1==num), taken && (2==num))

// if a bit does not wrap around, newest bits should not be xored onto it either
val newest_bits_set = (0 until max_update_num).map(i => (compLen-1-i, newest_bits_masked(i)))
// newest_bits_set = ((7, newest_bits_masked(0)), (6, newest_bits_masked(1)))

For the original folded history:

val original_bits_masked = VecInit(folded_hist.asBools.zipWithIndex.map{
  case (fb, i) => fb && !(num >= (len-i)).B
})
// original_bits_masked, identical in all 3 cases (num < len - i for every i): (fb(0), fb(1), fb(2), fb(3), fb(4), fb(5), fb(6), fb(7))
val original_bits_set = (0 until compLen).map(i => (i, original_bits_masked(i)))
// original_bits_set = ((0, fb(0)), (1, fb(1)), ..., (7, fb(7)))

Perform the XOR:

// do xor then shift
val xored = bitsets_xor(compLen, Seq(original_bits_set, oldest_bits_set, newest_bits_set)) // compLen = 8
// original_bits_set = ((0, fb(0)), (1, fb(1)), ..., (7, fb(7)))
// oldest_bits_set   = ((4, oldest_bits_masked(0)), (3, oldest_bits_masked(1)))
// newest_bits_set   = ((7, newest_bits_masked(0)), (6, newest_bits_masked(1)))

/*
  res    = Wire(Vec(8, Bool()))
  resArr = Array.fill(8)(List[Bool]())

  resArr(0) = fb(0)
  resArr(1) = fb(1)
  resArr(2) = fb(2)
  resArr(3) = fb(3) ^ oldest_bits_masked(1)
  resArr(4) = fb(4) ^ oldest_bits_masked(0)
  resArr(5) = fb(5)
  resArr(6) = fb(6) ^ newest_bits_masked(1)
  resArr(7) = fb(7) ^ newest_bits_masked(0)

  (histPtr = 1; bits listed in index order, bit 0 first)
  if num == 1:
    res = (fb(0), fb(1), fb(2), fb(3)          , fb(4) ^ ghr(14), fb(5), fb(6)        , fb(7) ^ taken)
  elif num == 2:
    res = (fb(0), fb(1), fb(2), fb(3) ^ ghr(13), fb(4) ^ ghr(14), fb(5), fb(6) ^ taken, fb(7)        )
*/

Finally, rotate left by num bits:

circular_shift_left(xored, num)
// num = 2 (bits listed MSB first, bit 7 leftmost)
// src_doubled = (fb(7), fb(6) ^ taken, ..., fb(0), fb(7), fb(6) ^ taken, ..., fb(0))
// shifted     = (fb(5), fb(4) ^ ghr(14), fb(3) ^ ghr(13), fb(2), fb(1), fb(0), fb(7), fb(6) ^ taken)

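This completes the update of folded_hist. In effect, the operation is equivalent to shifting the ghr left and then folding it. XORing with oldest_bits removes the contribution of the bits that leave the history window from the already-folded value (since x ^ x = 0).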
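
To check this equivalence, the plain-Scala sketch below re-implements new_folded_hist on integers and compares it against directly folding the shifted history (history bit p goes to folded bit p % compLen). It is a model, not XiangShan code, and the names are hypothetical: history bit 0 is assumed to be the newest outcome, the oldest bits are read straight from the model history instead of from ghr via histPtr, and max_update_num = 2. The comparison is only meaningful for len >= compLen, which holds for every TAGE configuration listed earlier.

```scala
import scala.util.Random

object FoldedHistoryModel {
  // history(p) is the outcome p branches ago; p = 0 is the newest bit.
  def foldDirect(history: Vector[Boolean], len: Int, compLen: Int): Int =
    (0 until len).foldLeft(0) { (acc, p) =>
      if (history(p)) acc ^ (1 << (p % compLen)) else acc
    }

  // Shift in `num` outcomes: num - 1 not-taken branches, then `taken` as the newest.
  def pushHistory(history: Vector[Boolean], num: Int, taken: Boolean): Vector[Boolean] =
    if (num == 0) history
    else (taken +: Vector.fill(num - 1)(false)) ++ history.dropRight(num)

  // Rotate a w-bit value left by sh (circular_shift_left).
  def rotl(x: Int, sh: Int, w: Int): Int = {
    val mask = (1 << w) - 1
    val s = sh % w
    if (s == 0) x & mask else ((x << s) | (x >>> (w - s))) & mask
  }

  // Incremental update mirroring new_folded_hist.
  def updateFolded(folded: Int, oldHist: Vector[Boolean], len: Int, compLen: Int,
                   num: Int, taken: Boolean, maxUpdateNum: Int = 2): Int =
    if (len > compLen) {
      var x = 0
      for (i <- 0 until compLen) { // original_bits_set: drop bits that leave without wrapping
        if (((folded >> i) & 1) == 1 && !(num >= len - i)) x ^= 1 << i
      }
      for (i <- 0 until maxUpdateNum) { // oldest_bits_set: only bits that wrapped around
        val p = len - i - 1
        if (p / compLen > 0 && i < num && oldHist(p)) x ^= 1 << (p % compLen)
      }
      for (i <- 0 until maxUpdateNum) { // newest_bits_set: taken bit of Br num - 1
        if (taken && i + 1 == num) x ^= 1 << (compLen - 1 - i)
      }
      rotl(x, num, compLen)
    } else {
      ((folded << num) | (if (taken) 1 else 0)) & ((1 << compLen) - 1)
    }

  // Compares the incremental update with a direct fold over random updates.
  def check(len: Int, compLen: Int, trials: Int = 2000, seed: Long = 1L): Boolean = {
    val rnd = new Random(seed)
    var hist = Vector.fill(len + 2)(rnd.nextBoolean())
    var folded = foldDirect(hist, len, compLen)
    (0 until trials).forall { _ =>
      val num = rnd.nextInt(3)
      val taken = num > 0 && rnd.nextBoolean()
      folded = updateFolded(folded, hist, len, compLen, num, taken)
      hist = pushHistory(hist, num, taken)
      folded == foldDirect(hist, len, compLen)
    }
  }
}
```

For example, check(13, 8) exercises the T2 tag configuration, and check(9, 8) exercises the corner case in which the second-oldest bit does not wrap around and is removed by original_bits_masked instead of by XOR.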

SRAM

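Back to the Tage Table code. Obtaining the index is straightforward: s0_pc is simply shifted right by 1 bit (because the compressed instruction set is used, instructions are 2-byte aligned):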

val req_unhashed_idx = getUnhashedIdx(io.req.bits.pc)
val tables = TageTableInfos.zipWithIndex.map {
  case ((nRows, histLen, tagLen), i) => {
    val t = Module(new TageTable(nRows, histLen, tagLen, i))
    t.io.req.valid := io.s0_fire
    t.io.req.bits.pc := s0_pc
    t.io.req.bits.folded_hist := io.in.bits.folded_hist
    t.io.req.bits.ghist := io.in.bits.ghist
    t
  }
}

Next, the SRAMs used by TAGE are declared:

// nRowsPerBr = 2048, uFoldedWidth = 8
val us = withReset(reset.asBool || io.update.reset_u.reduce(_||_)) {
  Module(new FoldedSRAMTemplate(Bool(), set=nRowsPerBr, width=uFoldedWidth, way=numBr, shouldReset=true, holdRead=true, singlePort=true))
}

val table_banks = Seq.fill(nBanks)(
  Module(new FoldedSRAMTemplate(new TageEntry, set=bankSize, width=bankFoldWidth, way=numBr, shouldReset=true, holdRead=true, singlePort=true)))

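us is the SRAM that stores the useful bits of the Tx table. It is reset not only by the module reset but also whenever the useful bits are periodically reset (io.update.reset_u). FoldedSRAMTemplate is used to instantiate the SRAMs: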

class FoldedSRAMTemplate[T <: Data](gen: T, set: Int, width: Int = 4, way: Int = 1,
  shouldReset: Boolean = false, holdRead: Boolean = false, singlePort: Boolean = false, bypassWrite: Boolean = false) extends Module {
  val io = IO(new Bundle {
    val r = Flipped(new SRAMReadBus(gen, set, way))
    val w = Flipped(new SRAMWriteBus(gen, set, way))
  })
  //   |<----- setIdx ----->|
  //   | ridx | width | way |

  require(width > 0 && isPow2(width))
  require(way > 0 && isPow2(way))
  require(set % width == 0)

  val nRows = set / width

  val array = Module(new SRAMTemplate(gen, set=nRows, way=width*way, shouldReset=shouldReset, holdRead=holdRead, singlePort=singlePort))

  io.r.req.ready := array.io.r.req.ready
  io.w.req.ready := array.io.w.req.ready

  val raddr = io.r.req.bits.setIdx >> log2Ceil(width)
  val ridx = RegNext(if (width != 1) io.r.req.bits.setIdx(log2Ceil(width)-1, 0) else 0.U(1.W))
  val ren  = io.r.req.valid

  array.io.r.req.valid := ren
  array.io.r.req.bits.setIdx := raddr

  val rdata = array.io.r.resp.data
  for (w <- 0 until way) {
    val wayData = VecInit(rdata.indices.filter(_ % way == w).map(rdata(_)))
    val holdRidx = HoldUnless(ridx, RegNext(io.r.req.valid))
    val realRidx = if (holdRead) holdRidx else ridx
    io.r.resp.data(w) := Mux1H(UIntToOH(realRidx, width), wayData)
  }

  val wen = io.w.req.valid
  val wdata = VecInit(Seq.fill(width)(io.w.req.bits.data).flatten)
  val waddr = io.w.req.bits.setIdx >> log2Ceil(width)
  val widthIdx = if (width != 1) io.w.req.bits.setIdx(log2Ceil(width)-1, 0) else 0.U
  val wmask = (width, way) match {
    case (1, 1) => 1.U(1.W)
    case (x, 1) => UIntToOH(widthIdx)
    case _      => VecInit(Seq.tabulate(width*way)(n => (n / way).U === widthIdx && io.w.req.bits.waymask.get(n % way))).asUInt
  }
  require(wmask.getWidth == way*width)

  array.io.w.apply(wen, wdata, waddr, wmask)
}

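The useful-bit SRAM has 2048 / 8 = 256 physical rows, each holding 16 ways (2 × 8 = 16); every 8 useful bits belong to one Br. The internals of FoldedSRAMTemplate are not discussed further. TAGE entries are stored in 8 SRAM banks; each bank has 256 sets, and each set has 2 ways, corresponding to the 2 Brs predicted per cycle.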
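
The address split performed by FoldedSRAMTemplate can be modelled in plain Scala (hypothetical names). For the useful-bit SRAM (set = 2048, width = 8, way = 2), logical set s maps to physical row s >> 3 and slot s & 7, and each physical row holds width * way = 16 columns laid out slot by slot.

```scala
object FoldedSramAddr {
  private def log2Ceil(x: Int): Int = 32 - Integer.numberOfLeadingZeros(x - 1)

  // Logical setIdx -> (physical row, slot within the row), as in
  // raddr = setIdx >> log2Ceil(width) and ridx = setIdx(log2Ceil(width)-1, 0).
  def split(setIdx: Int, width: Int): (Int, Int) = {
    val w = log2Ceil(width)
    (setIdx >> w, setIdx & (width - 1))
  }

  // Physical column of logical (slot, way): column n = slot * ways + way,
  // matching rdata.indices.filter(_ % way == w) and the (n / way === widthIdx) write mask.
  def column(slot: Int, wayIdx: Int, ways: Int): Int = slot * ways + wayIdx
}
```

For the TAGE entry banks, width = bankFoldWidth = 1, so the split is the identity and every set occupies its own physical row.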

Tage Resp

The TAGE entries are read, and the results are forwarded to the output ports:

val s1_unhashed_idx = RegEnable(req_unhashed_idx, io.req.fire)
val s1_idx = RegEnable(s0_idx, io.req.fire)
val s1_tag = RegEnable(s0_tag, io.req.fire)
val s1_pc  = RegEnable(io.req.bits.pc, io.req.fire)
val s1_bank_req_1h = RegEnable(s0_bank_req_1h, io.req.fire)
val s1_bank_has_write_last_cycle = RegNext(VecInit(table_banks.map(_.io.w.req.valid)))

val tables_r = table_banks.map(_.io.r.resp.data) // s1

val resp_selected = Mux1H(s1_bank_req_1h, tables_r)
val resp_invalid_by_write = Mux1H(s1_bank_req_1h, s1_bank_has_write_last_cycle)  // Update has higher prority than read

val per_br_resp = VecInit((0 until numBr).map(i => Mux1H(UIntToOH(get_phy_br_idx(s1_unhashed_idx, i), numBr), resp_selected)))
val per_br_u    = VecInit((0 until numBr).map(i => Mux1H(UIntToOH(get_phy_br_idx(s1_unhashed_idx, i), numBr), us.io.r.resp.data)))

val req_rhits = VecInit((0 until numBr).map(i =>
  per_br_resp(i).valid && per_br_resp(i).tag === s1_tag && !resp_invalid_by_write
))

for (i <- 0 until numBr) {
  io.resps(i).valid := req_rhits(i)
  io.resps(i).bits.ctr := per_br_resp(i).ctr
  io.resps(i).bits.u := per_br_u(i)
}

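The s1_* signals are the s0 signals registered for one cycle and describe the data currently read from the TAGE SRAMs. Note that reads have lower priority than writes in this implementation: if the selected bank was written in the previous cycle, the read result is treated as a miss (resp_invalid_by_write).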

The logic behind per_br_resp and per_br_u is somewhat convoluted:

val unshuffleBitWidth = log2Ceil(numBr)
def get_unshuffle_bits(idx: UInt) = idx(unshuffleBitWidth-1, 0)
// xor hashes are reversable
def get_phy_br_idx(unhashed_idx: UInt, br_lidx: Int)  = get_unshuffle_bits(unhashed_idx) ^ br_lidx.U(log2Ceil(numBr).W)

val per_br_resp = VecInit((0 until numBr).map(i => Mux1H(UIntToOH(get_phy_br_idx(s1_unhashed_idx, i), numBr), resp_selected)))
// (Mux1H(UIntToOH(s1_unhashed_idx(0) ^ 0.U(1.W), 2), resp_selected),
//  Mux1H(UIntToOH(s1_unhashed_idx(0) ^ 1.U(1.W), 2), resp_selected))
/*
  if s1_unhashed_idx(0) == 1'b0:
    per_br_resp = (resp_selected(0), resp_selected(1))
  elif s1_unhashed_idx(0) == 1'b1:
    per_br_resp = (resp_selected(1), resp_selected(0))
*/

val per_br_u    = VecInit((0 until numBr).map(i => Mux1H(UIntToOH(get_phy_br_idx(s1_unhashed_idx, i), numBr), us.io.r.resp.data)))
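
A small plain-Scala model of get_phy_br_idx and per_br_resp for numBr = 2 (hypothetical names). get_lgc_br_idx is not shown in the excerpt; the sketch assumes it is the same XOR, which is what the comment "xor hashes are reversable" suggests.

```scala
object BrShuffleModel {
  val numBr = 2
  private val unshuffleBits = 1 // log2Ceil(numBr)

  // Physical way holding logical branch brLidx: XOR with the low bits of the unhashed index.
  def phyBrIdx(unhashedIdx: Int, brLidx: Int): Int =
    (unhashedIdx & ((1 << unshuffleBits) - 1)) ^ brLidx

  // Assumed inverse mapping: XOR is its own inverse.
  def lgcBrIdx(unhashedIdx: Int, brPidx: Int): Int = phyBrIdx(unhashedIdx, brPidx)

  // per_br_resp: logical branch i reads physical way phyBrIdx(idx, i).
  def perBrResp[T](unhashedIdx: Int, respSelected: Seq[T]): Seq[T] =
    (0 until numBr).map(i => respSelected(phyBrIdx(unhashedIdx, i)))
}
```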

Then the entries read from the 2 ways (one per Br) are checked for a hit, and the results are driven to the resp interface:

val req_rhits = VecInit((0 until numBr).map(i =>
  per_br_resp(i).valid && per_br_resp(i).tag === s1_tag && !resp_invalid_by_write
))

for (i <- 0 until numBr) {
  io.resps(i).valid := req_rhits(i)
  io.resps(i).bits.ctr := per_br_resp(i).ctr
  io.resps(i).bits.u := per_br_u(i)
}

Tage Update

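The basic update logic is similar to the read path: both compute idx and tag, which is not repeated here. Note that the Tage Table update works together with the top-level TAGE module; only the logic inside the Tage Table is described here, and the top module is covered later.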

First, compute which way (which Br) in each bank needs to be updated:

val per_bank_not_silent_update = Wire(Vec(nBanks, Vec(numBr, Bool()))) // corresponds to physical branches
val per_bank_update_way_mask =
  VecInit((0 until nBanks).map(b =>
    VecInit((0 until numBr).map(pi => {
      // whether any of the logical branches updates on each slot
      Seq.tabulate(numBr)(li =>
        get_phy_br_idx(update_unhashed_idx, li) === pi.U &&
        io.update.mask(li)).reduce(_||_) && per_bank_not_silent_update(b)(pi)
    })).asUInt
  ))

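Note that the update condition also includes per_bank_not_silent_update, which indicates that the counter to be updated is not yet saturated in the direction of the outcome (or that the entry is being allocated); only then is the write performed.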

for (b <- 0 until nBanks) {
  table_banks(b).io.w.apply(
    valid   = io.update.mask.reduce(_||_) && update_req_bank_1h(b) && per_bank_not_silent_update(b).reduce(_||_),
    data    = per_bank_update_wdata(b),
    setIdx  = update_idx_in_bank,
    waymask = per_bank_update_way_mask(b)
  )
}

Each bank of the Tx table is updated with the following arguments:

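  • valid: asserted when all of the following hold:
    • the current table needs an update (update or allocate; see the top-level TAGE logic)
    • the update index falls in this bank
    • the counter to be updated in this bank is not yet saturated (not a silent update)
  • data: the data to write; it is not necessarily based on the update data passed in through io, and may instead be based on bypassed data (see below)
  • setIdx: the index within the bank
  • waymask: the ways (Brs) to update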

The useful-bit update logic is similar and is not discussed further:

val update_u_idx = update_idx
val update_u_way_mask = VecInit((0 until numBr).map(pi => {
  Seq.tabulate(numBr)(li =>
    get_phy_br_idx(update_unhashed_idx, li) === pi.U &&
    io.update.uMask(li)
  ).reduce(_||_)
})).asUInt

val update_u_wdata = VecInit((0 until numBr).map(pi =>
  Mux1H(Seq.tabulate(numBr)(li =>
    (get_phy_br_idx(update_unhashed_idx, li) === pi.U, io.update.us(li))
  ))
))

us.io.w.apply(io.update.uMask.reduce(_||_), update_u_wdata, update_u_idx, update_u_way_mask)

silentUpdate checks whether the counter to be updated is already saturated in the direction of the outcome:

def silentUpdate(ctr: UInt, taken: Bool) = {
  ctr.andR && taken || !ctr.orR && !taken
}
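
A plain-Scala sketch of silentUpdate together with inc_ctr (hypothetical names). inc_ctr is not shown in the excerpt; the model assumes it is a 3-bit saturating increment on taken and decrement on not taken.

```scala
object CtrUpdateModel {
  val TageCtrBits = 3
  private val MaxCtr = (1 << TageCtrBits) - 1

  // silentUpdate: ctr.andR && taken || !ctr.orR && !taken
  def silentUpdate(ctr: Int, taken: Boolean): Boolean =
    (ctr == MaxCtr && taken) || (ctr == 0 && !taken)

  // Assumed behaviour of inc_ctr: saturating increment on taken, decrement otherwise.
  def incCtr(ctr: Int, taken: Boolean): Int =
    if (taken) math.min(ctr + 1, MaxCtr) else math.max(ctr - 1, 0)
}
```

Under this model, an update is silent exactly when incCtr would return the counter unchanged, which is why such a write can be skipped.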

Bypass update strategy

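To keep "false updates" from happening frequently, XiangShan uses a bypass strategy for TAGE updates, which effectively reduces their number. A false update is one computed from a stale counter value: the ctr carried with an update reflects the entry at prediction time, and the entry may have been written again since then.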

XiangShan instantiates WrBypass modules in the Tage Table to handle false updates:

// perBankWrbypassEntries = 8
// log2Ceil(nRowsPerBr/nBanks) = 8
val bank_wrbypasses = Seq.fill(nBanks)(Seq.fill(numBr)(
  Module(new WrBypass(UInt(TageCtrBits.W), perBankWrbypassEntries, log2Ceil(nRowsPerBr/nBanks), tagWidth=tagLen))
)) // let it corresponds to logical brIdx

Each way (Br) of each bank has an 8-entry buffer that temporarily holds information recently written to that bank, namely the tag, idx, and ctr value.

WrBypass contains idx_tag_cam, an instance of CAMTemplate that stores the tag and idx and performs the comparison:

val idx_tag_cam = Module(new CAMTemplate(new Idx_Tag, numEntries, 1))

It also contains data_mem, which holds the ctr values, and the valid bits:

val data_mem = Mem(numEntries, Vec(numWays, gen))
val valids = RegInit(0.U.asTypeOf(Vec(numEntries, Vec(numWays, Bool()))))

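If the (idx, tag) of an update hits in idx_tag_cam, the corresponding ctr value in data_mem is read out and sent to the output port. On a miss, an entry is allocated and the tag, idx, and ctr value are written into it; allocation follows a round-robin (RR) order.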
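
A behavioural sketch of such a write-bypass buffer in plain Scala (hypothetical names). It is not the WrBypass RTL: it assumes a single data way and ignores read/write timing and the separate valids vector.

```scala
object WrBypassModel {
  final class WrBypass(numEntries: Int) {
    private val keys = Array.fill[Option[(Int, Int)]](numEntries)(None) // (idx, tag)
    private val data = Array.fill(numEntries)(0)                        // ctr values
    private var enqPtr = 0                                              // round-robin pointer

    // Lookup: returns the most recently written ctr for (idx, tag), if present.
    def lookup(idx: Int, tag: Int): Option[Int] = {
      val hit = keys.indexOf(Some((idx, tag)))
      if (hit >= 0) Some(data(hit)) else None
    }

    // Write: update the matching entry, or allocate the next entry in round-robin order.
    def write(idx: Int, tag: Int, ctr: Int): Unit = {
      val hit = keys.indexOf(Some((idx, tag)))
      val slot = if (hit >= 0) hit else { val s = enqPtr; enqPtr = (enqPtr + 1) % numEntries; s }
      keys(slot) = Some((idx, tag))
      data(slot) = ctr
    }
  }
}
```

In the Tage Table, a hit like this makes the next update of the same (idx, tag) start from the bypassed counter instead of the stale oldCtrs value.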

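Back in the Tage Table, the next step decides whether the update should start from the old ctr passed in from outside (the FTQ) or from the newer bypassed ctr, and checks whether that ctr is already saturated. On a wrbypass hit, the bypassed ctr value is used: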

for (b <- 0 until nBanks) {
  val not_silent_update = per_bank_not_silent_update(b)
  for (pi <- 0 until numBr) { // physical brIdx 
    val update_wdata = per_bank_update_wdata(b)(pi)
    val br_lidx = get_lgc_br_idx(update_unhashed_idx, pi.U(log2Ceil(numBr).W))
    // this 
    val wrbypass_io = Mux1H(UIntToOH(br_lidx, numBr), bank_wrbypasses(b).map(_.io))
    val wrbypass_hit = wrbypass_io.hit
    val wrbypass_ctr = wrbypass_io.hit_data(0).bits
    update_wdata.ctr :=
      Mux(io.update.alloc(br_lidx),
        Mux(io.update.takens(br_lidx), 4.U, 3.U),
        Mux(wrbypass_hit,
          inc_ctr(wrbypass_ctr,               io.update.takens(br_lidx)),
          inc_ctr(io.update.oldCtrs(br_lidx), io.update.takens(br_lidx))
        )
      )
    not_silent_update(pi) :=
      Mux(wrbypass_hit,
        !silentUpdate(wrbypass_ctr,               io.update.takens(br_lidx)),
        !silentUpdate(io.update.oldCtrs(br_lidx), io.update.takens(br_lidx))) ||
      io.update.alloc(br_lidx)

    update_wdata.valid := true.B
    update_wdata.tag   := update_tag
  }

  for (li <- 0 until numBr) {
    val wrbypass = bank_wrbypasses(b)(li)
    val br_pidx = get_phy_br_idx(update_unhashed_idx, li)
    wrbypass.io.wen := io.update.mask(li) && update_req_bank_1h(b)
    wrbypass.io.write_idx := get_bank_idx(update_idx)
    wrbypass.io.write_tag.map(_ := update_tag)
    wrbypass.io.write_data(0) := Mux1H(UIntToOH(br_pidx, numBr), per_bank_update_wdata(b)).ctr
  }
}

As the second loop shows, the updated information is finally written back into the wrbypass as well.

TageBTable

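bt is an instance of TageBTable, i.e. the base table T0. T0 is indexed directly by the PC of the prediction block and predicts with only a 2-bit saturating counter. TageBTable is implemented much like the Tage Table and also uses the bypass update mechanism, so it is not described further here.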

Top-level logic

s1_resps collects the prediction information from the Tx tables:

val s1_resps = VecInit(tables.map(_.io.resps))

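Next, intermediate signals for each pipeline stage are declared: s1_* hold the prediction information read from TAGE in the current cycle, and s2_* hold the information read in the previous cycle (s1_* registered for one cycle):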

val s1_provideds        = Wire(Vec(numBr, Bool()))
val s1_providers        = Wire(Vec(numBr, UInt(log2Ceil(TageNTables).W)))
val s1_providerResps    = Wire(Vec(numBr, new TageResp))
val s1_altUsed          = Wire(Vec(numBr, Bool()))
val s1_tageTakens       = Wire(Vec(numBr, Bool()))
val s1_finalAltPreds    = Wire(Vec(numBr, Bool()))
val s1_basecnts         = Wire(Vec(numBr, UInt(2.W)))
val s1_useAltOnNa       = Wire(Vec(numBr, Bool()))

val s2_provideds        = RegEnable(s1_provideds, io.s1_fire)
val s2_providers        = RegEnable(s1_providers, io.s1_fire)
val s2_providerResps    = RegEnable(s1_providerResps, io.s1_fire)
val s2_altUsed          = RegEnable(s1_altUsed, io.s1_fire)
val s2_tageTakens       = RegEnable(s1_tageTakens, io.s1_fire)
val s2_finalAltPreds    = RegEnable(s1_finalAltPreds, io.s1_fire)
val s2_basecnts         = RegEnable(s1_basecnts, io.s1_fire)
val s2_useAltOnNa       = RegEnable(s1_useAltOnNa, io.s1_fire)

The update logic is defined as follows. First, determine which of the 2 Brs need to be updated:

val updateValids = VecInit((0 until TageBanks).map(w =>
    update.ftb_entry.brValids(w) && u_valid && !update.ftb_entry.always_taken(w) &&
    !(PriorityEncoder(update.full_pred.br_taken_mask) < w.U)))

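Note that the FTB implements an always_taken mechanism: the always_taken bit of a branch is set to 1 when the branch is first encountered, meaning it is always predicted taken. TAGE is therefore updated only for branches whose always_taken bit is 0. A branch located after the first taken branch of the block is not updated either. Next, the update-related intermediate signals are defined: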

val updateMask    = WireInit(0.U.asTypeOf(Vec(numBr, Vec(TageNTables, Bool()))))
val updateUMask   = WireInit(0.U.asTypeOf(Vec(numBr, Vec(TageNTables, Bool()))))
val updateResetU  = WireInit(0.U.asTypeOf(Vec(numBr, Bool()))) // per predictor
val updateTakens  = Wire(Vec(numBr, Vec(TageNTables, Bool())))
val updateAlloc   = WireInit(0.U.asTypeOf(Vec(numBr, Vec(TageNTables, Bool()))))
val updateOldCtrs = Wire(Vec(numBr, Vec(TageNTables, UInt(TageCtrBits.W))))
val updateU       = Wire(Vec(numBr, Vec(TageNTables, Bool())))
val updatebcnt    = Wire(Vec(TageBanks, UInt(2.W)))
val baseupdate    = WireInit(0.U.asTypeOf(Vec(TageBanks, Bool())))
val bUpdateTakens = Wire(Vec(TageBanks, Bool()))

TageTableInfo ties each Tx table's resp to that table's index:

class TageTableInfo(implicit p: Parameters) extends XSBundle {
  val resp = new TageResp
  val tableIdx = UInt(log2Ceil(TageNTables).W)
}

The following resp and update processing is done for each of the 2 Brs (index i). inputRes packs each Tx table's information into a tuple (Tx resp valid, Tx tableInfo):

val inputRes = s1_per_br_resp.zipWithIndex.map{case (r, idx) => {
    val tableInfo = Wire(new TageTableInfo)
    tableInfo.resp := r.bits
    tableInfo.tableIdx := idx.U(log2Ceil(TageNTables).W)
    (r.valid, tableInfo)
  }}

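providerInfo is the prediction from the hitting table with the longest history (inputRes is reversed, so ParallelPriorityMux gives priority to the highest-numbered table), and provided indicates whether any Tx table hit: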

val providerInfo = ParallelPriorityMux(inputRes.reverse)
val provided = inputRes.map(_._1).reduce(_||_)

These are assigned to the s1_* signals and to resp_meta:

s1_provideds(i)      := provided
s1_providers(i)      := providerInfo.tableIdx
s1_providerResps(i)  := providerInfo.resp

resp_meta.providers(i).valid    := RegEnable(s2_provideds(i), io.s2_fire)
resp_meta.providers(i).bits     := RegEnable(s2_providers(i), io.s2_fire)
resp_meta.providerResps(i)      := RegEnable(s2_providerResps(i), io.s2_fire)
resp_meta.pred_cycle.map(_ := RegEnable(GTimer(), io.s2_fire))
resp_meta.use_alt_on_na.map(_(i) := RegEnable(s2_useAltOnNa(i), io.s2_fire))

resp_meta carries the s2_* signals registered for one more cycle.

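allocatableSlots computes which tables can be allocated: tables that missed, whose useful bit is 0, and whose history is longer than that of the provider (if any):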

val allocatableSlots =
  RegEnable(
    VecInit(s1_per_br_resp.map(r => !r.valid && !r.bits.u)).asUInt &
      ~(LowerMask(UIntToOH(s1_providers(i)), TageNTables) &
        Fill(TageNTables, s1_provideds(i).asUInt)),
    io.s1_fire
  )
// Assume s1_providers(i) = 1
// ~LowerMask(UIntToOH(s1_providers(i)), TageNTables) = ~(0011) = 1100
// So if the longest hit table is T2, then T3 and T4 is allocatable

resp_meta.allocates(i) := RegEnable(allocatableSlots, io.s2_fire)
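
The mask computed by allocatableSlots can be reproduced with a plain-Scala sketch (hypothetical names). It assumes, as the comment in the code indicates, that LowerMask(UIntToOH(p), TageNTables) sets bits 0 through p; hits and us stand for the valid and u fields of s1_per_br_resp for one Br.

```scala
object AllocatableSlotsModel {
  val TageNTables = 4

  // LowerMask(UIntToOH(p), n): bits 0..p set.
  def lowerMaskOfProvider(provider: Int): Int = (1 << (provider + 1)) - 1

  // Tables that missed, are not useful, and (if some table provided) have a longer history.
  def allocatableSlots(hits: Seq[Boolean], us: Seq[Boolean], provided: Boolean, provider: Int): Int = {
    val free = (0 until TageNTables).foldLeft(0) { (m, t) =>
      if (!hits(t) && !us(t)) m | (1 << t) else m
    }
    val shorter = if (provided) lowerMaskOfProvider(provider) else 0
    free & ~shorter & ((1 << TageNTables) - 1)
  }
}
```

With T1 and T2 hitting, T2 as the provider, and no useful bits set, the result is 1100 (T3 and T4), matching the comment in the code.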

The final s1 prediction read from TAGE is selected according to the USE_ALT_ON_NA policy; see the USE_ALT_ON_NA mechanism section.

Then resp_meta and resp_s2 are assigned:

resp_meta.altUsed(i)    := RegEnable(s2_altUsed(i), io.s2_fire)
resp_meta.altDiffers(i) := RegEnable(s2_finalAltPreds(i) =/= s2_tageTakens(i), io.s2_fire)
resp_meta.takens(i)     := RegEnable(s2_tageTakens(i), io.s2_fire)
resp_meta.basecnts(i)   := RegEnable(s2_basecnts(i), io.s2_fire)

when (io.ctrl.tage_enable) {
  resp_s2.full_pred.br_taken_mask(i) := s2_tageTakens(i)
}

For an update, the relevant information is gathered first:

val hasUpdate = updateValids(i)
val updateMispred = updateMisPreds(i)
val updateTaken = hasUpdate && update.full_pred.br_taken_mask(i)

val updateProvided     = updateMeta.providers(i).valid
val updateProvider     = updateMeta.providers(i).bits
val updateProviderResp = updateMeta.providerResps(i)
val updateProviderCorrect = updateProviderResp.ctr(TageCtrBits-1) === updateTaken
val updateUseAlt = updateMeta.altUsed(i)
val updateAltDiffers = updateMeta.altDiffers(i)
val updateAltIdx = use_alt_idx(update.pc)
val updateUseAltCtr = Mux1H(UIntToOH(updateAltIdx, NUM_USE_ALT_ON_NA), useAltOnNaCtrs(i))
val updateAltPred = updateMeta.altPreds(i)
val updateAltCorrect = updateAltPred === updateTaken

val updateProviderWeak = unconf(updateProviderResp.ctr)

The update information is then routed to the corresponding Tx table and to T0:

when (hasUpdate) {
  when (updateProvided) {
    updateMask(i)(updateProvider) := true.B
    updateUMask(i)(updateProvider) := updateAltDiffers
    updateU(i)(updateProvider) := updateProviderCorrect
    updateTakens(i)(updateProvider) := updateTaken
    updateOldCtrs(i)(updateProvider) := updateProviderResp.ctr
    updateAlloc(i)(updateProvider) := false.B
  }
}

// update base table if used base table to predict
baseupdate(i) := hasUpdate && updateUseAlt
updatebcnt(i) := updateMeta.basecnts(i)
bUpdateTakens(i) := updateTaken

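An allocation is needed when the update is valid and the branch was mispredicted, unless a Tx table provided a correct prediction but the alternate prediction was used instead: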

val needToAllocate = hasUpdate && updateMispred && !(updateUseAlt && updateProviderCorrect && updateProvided)

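The allocation target depends on a random LFSR mask: the lowest-numbered allocatable table among those selected by the mask is preferred, and if that table cannot be allocated, the lowest-numbered allocatable table is used instead: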

val allocLFSR = LFSR64()(TageNTables - 1, 0)
val longerHistoryTableMask = ~(LowerMask(UIntToOH(updateProvider), TageNTables) & Fill(TageNTables, updateProvided.asUInt))
val canAllocMask = allocatableMask & longerHistoryTableMask
val allocFailureMask = ~allocatableMask & longerHistoryTableMask
val firstEntry = PriorityEncoder(canAllocMask)
val maskedEntry = PriorityEncoder(canAllocMask & allocLFSR)
val allocate = Mux(canAllocMask(maskedEntry), maskedEntry, firstEntry)
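
A behavioural sketch of this choice (hypothetical names). canAllocMask corresponds to allocatableMask & longerHistoryTableMask, and lfsr stands for the low TageNTables bits of LFSR64(). PriorityEncoder is modelled with Chisel's semantics: the lowest set index, or TageNTables - 1 when no bit is set.

```scala
object AllocChoiceModel {
  val TageNTables = 4

  // Chisel PriorityEncoder: lowest set bit, or n - 1 when the input is all zeros.
  def priorityEncoder(mask: Int, n: Int = TageNTables): Int =
    (0 until n).find(i => ((mask >> i) & 1) == 1).getOrElse(n - 1)

  def chooseAlloc(canAllocMask: Int, lfsr: Int): Int = {
    val firstEntry  = priorityEncoder(canAllocMask)
    val maskedEntry = priorityEncoder(canAllocMask & lfsr)
    if (((canAllocMask >> maskedEntry) & 1) == 1) maskedEntry else firstEntry
  }
}
```

Under this model there is a corner case worth noting: if the LFSR selects none of the allocatable tables, maskedEntry is TageNTables - 1, so T4 is chosen whenever it is allocatable, and firstEntry is used otherwise.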

Then the allocation information is sent to the Tx tables:

when (canAllocate) {
  updateMask(i)(allocate) := true.B
  updateTakens(i)(allocate) := updateTaken
  updateAlloc(i)(allocate) := true.B
  updateUMask(i)(allocate) := true.B
  updateU(i)(allocate) := false.B
}

Finally, all update and allocation information is sent, registered for one cycle, to T0 through T4:

for (w <- 0 until TageBanks) {
  for (i <- 0 until TageNTables) {
    tables(i).io.update.mask(w)    := RegNext(updateMask(w)(i))
    tables(i).io.update.takens(w)  := RegNext(updateTakens(w)(i))
    tables(i).io.update.alloc(w)   := RegNext(updateAlloc(w)(i))
    tables(i).io.update.oldCtrs(w) := RegNext(updateOldCtrs(w)(i))

    tables(i).io.update.uMask(w)   := RegNext(updateUMask(w)(i))
    tables(i).io.update.us(w)      := RegNext(updateU(w)(i))
    tables(i).io.update.reset_u(w) := RegNext(updateResetU(w))
    // use fetch pc instead of instruction pc
    tables(i).io.update.pc       := RegNext(update.pc)
    tables(i).io.update.folded_hist := RegNext(updateFHist)
    tables(i).io.update.ghist := RegNext(io.update.bits.ghist)
  }
}
bt.io.update_mask := RegNext(baseupdate)
bt.io.update_cnt := RegNext(updatebcnt)
bt.io.update_pc := RegNext(update.pc)
bt.io.update_takens := RegNext(bUpdateTakens)

// all should be ready for req
io.s1_ready := tables.map(_.io.req.ready).reduce(_&&_)

Dynamic reset of the useful bits

XiangShan resets the useful bits of the Tx tables dynamically. The relevant counters are defined first:

// TickWidth = 7, Max cnt value = 127
val bankTickCtrDistanceToTops = Seq.fill(numBr)(RegInit((1 << (TickWidth-1)).U(TickWidth.W)))  // initial value = 7'h40
val bankTickCtrs = Seq.fill(numBr)(RegInit(0.U(TickWidth.W)))

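The counter actually used for the reset decision is bankTickCtrs; bankTickCtrDistanceToTops is used to detect when an increment would saturate it.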

The reset policy depends on the outcome of allocation:

val tickInc = PopCount(allocFailureMask) > PopCount(canAllocMask)
val tickDec = PopCount(canAllocMask) > PopCount(allocFailureMask)
val tickIncVal = PopCount(allocFailureMask) - PopCount(canAllocMask)
val tickDecVal = PopCount(canAllocMask) - PopCount(allocFailureMask)
val tickToPosSat = tickIncVal >= bankTickCtrDistanceToTops(i) && tickInc
val tickToNegSat = tickDecVal >= bankTickCtrs(i) && tickDec

when (needToAllocate) {
  // val allocate = updateMeta.allocates(i).bits
  when (tickInc) {
    when (tickToPosSat) {
      bankTickCtrs(i) := ((1 << TickWidth) - 1).U
      bankTickCtrDistanceToTops(i) := 0.U
    }.otherwise {
      bankTickCtrs(i) := bankTickCtrs(i) + tickIncVal
      bankTickCtrDistanceToTops(i) := bankTickCtrDistanceToTops(i) - tickIncVal
    }
  }.elsewhen (tickDec) {
    when (tickToNegSat) {
      bankTickCtrs(i) := 0.U
      bankTickCtrDistanceToTops(i) := ((1 << TickWidth) - 1).U
    }.otherwise {
      bankTickCtrs(i) := bankTickCtrs(i) - tickDecVal
      bankTickCtrDistanceToTops(i) := bankTickCtrDistanceToTops(i) + tickDecVal
    }
  }
  when (canAllocate) {
    updateMask(i)(allocate) := true.B
    updateTakens(i)(allocate) := updateTaken
    updateAlloc(i)(allocate) := true.B
    updateUMask(i)(allocate) := true.B
    updateU(i)(allocate) := false.B
  }
  when (bankTickCtrs(i) === ((1 << TickWidth) - 1).U) {
    bankTickCtrs(i) := 0.U
    bankTickCtrDistanceToTops(i) := ((1 << TickWidth) - 1).U
    updateResetU(i) := true.B
  }
}

when (bankTickCtrs(i) === ((1 << TickWidth) - 1).U) {
  bankTickCtrs(i) := 0.U
  bankTickCtrDistanceToTops(i) := ((1 << TickWidth) - 1).U
  updateResetU(i) := true.B
}

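On each allocation attempt, bankTickCtrs decreases by the difference when more tables can be allocated than fail to allocate, and increases by the difference in the opposite case, saturating at both ends. When it reaches all ones, the useful bits are reset (updateResetU) and the counter is cleared.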
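
A per-Br behavioural sketch of the tick counters during needToAllocate events (hypothetical names), with canAlloc and allocFail standing for PopCount(canAllocMask) and PopCount(allocFailureMask). As with Chisel's last-connect rule, the later when that checks for all ones overrides the increment or decrement in the same cycle; the identical check outside needToAllocate is not modelled.

```scala
object TickCounterModel {
  val TickWidth = 7
  val Max: Int = (1 << TickWidth) - 1 // 127

  final case class State(tick: Int, distToTop: Int)
  val Init: State = State(0, 1 << (TickWidth - 1)) // RegInit values: 0 and 0x40

  // One needToAllocate event: returns the new state and whether reset_u fires.
  def step(s: State, canAlloc: Int, allocFail: Int): (State, Boolean) = {
    val next =
      if (allocFail > canAlloc) {
        val inc = allocFail - canAlloc
        if (inc >= s.distToTop) State(Max, 0) else State(s.tick + inc, s.distToTop - inc)
      } else if (canAlloc > allocFail) {
        val dec = canAlloc - allocFail
        if (dec >= s.tick) State(0, Max) else State(s.tick - dec, s.distToTop + dec)
      } else s
    // The all-ones check looks at the register value before this update and wins.
    if (s.tick == Max) (State(0, Max), true) else (next, false)
  }
}
```

Because the counters start at 0 and 0x40, in this model a net excess of 64 allocation failures is enough to saturate bankTickCtrs, after which the next event triggers the reset.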

USE_ALT_ON_NA mechanism

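When the longest-history hit of TAGE has low confidence, XiangShan dynamically decides whether to use the alternate prediction instead. The relevant counters are defined first: NUM_USE_ALT_ON_NA counters of USE_ALT_ON_NA_WIDTH bits per Br, each initialized to 1 << (USE_ALT_ON_NA_WIDTH-1) = 8: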

val useAltOnNaCtrs = RegInit(
  VecInit(Seq.fill(numBr)(
    VecInit(Seq.fill(NUM_USE_ALT_ON_NA)((1 << (USE_ALT_ON_NA_WIDTH-1)).U(USE_ALT_ON_NA_WIDTH.W)))
  ))
)

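For each Br, providerUnconf tells whether the longest-history hit has low confidence, useAltCtr is the USE_ALT_ON_NA counter selected by the PC of the prediction block, and useAltOnNa uses the MSB of useAltCtr to decide whether to use the alternate prediction: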

val providerUnconf = unconf(providerInfo.resp.ctr)
val useAltCtr = Mux1H(UIntToOH(use_alt_idx(s1_pc), NUM_USE_ALT_ON_NA), useAltOnNaCtrs(i))
val useAltOnNa = useAltCtr(USE_ALT_ON_NA_WIDTH-1) // highest bit
val s1_bimCtr = bt.io.s1_cnt(i)
s1_tageTakens(i) := 
  Mux(!provided || providerUnconf && useAltOnNa,
    s1_bimCtr(1),
    providerInfo.resp.ctr(TageCtrBits-1)
  )
s1_altUsed(i)       := !provided || providerUnconf && useAltOnNa
s1_finalAltPreds(i) := s1_bimCtr(1)
s1_basecnts(i)      := s1_bimCtr
s1_useAltOnNa(i)    := providerUnconf && useAltOnNa

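If no Tx table hits, or if the longest-history hit has low confidence and the corresponding useAltCtr selects the alternate prediction, the alternate prediction is used. For timing reasons, XiangShan by default takes the alternate prediction from T0.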
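
The s1 selection can be sketched per Br in plain Scala (hypothetical names). unconf is not defined in the excerpt; the sketch assumes it flags the two weakest values of the 3-bit counter, 3 and 4.

```scala
object UseAltOnNaModel {
  val TageCtrBits = 3
  val UseAltOnNaWidth = 4

  // Assumed definition of unconf: the two weakest values of the 3-bit counter (3 and 4).
  def unconf(ctr: Int): Boolean = {
    val half = 1 << (TageCtrBits - 1)
    ctr == half || ctr == half - 1
  }

  // Returns (tageTaken, altUsed) for one Br in s1.
  def predict(provided: Boolean, providerCtr: Int, useAltCtr: Int, bimCtr: Int): (Boolean, Boolean) = {
    val useAltOnNa = ((useAltCtr >> (UseAltOnNaWidth - 1)) & 1) == 1 // MSB of useAltCtr
    val altUsed = !provided || (unconf(providerCtr) && useAltOnNa)
    val taken =
      if (altUsed) ((bimCtr >> 1) & 1) == 1              // T0 prediction: s1_bimCtr(1)
      else ((providerCtr >> (TageCtrBits - 1)) & 1) == 1 // provider ctr MSB
    (taken, altUsed)
  }
}
```

Since useAltOnNaCtrs are initialized to 8, their MSB is set after reset, so in this model a weak provider initially defers to T0.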

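On update, if the longest-history provider had low confidence and the alternate prediction differed from the final prediction, the corresponding useAltOnNaCtrs entry is updated according to whether the alternate prediction was correct: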

when (hasUpdate) {
  when (updateProvided && updateProviderWeak && updateAltDiffers) {
    val newCtr = satUpdate(updateUseAltCtr, USE_ALT_ON_NA_WIDTH, updateAltCorrect)
    useAltOnNaCtrs(i)(updateAltIdx) := newCtr
  }
}
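
A sketch of this update rule (hypothetical names). satUpdate is not shown in the excerpt; the model assumes it is a saturating increment when its last argument is true and a saturating decrement otherwise, on a USE_ALT_ON_NA_WIDTH = 4-bit counter.

```scala
object UseAltCtrUpdateModel {
  val UseAltOnNaWidth = 4
  private val MaxCtr = (1 << UseAltOnNaWidth) - 1

  // Assumed satUpdate semantics: saturating increment when `up`, else saturating decrement.
  def satUpdate(ctr: Int, up: Boolean): Int =
    if (up) math.min(ctr + 1, MaxCtr) else math.max(ctr - 1, 0)

  // One update of useAltOnNaCtrs(i)(updateAltIdx); other cases leave the counter unchanged.
  def update(ctr: Int, provided: Boolean, providerWeak: Boolean,
             altDiffers: Boolean, altCorrect: Boolean): Int =
    if (provided && providerWeak && altDiffers) satUpdate(ctr, altCorrect) else ctr
}
```

Under this model, each case where a weak provider disagreed with a correct T0 pushes the counter toward its MSB, i.e. toward preferring the alternate prediction.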

RISC-V CPU design engineer.