[SE] Remove Platform*Handle classes
Summary: As pointed out by jprice, these classes don't serve a purpose. Instead, we stay consistent with the way memory is managed and let the Stream and Kernel classes directly hold opaque handles to device Stream and Kernel instances, respectively. Reviewers: jprice, jlebar Subscribers: parallel_libs-commits Differential Revision: https://reviews.llvm.org/D24213 llvm-svn: 280719
This commit is contained in:
parent
df050fd585
commit
18ea094df1
@ -35,12 +35,11 @@ public:
|
||||
Expected<typename std::enable_if<std::is_base_of<KernelBase, KernelT>::value,
|
||||
KernelT>::type>
|
||||
createKernel(const MultiKernelLoaderSpec &Spec) {
|
||||
Expected<std::unique_ptr<PlatformKernelHandle>> MaybeKernelHandle =
|
||||
PDevice->createKernel(Spec);
|
||||
Expected<const void *> MaybeKernelHandle = PDevice->createKernel(Spec);
|
||||
if (!MaybeKernelHandle) {
|
||||
return MaybeKernelHandle.takeError();
|
||||
}
|
||||
return KernelT(Spec.getKernelName(), std::move(*MaybeKernelHandle));
|
||||
return KernelT(PDevice, *MaybeKernelHandle, Spec.getKernelName());
|
||||
}
|
||||
|
||||
/// Creates a stream object for this device.
|
||||
|
||||
@ -28,19 +28,32 @@
|
||||
|
||||
namespace streamexecutor {
|
||||
|
||||
class PlatformKernelHandle;
|
||||
class PlatformDevice;
|
||||
|
||||
/// The base class for all kernel types.
|
||||
///
|
||||
/// Stores the name of the kernel in both mangled and demangled forms.
|
||||
class KernelBase {
|
||||
public:
|
||||
KernelBase(llvm::StringRef Name);
|
||||
KernelBase(PlatformDevice *D, const void *PlatformKernelHandle,
|
||||
llvm::StringRef Name);
|
||||
|
||||
KernelBase(const KernelBase &Other) = delete;
|
||||
KernelBase &operator=(const KernelBase &Other) = delete;
|
||||
|
||||
KernelBase(KernelBase &&Other);
|
||||
KernelBase &operator=(KernelBase &&Other);
|
||||
|
||||
~KernelBase();
|
||||
|
||||
const void *getPlatformHandle() const { return PlatformKernelHandle; }
|
||||
const std::string &getName() const { return Name; }
|
||||
const std::string &getDemangledName() const { return DemangledName; }
|
||||
|
||||
private:
|
||||
PlatformDevice *PDevice;
|
||||
const void *PlatformKernelHandle;
|
||||
|
||||
std::string Name;
|
||||
std::string DemangledName;
|
||||
};
|
||||
@ -51,17 +64,12 @@ private:
|
||||
/// function.
|
||||
template <typename... ParameterTs> class Kernel : public KernelBase {
|
||||
public:
|
||||
Kernel(llvm::StringRef Name, std::unique_ptr<PlatformKernelHandle> PHandle)
|
||||
: KernelBase(Name), PHandle(std::move(PHandle)) {}
|
||||
Kernel(PlatformDevice *D, const void *PlatformKernelHandle,
|
||||
llvm::StringRef Name)
|
||||
: KernelBase(D, PlatformKernelHandle, Name) {}
|
||||
|
||||
Kernel(Kernel &&Other) = default;
|
||||
Kernel &operator=(Kernel &&Other) = default;
|
||||
|
||||
/// Gets the underlying platform-specific handle for this kernel.
|
||||
PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<PlatformKernelHandle> PHandle;
|
||||
};
|
||||
|
||||
} // namespace streamexecutor
|
||||
|
||||
@ -31,34 +31,6 @@
|
||||
|
||||
namespace streamexecutor {
|
||||
|
||||
class PlatformDevice;
|
||||
|
||||
/// Platform-specific kernel handle.
|
||||
class PlatformKernelHandle {
|
||||
public:
|
||||
explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
|
||||
|
||||
virtual ~PlatformKernelHandle();
|
||||
|
||||
PlatformDevice *getDevice() { return PDevice; }
|
||||
|
||||
private:
|
||||
PlatformDevice *PDevice;
|
||||
};
|
||||
|
||||
/// Platform-specific stream handle.
|
||||
class PlatformStreamHandle {
|
||||
public:
|
||||
explicit PlatformStreamHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
|
||||
|
||||
virtual ~PlatformStreamHandle();
|
||||
|
||||
PlatformDevice *getDevice() { return PDevice; }
|
||||
|
||||
private:
|
||||
PlatformDevice *PDevice;
|
||||
};
|
||||
|
||||
/// Raw executor methods that must be implemented by each platform.
|
||||
///
|
||||
/// This class defines the platform interface that supports executing work on a
|
||||
@ -73,19 +45,30 @@ public:
|
||||
virtual std::string getName() const = 0;
|
||||
|
||||
/// Creates a platform-specific kernel.
|
||||
virtual Expected<std::unique_ptr<PlatformKernelHandle>>
|
||||
virtual Expected<const void *>
|
||||
createKernel(const MultiKernelLoaderSpec &Spec) {
|
||||
return make_error("createKernel not implemented for platform " + getName());
|
||||
}
|
||||
|
||||
/// Destroys a platform-specific kernel handle.
///
/// This default implementation does not release anything: it returns an error
/// stating that destroyKernel is not implemented for the platform reported by
/// getName(). \p Handle is unused here.
virtual Error destroyKernel(const void *Handle) {
  const std::string PlatformName = getName();
  return make_error("destroyKernel not implemented for platform " +
                    PlatformName);
}
|
||||
|
||||
/// Creates a platform-specific stream.
|
||||
virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() {
|
||||
virtual Expected<const void *> createStream() {
|
||||
return make_error("createStream not implemented for platform " + getName());
|
||||
}
|
||||
|
||||
/// Destroys a platform-specific stream handle.
///
/// This default implementation does not release anything: it returns an error
/// stating that destroyStream is not implemented for the platform reported by
/// getName(). \p Handle is unused here.
virtual Error destroyStream(const void *Handle) {
  const std::string PlatformName = getName();
  return make_error("destroyStream not implemented for platform " +
                    PlatformName);
}
|
||||
|
||||
/// Launches a kernel on the given stream.
|
||||
virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize,
|
||||
GridDimensions GridSize, PlatformKernelHandle *K,
|
||||
virtual Error launch(const void *PlatformStreamHandle,
|
||||
BlockDimensions BlockSize, GridDimensions GridSize,
|
||||
const void *PKernelHandle,
|
||||
const PackedKernelArgumentArrayBase &ArgumentArray) {
|
||||
return make_error("launch not implemented for platform " + getName());
|
||||
}
|
||||
@ -94,9 +77,9 @@ public:
|
||||
///
|
||||
/// HostDst should have been allocated by allocateHostMemory or registered
|
||||
/// with registerHostMemory.
|
||||
virtual Error copyD2H(PlatformStreamHandle *S, const void *DeviceSrcHandle,
|
||||
size_t SrcByteOffset, void *HostDst,
|
||||
size_t DstByteOffset, size_t ByteCount) {
|
||||
virtual Error copyD2H(const void *PlatformStreamHandle,
|
||||
const void *DeviceSrcHandle, size_t SrcByteOffset,
|
||||
void *HostDst, size_t DstByteOffset, size_t ByteCount) {
|
||||
return make_error("copyD2H not implemented for platform " + getName());
|
||||
}
|
||||
|
||||
@ -104,22 +87,23 @@ public:
|
||||
///
|
||||
/// HostSrc should have been allocated by allocateHostMemory or registered
|
||||
/// with registerHostMemory.
|
||||
virtual Error copyH2D(PlatformStreamHandle *S, const void *HostSrc,
|
||||
virtual Error copyH2D(const void *PlatformStreamHandle, const void *HostSrc,
|
||||
size_t SrcByteOffset, const void *DeviceDstHandle,
|
||||
size_t DstByteOffset, size_t ByteCount) {
|
||||
return make_error("copyH2D not implemented for platform " + getName());
|
||||
}
|
||||
|
||||
/// Copies data from one device location to another.
|
||||
virtual Error copyD2D(PlatformStreamHandle *S, const void *DeviceSrcHandle,
|
||||
size_t SrcByteOffset, const void *DeviceDstHandle,
|
||||
size_t DstByteOffset, size_t ByteCount) {
|
||||
virtual Error copyD2D(const void *PlatformStreamHandle,
|
||||
const void *DeviceSrcHandle, size_t SrcByteOffset,
|
||||
const void *DeviceDstHandle, size_t DstByteOffset,
|
||||
size_t ByteCount) {
|
||||
return make_error("copyD2D not implemented for platform " + getName());
|
||||
}
|
||||
|
||||
/// Blocks the host until the given stream completes all the work enqueued up
|
||||
/// to the point this function is called.
|
||||
virtual Error blockHostUntilDone(PlatformStreamHandle *S) {
|
||||
virtual Error blockHostUntilDone(const void *PlatformStreamHandle) {
|
||||
return make_error("blockHostUntilDone not implemented for platform " +
|
||||
getName());
|
||||
}
|
||||
|
||||
@ -59,10 +59,13 @@ namespace streamexecutor {
|
||||
/// of a stream once it is in an error state.
|
||||
class Stream {
|
||||
public:
|
||||
explicit Stream(std::unique_ptr<PlatformStreamHandle> PStream);
|
||||
Stream(PlatformDevice *D, const void *PlatformStreamHandle);
|
||||
|
||||
Stream(Stream &&Other) = default;
|
||||
Stream &operator=(Stream &&Other) = default;
|
||||
Stream(const Stream &Other) = delete;
|
||||
Stream &operator=(const Stream &Other) = delete;
|
||||
|
||||
Stream(Stream &&Other);
|
||||
Stream &operator=(Stream &&Other);
|
||||
|
||||
~Stream();
|
||||
|
||||
@ -88,7 +91,7 @@ public:
|
||||
//
|
||||
// Returns the result of getStatus() after the Stream work completes.
|
||||
Error blockHostUntilDone() {
|
||||
setError(PDevice->blockHostUntilDone(ThePlatformStream.get()));
|
||||
setError(PDevice->blockHostUntilDone(PlatformStreamHandle));
|
||||
return getStatus();
|
||||
}
|
||||
|
||||
@ -105,7 +108,7 @@ public:
|
||||
const ParameterTs &... Arguments) {
|
||||
auto ArgumentArray =
|
||||
make_kernel_argument_pack<ParameterTs...>(Arguments...);
|
||||
setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize,
|
||||
setError(PDevice->launch(PlatformStreamHandle, BlockSize, GridSize,
|
||||
K.getPlatformHandle(), ArgumentArray));
|
||||
return *this;
|
||||
}
|
||||
@ -136,7 +139,7 @@ public:
|
||||
setError("copying too many elements, " + llvm::Twine(ElementCount) +
|
||||
", to a host array of element count " + llvm::Twine(Dst.size()));
|
||||
else
|
||||
setError(PDevice->copyD2H(ThePlatformStream.get(),
|
||||
setError(PDevice->copyD2H(PlatformStreamHandle,
|
||||
Src.getBaseMemory().getHandle(),
|
||||
Src.getElementOffset() * sizeof(T), Dst.data(),
|
||||
0, ElementCount * sizeof(T)));
|
||||
@ -196,10 +199,9 @@ public:
|
||||
", to a device array of element count " +
|
||||
llvm::Twine(Dst.getElementCount()));
|
||||
else
|
||||
setError(PDevice->copyH2D(ThePlatformStream.get(), Src.data(), 0,
|
||||
Dst.getBaseMemory().getHandle(),
|
||||
Dst.getElementOffset() * sizeof(T),
|
||||
ElementCount * sizeof(T)));
|
||||
setError(PDevice->copyH2D(
|
||||
PlatformStreamHandle, Src.data(), 0, Dst.getBaseMemory().getHandle(),
|
||||
Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
|
||||
return *this;
|
||||
}
|
||||
|
||||
@ -254,7 +256,7 @@ public:
|
||||
llvm::Twine(Dst.getElementCount()));
|
||||
else
|
||||
setError(PDevice->copyD2D(
|
||||
ThePlatformStream.get(), Src.getBaseMemory().getHandle(),
|
||||
PlatformStreamHandle, Src.getBaseMemory().getHandle(),
|
||||
Src.getElementOffset() * sizeof(T), Dst.getBaseMemory().getHandle(),
|
||||
Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
|
||||
return *this;
|
||||
@ -342,7 +344,7 @@ private:
|
||||
PlatformDevice *PDevice;
|
||||
|
||||
/// The platform-specific stream handle for this instance.
|
||||
std::unique_ptr<PlatformStreamHandle> ThePlatformStream;
|
||||
const void *PlatformStreamHandle;
|
||||
|
||||
/// Mutex that guards the error state flags.
|
||||
std::unique_ptr<llvm::sys::RWMutex> ErrorMessageMutex;
|
||||
@ -350,9 +352,6 @@ private:
|
||||
/// First error message for an operation in this stream or empty if there have
|
||||
/// been no errors.
|
||||
llvm::Optional<std::string> ErrorMessage;
|
||||
|
||||
Stream(const Stream &) = delete;
|
||||
void operator=(const Stream &) = delete;
|
||||
};
|
||||
|
||||
} // namespace streamexecutor
|
||||
|
||||
@ -28,14 +28,11 @@ Device::Device(PlatformDevice *PDevice) : PDevice(PDevice) {}
|
||||
Device::~Device() = default;
|
||||
|
||||
Expected<Stream> Device::createStream() {
|
||||
Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream =
|
||||
PDevice->createStream();
|
||||
Expected<const void *> MaybePlatformStream = PDevice->createStream();
|
||||
if (!MaybePlatformStream) {
|
||||
return MaybePlatformStream.takeError();
|
||||
}
|
||||
assert((*MaybePlatformStream)->getDevice() == PDevice &&
|
||||
"an executor created a stream with a different stored executor");
|
||||
return Stream(std::move(*MaybePlatformStream));
|
||||
return Stream(PDevice, *MaybePlatformStream);
|
||||
}
|
||||
|
||||
} // namespace streamexecutor
|
||||
|
||||
@ -12,16 +12,49 @@
|
||||
///
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "streamexecutor/Kernel.h"
|
||||
#include <cassert>
|
||||
|
||||
#include "streamexecutor/Device.h"
|
||||
#include "streamexecutor/Kernel.h"
|
||||
#include "streamexecutor/PlatformInterfaces.h"
|
||||
|
||||
#include "llvm/DebugInfo/Symbolize/Symbolize.h"
|
||||
|
||||
namespace streamexecutor {
|
||||
|
||||
KernelBase::KernelBase(llvm::StringRef Name)
|
||||
: Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName(
|
||||
Name, nullptr)) {}
|
||||
/// Builds a kernel object around an opaque platform kernel handle.
///
/// Records the device \p D and the handle \p PlatformKernelHandle, stores
/// \p Name, and stores the result of LLVMSymbolizer::DemangleName on \p Name
/// as the demangled name. Both \p D and \p PlatformKernelHandle are asserted
/// to be non-null.
KernelBase::KernelBase(PlatformDevice *D, const void *PlatformKernelHandle,
                       llvm::StringRef Name)
    : PDevice(D),
      PlatformKernelHandle(PlatformKernelHandle),
      Name(Name),
      DemangledName(
          llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr)) {
  // Inside this body `PlatformKernelHandle` names the parameter, which shadows
  // the member of the same name.
  assert(D != nullptr &&
         "cannot construct a kernel object with a null platform device");
  assert(PlatformKernelHandle != nullptr &&
         "cannot construct a kernel object with a null platform kernel handle");
}
|
||||
|
||||
/// Move constructor: takes over the device pointer, kernel handle and both
/// name strings from \p Other.
///
/// \p Other is left with null device and handle pointers, so its destructor
/// will not call destroyKernel on the handle that now belongs to this object.
KernelBase::KernelBase(KernelBase &&Other)
    : PDevice(Other.PDevice),
      PlatformKernelHandle(Other.PlatformKernelHandle),
      Name(std::move(Other.Name)),
      DemangledName(std::move(Other.DemangledName)) {
  // Clear the source so that exactly one object refers to the handle.
  Other.PlatformKernelHandle = nullptr;
  Other.PDevice = nullptr;
}
|
||||
|
||||
/// Move assignment: replaces this kernel's state with the state of \p Other.
///
/// Any platform kernel handle currently held by this object is released
/// through PDevice->destroyKernel before it is overwritten, mirroring what the
/// destructor does; without this the previously held handle would be dropped
/// without ever being destroyed. Self-assignment is a no-op. After the call
/// \p Other holds null device and handle pointers.
///
/// \returns a reference to this object.
KernelBase &KernelBase::operator=(KernelBase &&Other) {
  // Guard against `K = std::move(K)`: the steps below would otherwise destroy
  // the handle and then null it out, leaving this object empty.
  if (this == &Other)
    return *this;

  // Release the handle this object is about to stop referring to.
  if (PlatformKernelHandle)
    // TODO(jhen): Handle the error here.
    consumeError(PDevice->destroyKernel(PlatformKernelHandle));

  PDevice = Other.PDevice;
  PlatformKernelHandle = Other.PlatformKernelHandle;
  Name = std::move(Other.Name);
  DemangledName = std::move(Other.DemangledName);

  // Clear the source so its destructor does not destroy the moved handle.
  Other.PDevice = nullptr;
  Other.PlatformKernelHandle = nullptr;
  return *this;
}
|
||||
|
||||
/// Destroys the platform kernel handle, if this object still holds one.
///
/// A null handle (as left behind by a move) means there is nothing to release.
/// Any error returned by destroyKernel is currently consumed and discarded.
KernelBase::~KernelBase() {
  if (!PlatformKernelHandle)
    return;
  // TODO(jhen): Handle the error here.
  consumeError(PDevice->destroyKernel(PlatformKernelHandle));
}
|
||||
|
||||
} // namespace streamexecutor
|
||||
|
||||
@ -16,8 +16,6 @@
|
||||
|
||||
namespace streamexecutor {
|
||||
|
||||
PlatformStreamHandle::~PlatformStreamHandle() = default;
|
||||
|
||||
PlatformDevice::~PlatformDevice() = default;
|
||||
|
||||
} // namespace streamexecutor
|
||||
|
||||
@ -12,14 +12,43 @@
|
||||
///
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "streamexecutor/Stream.h"
|
||||
|
||||
namespace streamexecutor {
|
||||
|
||||
Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream)
|
||||
: PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)),
|
||||
ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {}
|
||||
/// Builds a stream object around an opaque platform stream handle.
///
/// Records the device \p D and the handle \p PlatformStreamHandle and
/// allocates the RWMutex that guards the error state. Both \p D and
/// \p PlatformStreamHandle are asserted to be non-null.
Stream::Stream(PlatformDevice *D, const void *PlatformStreamHandle)
    : PDevice(D),
      PlatformStreamHandle(PlatformStreamHandle),
      ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {
  // Inside this body `PlatformStreamHandle` names the parameter, which shadows
  // the member of the same name.
  assert(D != nullptr &&
         "cannot construct a stream object with a null platform device");
  assert(PlatformStreamHandle != nullptr &&
         "cannot construct a stream object with a null platform stream handle");
}
|
||||
|
||||
Stream::~Stream() = default;
|
||||
/// Move constructor: takes over the device pointer, stream handle, error mutex
/// and error message from \p Other.
///
/// \p Other is left with null device and handle pointers, so its destructor
/// will not call destroyStream on the handle that now belongs to this object.
Stream::Stream(Stream &&Other)
    : PDevice(Other.PDevice),
      PlatformStreamHandle(Other.PlatformStreamHandle),
      ErrorMessageMutex(std::move(Other.ErrorMessageMutex)),
      ErrorMessage(std::move(Other.ErrorMessage)) {
  // Clear the source so that exactly one object refers to the handle.
  Other.PlatformStreamHandle = nullptr;
  Other.PDevice = nullptr;
}
|
||||
|
||||
/// Move assignment: replaces this stream's state with the state of \p Other.
///
/// Any platform stream handle currently held by this object is released
/// through PDevice->destroyStream before it is overwritten, mirroring what the
/// destructor does; without this the previously held handle would be dropped
/// without ever being destroyed. Self-assignment is a no-op. After the call
/// \p Other holds null device and handle pointers.
///
/// \returns a reference to this object.
Stream &Stream::operator=(Stream &&Other) {
  // Guard against `S = std::move(S)`: the steps below would otherwise destroy
  // the handle and then null it out, leaving this object empty.
  if (this == &Other)
    return *this;

  // Release the handle this object is about to stop referring to.
  if (PlatformStreamHandle)
    // TODO(jhen): Handle error condition here.
    consumeError(PDevice->destroyStream(PlatformStreamHandle));

  PDevice = Other.PDevice;
  PlatformStreamHandle = Other.PlatformStreamHandle;
  ErrorMessageMutex = std::move(Other.ErrorMessageMutex);
  ErrorMessage = std::move(Other.ErrorMessage);

  // Clear the source so its destructor does not destroy the moved handle.
  Other.PDevice = nullptr;
  Other.PlatformStreamHandle = nullptr;
  return *this;
}
|
||||
|
||||
/// Destroys the platform stream handle, if this object still holds one.
///
/// A null handle (as left behind by a move) means there is nothing to release.
/// Any error returned by destroyStream is currently consumed and discarded.
Stream::~Stream() {
  if (!PlatformStreamHandle)
    return;
  // TODO(jhen): Handle error condition here.
  consumeError(PDevice->destroyStream(PlatformStreamHandle));
}
|
||||
|
||||
} // namespace streamexecutor
|
||||
|
||||
@ -34,9 +34,7 @@ class SimpleHostPlatformDevice : public streamexecutor::PlatformDevice {
|
||||
public:
|
||||
std::string getName() const override { return "SimpleHostPlatformDevice"; }
|
||||
|
||||
streamexecutor::Expected<
|
||||
std::unique_ptr<streamexecutor::PlatformStreamHandle>>
|
||||
createStream() override {
|
||||
streamexecutor::Expected<const void *> createStream() override {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -69,7 +67,7 @@ public:
|
||||
return streamexecutor::Error::success();
|
||||
}
|
||||
|
||||
streamexecutor::Error copyD2H(streamexecutor::PlatformStreamHandle *S,
|
||||
streamexecutor::Error copyD2H(const void *StreamHandle,
|
||||
const void *DeviceHandleSrc,
|
||||
size_t SrcByteOffset, void *HostDst,
|
||||
size_t DstByteOffset,
|
||||
@ -80,8 +78,8 @@ public:
|
||||
return streamexecutor::Error::success();
|
||||
}
|
||||
|
||||
streamexecutor::Error copyH2D(streamexecutor::PlatformStreamHandle *S,
|
||||
const void *HostSrc, size_t SrcByteOffset,
|
||||
streamexecutor::Error copyH2D(const void *StreamHandle, const void *HostSrc,
|
||||
size_t SrcByteOffset,
|
||||
const void *DeviceHandleDst,
|
||||
size_t DstByteOffset,
|
||||
size_t ByteCount) override {
|
||||
@ -92,7 +90,7 @@ public:
|
||||
}
|
||||
|
||||
streamexecutor::Error
|
||||
copyD2D(streamexecutor::PlatformStreamHandle *S, const void *DeviceHandleSrc,
|
||||
copyD2D(const void *StreamHandle, const void *DeviceHandleSrc,
|
||||
size_t SrcByteOffset, const void *DeviceHandleDst,
|
||||
size_t DstByteOffset, size_t ByteCount) override {
|
||||
std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
|
||||
|
||||
@ -34,11 +34,11 @@ const auto &getDeviceValue =
|
||||
class StreamTest : public ::testing::Test {
|
||||
public:
|
||||
StreamTest()
|
||||
: Device(&PDevice),
|
||||
Stream(llvm::make_unique<se::PlatformStreamHandle>(&PDevice)),
|
||||
HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9},
|
||||
HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23},
|
||||
Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35},
|
||||
: DummyPlatformStream(1), Device(&PDevice),
|
||||
Stream(&PDevice, &DummyPlatformStream), HostA5{0, 1, 2, 3, 4},
|
||||
HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16},
|
||||
HostB7{17, 18, 19, 20, 21, 22, 23}, Host5{24, 25, 26, 27, 28},
|
||||
Host7{29, 30, 31, 32, 33, 34, 35},
|
||||
DeviceA5(getOrDie(Device.allocateDeviceMemory<int>(5))),
|
||||
DeviceB5(getOrDie(Device.allocateDeviceMemory<int>(5))),
|
||||
DeviceA7(getOrDie(Device.allocateDeviceMemory<int>(7))),
|
||||
@ -50,6 +50,8 @@ public:
|
||||
}
|
||||
|
||||
protected:
|
||||
int DummyPlatformStream; // Mimicking a platform where the platform stream
|
||||
// handle is just a stream number.
|
||||
se::test::SimpleHostPlatformDevice PDevice;
|
||||
se::Device Device;
|
||||
se::Stream Stream;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user