Skip to content

Commit

Permalink
[ESI] Add hostmem write support to cosim
Browse files Browse the repository at this point in the history
  • Loading branch information
teqdruid committed Jan 9, 2025
1 parent 4e47877 commit 8b29dfd
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 29 deletions.
45 changes: 45 additions & 0 deletions frontends/PyCDE/integration_test/esitester.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,50 @@ def construct(ports):
mem_data_ce.assign(hostmem_read_resp_valid)


class WriteMem(Module):
clk = Clock()
rst = Reset()

@generator
def construct(ports):
cmd_chan_wire = Wire(Channel(esi.MMIOReadWriteCmdType))
resp_ready_wire = Wire(Bits(1))
cmd, cmd_valid = cmd_chan_wire.unwrap(resp_ready_wire)
mmio_xact = cmd_valid & resp_ready_wire

write_loc_ce = mmio_xact & cmd.write & (cmd.offset == UInt(32)(0))
write_loc = Reg(UInt(64),
clk=ports.clk,
rst=ports.rst,
rst_value=0,
ce=write_loc_ce)
write_loc.assign(cmd.data.as_uint())

response_data = write_loc.as_bits()
response_chan, response_ready = Channel(Bits(64)).wrap(
response_data, cmd_valid)
resp_ready_wire.assign(response_ready)

mmio_rw = esi.MMIO.read_write(appid=AppID("WriteMem"))
mmio_rw_cmd_chan = mmio_rw.unpack(data=response_chan)['cmd']
cmd_chan_wire.assign(mmio_rw_cmd_chan)

tag = Counter(8)(clk=ports.clk, rst=ports.rst, increment=mmio_xact)

cycle_counter = Counter(64)(clk=ports.clk,
rst=ports.rst,
increment=Bits(1)(1))

hostmem_write_req, _ = esi.HostMem.wrap_write_req(
write_loc,
cycle_counter.out.as_bits(),
tag.out,
valid=mmio_xact.reg(ports.clk, ports.rst))

hostmem_write_resp = esi.HostMem.write(appid=AppID("WriteMem_hostwrite"),
req=hostmem_write_req)


class EsiTesterTop(Module):
clk = Clock()
rst = Reset()
Expand All @@ -122,6 +166,7 @@ class EsiTesterTop(Module):
def construct(ports):
PrintfExample(clk=ports.clk, rst=ports.rst)
ReadMem(clk=ports.clk, rst=ports.rst)
WriteMem(clk=ports.clk, rst=ports.rst)


if __name__ == "__main__":
Expand Down
67 changes: 62 additions & 5 deletions frontends/PyCDE/src/pycde/bsp/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
from struct import Struct

from ..common import Clock, Input, Output, Reset
from ..constructs import AssignableSignal, ControlReg, NamedWire, Wire
Expand Down Expand Up @@ -266,34 +267,90 @@ class ChannelHostMemImpl(esi.ServiceImplementation):
clk = Clock()
rst = Reset()

UpstreamReq = StructType([
UpstreamReadReq = StructType([
("address", UInt(64)),
("length", UInt(32)),
("tag", UInt(8)),
])
read = Output(
Bundle([
BundledChannel("req", ChannelDirection.TO, UpstreamReq),
BundledChannel("req", ChannelDirection.TO, UpstreamReadReq),
BundledChannel(
"resp", ChannelDirection.FROM,
StructType([
("tag", UInt(8)),
("data", Bits(read_width)),
])),
]))
UpstreamWriteReq = StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", Bits(write_width)),
])
write = Output(
Bundle([
BundledChannel("req", ChannelDirection.TO, UpstreamWriteReq),
BundledChannel("ackTag", ChannelDirection.FROM, UInt(8)),
]))

@generator
def generate(ports, bundles: esi._ServiceGeneratorBundles):
read_reqs = [req for req in bundles.to_client_reqs if req.port == 'read']
ports.read = ChannelHostMemImpl.build_tagged_read_mux(ports, read_reqs)
write_reqs = [
req for req in bundles.to_client_reqs if req.port == 'write'
]
ports.write = ChannelHostMemImpl.build_write_mux(ports, write_reqs)

@staticmethod
def build_write_mux(ports,
reqs: List[esi._OutputBundleSetter]) -> BundleSignal:
"""Build the write side of the HostMem service."""
if len(reqs) == 0:
req, _ = Channel(ChannelHostMemImpl.UpstreamWriteReq).wrap(
{
"address": 0,
"tag": 0,
"data": 0
}, 0)
write_bundle, _ = ChannelHostMemImpl.write.type.pack(req=req)
return write_bundle

# TODO: mux together multiple write clients.
assert len(reqs) == 1, "Only one write client supported for now."

write_channels: List[ChannelSignal] = []
write_acks = []
for idx, req in enumerate(reqs):
reqch = [c.channel for c in req.type.channels if c.name == 'req'][0]
data_type = reqch.inner_type.data
assert data_type == Bits(
write_width
), f"Gearboxing not yet supported. Client {req.client_name}"

write_ack = Wire(Channel(UInt(8)))
write_acks.append(write_ack)
write_req_bundle_type = esi.HostMem.write_req_bundle_type(data_type)
bundle_sig, froms = write_req_bundle_type.pack(ackTag=write_ack)
tagged_client_req = froms["req"]
write_channels.append(tagged_client_req)
req.assign(bundle_sig)

tagged_write_channel = esi.ChannelMux(write_channels)
upstream_write_bundle, froms = ChannelHostMemImpl.write.type.pack(
req=tagged_write_channel)
ack_tag = froms["ackTag"]
# TODO: decode the ack tag and assign it to the correct client.
write_acks[0].assign(ack_tag)
return upstream_write_bundle

@staticmethod
def build_tagged_read_mux(
ports, reqs: List[esi._OutputBundleSetter]) -> BundleSignal:
"""Build the read side of the HostMem service."""

if len(reqs) == 0:
req, req_ready = Channel(ChannelHostMemImpl.UpstreamReq).wrap(
req, req_ready = Channel(ChannelHostMemImpl.UpstreamReadReq).wrap(
{
"tag": 0,
"length": 0,
Expand All @@ -305,7 +362,7 @@ def build_tagged_read_mux(
# TODO: mux together multiple read clients.
assert len(reqs) == 1, "Only one read client supported for now."

req = Wire(Channel(ChannelHostMemImpl.UpstreamReq))
req = Wire(Channel(ChannelHostMemImpl.UpstreamReadReq))
read_bundle, froms = ChannelHostMemImpl.read.type.pack(req=req)
resp_chan_ready = Wire(Bits(1))
resp_data, resp_valid = froms["resp"].unwrap(resp_chan_ready)
Expand Down Expand Up @@ -335,7 +392,7 @@ def build_tagged_read_mux(

# Assign the multiplexed read request to the upstream request.
req.assign(
client_req.transform(lambda r: ChannelHostMemImpl.UpstreamReq({
client_req.transform(lambda r: ChannelHostMemImpl.UpstreamReadReq({
"address": r.address,
"length": 1,
"tag": r.tag
Expand Down
6 changes: 6 additions & 0 deletions frontends/PyCDE/src/pycde/bsp/cosim.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ def build(ports):
resp_wire.type)
resp_wire.assign(data)

ack_wire = Wire(Channel(UInt(8)))
write_req = hostmem.write.unpack(ackTag=ack_wire)['req']
ack_tag = esi.CallService.call(esi.AppID("__cosim_hostmem_write"),
write_req, UInt(8))
ack_wire.assign(ack_tag)

class ESI_Cosim_Top(Module):
clk = Clock()
rst = Input(Bits(1))
Expand Down
36 changes: 21 additions & 15 deletions frontends/PyCDE/src/pycde/esi.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,24 +519,32 @@ class _HostMem(ServiceDecl):
("tag", UInt(8)),
])

WriteReqType = StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", Any()),
])

def __init__(self):
super().__init__(self.__class__)

def write_req_bundle_type(self, data_type: Type) -> Bundle:
write_req_type = StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", data_type),
])
return Bundle([
BundledChannel("req", ChannelDirection.FROM, write_req_type),
BundledChannel("ackTag", ChannelDirection.TO, UInt(8))
])

def write_req_channel_type(self, data_type: Type) -> StructType:
return StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", data_type),
])

def wrap_write_req(self, address: UIntSignal, data: Type, tag: UIntSignal,
valid: BitsSignal) -> Tuple[ChannelSignal, BitsSignal]:
"""Create the proper channel type for a write request and use it to wrap the
given request arguments. Returns the Channel signal and a ready bit."""
inner_type = StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", data.type),
])
inner_type = self.write_req_channel_type(data.type)
return Channel(inner_type).wrap(
inner_type({
"address": address,
Expand All @@ -548,10 +556,8 @@ def write(self, appid: AppID, req: ChannelSignal) -> ChannelSignal:
"""Create a write request to the host memory out of a request channel."""
self._materialize_service_decl()

write_bundle_type = Bundle([
BundledChannel("req", ChannelDirection.FROM, _HostMem.WriteReqType),
BundledChannel("ackTag", ChannelDirection.TO, UInt(8))
])
req_data_type = req.type.inner_type.data
write_bundle_type = self.write_req_bundle_type(req_data_type)

bundle = cast(
BundleSignal,
Expand Down
5 changes: 3 additions & 2 deletions frontends/PyCDE/src/pycde/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,8 +809,9 @@ def unpack(self, **kwargs: ChannelSignal) -> Dict[str, ChannelSignal]:
raise ValueError(
f"Missing channel values for {', '.join(from_channels.keys())}")

unpack_op = esi.UnpackBundleOp([bc.channel._type for bc in to_channels],
self.value, operands)
with get_user_loc():
unpack_op = esi.UnpackBundleOp([bc.channel._type for bc in to_channels],
self.value, operands)

to_channels_results = unpack_op.toChannels
ret = {
Expand Down
7 changes: 4 additions & 3 deletions frontends/PyCDE/src/pycde/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,9 +858,10 @@ def pack(
if len(to_channels) > 0:
raise ValueError(f"Missing channels: {', '.join(to_channels.keys())}")

pack_op = esi.PackBundleOp(self._type,
[bc.channel._type for bc in from_channels],
operands)
with get_user_loc():
pack_op = esi.PackBundleOp(self._type,
[bc.channel._type for bc in from_channels],
operands)

return BundleSignal(pack_op.bundle, self), Bundle.PackSignalResults(
[_FromCirctValue(c) for c in pack_op.fromChannels], self)
Expand Down
4 changes: 2 additions & 2 deletions frontends/PyCDE/test/test_esi.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ def build(ports):
# CHECK-NEXT: [[R5:%.+]] = hwarith.constant 0 : ui256
# CHECK-NEXT: [[R6:%.+]] = hw.struct_create ([[R0]], [[R4]], [[R5]]) : !hw.struct<address: ui64, tag: ui8, data: ui256>
# CHECK-NEXT: %chanOutput_0, %ready_1 = esi.wrap.vr [[R6]], %false : !hw.struct<address: ui64, tag: ui8, data: ui256>
# CHECK-NEXT: [[R7:%.+]] = esi.service.req <@_HostMem::@write>(#esi.appid<"host_mem_write_req">) : !esi.bundle<[!esi.channel<!hw.struct<address: ui64, tag: ui8, data: !esi.any>> from "req", !esi.channel<ui8> to "ackTag"]>
# CHECK-NEXT: %ackTag = esi.bundle.unpack %chanOutput_0 from [[R7]] : !esi.bundle<[!esi.channel<!hw.struct<address: ui64, tag: ui8, data: !esi.any>> from "req", !esi.channel<ui8> to "ackTag"]>
# CHECK-NEXT: [[R7:%.+]] = esi.service.req <@_HostMem::@write>(#esi.appid<"host_mem_write_req">) : !esi.bundle<[!esi.channel<!hw.struct<address: ui64, tag: ui8, data: ui256>> from "req", !esi.channel<ui8> to "ackTag"]>
# CHECK-NEXT: %ackTag = esi.bundle.unpack %chanOutput_0 from [[R7]] : !esi.bundle<[!esi.channel<!hw.struct<address: ui64, tag: ui8, data: ui256>> from "req", !esi.channel<ui8> to "ackTag"]>
# CHECK: esi.service.std.hostmem @_HostMem
@unittestmodule(esi_sys=True)
class HostMemReq(Module):
Expand Down
56 changes: 56 additions & 0 deletions lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,14 @@ struct HostMemReadResp {
uint64_t data;
uint8_t tag;
};

struct HostMemWriteReq {
uint64_t data;
uint8_t tag;
uint64_t address;
};

using HostMemWriteResp = uint8_t;
#pragma pack(pop)

class CosimHostMem : public HostMem {
Expand Down Expand Up @@ -465,6 +473,34 @@ class CosimHostMem : public HostMem {
*readRespPort, *readReqPort));
read->connect([this](const MessageData &req) { return serviceRead(req); },
true);

// Setup the write side callback.
ChannelDesc writeArg, writeResp;
if (!rpcClient->getChannelDesc("__cosim_hostmem_write.arg", writeArg) ||
!rpcClient->getChannelDesc("__cosim_hostmem_write.result", writeResp))
throw std::runtime_error("Could not find HostMem channels");

const esi::Type *writeRespType =
getType(ctxt, new StructType(writeResp.type(),
{{"tag", new UIntType("ui8", 8)},
{"data", new BitsType("i64", 64)}}));
const esi::Type *writeReqType =
getType(ctxt, new StructType(writeArg.type(),
{{"address", new UIntType("ui64", 64)},
{"length", new UIntType("ui32", 32)},
{"tag", new UIntType("ui8", 8)}}));

// Get ports, create the function, then connect to it.
writeRespPort = std::make_unique<WriteCosimChannelPort>(
rpcClient->stub.get(), writeResp, writeRespType,
"__cosim_hostmem_write.result");
writeReqPort = std::make_unique<ReadCosimChannelPort>(
rpcClient->stub.get(), writeArg, writeReqType,
"__cosim_hostmem_write.arg");
write.reset(CallService::Callback::get(acc, AppID("__cosim_hostmem_write"),
*writeRespPort, *writeReqPort));
write->connect([this](const MessageData &req) { return serviceWrite(req); },
true);
}

// Service the read request as a callback. Simply reads the data from the
Expand All @@ -491,6 +527,23 @@ class CosimHostMem : public HostMem {
return MessageData::from(resp);
}

// Service a write request as a callback. Simply write the data to the
// location specified. TODO: check that the memory has been mapped.
MessageData serviceWrite(const MessageData &reqBytes) {
const HostMemWriteReq *req = reqBytes.as<HostMemWriteReq>();
acc.getLogger().debug(
[&](std::string &subsystem, std::string &msg,
std::unique_ptr<std::map<std::string, std::any>> &details) {
subsystem = "HostMem";
msg = "Write request: addr=0x" + toHex(req->address) + " data=0x" +
toHex(req->data) + " tag=" + std::to_string(req->tag);
});
uint64_t *dataPtr = reinterpret_cast<uint64_t *>(req->address);
*dataPtr = req->data;
HostMemWriteResp resp = req->tag;
return MessageData::from(resp);
}

struct CosimHostMemRegion : public HostMemRegion {
CosimHostMemRegion(std::size_t size) {
ptr = malloc(size);
Expand Down Expand Up @@ -530,6 +583,9 @@ class CosimHostMem : public HostMem {
std::unique_ptr<WriteCosimChannelPort> readRespPort;
std::unique_ptr<ReadCosimChannelPort> readReqPort;
std::unique_ptr<CallService::Callback> read;
std::unique_ptr<WriteCosimChannelPort> writeRespPort;
std::unique_ptr<ReadCosimChannelPort> writeReqPort;
std::unique_ptr<CallService::Callback> write;
};

} // namespace
Expand Down
Loading

0 comments on commit 8b29dfd

Please sign in to comment.