Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 59 additions & 12 deletions src/operator/contrib/transformer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,65 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
<< "Must init CuBLAS handle in stream";

cublasHandle_t blas_handle = mshadow::Stream<gpu>::GetBlasHandle(s);
auto err = CUBLAS_STATUS_SUCCESS;
// TODO(cfujitsang): handle computation_precision
err = cublasGemmStridedBatchedEx(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<void*>(&alpha),
a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
reinterpret_cast<void*>(&beta),
c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
static_cast<int>(batchCount), CUDA_R_32F, algo);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas gemmEx fail.";
// cublasGemmStridedBatchedEx is only supported on GPUs with compute
// capability 5.0 or greater. On older GPUs, fall back to the type-specific
// cublas{S,D,H}gemmStridedBatched routines (selected by DType), which don't
// support implicit conversion to half-precision to use TensorCores.
auto cc_major = (s->prop).major;
if (cc_major >= 5) {
CUBLAS_CALL(cublasGemmStridedBatchedEx(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<void*>(&alpha),
a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
reinterpret_cast<void*>(&beta),
c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
static_cast<int>(batchCount), CUDA_R_32F, algo));
} else {
if (std::is_same<DType, float>::value) {
CUBLAS_CALL(cublasSgemmStridedBatched(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<float*>(&alpha),
reinterpret_cast<const float*>(a),
static_cast<int>(lda), strideA,
reinterpret_cast<const float*>(b),
static_cast<int>(ldb), strideB,
reinterpret_cast<float*>(&beta),
reinterpret_cast<float*>(c),
static_cast<int>(ldc), strideC,
static_cast<int>(batchCount)));
} else if (std::is_same<DType, double>::value) {
CUBLAS_CALL(cublasDgemmStridedBatched(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<double*>(&alpha),
reinterpret_cast<const double*>(a),
static_cast<int>(lda), strideA,
reinterpret_cast<const double*>(b),
static_cast<int>(ldb), strideB,
reinterpret_cast<double*>(&beta),
reinterpret_cast<double*>(c),
static_cast<int>(ldc), strideC,
static_cast<int>(batchCount)));
} else if (std::is_same<DType, mshadow::half::half_t>::value) {
CUBLAS_CALL(cublasHgemmStridedBatched(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<__half*>(&alpha),
reinterpret_cast<const __half*>(a),
static_cast<int>(lda), strideA,
reinterpret_cast<const __half*>(b),
static_cast<int>(ldb), strideB,
reinterpret_cast<__half*>(&beta),
reinterpret_cast<__half*>(c),
static_cast<int>(ldc), strideC,
static_cast<int>(batchCount)));
} else {
LOG(FATAL) << "Unsupported DType in CublasStridedBatchedGemm.";
}
}
#else
LOG(FATAL) << "Not implemented with CUDA < 9.1";
#endif
Expand Down